Skip to content

Commit

Permalink
better abstracted away the authenticators, as suggested by Jon
Browse files Browse the repository at this point in the history
  • Loading branch information
brennanjl committed Oct 3, 2023
1 parent e110a4e commit 4ec61cb
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 52 deletions.
7 changes: 1 addition & 6 deletions internal/app/kwild/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,9 @@ func buildEngine(d *coreDependencies, a *sessions.AtomicCommitter) *engine.Engin
log: *d.log.Named("sqlite-committable"),
}

addressFuncs := make(map[string]engine.AddressFunc)
for _, a := range auth.ListAuthenticators() {
addressFuncs[a.Name] = a.Authenticator.Address
}

e, err := engine.Open(d.ctx, d.opener,
sqlCommitRegister,
addressFuncs,
auth.GetAddress,
engine.WithLogger(*d.log.Named("engine")),
engine.WithExtensions(adaptExtensions(extensions)),
)
Expand Down
14 changes: 12 additions & 2 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ func RegisterAuthenticator(name string, auth Authenticator) error {
return nil
}

// GetAuthenticator returns an authenticator by the name it was registered with
func GetAuthenticator(name string) (Authenticator, error) {
// getAuthenticator returns an authenticator by the name it was registered with
func getAuthenticator(name string) (Authenticator, error) {
name = strings.ToLower(name)
auth, ok := registeredAuthenticators[name]
if !ok {
Expand All @@ -40,6 +40,16 @@ func GetAuthenticator(name string) (Authenticator, error) {
return auth, nil
}

// GetAddress returns an address from a public key and authenticator type
func GetAddress(authType string, sender []byte) (string, error) {
auth, err := getAuthenticator(authType)
if err != nil {
return "", err
}

return auth.Address(sender)
}

// ListAuthenticators returns a list of registered authenticators
func ListAuthenticators() []struct {
Name string
Expand Down
8 changes: 2 additions & 6 deletions pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,12 @@ func Test_Auth(t *testing.T) {
sig, err := tc.signer.Sign(msg)
assert.NoError(t, err)

// get the authenticator
authenticator, err := auth.GetAuthenticator(sig.Type)
assert.NoError(t, err)

// verify the signature
err = authenticator.Verify(tc.signer.PublicKey(), msg, sig.Signature)
err = sig.Verify(tc.signer.PublicKey(), msg)
assert.NoError(t, err)

// check the address
address, err := authenticator.Address(tc.signer.PublicKey())
address, err := auth.GetAddress(sig.Type, tc.signer.PublicKey())
assert.NoError(t, err)

assert.Equal(t, tc.address, address)
Expand Down
10 changes: 10 additions & 0 deletions pkg/auth/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ type Signature struct {
Type string `json:"signature_type"`
}

// Verify verifies the signature against the given message and public key.
func (s *Signature) Verify(senderPubKey, msg []byte) error {
a, err := getAuthenticator(s.Type)
if err != nil {
return err
}

return a.Verify(senderPubKey, msg, s.Signature)
}

// Signer is an interface for something that can sign messages.
// It returns signatures with a designated AuthType, which should
// be used to determine how to verify the signature.
Expand Down
9 changes: 4 additions & 5 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,21 @@ type Engine struct {

// commitRegister is the commit register that is used to register commits
commitRegister CommitRegister

// addressFuncs is a map of address functions that are used to generate addresses from public keys
addressFuncs map[string]AddressFunc
// addresser takes in an address type and a public key and returns an address
addresser Addresser
}

// Open opens a new engine with the provided options.
// It will also open any stored datasets.
func Open(ctx context.Context, dbOpener sql.Opener, commitRegister CommitRegister, addressFuncs map[string]AddressFunc, opts ...EngineOpt) (*Engine, error) {
func Open(ctx context.Context, dbOpener sql.Opener, commitRegister CommitRegister, addresser Addresser, opts ...EngineOpt) (*Engine, error) {
e := &Engine{
name: masterDBName,
log: log.NewNoOp(),
datasets: make(map[string]Dataset),
extensions: make(map[string]ExtensionInitializer),
opener: dbOpener,
commitRegister: commitRegister,
addressFuncs: addressFuncs,
addresser: addresser,
}

for _, opt := range opts {
Expand Down
12 changes: 3 additions & 9 deletions pkg/engine/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package engine

import (
"context"
"fmt"
"io"

"github.com/kwilteam/kwil-db/pkg/engine/dataset"
Expand Down Expand Up @@ -53,12 +52,7 @@ func (e *Engine) newDatasetUser(u *types.User) (*datasetUser, error) {
return nil, err
}

addrFunc, ok := e.addressFuncs[u.AuthType]
if !ok {
return nil, fmt.Errorf("unknown auth type %s", u.AuthType)
}

addr, err := addrFunc(u.PublicKey)
addr, err := e.addresser(u.AuthType, u.PublicKey)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -90,5 +84,5 @@ func (u *datasetUser) Address() string {
return u.address
}

// AddressFunc is a function that takes a public key and returns an address
type AddressFunc func([]byte) (string, error)
// Addresser is a function that takes an address type and a public key and returns an address
type Addresser func(addressType string, pubkey []byte) (string, error)
7 changes: 1 addition & 6 deletions pkg/engine/testing/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@ import (
func NewTestEngine(ctx context.Context, ec engine.CommitRegister, opts ...engine.EngineOpt) (*engine.Engine, func() error, error) {
opener := newTestDBOpener()

addressFuncs := make(map[string]engine.AddressFunc)
for _, a := range auth.ListAuthenticators() {
addressFuncs[a.Name] = a.Authenticator.Address
}

e, err := engine.Open(ctx, opener, ec, addressFuncs,
e, err := engine.Open(ctx, opener, ec, auth.GetAddress,
opts...,
)
if err != nil {
Expand Down
7 changes: 1 addition & 6 deletions pkg/sql/sqlite/functions/addresses/addresses.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,7 @@ func addressFunc(ctx sqlite.Context, args []sqlite.Value) (sqlite.Value, error)
return raiseErr(addressFuncName, fmt.Errorf("failed to read public key identifier: %w", err))
}

authenticator, err := auth.GetAuthenticator(ident.AuthType)
if err != nil {
return raiseErr(addressFuncName, fmt.Errorf("failed to get authenticator: %w", err))
}

address, err := authenticator.Address(ident.PublicKey)
address, err := auth.GetAddress(ident.AuthType, ident.PublicKey)
if err != nil {
return raiseErr(addressFuncName, fmt.Errorf("failed to get address: %w", err))
}
Expand Down
7 changes: 1 addition & 6 deletions pkg/transactions/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,5 @@ func (s *CallMessage) Verify() error {
return err
}

authenticator, err := auth.GetAuthenticator(s.Signature.Type)
if err != nil {
return err
}

return authenticator.Verify(s.Sender, msg, s.Signature.Signature)
return s.Signature.Verify(s.Sender, msg)
}
7 changes: 1 addition & 6 deletions pkg/transactions/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,7 @@ func (t *Transaction) Verify() error {
return err
}

authenticator, err := auth.GetAuthenticator(t.Signature.Type)
if err != nil {
return err
}

return authenticator.Verify(t.Sender, msg, t.Signature.Signature)
return t.Signature.Verify(t.Sender, msg)
}

// Sign signs transaction body with given signer.
Expand Down

0 comments on commit 4ec61cb

Please sign in to comment.