Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
charithabandi committed Sep 26, 2023
1 parent e8c5e91 commit b9d101e
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 39 deletions.
2 changes: 1 addition & 1 deletion pkg/validators/mgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func (vm *ValidatorMgr) Finalize(ctx context.Context) []*Validator {
for candidate, join := range vm.candidates {
if join.votes() < join.requiredVotes() {

if vm.lastBlockHeight >= join.expiresAt {
if join.expiresAt != -1 && vm.lastBlockHeight >= join.expiresAt {
// Join request expired
delete(vm.candidates, candidate)
if err := vm.db.DeleteJoinRequest(ctx, join.pubkey); err != nil {
Expand Down
7 changes: 1 addition & 6 deletions pkg/validators/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,9 @@ func newValidatorStore(ctx context.Context, datastore Datastore, log log.Logger)
log: log,
}

// err := ar.initTables(ctx)
// if err != nil {
// return nil, fmt.Errorf("failed to initialize tables: %w", err)
// }

err := ar.databaseUpgrade(ctx)
if err != nil {
return nil, fmt.Errorf("failed to upgrade database: %w", err)
return nil, fmt.Errorf("failed to initialize database at version %d due to error: %w", valStoreVersion, err)
}

// err = ar.prepareStatements()
Expand Down
26 changes: 14 additions & 12 deletions pkg/validators/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,21 @@ func (vs *validatorStore) CheckVersion(ctx context.Context) (int, upgradeAction,
_, valErr := vs.currentValidators(ctx)

if versionErr != nil && valErr != nil {
// Fresh db, do regular init
// Fresh db, do regular initialization at valStoreVersion
return valStoreVersion, upgradeActionNone, nil
} else if versionErr != nil && valErr == nil {
// Legacy db
// Legacy db without version tracking
return 0, upgradeActionLegacy, nil
} else if versionErr == nil && valErr == nil {
// both tables exist
if version == valStoreVersion {
// Nothing to do
// DB on the latest version
return version, upgradeActionNone, nil
} else if version < valStoreVersion {
// Run DB migrations
// DB on previous version, Run DB migrations
return version, upgradeActionRunMigrations, nil
} else if version > valStoreVersion {
// Error
// Invalid DB version, return error
return version, upgradeActionNone, fmt.Errorf("validator store version %d is newer than the current version %d", version, valStoreVersion)
}
} else {
Expand All @@ -68,12 +68,12 @@ func (vs *validatorStore) CheckVersion(ctx context.Context) (int, upgradeAction,

func (vs *validatorStore) databaseUpgrade(ctx context.Context) error {
version, action, err := vs.CheckVersion(ctx)
vs.log.Info("databaseUpgrade", zap.String("version", fmt.Sprintf("%d", version)), zap.String("action", upgradeActionString(action)), zap.Error(err))

if err != nil {
return err
}

vs.log.Info("databaseUpgrade", zap.String("version", fmt.Sprintf("%d", version)), zap.String("action", upgradeActionString(action)), zap.Error(err))

switch action {
case upgradeActionNone:
return vs.initTables(ctx)
Expand All @@ -92,11 +92,13 @@ func (vs *validatorStore) runMigrations(ctx context.Context, version int) error
if err := vs.upgradeValidatorsDB_0_1(ctx); err != nil {
return err
}
// fallthrough
default:
fallthrough
case valStoreVersion:
vs.log.Info("databaseUpgrade: completed successfully")
return nil
default:
return fmt.Errorf("unknown version: %d", version)
}
return nil
}

/*
Expand All @@ -117,11 +119,11 @@ func (vs *validatorStore) upgradeValidatorsDB_0_1(ctx context.Context) error {
}

if err := vs.db.Execute(ctx, "ALTER TABLE join_reqs ADD COLUMN expiresAt INTEGER;", nil); err != nil {
return fmt.Errorf("failed to upgrade validators db from version 0 to 1: %w", err)
return fmt.Errorf("failed to add expiresAt column to join_reqs table: %w", err)
}

if err := vs.db.Execute(ctx, "UPDATE join_reqs SET expiresAt = -1;", nil); err != nil {
return fmt.Errorf("failed to upgrade validators db from version 0 to 1: %w", err)
return fmt.Errorf("failed to set indefinite join expiry for existing join requests: %w", err)
}
return nil
}
151 changes: 131 additions & 20 deletions pkg/validators/upgrade_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package validators

import (
"bytes"
"context"
"os"
"testing"
Expand All @@ -9,9 +10,10 @@ import (
sqlTesting "github.com/kwilteam/kwil-db/pkg/sql/testing"
)

func setup(srcfile string) {
func setup(srcFile string) {
// Copies the db file to tmp
os.MkdirAll("tmp", os.ModePerm)
bts, err := os.ReadFile(srcfile)
bts, err := os.ReadFile(srcFile)
if err != nil {
panic(err)
}
Expand All @@ -21,8 +23,15 @@ func setup(srcfile string) {
panic(err)
}
}

func TestValidatorStoreUpgradeLegacyToV1(t *testing.T) {
setup("./test_data/version0.sqlite")

/*
Open Version 0 DB. It contains:
- 1 validator
- 3 join requests
*/
ds, td, err := sqlTesting.OpenTestDB("validator_db")
if err != nil {
t.Fatal(err)
Expand All @@ -31,12 +40,13 @@ func TestValidatorStoreUpgradeLegacyToV1(t *testing.T) {
ctx := context.Background()
logger := log.NewStdOut(log.DebugLevel)

// vs
// validator store
vs := &validatorStore{
db: ds,
log: logger,
}

// Verify validator count is 1
results, err := vs.db.Query(ctx, "SELECT COUNT(*) FROM validators", nil)
if err != nil {
t.Fatal(err)
Expand All @@ -45,7 +55,11 @@ func TestValidatorStoreUpgradeLegacyToV1(t *testing.T) {
t.Fatalf("Expected 1 result, got %d", len(results))
}

// CheckVersion
/*
CheckVersion and Upgrade Action to take:
- Version: 0
- Action: upgradeActionLegacy
*/
version, action, err := vs.CheckVersion(ctx)
if err != nil {
t.Fatal(err)
Expand All @@ -54,22 +68,22 @@ func TestValidatorStoreUpgradeLegacyToV1(t *testing.T) {
t.Fatalf("Expected version 0, got %d", version)
}
if action != upgradeActionLegacy {
t.Fatalf("Expected action %d, got %d", upgradeActionLegacy, action)
t.Fatalf("Expected action %s, got %s", upgradeActionString(upgradeActionLegacy), upgradeActionString(action))
}

// Get JoinRequest entries
// Expect failure as expiresAt column doesn't exist in legacy code
_, err = vs.ActiveVotes(ctx)
if err == nil {
t.Fatal(err)
}

// Upgrade
// Upgrade DB to version 1
err = vs.databaseUpgrade(ctx)
if err != nil {
t.Fatal(err)
}

// Check Version Table
// Check Version Table to ensure version is 1
version, err = vs.currentVersion(ctx)
if err != nil {
t.Fatal(err)
Expand All @@ -78,19 +92,27 @@ func TestValidatorStoreUpgradeLegacyToV1(t *testing.T) {
t.Fatalf("Expected version %d, got %d", valStoreVersion, version)
}

// Get JoinRequest entries
// Expect 3 join requests
votes, err := vs.ActiveVotes(ctx)
if err != nil {
t.Fatal(err)
}
if len(votes) != 3 {
t.Fatalf("Starting votes not empty (%d)", len(votes))
t.Fatalf("Mismatch in join_requests (%d)", len(votes))
}

testValidatorJoins(t, vs, ctx)

}

func TestValidatorStoreUpgradeV1(t *testing.T) {
setup("./test_data/version1.sqlite")

/*
Open Version 0 DB. It contains:
- 1 validator
- 3 join requests
*/
ds, td, err := sqlTesting.OpenTestDB("validator_db")
if err != nil {
t.Fatal(err)
Expand All @@ -99,12 +121,13 @@ func TestValidatorStoreUpgradeV1(t *testing.T) {
ctx := context.Background()
logger := log.NewStdOut(log.DebugLevel)

// vs
// validator store
vs := &validatorStore{
db: ds,
log: logger,
}

// Verify validator count is 1
results, err := vs.db.Query(ctx, "SELECT COUNT(*) FROM validators", nil)
if err != nil {
t.Fatal(err)
Expand All @@ -113,19 +136,23 @@ func TestValidatorStoreUpgradeV1(t *testing.T) {
t.Fatalf("Expected 1 result, got %d", len(results))
}

// CheckVersion
version, action, err := vs.CheckVersion(ctx)
/*
CheckVersion and Upgrade Action to take:
- Version: 1
- Action: upgradeActionNone
*/
versionPre, action, err := vs.CheckVersion(ctx)
if err != nil {
t.Fatal(err)
}
if version != 1 {
t.Fatalf("Expected version 0, got %d", version)
if versionPre != 1 {
t.Fatalf("Expected version 0, got %d", versionPre)
}
if action != upgradeActionNone {
t.Fatalf("Expected action %d, got %d", upgradeActionLegacy, action)
t.Fatalf("Expected action %s, got %s", upgradeActionString(upgradeActionNone), upgradeActionString(action))
}

// Get JoinRequest entries
// Three entries in join_reqs table with expiresAt column
votes, err := vs.ActiveVotes(ctx)
if err != nil {
t.Fatal(err)
Expand All @@ -141,12 +168,13 @@ func TestValidatorStoreUpgradeV1(t *testing.T) {
}

// Check Version Table
version, err = vs.currentVersion(ctx)
versionPost, err := vs.currentVersion(ctx)
if err != nil {
t.Fatal(err)
}
if version != valStoreVersion {
t.Fatalf("Expected version %d, got %d", valStoreVersion, version)
// Version should be 1, no upgrade
if versionPost != versionPre {
t.Fatalf("Expected version %d, got %d", versionPre, versionPost)
}

// Get JoinRequest entries
Expand All @@ -157,4 +185,87 @@ func TestValidatorStoreUpgradeV1(t *testing.T) {
if len(votes) != 3 {
t.Fatalf("Starting votes not empty (%d)", len(votes))
}

testValidatorJoins(t, vs, ctx)
}

// Verify validator joins
func testValidatorJoins(t *testing.T, vs *validatorStore, ctx context.Context) {

// Add 2 validators
validators := make([]*Validator, 2)
for i := range validators {
validators[i] = newValidator()
err := vs.AddValidator(ctx, validators[i].PubKey, validators[i].Power)
if err != nil {
t.Fatal(err)
}
}

// Validator count = 3
curVals, err := vs.CurrentValidators(ctx)
if err != nil {
t.Fatal(err)
}
if len(curVals) != 3 {
t.Fatalf("Expected 3 validators, got %d", len(curVals))
}

// Validators public keys
approvers := make([][]byte, len(curVals))
for i, vi := range curVals {
approvers[i] = vi.PubKey
}

// Create a Joiner node and Initiate a join request
joiner := newValidator()
expiresAt := int64(3)
err = vs.StartJoinRequest(ctx, joiner.PubKey, approvers, 1, expiresAt)
if err != nil {
t.Fatal(err)
}

// Approve the join request
ApproveJoinRequest(t, vs, ctx, joiner, approvers[0], 1)

// Approve the join request
ApproveJoinRequest(t, vs, ctx, joiner, approvers[1], 2)
}

func ApproveJoinRequest(t *testing.T, vs *validatorStore, ctx context.Context, joiner *Validator, approver []byte, approvalCnt int) {
// Approve the join request
err := vs.AddApproval(ctx, joiner.PubKey, approver)
if err != nil {
t.Fatal("Unable to add approval", err)
}

join_stats, err := vs.ActiveVotes(ctx)
if err != nil {
t.Fatal(err)
}

// 4 join requests
if len(join_stats) != 4 {
t.Fatalf("Expected 4 join requests, got %d", len(join_stats))
}

// # of approvals = 1
count := 0
for _, v := range join_stats {
if bytes.Equal(v.Candidate, joiner.PubKey) {
if v.ExpiresAt != 3 {
t.Fatalf("Expected expiresAt 3, got %d", v.ExpiresAt)
}

for _, approved := range v.Approved {
if approved {
count++
}
}
break
}
}
if count != approvalCnt {
t.Fatalf("Expected approval cnt %d, got %d", approvalCnt, count)
}
}

0 comments on commit b9d101e

Please sign in to comment.