Skip to content

Commit

Permalink
topdown/io.jwt.encode_sign: use randomness from bctx.Seed
Browse files Browse the repository at this point in the history
Signed-off-by: Stephan Renatus <[email protected]>
  • Loading branch information
srenatus committed Aug 16, 2021
1 parent ad89378 commit 033feae
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 15 deletions.
7 changes: 6 additions & 1 deletion bundle/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package bundle

import (
"crypto/rand"
"encoding/json"
"fmt"

Expand Down Expand Up @@ -76,7 +77,11 @@ func (*DefaultSigner) GenerateSignedToken(files []FileInfo, sc *SigningConfig, k
return "", err
}

token, err := jws.SignLiteral(payload, jwa.SignatureAlgorithm(sc.Algorithm), privateKey, hdr)
token, err := jws.SignLiteral(payload,
jwa.SignatureAlgorithm(sc.Algorithm),
privateKey,
hdr,
rand.Reader)
if err != nil {
return "", err
}
Expand Down
17 changes: 14 additions & 3 deletions internal/jwx/jws/jws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ package jws

import (
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/json"
"io"
"strings"

"github.com/open-policy-agent/opa/internal/jwx/jwa"
Expand All @@ -37,7 +39,7 @@ import (
// it in compact serialization format. In this format you may NOT use
// multiple signers.
//
func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte) ([]byte, error) {
func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte, rnd io.Reader) ([]byte, error) {
encodedHdr := base64.RawURLEncoding.EncodeToString(hdrBuf)
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
signingInput := strings.Join(
Expand All @@ -50,7 +52,14 @@ func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hd
if err != nil {
return nil, errors.Wrap(err, `failed to create signer`)
}
signature, err := signer.Sign([]byte(signingInput), key)

var signature []byte
switch s := signer.(type) {
case *sign.ECDSASigner:
signature, err = s.SignWithRand([]byte(signingInput), key, rnd)
default:
signature, err = signer.Sign([]byte(signingInput), key)
}
if err != nil {
return nil, errors.Wrap(err, `failed to sign Payload`)
}
Expand Down Expand Up @@ -81,7 +90,9 @@ func SignWithOption(payload []byte, alg jwa.SignatureAlgorithm, key interface{})
if err != nil {
return nil, errors.Wrap(err, `failed to marshal Headers`)
}
return SignLiteral(payload, alg, key, hdrBuf)
// NOTE(sr): we don't use SignWithOption -- if we did, this rand.Reader
// should come from the BuiltinContext's Seed, too.
return SignLiteral(payload, alg, key, hdrBuf, rand.Reader)
}

// Verify checks if the given JWS message is verifiable using `alg` and `key`.
Expand Down
6 changes: 3 additions & 3 deletions internal/jwx/jws/jws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func TestEncode(t *testing.T) {
t.Fatal("Failed to parse key")
}
var jwsCompact []byte
jwsCompact, err = jws.SignLiteral([]byte(examplePayload), alg, key, hdrBytes)
jwsCompact, err = jws.SignLiteral([]byte(examplePayload), alg, key, hdrBytes, rand.Reader)
if err != nil {
t.Fatal("Failed to sign message")
}
Expand Down Expand Up @@ -599,13 +599,13 @@ func TestSignErrors(t *testing.T) {
}
})
t.Run("Invalid signature algorithm", func(t *testing.T) {
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"))
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"), rand.Reader)
if err == nil {
t.Fatal("JWS signing should have failed")
}
})
t.Run("Invalid signature algorithm", func(t *testing.T) {
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"))
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"), rand.Reader)
if err == nil {
t.Fatal("JWS signing should have failed")
}
Expand Down
16 changes: 11 additions & 5 deletions internal/jwx/jws/sign/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"io"

"github.com/open-policy-agent/opa/internal/jwx/jwa"

Expand All @@ -25,7 +26,7 @@ func init() {
}

func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc {
return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey) ([]byte, error) {
return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey, rnd io.Reader) ([]byte, error) {
curveBits := key.Curve.Params().BitSize
keyBytes := curveBits / 8
// Curve bits do not need to be a multiple of 8.
Expand All @@ -34,7 +35,7 @@ func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc {
}
h := hash.New()
h.Write(payload)
r, s, err := ecdsa.Sign(rand.Reader, key, h.Sum(nil))
r, s, err := ecdsa.Sign(rnd, key, h.Sum(nil))
if err != nil {
return nil, errors.Wrap(err, "failed to sign payload using ecdsa")
}
Expand Down Expand Up @@ -69,8 +70,9 @@ func (s ECDSASigner) Algorithm() jwa.SignatureAlgorithm {
return s.alg
}

// Sign signs payload with a ECDSA private key
func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
// SignWithRand signs payload with a ECDSA private key and a provided randomness
// source (such as `rand.Reader`).
func (s ECDSASigner) SignWithRand(payload []byte, key interface{}, r io.Reader) ([]byte, error) {
if key == nil {
return nil, errors.New(`missing private key while signing payload`)
}
Expand All @@ -79,6 +81,10 @@ func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
if !ok {
return nil, errors.Errorf(`invalid key type %T. *ecdsa.PrivateKey is required`, key)
}
return s.sign(payload, privateKey, r)
}

return s.sign(payload, privateKey)
// Sign signs payload with a ECDSA private key
func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
return s.SignWithRand(payload, key, rand.Reader)
}
3 changes: 2 additions & 1 deletion internal/jwx/jws/sign/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sign
import (
"crypto/ecdsa"
"crypto/rsa"
"io"

"github.com/open-policy-agent/opa/internal/jwx/jwa"
)
Expand All @@ -28,7 +29,7 @@ type RSASigner struct {
sign rsaSignFunc
}

type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey) ([]byte, error)
type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey, io.Reader) ([]byte, error)

// ECDSASigner uses crypto/ecdsa to sign the payloads.
type ECDSASigner struct {
Expand Down
6 changes: 5 additions & 1 deletion plugins/rest/rest_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,11 @@ func (ap *oauth2ClientCredentialsAuthPlugin) createAuthJWT(claims map[string]int
jwsHeaders = []byte(fmt.Sprintf(`{"typ":"JWT","alg":"%s"}`, ap.signingKey.Algorithm))
}

jwsCompact, err := jws.SignLiteral(payload, jwa.SignatureAlgorithm(ap.signingKey.Algorithm), signingKey, jwsHeaders)
jwsCompact, err := jws.SignLiteral(payload,
jwa.SignatureAlgorithm(ap.signingKey.Algorithm),
signingKey,
jwsHeaders,
rand.Reader)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion topdown/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,7 @@ func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, j

// process payload and sign
var jwsCompact []byte
jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders)
jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders, bctx.Seed)
if err != nil {
return err
}
Expand Down
47 changes: 47 additions & 0 deletions topdown/tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,50 @@ func TestTopDownJWTEncodeSignES512(t *testing.T) {
t.Fatal("Failed to verify message")
}
}

// NOTE(sr): The stdlib ecdsa package will randomly read 1 byte from the source
// and discard it: so passing a fixed-seed `rand.New(rand.Source(seed))` via
// `rego.WithSeed` will not do the trick, the output would still randomly be
// one of two possible signatures. To fix that for testing, we're reaching
// deeper here, and use a "constant number generator". It doesn't matter if the
// first byte is discarded, the second one looks just the same.
type cng struct{}

func (*cng) Read(p []byte) (int, error) {
for i := range p {
p[i] = 4
}
return len(p), nil
}

func TestTopdownJWTEncodeSignECWithSeedReturnsSameSignature(t *testing.T) {
query := `io.jwt.encode_sign({"alg": "ES256"},{"pay": "load"},
{"kty":"EC",
"crv":"P-256",
"x":"f83OJ3D2xF1Bg8vub9tLe1gHMzV76e8Tus9uPHvRVEU",
"y":"x_FEzRu9m36HLN_tue659LNpXW6pCyStikYjKIWI5a0",
"d":"jpsQnnGQmL-YBIffH1136cspYG6-0iY7X1fCE9-E9LI"
}, x)`
encodedSigned := "eyJhbGciOiAiRVMyNTYifQ.eyJwYXkiOiAibG9hZCJ9.-LoHxtbT8t_TnqlLyONI4BtjvfkySO8TcoCFENqTTH2AKxvn29nAjxOdlbY-0EKVM2nJ4ukCx4IGtZtuwXr0VQ"

for i := 0; i < 10; i++ {
q := NewQuery(ast.MustParseBody(query)).
WithSeed(&cng{}).
WithStrictBuiltinErrors(true).
WithCompiler(ast.NewCompiler())

qrs, err := q.Run(context.Background())
if err != nil {
t.Fatal(err)
} else if len(qrs) != 1 {
t.Fatal("expected exactly one result but got:", qrs)
}

if exp, act := 1, len(qrs); exp != act {
t.Fatalf("expected %d results, got %d", exp, act)
}
if exp, act := ast.String(encodedSigned), qrs[0][ast.Var("x")].Value; !exp.Equal(act) {
t.Fatalf("unexpected result: want %v, got %v", exp, act)
}
}
}

0 comments on commit 033feae

Please sign in to comment.