diff --git a/go.mod b/go.mod index acdd10f75..146f2f123 100644 --- a/go.mod +++ b/go.mod @@ -185,7 +185,7 @@ require ( github.com/btcsuite/btcd/btcutil v1.1.3 // indirect github.com/cenkalti/backoff/v4 v4.2.0 // indirect github.com/chzyer/readline v1.5.0 // indirect - github.com/cometbft/cometbft v0.37.1 + github.com/cometbft/cometbft v0.37.2 github.com/containerd/containerd v1.6.19 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect diff --git a/go.sum b/go.sum index cbf0f3901..7927038aa 100644 --- a/go.sum +++ b/go.sum @@ -199,6 +199,8 @@ github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4 h1:/inchEIKaYC1Akx+H+gqO04wryn5h75LSazbRlnya1k= github.com/cometbft/cometbft v0.37.1 h1:KLxkQTK2hICXYq21U2hn1W5hOVYUdQgDQ1uB+90xPIg= github.com/cometbft/cometbft v0.37.1/go.mod h1:Y2MMMN//O5K4YKd8ze4r9jmk4Y7h0ajqILXbH5JQFVs= +github.com/cometbft/cometbft v0.37.2 h1:XB0yyHGT0lwmJlFmM4+rsRnczPlHoAKFX6K8Zgc2/Jc= +github.com/cometbft/cometbft v0.37.2/go.mod h1:Y2MMMN//O5K4YKd8ze4r9jmk4Y7h0ajqILXbH5JQFVs= github.com/cometbft/cometbft-db v0.7.0 h1:uBjbrBx4QzU0zOEnU8KxoDl18dMNgDh+zZRUE0ucsbo= github.com/cometbft/cometbft-db v0.7.0/go.mod h1:yiKJIm2WKrt6x8Cyxtq9YTEcIMPcEe4XPxhgX59Fzf0= github.com/compose-spec/compose-go v1.13.4 h1:O6xAsPqaY1s9KXteiO7wRCDTJLahv1XP/z/eUO9EfbI= diff --git a/internal/app/kwild/config/variables.go b/internal/app/kwild/config/variables.go index 85df06dcd..3a129ead7 100644 --- a/internal/app/kwild/config/variables.go +++ b/internal/app/kwild/config/variables.go @@ -1,18 +1,16 @@ package config import ( - "crypto/ecdsa" "fmt" "os" "strings" + "github.com/kwilteam/kwil-db/pkg/crypto" "github.com/kwilteam/kwil-db/pkg/log" "github.com/kwilteam/kwil-db/pkg/config" - cmtCrypto "github.com/cometbft/cometbft/crypto" "github.com/cstockton/go-conv" - "github.com/ethereum/go-ethereum/crypto" ) const ( @@ -22,13 +20,12 @@ const ( type KwildConfig struct { GrpcListenAddress string HttpListenAddress string - PrivateKey *ecdsa.PrivateKey + PrivateKey *crypto.Ed25519PrivateKey SqliteFilePath string Log log.Config ExtensionEndpoints []string ArweaveConfig ArweaveConfig BcRpcUrl string - BCPrivateKey cmtCrypto.PrivKey WithoutGasCosts bool WithoutNonces bool } @@ -60,7 +57,7 @@ var ( Setter: func(val any) (any, error) { if val == nil { fmt.Println("no private key provided, generating a new one...") - return crypto.GenerateKey() + return crypto.GenerateEd25519Key() } strVal, err := conv.String(val) @@ -68,7 +65,7 @@ var ( return nil, err } - return crypto.HexToECDSA(strVal) + return crypto.Ed25519PrivateKeyFromHex(strVal) }, } diff --git a/internal/app/kwild/server/root.go b/internal/app/kwild/server/root.go index ff0bab183..54c62487d 100644 --- a/internal/app/kwild/server/root.go +++ b/internal/app/kwild/server/root.go @@ -15,18 +15,20 @@ import ( "github.com/kwilteam/kwil-db/internal/pkg/healthcheck" simple_checker "github.com/kwilteam/kwil-db/internal/pkg/healthcheck/simple-checker" "github.com/kwilteam/kwil-db/pkg/abci" + "github.com/kwilteam/kwil-db/pkg/abci/cometbft" "github.com/kwilteam/kwil-db/pkg/balances" "github.com/kwilteam/kwil-db/pkg/engine" "github.com/kwilteam/kwil-db/pkg/grpc/gateway" "github.com/kwilteam/kwil-db/pkg/grpc/gateway/middleware/cors" grpc "github.com/kwilteam/kwil-db/pkg/grpc/server" + "github.com/kwilteam/kwil-db/pkg/kv/badger" "github.com/kwilteam/kwil-db/pkg/log" "github.com/kwilteam/kwil-db/pkg/modules/datasets" "github.com/kwilteam/kwil-db/pkg/modules/validators" "github.com/kwilteam/kwil-db/pkg/sql" vmgr "github.com/kwilteam/kwil-db/pkg/validators" - // CometBFT + abciTypes "github.com/cometbft/cometbft/abci/types" cmtcfg "github.com/cometbft/cometbft/config" cmtflags "github.com/cometbft/cometbft/libs/cli/flags" cmtlog "github.com/cometbft/cometbft/libs/log" @@ -261,6 +263,32 @@ func buildCometBftClient(cometBftNode *nm.Node) *cmtlocal.Local { return cmtlocal.New(cometBftNode) } +func buildCometNode(d *coreDependencies, abciApp abciTypes.Application) *cometbft.CometBftNode { + // TODO: a lot of the filepaths and logging here are hardcoded. This should be cleaned up + // with a config + + // for now, I'm just using a KV store for my atomic commit. This probably is not ideal; a file may be better + // I'm simply using this because we know it fsyncs the data to disk + db, err := badger.NewBadgerDB("abci/signing", &badger.Options{ + GuaranteeFSync: true, + }) + if err != nil { + failBuild(err, "failed to build comet node") + } + + readWriter := &atomicReadWriter{ + kv: db, + key: []byte("az"), // any key here will work + } + + node, err := cometbft.NewCometBftNode(abciApp, d.cfg.PrivateKey.Bytes(), readWriter, "abci/data", "debug") + if err != nil { + failBuild(err, "failed to build comet node") + } + + return node +} + // TODO: clean this up --> @jchappelow // it seems some of this should be handled in ABCI package if we do not provide it as a package func newCometNode(app *abci.AbciApp, cfg *config.KwildConfig) (*nm.Node, error) { diff --git a/internal/app/kwild/server/utils.go b/internal/app/kwild/server/utils.go index d6eb93d4a..599d000dc 100644 --- a/internal/app/kwild/server/utils.go +++ b/internal/app/kwild/server/utils.go @@ -9,6 +9,7 @@ import ( cmttypes "github.com/cometbft/cometbft/types" "github.com/kwilteam/kwil-db/pkg/engine" "github.com/kwilteam/kwil-db/pkg/extensions" + "github.com/kwilteam/kwil-db/pkg/kv" "github.com/kwilteam/kwil-db/pkg/log" "github.com/kwilteam/kwil-db/pkg/sql" "github.com/kwilteam/kwil-db/pkg/sql/client" @@ -106,3 +107,27 @@ func (wc *wrappedCometBFTClient) BroadcastTxAsync(ctx context.Context, tx *trans return err } + +// atomicReadWriter implements the CometBFt AtomicReadWriter interface. +// This should probably be done with a file instead of a KV store, +// but we already have a good implementation of an atomic KV store. +type atomicReadWriter struct { + kv kv.KVStore + key []byte +} + +func (a *atomicReadWriter) Read() ([]byte, error) { + res, err := a.kv.Get(a.key) + if err == kv.ErrKeyNotFound { + return nil, nil + } + if err != nil { + return nil, err + } + + return res, nil +} + +func (a *atomicReadWriter) Write(val []byte) error { + return a.kv.Set(a.key, val) +} diff --git a/internal/pkg/nodecfg/generate.go b/internal/pkg/nodecfg/generate.go index d8374b90e..4103a6d08 100644 --- a/internal/pkg/nodecfg/generate.go +++ b/internal/pkg/nodecfg/generate.go @@ -9,8 +9,9 @@ import ( cmtCfg "github.com/cometbft/cometbft/config" cmtos "github.com/cometbft/cometbft/libs/os" - cmtrand "github.com/cometbft/cometbft/libs/rand" "github.com/cometbft/cometbft/p2p" + + cmtrand "github.com/cometbft/cometbft/libs/rand" "github.com/cometbft/cometbft/privval" "github.com/cometbft/cometbft/types" cmttime "github.com/cometbft/cometbft/types/time" @@ -42,6 +43,8 @@ type TestnetGenerateConfig struct { P2pPort int } +// TODO: if we use our own keys for cosmos, this will not work +// privval.LoadFilePV will need to be replacew with something else func GenerateNodeConfig(genCfg *NodeGenerateConfig) error { cfg := cmtCfg.DefaultConfig() cfg.SetRoot(genCfg.HomeDir) @@ -210,6 +213,8 @@ func GenerateTestnetConfig(genCfg *TestnetGenerateConfig) error { return nil } +// TODO: we definitely want to get rid of this, or at least make it more understandable / move it +// It generates private keys, which we should not leave up to Comet func initFilesWithConfig(cfg *cmtCfg.Config) error { // private validator privValKeyFile := cfg.PrivValidatorKeyFile() diff --git a/pkg/abci/abci.go b/pkg/abci/abci.go index 46edcd87c..56e982f9d 100644 --- a/pkg/abci/abci.go +++ b/pkg/abci/abci.go @@ -37,6 +37,10 @@ func (fe FatalError) String() string { } func newFatalError(method string, request fmt.Stringer, message string) FatalError { + if request == nil { + request = nilStringer{} + } + return FatalError{ AppMethod: method, Request: request, @@ -44,9 +48,10 @@ func newFatalError(method string, request fmt.Stringer, message string) FatalErr } } -type appState struct { // TODO - prevBlockHeight int64 - prevAppHash []byte +type nilStringer struct{} + +func (ds nilStringer) String() string { + return "no message" } func NewAbciApp(database DatasetsModule, vldtrs ValidatorModule, kv KVStore, committer AtomicCommitter, opts ...AbciOpt) *AbciApp { @@ -58,6 +63,9 @@ func NewAbciApp(database DatasetsModule, vldtrs ValidatorModule, kv KVStore, com kv: kv, }, + valAddrToKey: make(map[string][]byte), + valUpdates: make([]*validators.Validator, 0), + log: log.NewNoOp(), commitWaiter: sync.WaitGroup{}, @@ -111,6 +119,8 @@ type AbciApp struct { applicationVersion uint64 } +var _ abciTypes.Application = &AbciApp{} + func (a *AbciApp) ApplySnapshotChunk(p0 abciTypes.RequestApplySnapshotChunk) abciTypes.ResponseApplySnapshotChunk { return abciTypes.ResponseApplySnapshotChunk{} } @@ -126,8 +136,7 @@ func (a *AbciApp) BeginBlock(req abciTypes.RequestBeginBlock) abciTypes.Response err := a.committer.Begin(context.Background()) if err != nil { - a.log.Error("failed to begin atomic commit", zap.Error(err)) - return abciTypes.ResponseBeginBlock{} + panic(newFatalError("BeginBlock", &req, err.Error())) } // Punish bad validators. @@ -142,12 +151,12 @@ func (a *AbciApp) BeginBlock(req abciTypes.RequestBeginBlock) abciTypes.Response // This is why we need the addr=>pubkey map. Why, comet, why? pubkey, ok := a.valAddrToKey[addr] if !ok { - panic(fmt.Sprintf("unknown validator address %v", addr)) + panic(newFatalError("BeginBlock", &req, fmt.Sprintf("unknown validator address %v", addr))) } const punishDelta = 1 newPower := ev.Validator.Power - punishDelta if err = a.validators.Punish(context.Background(), pubkey, newPower); err != nil { - panic(fmt.Sprintf("failed to punish validator %v: %v", addr, err)) + panic(newFatalError("BeginBlock", &req, fmt.Sprintf("failed to punish validator %v", addr))) } } @@ -171,13 +180,6 @@ func (a *AbciApp) CheckTx(incoming abciTypes.RequestCheckTx) abciTypes.ResponseC return abciTypes.ResponseCheckTx{Code: 0} } -// pubkeys in event attributes returned to comet as strings are base64 encoded, -// apparently. -// TODO: move this somewhere else in the file -func encodeBase64(b []byte) string { - return base64.StdEncoding.EncodeToString(b) -} - func (a *AbciApp) DeliverTx(req abciTypes.RequestDeliverTx) abciTypes.ResponseDeliverTx { ctx := context.Background() @@ -354,32 +356,28 @@ func (a *AbciApp) Commit() abciTypes.ResponseCommit { // generate the unique id for all changes occurred thus far id, err := a.committer.ID(ctx) if err != nil { - a.log.Error("failed to get committer id", zap.Error(err)) - return abciTypes.ResponseCommit{} + panic(newFatalError("Commit", nil, fmt.Sprintf("failed to get commit id: %v", err))) } appHash, err := a.createNewAppHash(ctx, id) if err != nil { - a.log.Error("failed to create new app hash", zap.Error(err)) - return abciTypes.ResponseCommit{} + panic(newFatalError("Commit", nil, fmt.Sprintf("failed to create new app hash: %v", err))) } err = a.metadataStore.IncrementBlockHeight(ctx) if err != nil { - a.log.Error("failed to increment block height", zap.Error(err)) - return abciTypes.ResponseCommit{} + panic(newFatalError("Commit", nil, fmt.Sprintf("failed to increment block height: %v", err))) } err = a.committer.Commit(ctx, func(err error) { if err != nil { - a.log.Error("failed to apply atomic commit", zap.Error(err)) + panic(newFatalError("Commit", nil, fmt.Sprintf("failed to commit atomic commit: %v", err))) } a.commitWaiter.Done() }) if err != nil { - a.log.Error("failed to commit atomic commit", zap.Error(err)) - return abciTypes.ResponseCommit{} + panic(newFatalError("Commit", nil, fmt.Sprintf("failed to commit atomic commit: %v", err))) } // Update the validator address=>pubkey map used by Penalize. @@ -420,19 +418,12 @@ func (a *AbciApp) Info(p0 abciTypes.RequestInfo) abciTypes.ResponseInfo { height, err := a.metadataStore.GetBlockHeight(ctx) if err != nil { - a.log.Error("failed to get block height", zap.Error(err)) - return abciTypes.ResponseInfo{ - AppVersion: a.applicationVersion, - } + panic(newFatalError("Info", &p0, fmt.Sprintf("failed to get block height: %v", err))) } appHash, err := a.metadataStore.GetAppHash(ctx) if err != nil { - a.log.Error("failed to get app hash", zap.Error(err)) - return abciTypes.ResponseInfo{ - LastBlockHeight: height, - AppVersion: a.applicationVersion, - } + panic(newFatalError("Info", &p0, fmt.Sprintf("failed to get app hash: %v", err))) } return abciTypes.ResponseInfo{ @@ -472,11 +463,10 @@ func (a *AbciApp) InitChain(p0 abciTypes.RequestInitChain) abciTypes.ResponseIni apphash, err := a.metadataStore.GetAppHash(ctx) if err != nil { - a.log.Error("failed to get app hash", zap.Error(err)) - + panic(fmt.Sprintf("failed to get app hash: %v", err)) // TODO: should we initialize with a genesis hash instead if it fails // TODO: apparently InitChain is only genesis, so yes it should only be genesis hash - apphash = []byte{} + // in fact, I don't think we should be getting it from this store at all } return abciTypes.ResponseInitChain{ @@ -536,8 +526,8 @@ func convertArgs(args [][]string) [][]any { } var ( - appHashKey = []byte("appHash") - blockHeightKey = []byte("blockHeight") + appHashKey = []byte("a") + blockHeightKey = []byte("b") ) type metadataStore struct { @@ -576,3 +566,9 @@ func (m *metadataStore) IncrementBlockHeight(ctx context.Context) error { return m.SetBlockHeight(ctx, height+1) } + +// pubkeys in event attributes returned to comet as strings are base64 encoded, +// apparently. +func encodeBase64(b []byte) string { + return base64.StdEncoding.EncodeToString(b) +} diff --git a/pkg/abci/cometbft/comet_test.go b/pkg/abci/cometbft/comet_test.go deleted file mode 100644 index 2a92e6eef..000000000 --- a/pkg/abci/cometbft/comet_test.go +++ /dev/null @@ -1,6 +0,0 @@ -package cometbft_test - -import "testing" - -// TODO: delete this -func Test_CometBFT(t *testing.T) {} diff --git a/pkg/abci/cometbft/implementations.go b/pkg/abci/cometbft/implementations.go deleted file mode 100644 index eaf0f8414..000000000 --- a/pkg/abci/cometbft/implementations.go +++ /dev/null @@ -1,44 +0,0 @@ -package cometbft - -import ( - "github.com/cometbft/cometbft/crypto" - "github.com/cometbft/cometbft/privval" -) - -// this file contains some basic implementations of cometbft validator interfaces. -// Some of these are actually hard to implement (like PrivKey), because some basic digging -// reveals that the internals of CometBFT might be tied to their own implementations -// I am just including these here to organize my thoughts and requirements - -// this is just a placeholder -type CometBftPrivateKey struct { -} - -func (c *CometBftPrivateKey) Bytes() []byte { - panic("TODO") -} - -func (c *CometBftPrivateKey) Equals(p0 crypto.PrivKey) bool { - panic("TODO") -} - -func (c *CometBftPrivateKey) PubKey() crypto.PubKey { - panic("TODO") -} - -func (c *CometBftPrivateKey) Sign(msg []byte) ([]byte, error) { - panic("TODO") -} - -// this Type seems to be the big hold up, since it essentially couples their internal implementations -func (c *CometBftPrivateKey) Type() string { - panic("TODO") -} - -// newPrivateValidator creates a new private validator with the given private key -// we don't need the filepaths, they are only used for the Save() method which is only used -// in testing and cometBFTs Cobra Commands. Save() is not included in the interface required -// by NewNode() -func newPrivateValidator(pk *CometBftPrivateKey) *privval.FilePV { - return privval.NewFilePV(pk, "", "") // save is not called, so we don't need to worry about the file paths -} diff --git a/pkg/abci/cometbft/node.go b/pkg/abci/cometbft/node.go index 83dcde87f..1234c32d4 100644 --- a/pkg/abci/cometbft/node.go +++ b/pkg/abci/cometbft/node.go @@ -6,56 +6,42 @@ import ( abciTypes "github.com/cometbft/cometbft/abci/types" cometConfig "github.com/cometbft/cometbft/config" + cometEd25519 "github.com/cometbft/cometbft/crypto/ed25519" cometFlags "github.com/cometbft/cometbft/libs/cli/flags" cometLog "github.com/cometbft/cometbft/libs/log" cometNodes "github.com/cometbft/cometbft/node" "github.com/cometbft/cometbft/p2p" "github.com/cometbft/cometbft/proxy" + "github.com/kwilteam/kwil-db/pkg/abci/cometbft/privval" ) type CometBftNode struct { node *cometNodes.Node } -type Config struct { - // directory is the path to where files should be read and written for cometbft - Directory string - - // LogLevel is the log level for cometbft - LogLevel string -} - -func NewCometBftNode(app abciTypes.Application, config *Config) (*CometBftNode, error) { - conf := cometConfig.DefaultConfig().SetRoot(config.Directory) +// NewCometBftNode creates a new CometBFT node. +// I don't love this constructor function signature; I can definitely make it better +func NewCometBftNode(app abciTypes.Application, privateKey []byte, atomicStore privval.AtomicReadWriter, directory string, logLevel string) (*CometBftNode, error) { + conf := cometConfig.DefaultConfig().SetRoot(directory) logger := cometLog.NewTMLogger(cometLog.NewSyncWriter(os.Stdout)) - logger, err := cometFlags.ParseLogLevel(conf.LogLevel, logger, config.LogLevel) + logger, err := cometFlags.ParseLogLevel(conf.LogLevel, logger, logLevel) if err != nil { return nil, fmt.Errorf("failed to parse log level: %v", err) } - privKey := &CometBftPrivateKey{} + privateValidator, err := privval.NewValidatorSigner(privateKey, atomicStore) + if err != nil { + return nil, fmt.Errorf("failed to create private validator: %v", err) + } node, err := cometNodes.NewNode( conf, - - // ideally this takes our own custom implementation, since CometBFT - // only supports signers created from files, or remote signers that need a connection - newPrivateValidator(privKey), - - // ideally we can use our own custom implementation. - // CometBFT supports both ED25519 and SECP256K1, but the default seems to be ED25519. - // either translating our ED25519 and SECP256K1 keys to CometBFT's format, or - // creating our own CometBet PrivKey implementation is ideal. - // It does seems that internally, CometBFT is tied to their own implementations of PrivKey, - // but I am not certain + privateValidator, &p2p.NodeKey{ - PrivKey: &CometBftPrivateKey{}, + PrivKey: cometEd25519.PrivKey(privateKey), }, proxy.NewLocalClientCreator(app), cometNodes.DefaultGenesisDocProviderFunc(conf), - - // There coukd be a good reason to switch this with our own implementation, - // seems lower priority than others though cometNodes.DefaultDBProvider, cometNodes.DefaultMetricsProvider(conf.Instrumentation), logger, @@ -68,3 +54,13 @@ func NewCometBftNode(app abciTypes.Application, config *Config) (*CometBftNode, node: node, }, nil } + +// Start starts the CometBFT node. +func (n *CometBftNode) Start() error { + return n.node.Start() +} + +// Stop stops the CometBFT node. +func (n *CometBftNode) Stop() error { + return n.node.Stop() +} diff --git a/pkg/abci/cometbft/privvalidator.go b/pkg/abci/cometbft/privval/privvalidator.go similarity index 71% rename from pkg/abci/cometbft/privvalidator.go rename to pkg/abci/cometbft/privval/privvalidator.go index 15ca9ff31..6e206b48b 100644 --- a/pkg/abci/cometbft/privvalidator.go +++ b/pkg/abci/cometbft/privval/privvalidator.go @@ -1,4 +1,4 @@ -package cometbft +package privval /* Much of the code in this package is inspired or pulled directly from cometbft/privval, @@ -9,6 +9,7 @@ package cometbft import ( "bytes" + "encoding/json" "errors" "fmt" "time" @@ -16,27 +17,25 @@ import ( cometEd25519 "github.com/cometbft/cometbft/crypto/ed25519" cmtbytes "github.com/cometbft/cometbft/libs/bytes" "github.com/cometbft/cometbft/libs/protoio" - cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" + tendermintTypes "github.com/cometbft/cometbft/proto/tendermint/types" cmttime "github.com/cometbft/cometbft/types/time" "github.com/cosmos/gogoproto/proto" "github.com/cometbft/cometbft/crypto" "github.com/cometbft/cometbft/types" - "github.com/kwilteam/kwil-db/pkg/kv" - numBytes "github.com/kwilteam/kwil-db/pkg/utils/numbers/bytes" ) // NewValidatorSigner returns a new ValidatorSigner // it takes in an ed25519 key, and a keyvalue store // the key values store should NOT be atomically committed with other KV // stores. Instead, it should simply fsync after every write/commit -func NewValidatorSigner(ed25519Key []byte, kvStore AtomicKV) (*ValidatorSigner, error) { +func NewValidatorSigner(ed25519Key []byte, storer AtomicReadWriter) (*ValidatorSigner, error) { if len(ed25519Key) != cometEd25519.PrivateKeySize { return nil, fmt.Errorf("invalid private key size. received: %d, expected: %d", len(ed25519Key), cometEd25519.PrivateKeySize) } - lss := &lastSignState{ - kv: kvStore, + lss := &LastSignState{ + storer: storer, } err := lss.loadLatest() if err != nil { @@ -57,7 +56,7 @@ type ValidatorSigner struct { privateKey crypto.PrivKey // lastSignedState is the most recent signature made by this validator - lastSignedState *lastSignState + lastSignedState *LastSignState } var _ types.PrivValidator = (*ValidatorSigner)(nil) @@ -71,7 +70,7 @@ func (v *ValidatorSigner) GetPubKey() (crypto.PubKey, error) { // SignProposal signs a proposal message // It is part of the cometTypes.PrivValidator interface -func (v *ValidatorSigner) SignProposal(chainID string, proposal *cmtproto.Proposal) error { +func (v *ValidatorSigner) SignProposal(chainID string, proposal *tendermintTypes.Proposal) error { height, round, step := proposal.Height, proposal.Round, stepPropose sameHRS, err := v.lastSignedState.checkHRS(height, round, step) @@ -93,8 +92,9 @@ func (v *ValidatorSigner) SignProposal(chainID string, proposal *cmtproto.Propos proposal.Signature = v.lastSignedState.Signature proposal.Timestamp = timestamp } else { - return fmt.Errorf("proposal sign bytes differ from last sign bytes") + err = fmt.Errorf("proposal sign bytes differ from last sign bytes") } + return err } // Sign the proposal @@ -111,8 +111,8 @@ func (v *ValidatorSigner) SignProposal(chainID string, proposal *cmtproto.Propos // SignVote signs a vote message // It is part of the cometTypes.PrivValidator interface -func (v *ValidatorSigner) SignVote(chainID string, vote *cmtproto.Vote) error { - height, round, step := vote.Height, vote.Round, voteToStep(vote) +func (v *ValidatorSigner) SignVote(chainID string, vote *tendermintTypes.Vote) error { + height, round, step := vote.Height, vote.Round, VoteToStep(vote) sameHRS, err := v.lastSignedState.checkHRS(height, round, step) if err != nil { @@ -175,110 +175,56 @@ func (v *ValidatorSigner) signAndPersist(height int64, round int32, step int8, s // made by this validator. It is atomically committed to disk // before it is used for anything else, and can be reloaded in case // of a crash -type lastSignState struct { +type LastSignState struct { // Height is the height of the block that the message was signed for - Height int64 + Height int64 `json:"height"` // Round is the consensus round that the message was signed for // CometBFT can have an arbitrary number of rounds per height - Round int32 + Round int32 `json:"round"` // Step is the consensus step that the message was signed for // e.g. propose, prevote, precommit - Step int8 + Step int8 `json:"step"` // Signature is the signature generated by the validator - Signature []byte + Signature []byte `json:"signature"` // SignBytes is the bytes that were signed by the validator - SignBytes cmtbytes.HexBytes + SignBytes cmtbytes.HexBytes `json:"sign_bytes"` - // kv is the keyvalue store that this lastSignState is persisted to - kv AtomicKV + // storer is the store that this lastSignState is persisted to + storer AtomicReadWriter } // store stores the lastSignState to the given KV store // it is atomic, and will only commit if all writes succeed -func (l *lastSignState) store() error { - tx, err := l.kv.BeginTransaction() +func (l *LastSignState) store() error { + bts, err := json.Marshal(l) if err != nil { return err } - defer tx.Discard() - - err = errors.Join( - tx.Set(heightKey, numBytes.Int64ToBytes(l.Height)), - tx.Set(roundKey, numBytes.Int32ToBytes(l.Round)), - tx.Set(stepKey, []byte{byte(l.Step)}), - tx.Set(signatureKey, l.Signature), - tx.Set(signBytesKey, l.SignBytes.Bytes()), - ) - if err != nil { - return err - } - - return tx.Commit() + return l.storer.Write(bts) } // loadLatest loads the latest lastSignState from the given KV store // if none exists, it sets all fields to zero values -func (l *lastSignState) loadLatest() (err error) { - // if we encounter a key not found error, we set all fields to zero values - // if not, then there is a KV store issue, and we do not want to sign for risk - // of double signing - defer func() { - if err == kv.ErrKeyNotFound { - l.setZero() - err = nil - } - }() - - tx, err := l.kv.BeginTransaction() - if err != nil { - return err - } - defer tx.Discard() - - height, err := tx.Get(heightKey) - if err != nil { - return err - } - - round, err := tx.Get(roundKey) - if err != nil { - return err - } - - step, err := tx.Get(stepKey) - if err != nil { - return err - } - // guard against index out of range panic - if len(step) != 1 { - return fmt.Errorf("invalid step length. received: %d, expected: 1", len(step)) - } - - signature, err := tx.Get(signatureKey) +func (l *LastSignState) loadLatest() (err error) { + bts, err := l.storer.Read() if err != nil { return err } - signBytes, err := tx.Get(signBytesKey) - if err != nil { - return err + if bts == nil { + l.setZero() + return nil } - l.Height = numBytes.BytesToInt64(height) - l.Round = numBytes.BytesToInt32(round) - l.Step = int8(step[0]) - l.Signature = signature - l.SignBytes = cmtbytes.HexBytes(signBytes) - - return nil + return json.Unmarshal(bts, l) } // setZero sets all fields to zero values -func (l *lastSignState) setZero() { +func (l *LastSignState) setZero() { l.Height = 0 l.Round = 0 l.Step = 0 @@ -286,23 +232,23 @@ func (l *lastSignState) setZero() { l.SignBytes = nil } -// checkHRS checks that the given height, round, and step match the lastSignState -// if they do not, it returns an error -func (lss *lastSignState) checkHRS(height int64, round int32, step int8) (bool, error) { +// checkHRS checks that the given height, round, and step match the lastSignState. +func (lss *LastSignState) checkHRS(height int64, round int32, step int8) (bool, error) { if lss.Height > height { - return false, fmt.Errorf("height regression. Got %v, last height %v", height, lss.Height) + return false, fmt.Errorf("%w: height regression. Got %v, last height %v", ErrHeightRegression, height, lss.Height) } if lss.Height == height { if lss.Round > round { - return false, fmt.Errorf("round regression at height %v. Got %v, last round %v", height, round, lss.Round) + return false, fmt.Errorf("%w: round regression at height %v. Got %v, last round %v", ErrRoundRegression, height, round, lss.Round) } if lss.Round == round { if lss.Step > step { return false, fmt.Errorf( - "step regression at height %v round %v. Got %v, last step %v", + "%w: step regression at height %v round %v. Got %v, last step %v", + ErrStepRegression, height, round, step, @@ -311,7 +257,7 @@ func (lss *lastSignState) checkHRS(height int64, round int32, step int8) (bool, } else if lss.Step == step { if lss.SignBytes != nil { if lss.Signature == nil { - panic("pv: Signature is nil but SignBytes is not!") + return false, fmt.Errorf("%w: Signature is nil but SignBytes is not", ErrNilSignature) } return true, nil } @@ -322,27 +268,12 @@ func (lss *lastSignState) checkHRS(height int64, round int32, step int8) (bool, return false, nil } -// AtomicKV is an interface for a keyvalue store. -// This should be durable, and fsync (or equivalent) after every write/commit -type AtomicKV interface { - // BeginTransaction starts a new transaction - BeginTransaction() (kv.Transaction, error) -} - -var ( - heightKey = []byte{byte(10)} - roundKey = []byte{byte(20)} - stepKey = []byte{byte(30)} - signatureKey = []byte{byte(40)} - signBytesKey = []byte{byte(50)} -) - // Returns the timestamp from the lastSignBytes. // Returns true if the only difference in the votes is their timestamp. // Performs these checks on the canonical votes (excluding the vote extension // and vote extension signatures). func checkVotesOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) { - var lastVote, newVote cmtproto.CanonicalVote + var lastVote, newVote tendermintTypes.CanonicalVote if err := protoio.UnmarshalDelimited(lastSignBytes, &lastVote); err != nil { panic(fmt.Sprintf("LastSignBytes cannot be unmarshalled into vote: %v", err)) } @@ -362,7 +293,7 @@ func checkVotesOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.T // returns the timestamp from the lastSignBytes. // returns true if the only difference in the proposals is their timestamp func checkProposalsOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (time.Time, bool) { - var lastProposal, newProposal cmtproto.CanonicalProposal + var lastProposal, newProposal tendermintTypes.CanonicalProposal if err := protoio.UnmarshalDelimited(lastSignBytes, &lastProposal); err != nil { panic(fmt.Sprintf("LastSignBytes cannot be unmarshalled into proposal: %v", err)) } @@ -379,12 +310,13 @@ func checkProposalsOnlyDifferByTimestamp(lastSignBytes, newSignBytes []byte) (ti return lastTime, proto.Equal(&newProposal, &lastProposal) } +// this should be unexported, but is needed for testing // A vote is either stepPrevote or stepPrecommit. -func voteToStep(vote *cmtproto.Vote) int8 { +func VoteToStep(vote *tendermintTypes.Vote) int8 { switch vote.Type { - case cmtproto.PrevoteType: + case tendermintTypes.PrevoteType: return stepPrevote - case cmtproto.PrecommitType: + case tendermintTypes.PrecommitType: return stepPrecommit default: panic(fmt.Sprintf("Unknown vote type: %v", vote.Type)) @@ -397,3 +329,20 @@ const ( stepPrevote int8 = 2 stepPrecommit int8 = 3 ) + +// AtomicReadWriter is an interface for any store +// that can atomically read and write to a persistent store +type AtomicReadWriter interface { + // Write should overwrite the current value with the given value + Write([]byte) error + // Read should return the current value + // if the value is empty, it should return empty bytes and no error + Read() ([]byte, error) +} + +var ( + ErrHeightRegression = errors.New("height regression") + ErrRoundRegression = errors.New("round regression") + ErrStepRegression = errors.New("step regression") + ErrNilSignature = errors.New("signature is nil") +) diff --git a/pkg/abci/cometbft/privval/privvalidator_test.go b/pkg/abci/cometbft/privval/privvalidator_test.go new file mode 100644 index 000000000..a6712e5c3 --- /dev/null +++ b/pkg/abci/cometbft/privval/privvalidator_test.go @@ -0,0 +1,456 @@ +package privval_test + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "testing" + + "github.com/aws/smithy-go/time" + cmtproto "github.com/cometbft/cometbft/proto/tendermint/types" + "github.com/cometbft/cometbft/types" + "github.com/kwilteam/kwil-db/pkg/abci/cometbft/privval" + "github.com/stretchr/testify/assert" +) + +const defaultChainID = "test-chain" +const defaultPrivateKey = "7c67e60fce0c403ff40193a3128e5f3d8c2139aed36d76d7b5f1e70ec19c43f00aa611bf555596912bc6f9a9f169f8785918e7bab9924001895798ff13f05842" + +func Test_PrivValidatorVote(t *testing.T) { + type testCase struct { + // name is the name of the test case. + name string + // lastSigned is the last signed vote. + // it can be nil. + lastSigned *cmtproto.Vote + // vote is the vote to sign. + vote *cmtproto.Vote + // secondVote is the second vote to sign. + // it can be nil. + secondVote *cmtproto.Vote + + // chainid is the default chain ID to use. + chainID string + + // privKey is the private key to use. + privKey string + + // err is the expected error. + // if nil, no error is expected. + err error + + // after is a function to run after the test case. + // it can be nil. + after func(t *testing.T, tc *testCase) + } + + tests := []testCase{ + { + name: "signing a vote with no other votes signed", + vote: testVote(), + chainID: defaultChainID, + privKey: defaultPrivateKey, + }, + { + name: "signing two separate votes, validly", + vote: testVote(height(1)), + secondVote: testVote(height(2)), + chainID: defaultChainID, + privKey: defaultPrivateKey, + }, + { + name: "signing a vote with a different previous vote signed", + lastSigned: testVote(height(1)), + vote: testVote(height(2)), + chainID: defaultChainID, + privKey: defaultPrivateKey, + }, + { + name: "signing the same vote despite it being signed already, first vote is last signed", + lastSigned: testVote(signed("sig")), + vote: testVote(), + chainID: defaultChainID, + privKey: defaultPrivateKey, + after: func(t *testing.T, tc *testCase) { + // it should have the same signature as the last signed vote. + assert.Equal(t, tc.lastSigned.Signature, tc.vote.Signature) + }, + }, + { + name: "signing same vote twice, with different timestamps", + lastSigned: testVote(signed("sig"), timestamped(100)), + vote: testVote(timestamped(200)), + chainID: defaultChainID, + privKey: defaultPrivateKey, + after: func(t *testing.T, tc *testCase) { + // it should have the same signature as the last signed vote. + assert.Equal(t, tc.lastSigned.Signature, tc.vote.Signature) + assert.Equal(t, tc.lastSigned.Timestamp, tc.vote.Timestamp) + }, + }, + { + name: "test height regression", + lastSigned: testVote(height(100)), + vote: testVote(height(99)), + chainID: defaultChainID, + privKey: defaultPrivateKey, + err: privval.ErrHeightRegression, + }, + { + name: "test round regression", + lastSigned: testVote(round(100)), + vote: testVote(round(99)), + chainID: defaultChainID, + privKey: defaultPrivateKey, + err: privval.ErrRoundRegression, + }, + { + name: "test step regression", + lastSigned: testVote(step(cmtproto.PrecommitType)), + vote: testVote(step(cmtproto.PrevoteType)), + chainID: defaultChainID, + privKey: defaultPrivateKey, + err: privval.ErrStepRegression, + }, + } + + // test cases for signing a vote. + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // outer function to catch any returns errors + err := func() error { + privKeyBts, err := hex.DecodeString(tc.privKey) + if err != nil { + return err + } + + store := newMockStore() + if tc.lastSigned != nil { + if err := setKeys(store, tc.lastSigned.Height, tc.lastSigned.Round, privval.VoteToStep(tc.lastSigned), tc.lastSigned.Signature, types.VoteSignBytes(tc.chainID, tc.lastSigned)); err != nil { + return err + } + } + + privVal, err := privval.NewValidatorSigner(privKeyBts, store) + if err != nil { + return err + } + + // sign the vote + err = privVal.SignVote(tc.chainID, tc.vote) + if err != nil { + return err + } + + assert.NotNil(t, tc.vote.Signature) + + if tc.secondVote != nil { + // sign the second vote + err = privVal.SignVote(tc.chainID, tc.secondVote) + if err != nil { + return err + } + + assert.Equal(t, tc.vote.Timestamp, tc.secondVote.Timestamp) + } + + return nil + }() + if err != nil { + if tc.err == nil { + t.Fatalf("unexpected error: %v", err) + } + + assert.ErrorIs(t, err, tc.err) + return + } + if tc.err != nil { + t.Fatalf("expected error: %v", tc.err) + } + + if tc.after != nil { + tc.after(t, &tc) + } + }) + } +} + +func Test_Proposals(t *testing.T) { + type testCase struct { + // name is the name of the test case. + name string + // lastSigned is the last signed vote. + // it can be nil. + lastSigned *cmtproto.Proposal + // vote is the vote to sign. + vote *cmtproto.Proposal + // secondVote is the second vote to sign. + // it can be nil. + secondVote *cmtproto.Proposal + + // err is the expected error. + // if nil, no error is expected. + err error + + // after is a function to run after the test case. + // it can be nil. + after func(t *testing.T, tc *testCase) + } + + tests := []testCase{ + { + name: "signing a vote with no other votes signed", + vote: testProposal(), + }, + { + name: "signing two separate votes, validly", + vote: testProposal(height(1)), + secondVote: testProposal(height(2)), + }, + { + name: "signing a vote with a different previous vote signed", + lastSigned: testProposal(height(1)), + vote: testProposal(height(2)), + }, + { + name: "signing the same vote despite it being signed already, first vote is last signed", + lastSigned: testProposal(signed("sig")), + vote: testProposal(), + + after: func(t *testing.T, tc *testCase) { + // it should have the same signature as the last signed vote. + assert.Equal(t, tc.lastSigned.Signature, tc.vote.Signature) + }, + }, + { + name: "signing same vote twice, with different timestamps", + lastSigned: testProposal(signed("sig"), timestamped(100)), + vote: testProposal(timestamped(200)), + + after: func(t *testing.T, tc *testCase) { + // it should have the same signature as the last signed vote. + assert.Equal(t, tc.lastSigned.Signature, tc.vote.Signature) + assert.Equal(t, tc.lastSigned.Timestamp, tc.vote.Timestamp) + }, + }, + { + name: "test height regression", + lastSigned: testProposal(height(100)), + vote: testProposal(height(99)), + + err: privval.ErrHeightRegression, + }, + { + name: "test round regression", + lastSigned: testProposal(round(100)), + vote: testProposal(round(99)), + + err: privval.ErrRoundRegression, + }, + } + + // test cases for signing a vote. + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // outer function to catch any returns errors + err := func() error { + privKeyBts, err := hex.DecodeString(defaultPrivateKey) + if err != nil { + return err + } + + store := newMockStore() + if tc.lastSigned != nil { + if err := setKeys(store, tc.lastSigned.Height, tc.lastSigned.Round, 1, tc.lastSigned.Signature, types.ProposalSignBytes(defaultChainID, tc.lastSigned)); err != nil { + return err + } + } + + privVal, err := privval.NewValidatorSigner(privKeyBts, store) + if err != nil { + return err + } + + // sign the vote + err = privVal.SignProposal(defaultChainID, tc.vote) + if err != nil { + return err + } + + assert.NotNil(t, tc.vote.Signature) + + if tc.secondVote != nil { + // sign the second vote + err = privVal.SignProposal(defaultChainID, tc.secondVote) + if err != nil { + return err + } + + assert.Equal(t, tc.vote.Timestamp, tc.secondVote.Timestamp) + } + + return nil + }() + if err != nil { + if tc.err == nil { + t.Fatalf("unexpected error: %v", err) + } + + assert.ErrorIs(t, err, tc.err) + return + } + if tc.err != nil { + t.Fatalf("expected error: %v", tc.err) + } + + if tc.after != nil { + tc.after(t, &tc) + } + }) + } +} + +// setKeys from vote sets the keys in the AtomicKV from the vote. +// this is useful for testing starting up the atomic KV with an existing vote. +func setKeys(store privval.AtomicReadWriter, ht int64, rnd int32, stp int8, signature []byte, signBytes []byte) error { + latest := privval.LastSignState{ + Height: ht, + Round: rnd, + Step: stp, + Signature: signature, + SignBytes: signBytes, + } + + latestBts, err := json.Marshal(latest) + if err != nil { + return err + } + + return store.Write(latestBts) +} + +// mockStore implements AtomicReadWriter +type mockStore struct { + latest []byte +} + +func newMockStore() *mockStore { + return &mockStore{ + latest: nil, + } +} + +// testVote is a valid vote for height 1 +func testVote(opts ...testVotOpt) *cmtproto.Vote { + options := defaultOptions() + for _, opt := range opts { + opt(options) + } + + return &cmtproto.Vote{ + Type: options.step, + Height: options.height, + Round: options.round, + BlockID: cmtproto.BlockID{ + Hash: hash("hash1"), + PartSetHeader: cmtproto.PartSetHeader{ + Total: 1, + Hash: hash("hash12"), + }, + }, + Timestamp: time.ParseEpochSeconds(options.timestamp), + ValidatorAddress: []byte("validator1"), + ValidatorIndex: 1, + Signature: options.signature, + } +} + +func testProposal(opts ...testVotOpt) *cmtproto.Proposal { + options := defaultOptions() + for _, opt := range opts { + opt(options) + } + if options.step != 1 { + panic("cannot create proposal with step != 1") + } + + return &cmtproto.Proposal{ + Type: 1, + Height: options.height, + Round: options.round, + PolRound: 0, + BlockID: cmtproto.BlockID{ + Hash: hash("hash1"), + PartSetHeader: cmtproto.PartSetHeader{ + Total: 1, + Hash: hash("hash12"), + }, + }, + Timestamp: time.ParseEpochSeconds(options.timestamp), + Signature: options.signature, + } +} + +func defaultOptions() *testVoteOptions { + return &testVoteOptions{ + timestamp: 500, + height: 10, + round: 0, + step: cmtproto.PrevoteType, + } +} + +type testVoteOptions struct { + timestamp float64 + signature []byte + height int64 + round int32 + step cmtproto.SignedMsgType +} + +type testVotOpt func(*testVoteOptions) + +func timestamped(ts float64) testVotOpt { + return func(opts *testVoteOptions) { + opts.timestamp = ts + } +} + +func signed(sig string) testVotOpt { + return func(opts *testVoteOptions) { + opts.signature = []byte(sig) + } +} + +func height(h int64) testVotOpt { + return func(opts *testVoteOptions) { + opts.height = h + } +} + +func round(r int32) testVotOpt { + return func(opts *testVoteOptions) { + opts.round = r + } +} + +func step(s cmtproto.SignedMsgType) testVotOpt { + return func(opts *testVoteOptions) { + opts.step = s + } +} + +func (m *mockStore) Read() ([]byte, error) { + return m.latest, nil +} + +func (m *mockStore) Write(p0 []byte) error { + m.latest = p0 + return nil +} + +func hash(s string) []byte { + hasher := sha256.New() + hasher.Write([]byte(s)) + return hasher.Sum(nil) +} diff --git a/pkg/abci/cometbft/privvalidator_test.go b/pkg/abci/cometbft/privvalidator_test.go deleted file mode 100644 index 12fa3f0a3..000000000 --- a/pkg/abci/cometbft/privvalidator_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package cometbft_test - -import "testing" - -func Test_PrivValidator(t *testing.T) { - type testCase struct { - name string - } -} diff --git a/pkg/abci/utils.go b/pkg/abci/utils.go index f5d3339b4..7c10bb734 100644 --- a/pkg/abci/utils.go +++ b/pkg/abci/utils.go @@ -83,6 +83,8 @@ func ResetState(dbDir string) error { return nil } +// TODO: we will have to get rid of this if we use our own private keys from CometBFT +// Resetting the privValStateFile is ok, however we don't persist Comet's private key func ResetFilePV(privValKeyFile, privValStateFile string) { if _, err := os.Stat(privValKeyFile); err == nil { pv := privval.LoadFilePVEmptyState(privValKeyFile, privValStateFile) diff --git a/pkg/crypto/ed25519.go b/pkg/crypto/ed25519.go index 138a36065..a80f2e4cc 100644 --- a/pkg/crypto/ed25519.go +++ b/pkg/crypto/ed25519.go @@ -96,3 +96,15 @@ func (s Ed25519Address) String() string { // TODO: need an address format return hex.EncodeToString(s.Bytes()) } + +// GenerateEd25519Key generates a new ed25519 key pair. +func GenerateEd25519Key() (*Ed25519PrivateKey, error) { + _, priv, err := ed25519.GenerateKey(nil) + if err != nil { + return nil, err + } + + return &Ed25519PrivateKey{ + key: priv, + }, nil +} diff --git a/pkg/crypto/ed25519_test.go b/pkg/crypto/ed25519_test.go index 5483839c4..5d5322095 100644 --- a/pkg/crypto/ed25519_test.go +++ b/pkg/crypto/ed25519_test.go @@ -105,3 +105,12 @@ func TestEd25519PublicKey_Address(t *testing.T) { eq := pubKey.Address().String() == "0aa611bf555596912bc6f9a9f169f8785918e7ba" assert.True(t, eq, "mismatch address") } + +func Test_GenerateEd25518PrivateKey(t *testing.T) { + pk, err := crypto.GenerateEd25519Key() + require.NoError(t, err, "error generate key") + + if len(pk.Bytes()) != 64 { + t.Errorf("invalid private key length: %d", len(pk.Bytes())) + } +} diff --git a/pkg/engine/db/db.go b/pkg/engine/db/db.go index 55a08ef4a..7dcc6e632 100644 --- a/pkg/engine/db/db.go +++ b/pkg/engine/db/db.go @@ -40,7 +40,7 @@ func (d *DB) Prepare(ctx context.Context, query string) (*PreparedStatement, err return nil, err } - err = sqlanalyzer.ApplyRules(ast, sqlanalyzer.DeterministicAggregates, &sqlanalyzer.RuleMetadata{ + err = sqlanalyzer.ApplyRules(ast, sqlanalyzer.AllRules, &sqlanalyzer.RuleMetadata{ Tables: tables, }) if err != nil { diff --git a/pkg/kv/atomic/kv_test.go b/pkg/kv/atomic/kv_test.go index 0fbac1ca0..ec9186f21 100644 --- a/pkg/kv/atomic/kv_test.go +++ b/pkg/kv/atomic/kv_test.go @@ -12,6 +12,7 @@ import ( // tests basic KV functionality; anything that is not the sessions.Committable interface func Test_BasicKV(t *testing.T) { + dbType := kvTesting.TestKVFlagInMemory type testCase struct { name string testFunc func(t *testing.T, db *atomic.AtomicKV) @@ -146,7 +147,7 @@ func Test_BasicKV(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - db, td, err := kvTesting.OpenTestKv("kv_test", kvTesting.TestKVFlagInMemory) + db, td, err := kvTesting.OpenTestKv("kv_test", dbType) if err != nil { t.Fatal(err) } diff --git a/pkg/kv/atomic/testing/testing.go b/pkg/kv/atomic/testing/testing.go index fdc90e7ef..e0870633f 100644 --- a/pkg/kv/atomic/testing/testing.go +++ b/pkg/kv/atomic/testing/testing.go @@ -1,13 +1,11 @@ package testing import ( - "errors" "fmt" - "os" - "github.com/kwilteam/kwil-db/pkg/kv" "github.com/kwilteam/kwil-db/pkg/kv/atomic" - "github.com/kwilteam/kwil-db/pkg/kv/badger" + badgerTesting "github.com/kwilteam/kwil-db/pkg/kv/badger/testing" + kvTesting "github.com/kwilteam/kwil-db/pkg/kv/testing" ) type TestKVFlag uint8 @@ -17,13 +15,10 @@ const ( TestKVFlagBadger ) -const defaultPath = "./tmp/" - // OpenTestKv opens a new test kv store // It returns a teardown function. If a teardown // function is not necessary, it does nothing func OpenTestKv(name string, flag TestKVFlag) (*atomic.AtomicKV, func() error, error) { - directory := fmt.Sprintf("%s%s", defaultPath, name) switch flag { case TestKVFlagInMemory: @@ -31,125 +26,17 @@ func OpenTestKv(name string, flag TestKVFlag) (*atomic.AtomicKV, func() error, e return nil } - db, err := atomic.NewAtomicKV(newMemoryKV()) + db, err := atomic.NewAtomicKV(kvTesting.NewMemoryKV()) return db, fn, err case TestKVFlagBadger: - badgerDB, err := badger.NewBadgerDB(directory) + badgerDB, td, err := badgerTesting.NewTestBadgerDB(name, nil) if err != nil { return nil, nil, err } - fn := func() error { - return errors.Join( - badgerDB.Close(), - os.RemoveAll(directory), - ) - } - db, err := atomic.NewAtomicKV(badgerDB) - return db, fn, err + return db, td, err default: return nil, nil, fmt.Errorf("unknown flag: %d", flag) } } - -func newMemoryKV() *MemoryKV { - return &MemoryKV{ - values: make(map[string][]byte), - } -} - -type MemoryKV struct { - values map[string][]byte -} - -func (m *MemoryKV) BeginTransaction() kv.Transaction { - - return &MemoryTransaction{ - kv: m, - currentTx: make(map[string][]byte), - currentDeletes: make(map[string]struct{}), - } -} - -func (m *MemoryKV) Delete(key []byte) error { - _, ok := m.values[string(key)] - if !ok { - return kv.ErrKeyNotFound - } - - delete(m.values, string(key)) - - return nil -} - -func (m *MemoryKV) Get(key []byte) ([]byte, error) { - val, ok := m.values[string(key)] - if !ok { - return nil, kv.ErrKeyNotFound - } - - return val, nil -} - -func (m *MemoryKV) Set(key []byte, value []byte) error { - m.values[string(key)] = value - - return nil -} - -type MemoryTransaction struct { - currentTx map[string][]byte - currentDeletes map[string]struct{} - kv *MemoryKV -} - -func (m *MemoryTransaction) Commit() error { - for k, v := range m.currentTx { - m.kv.values[k] = v - } - - for k := range m.currentDeletes { - delete(m.kv.values, k) - } - - m.currentTx = nil - - return nil -} - -func (m *MemoryTransaction) Delete(key []byte) error { - _, err := m.Get(key) - if err != nil { - return err - } - - m.currentDeletes[string(key)] = struct{}{} - - return nil -} - -func (m *MemoryTransaction) Discard() { - m.currentTx = nil - m.currentDeletes = nil -} - -func (m *MemoryTransaction) Get(key []byte) ([]byte, error) { - val, ok := m.currentTx[string(key)] - if ok { - return val, nil - } - - val, ok = m.kv.values[string(key)] - if ok { - return val, nil - } - - return nil, kv.ErrKeyNotFound -} - -func (m *MemoryTransaction) Set(key []byte, value []byte) error { - m.currentTx[string(key)] = value - - return nil -} diff --git a/pkg/kv/badger/db.go b/pkg/kv/badger/db.go index 04532fb65..7f6a19737 100644 --- a/pkg/kv/badger/db.go +++ b/pkg/kv/badger/db.go @@ -9,13 +9,19 @@ import ( // NewBadgerDB creates a new BadgerDB. // It takes a path, like path/to/db, where the database will be stored. -func NewBadgerDB(path string) (*BadgerDB, error) { - opts := badger.DefaultOptions(path) - opts.Logger = nil - db, err := badger.Open(opts) +func NewBadgerDB(path string, options *Options) (*BadgerDB, error) { + badgerOpts := badger.DefaultOptions(path) + + if options != nil { + options.apply(&badgerOpts) + } + + badgerOpts.Logger = nil + db, err := badger.Open(badgerOpts) if err != nil { return nil, err } + return &BadgerDB{db: db}, nil } @@ -114,3 +120,18 @@ func (t *Transaction) Get(key []byte) ([]byte, error) { val, err = item.ValueCopy(nil) return val, err } + +// Options are options for the BadgerDB. +// These get translated into Badger's options. +// We provide this abstraction layer since Badger has a lot of options, +// and I don't want future users of this to worry about all of them. +type Options struct { + // GuaranteeFSync guarantees that all writes to the wal are fsynced before + // attemtping to be written to the LSM tree. + GuaranteeFSync bool +} + +// apply applies the options to the badger options. +func (o *Options) apply(opts *badger.Options) { + opts.SyncWrites = o.GuaranteeFSync +} diff --git a/pkg/kv/badger/db_test.go b/pkg/kv/badger/db_test.go new file mode 100644 index 000000000..b8e80f22f --- /dev/null +++ b/pkg/kv/badger/db_test.go @@ -0,0 +1,26 @@ +package badger_test + +import ( + "testing" + + badgerTesting "github.com/kwilteam/kwil-db/pkg/kv/badger/testing" +) + +// testing double write does not produce an error +func Test_BadgerKV(t *testing.T) { + db, td, err := badgerTesting.NewTestBadgerDB("test", nil) + if err != nil { + t.Fatal(err) + } + defer td() + + err = db.Set([]byte("key"), []byte("value")) + if err != nil { + t.Fatal(err) + } + + err = db.Set([]byte("key"), []byte("value2")) + if err != nil { + t.Fatal(err) + } +} diff --git a/pkg/kv/badger/testing/db.go b/pkg/kv/badger/testing/db.go new file mode 100644 index 000000000..63b1d0cff --- /dev/null +++ b/pkg/kv/badger/testing/db.go @@ -0,0 +1,31 @@ +package testing + +import ( + "errors" + "fmt" + "os" + + "github.com/kwilteam/kwil-db/pkg/kv/badger" +) + +const defaultPath = "./tmp/" + +// NewTestBadgerDB returns a new badger db for testing +// it also returns a teardown function, which will remove +// the db and the directory +func NewTestBadgerDB(name string, options *badger.Options) (*badger.BadgerDB, func() error, error) { + directory := fmt.Sprintf("%s%s", defaultPath, name) + + db, err := badger.NewBadgerDB(directory, nil) + if err != nil { + return nil, nil, err + } + + fn := func() error { + return errors.Join( + db.Close(), + os.RemoveAll(defaultPath), + ) + } + return db, fn, err +} diff --git a/pkg/kv/testing/memory.go b/pkg/kv/testing/memory.go new file mode 100644 index 000000000..63dbc0e61 --- /dev/null +++ b/pkg/kv/testing/memory.go @@ -0,0 +1,106 @@ +package testing + +import "github.com/kwilteam/kwil-db/pkg/kv" + +func NewMemoryKV() *MemoryKV { + return &MemoryKV{ + values: make(map[string][]byte), + } +} + +type MemoryKV struct { + values map[string][]byte +} + +var _ kv.KVStore = (*MemoryKV)(nil) + +func (m *MemoryKV) BeginTransaction() kv.Transaction { + + return &MemoryTransaction{ + kv: m, + currentTx: make(map[string][]byte), + currentDeletes: make(map[string]struct{}), + } +} + +func (m *MemoryKV) Delete(key []byte) error { + _, ok := m.values[string(key)] + if !ok { + return kv.ErrKeyNotFound + } + + delete(m.values, string(key)) + + return nil +} + +func (m *MemoryKV) Get(key []byte) ([]byte, error) { + val, ok := m.values[string(key)] + if !ok { + return nil, kv.ErrKeyNotFound + } + + return val, nil +} + +func (m *MemoryKV) Set(key []byte, value []byte) error { + m.values[string(key)] = value + + return nil +} + +type MemoryTransaction struct { + currentTx map[string][]byte + currentDeletes map[string]struct{} + kv *MemoryKV +} + +func (m *MemoryTransaction) Commit() error { + for k, v := range m.currentTx { + m.kv.values[k] = v + } + + for k := range m.currentDeletes { + delete(m.kv.values, k) + } + + m.currentTx = nil + + return nil +} + +func (m *MemoryTransaction) Delete(key []byte) error { + _, err := m.Get(key) + if err != nil { + return err + } + + m.currentDeletes[string(key)] = struct{}{} + + return nil +} + +func (m *MemoryTransaction) Discard() { + m.currentTx = nil + m.currentDeletes = nil +} + +func (m *MemoryTransaction) Get(key []byte) ([]byte, error) { + val, ok := m.currentTx[string(key)] + if ok { + return val, nil + } + + val, ok = m.kv.values[string(key)] + if ok { + return val, nil + } + + return nil, kv.ErrKeyNotFound +} + +func (m *MemoryTransaction) Set(key []byte, value []byte) error { + m.currentTx[string(key)] = value + + return nil +} diff --git a/pkg/sessions/session.go b/pkg/sessions/session.go index 5c71da8af..ebc5f910f 100644 --- a/pkg/sessions/session.go +++ b/pkg/sessions/session.go @@ -71,7 +71,7 @@ func NewAtomicCommitter(ctx context.Context, committables map[string]Committable opt(a) } - err := a.flushWal(ctx) + err := a.applyWal(ctx) if err != nil { return nil, err } @@ -95,7 +95,7 @@ func (a *AtomicCommitter) Begin(ctx context.Context) (err error) { // Commit commits the atomic session. // It aggregates all commit ids from the committables and returns them as a single Sha256 hash. -// It can be given a callback function to handle any errors that occur during the apply phase (which procedes asynchronously) after this function returns. +// It can be given a callback function to handle any errors that occur during the apply phase (which proceeds asynchronously) after this function returns. func (a *AtomicCommitter) Commit(ctx context.Context, applyCallback func(error)) (err error) { a.mu.Lock() defer a.mu.Unlock() @@ -209,11 +209,11 @@ func (a *AtomicCommitter) handleErr(ctx context.Context, err *error) { } } -// flushWal will try to apply all changes in the WAL to the committables. +// applyWal will try to apply all changes in the WAL to the committables. // If the wal does not contain a commit record, it will delete all changes in the WAL. // If the wal contains a commit record, it will apply all changes in the WAL to the committables. // If the wal contains a commit record, but the commit fails, it will return an error. -func (a *AtomicCommitter) flushWal(ctx context.Context) (err error) { +func (a *AtomicCommitter) applyWal(ctx context.Context) (err error) { beginRecord, err := a.wal.ReadNext(ctx) if err == io.EOF { return a.wal.Truncate(ctx)