Skip to content

Commit

Permalink
JWT Version Checks
Browse files Browse the repository at this point in the history
  • Loading branch information
mkysel committed Jan 3, 2025
1 parent a1f6462 commit af31869
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 18 deletions.
1 change: 1 addition & 0 deletions cmd/replication/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ func main() {
dbInstance,
blockchainPublisher,
fmt.Sprintf("0.0.0.0:%d", options.API.Port),
Commit,
)
if err != nil {
log.Fatal("initializing server", zap.Error(err))
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ require (
)

require (
github.com/Masterminds/semver/v3 v3.1.1
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
)
Expand Down
40 changes: 40 additions & 0 deletions pkg/authn/claims.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package authn

import (
"fmt"
"github.com/Masterminds/semver/v3"
"github.com/golang-jwt/jwt/v5"
)

const (
// XMTPD_COMPATIBLE_VERSION_CONSTRAINT major or minor version bumps indicate backwards incompatible changes
XMTPD_COMPATIBLE_VERSION_CONSTRAINT = "~ 0.1.3"
)

type XmtpdClaims struct {
Version *string `json:"version,omitempty"`
jwt.RegisteredClaims
}

func ValidateVersionClaimIsCompatible(claims *XmtpdClaims) error {
if claims.Version == nil || *claims.Version == "" {
return nil
}

c, err := semver.NewConstraint(XMTPD_COMPATIBLE_VERSION_CONSTRAINT)
if err != nil {
return err
}

v, err := semver.NewVersion(*claims.Version)

if err != nil {
return err
}

if ok := c.Check(v); !ok {
return fmt.Errorf("version %s is not compatible", *claims.Version)
}

return nil
}
22 changes: 15 additions & 7 deletions pkg/authn/tokenFactory.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,33 @@ const (
type TokenFactory struct {
privateKey *ecdsa.PrivateKey
nodeID uint32
version string
}

func NewTokenFactory(privateKey *ecdsa.PrivateKey, nodeID uint32) *TokenFactory {
func NewTokenFactory(privateKey *ecdsa.PrivateKey, nodeID uint32, version string) *TokenFactory {
return &TokenFactory{
privateKey: privateKey,
nodeID: nodeID,
version: version,
}
}

func (f *TokenFactory) CreateToken(forNodeID uint32) (*Token, error) {
now := time.Now()
expiresAt := now.Add(TOKEN_DURATION)

token := jwt.NewWithClaims(&SigningMethodSecp256k1{}, &jwt.RegisteredClaims{
Subject: strconv.Itoa(int(f.nodeID)),
Audience: []string{strconv.Itoa(int(forNodeID))},
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
})
claims := &XmtpdClaims{
Version: &f.version,
RegisteredClaims: jwt.RegisteredClaims{
Subject: strconv.Itoa(int(f.nodeID)),
Audience: []string{strconv.Itoa(int(forNodeID))},
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(now),
},
}

// Create a new token with custom claims
token := jwt.NewWithClaims(&SigningMethodSecp256k1{}, claims)

signedString, err := token.SignedString(f.privateKey)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/authn/tokenFactory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func TestTokenFactory(t *testing.T) {
privateKey := testutils.RandomPrivateKey(t)
factory := NewTokenFactory(privateKey, 100)
factory := NewTokenFactory(privateKey, 100, "")

token, err := factory.CreateToken(200)
require.NoError(t, err)
Expand Down
26 changes: 24 additions & 2 deletions pkg/authn/verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@ func NewRegistryVerifier(registry registry.NodeRegistry, myNodeID uint32) *Regis
func (v *RegistryVerifier) Verify(tokenString string) error {
var token *jwt.Token
var err error
if token, err = jwt.Parse(tokenString, v.getMatchingPublicKey); err != nil {

if token, err = jwt.ParseWithClaims(
tokenString,
&XmtpdClaims{},
v.getMatchingPublicKey,
); err != nil {
return err
}

if err = v.validateAudience(token); err != nil {
return err
}
Expand All @@ -42,6 +46,10 @@ func (v *RegistryVerifier) Verify(tokenString string) error {
return err
}

if err = v.validateClaims(token); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -85,6 +93,20 @@ func (v *RegistryVerifier) validateAudience(token *jwt.Token) error {
return fmt.Errorf("could not find node ID in audience %v", audience)
}

func (v *RegistryVerifier) validateClaims(token *jwt.Token) error {
claims, ok := token.Claims.(*XmtpdClaims)
if !ok {
return fmt.Errorf("invalid token claims type")
}

// Check if the token is valid
if !token.Valid {
return fmt.Errorf("invalid token")
}

return ValidateVersionClaimIsCompatible(claims)
}

// Parse the subject claim of the JWT and return the node ID as a uint32
func getSubjectNodeId(token *jwt.Token) (uint32, error) {
subject, err := token.Claims.GetSubject()
Expand Down
8 changes: 4 additions & 4 deletions pkg/authn/verifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func buildJwt(
func TestVerifier(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID))
tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), "")

verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID))
nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(&registry.Node{
Expand All @@ -79,7 +79,7 @@ func TestVerifier(t *testing.T) {
func TestWrongAudience(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID))
tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), "")

verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID))
nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(&registry.Node{
Expand All @@ -97,7 +97,7 @@ func TestWrongAudience(t *testing.T) {
func TestUnknownNode(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID))
tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), "")

verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID))
nodeRegistry.EXPECT().GetNode(uint32(SIGNER_NODE_ID)).Return(nil, errors.New("node not found"))
Expand All @@ -112,7 +112,7 @@ func TestUnknownNode(t *testing.T) {
func TestWrongPublicKey(t *testing.T) {
signerPrivateKey := testutils.RandomPrivateKey(t)

tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID))
tokenFactory := NewTokenFactory(signerPrivateKey, uint32(SIGNER_NODE_ID), "")

verifier, nodeRegistry := buildVerifier(t, uint32(VERIFIER_NODE_ID))

Expand Down
2 changes: 1 addition & 1 deletion pkg/interceptors/client/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestAuthInterceptor(t *testing.T) {
privateKey := testutils.RandomPrivateKey(t)
myNodeID := uint32(100)
targetNodeID := uint32(200)
tokenFactory := authn.NewTokenFactory(privateKey, myNodeID)
tokenFactory := authn.NewTokenFactory(privateKey, myNodeID, "")
interceptor := NewAuthInterceptor(tokenFactory, targetNodeID)
token, err := interceptor.getToken()
require.NoError(t, err)
Expand Down
3 changes: 2 additions & 1 deletion pkg/registrant/registrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func NewRegistrant(
db *queries.Queries,
nodeRegistry registry.NodeRegistry,
privateKeyString string,
version string,
) (*Registrant, error) {
privateKey, err := utils.ParseEcdsaPrivateKey(privateKeyString)
if err != nil {
Expand All @@ -47,7 +48,7 @@ func NewRegistrant(
return nil, err
}

tokenFactory := authn.NewTokenFactory(privateKey, record.NodeID)
tokenFactory := authn.NewTokenFactory(privateKey, record.NodeID, version)

log.Info(
"Registrant identified",
Expand Down
10 changes: 10 additions & 0 deletions pkg/registrant/registrant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type deps struct {
privKey1Str string
privKey2 *ecdsa.PrivateKey
privKey3 *ecdsa.PrivateKey
version string
}

func setup(t *testing.T) (deps, func()) {
Expand All @@ -54,6 +55,7 @@ func setup(t *testing.T) (deps, func()) {
privKey1Str: privKey1Str,
privKey2: privKey2,
privKey3: privKey3,
version: "",
}, dbCleanup
}

Expand All @@ -70,6 +72,7 @@ func setupWithRegistrant(t *testing.T) (deps, *registrant.Registrant, func()) {
deps.db,
deps.registry,
deps.privKey1Str,
deps.version,
)
require.NoError(t, err)

Expand All @@ -86,6 +89,7 @@ func TestNewRegistrantBadPrivateKey(t *testing.T) {
deps.db,
deps.registry,
"badkey",
deps.version,
)
require.ErrorContains(t, err, "parse")
}
Expand All @@ -105,6 +109,7 @@ func TestNewRegistrantNotInRegistry(t *testing.T) {
deps.db,
deps.registry,
deps.privKey1Str,
deps.version,
)
require.ErrorContains(t, err, "registry")
}
Expand All @@ -125,6 +130,7 @@ func TestNewRegistrantNewDatabase(t *testing.T) {
deps.db,
deps.registry,
deps.privKey1Str,
deps.version,
)
require.NoError(t, err)
}
Expand Down Expand Up @@ -152,6 +158,7 @@ func TestNewRegistrantExistingDatabase(t *testing.T) {
deps.db,
deps.registry,
deps.privKey1Str,
deps.version,
)
require.NoError(t, err)
}
Expand Down Expand Up @@ -179,6 +186,7 @@ func TestNewRegistrantMismatchingDatabaseNodeId(t *testing.T) {
deps.db,
deps.registry,
deps.privKey1Str,
deps.version,
)
require.ErrorContains(t, err, "does not match")
}
Expand Down Expand Up @@ -206,6 +214,7 @@ func TestNewRegistrantMismatchingDatabasePublicKey(t *testing.T) {
deps.db,
deps.registry,
deps.privKey1Str,
deps.version,
)
require.ErrorContains(t, err, "does not match")
}
Expand All @@ -224,6 +233,7 @@ func TestNewRegistrantPrivateKeyNo0x(t *testing.T) {
deps.db,
deps.registry,
utils.HexEncode(crypto.FromECDSA(deps.privKey1)),
deps.version,
)
require.NoError(t, err)
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func NewReplicationServer(
writerDB *sql.DB,
blockchainPublisher blockchain.IBlockchainPublisher,
listenAddress string,
version string,
) (*ReplicationServer, error) {
var err error

Expand Down Expand Up @@ -89,6 +90,7 @@ func NewReplicationServer(
queries.New(writerDB),
nodeRegistry,
options.Signer.PrivateKey,
version,
)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func NewTestServer(
//Payer: config.PayerOptions{
// Enable: true,
//},
}, registry, db, messagePublisher, fmt.Sprintf("localhost:%d", port))
}, registry, db, messagePublisher, fmt.Sprintf("localhost:%d", port), "")
require.NoError(t, err)

return server
Expand Down
9 changes: 8 additions & 1 deletion pkg/testutils/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,14 @@ func NewTestAPIServer(t *testing.T) (*api.ApiServer, *sql.DB, func()) {
mockRegistry.EXPECT().GetNodes().Return([]registry.Node{
{NodeID: 100, SigningKey: &privKey.PublicKey},
}, nil)
registrant, err := registrant.NewRegistrant(ctx, log, queries.New(db), mockRegistry, privKeyStr)
registrant, err := registrant.NewRegistrant(
ctx,
log,
queries.New(db),
mockRegistry,
privKeyStr,
"",
)
require.NoError(t, err)
mockMessagePublisher := blockchain.NewMockIBlockchainPublisher(t)

Expand Down

0 comments on commit af31869

Please sign in to comment.