Skip to content

Commit

Permalink
topdown/buitins: io.jwt.encode_sign uses BuiltinContext random source (
Browse files Browse the repository at this point in the history
…#3738)

It didn't before, so we had not much control over the entropy that is getting into the signature for ecdsa.

This is useful if you want reproducible outcomes over multiple policy evaluations, such as in testing.

Signed-off-by: Stephan Renatus <[email protected]>
  • Loading branch information
srenatus authored Aug 19, 2021
1 parent f2db31d commit e732b0b
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 41 deletions.
7 changes: 6 additions & 1 deletion bundle/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package bundle

import (
"crypto/rand"
"encoding/json"
"fmt"

Expand Down Expand Up @@ -76,7 +77,11 @@ func (*DefaultSigner) GenerateSignedToken(files []FileInfo, sc *SigningConfig, k
return "", err
}

token, err := jws.SignLiteral(payload, jwa.SignatureAlgorithm(sc.Algorithm), privateKey, hdr)
token, err := jws.SignLiteral(payload,
jwa.SignatureAlgorithm(sc.Algorithm),
privateKey,
hdr,
rand.Reader)
if err != nil {
return "", err
}
Expand Down
17 changes: 14 additions & 3 deletions internal/jwx/jws/jws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ package jws

import (
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/json"
"io"
"strings"

"github.com/open-policy-agent/opa/internal/jwx/jwa"
Expand All @@ -37,7 +39,7 @@ import (
// it in compact serialization format. In this format you may NOT use
// multiple signers.
//
func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte) ([]byte, error) {
func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte, rnd io.Reader) ([]byte, error) {
encodedHdr := base64.RawURLEncoding.EncodeToString(hdrBuf)
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
signingInput := strings.Join(
Expand All @@ -50,7 +52,14 @@ func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hd
if err != nil {
return nil, errors.Wrap(err, `failed to create signer`)
}
signature, err := signer.Sign([]byte(signingInput), key)

var signature []byte
switch s := signer.(type) {
case *sign.ECDSASigner:
signature, err = s.SignWithRand([]byte(signingInput), key, rnd)
default:
signature, err = signer.Sign([]byte(signingInput), key)
}
if err != nil {
return nil, errors.Wrap(err, `failed to sign Payload`)
}
Expand Down Expand Up @@ -81,7 +90,9 @@ func SignWithOption(payload []byte, alg jwa.SignatureAlgorithm, key interface{})
if err != nil {
return nil, errors.Wrap(err, `failed to marshal Headers`)
}
return SignLiteral(payload, alg, key, hdrBuf)
// NOTE(sr): we don't use SignWithOption -- if we did, this rand.Reader
// should come from the BuiltinContext's Seed, too.
return SignLiteral(payload, alg, key, hdrBuf, rand.Reader)
}

// Verify checks if the given JWS message is verifiable using `alg` and `key`.
Expand Down
6 changes: 3 additions & 3 deletions internal/jwx/jws/jws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func TestEncode(t *testing.T) {
t.Fatal("Failed to parse key")
}
var jwsCompact []byte
jwsCompact, err = jws.SignLiteral([]byte(examplePayload), alg, key, hdrBytes)
jwsCompact, err = jws.SignLiteral([]byte(examplePayload), alg, key, hdrBytes, rand.Reader)
if err != nil {
t.Fatal("Failed to sign message")
}
Expand Down Expand Up @@ -599,13 +599,13 @@ func TestSignErrors(t *testing.T) {
}
})
t.Run("Invalid signature algorithm", func(t *testing.T) {
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"))
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"), rand.Reader)
if err == nil {
t.Fatal("JWS signing should have failed")
}
})
t.Run("Invalid signature algorithm", func(t *testing.T) {
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"))
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"), rand.Reader)
if err == nil {
t.Fatal("JWS signing should have failed")
}
Expand Down
16 changes: 11 additions & 5 deletions internal/jwx/jws/sign/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"io"

"github.com/open-policy-agent/opa/internal/jwx/jwa"

Expand All @@ -25,7 +26,7 @@ func init() {
}

func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc {
return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey) ([]byte, error) {
return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey, rnd io.Reader) ([]byte, error) {
curveBits := key.Curve.Params().BitSize
keyBytes := curveBits / 8
// Curve bits do not need to be a multiple of 8.
Expand All @@ -34,7 +35,7 @@ func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc {
}
h := hash.New()
h.Write(payload)
r, s, err := ecdsa.Sign(rand.Reader, key, h.Sum(nil))
r, s, err := ecdsa.Sign(rnd, key, h.Sum(nil))
if err != nil {
return nil, errors.Wrap(err, "failed to sign payload using ecdsa")
}
Expand Down Expand Up @@ -69,8 +70,9 @@ func (s ECDSASigner) Algorithm() jwa.SignatureAlgorithm {
return s.alg
}

// Sign signs payload with a ECDSA private key
func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
// SignWithRand signs payload with a ECDSA private key and a provided randomness
// source (such as `rand.Reader`).
func (s ECDSASigner) SignWithRand(payload []byte, key interface{}, r io.Reader) ([]byte, error) {
if key == nil {
return nil, errors.New(`missing private key while signing payload`)
}
Expand All @@ -79,6 +81,10 @@ func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
if !ok {
return nil, errors.Errorf(`invalid key type %T. *ecdsa.PrivateKey is required`, key)
}
return s.sign(payload, privateKey, r)
}

return s.sign(payload, privateKey)
// Sign signs payload with a ECDSA private key
func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
return s.SignWithRand(payload, key, rand.Reader)
}
3 changes: 2 additions & 1 deletion internal/jwx/jws/sign/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sign
import (
"crypto/ecdsa"
"crypto/rsa"
"io"

"github.com/open-policy-agent/opa/internal/jwx/jwa"
)
Expand All @@ -28,7 +29,7 @@ type RSASigner struct {
sign rsaSignFunc
}

type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey) ([]byte, error)
type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey, io.Reader) ([]byte, error)

// ECDSASigner uses crypto/ecdsa to sign the payloads.
type ECDSASigner struct {
Expand Down
6 changes: 5 additions & 1 deletion plugins/rest/rest_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,11 @@ func (ap *oauth2ClientCredentialsAuthPlugin) createAuthJWT(claims map[string]int
jwsHeaders = []byte(fmt.Sprintf(`{"typ":"JWT","alg":"%s"}`, ap.signingKey.Algorithm))
}

jwsCompact, err := jws.SignLiteral(payload, jwa.SignatureAlgorithm(ap.signingKey.Algorithm), signingKey, jwsHeaders)
jwsCompact, err := jws.SignLiteral(payload,
jwa.SignatureAlgorithm(ap.signingKey.Algorithm),
signingKey,
jwsHeaders,
rand.Reader)
if err != nil {
return nil, err
}
Expand Down
51 changes: 24 additions & 27 deletions topdown/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -823,69 +823,66 @@ func (header *tokenHeader) valid() bool {
return true
}

func commonBuiltinJWTEncodeSign(inputHeaders, jwsPayload, jwkSrc string) (ast.Value, error) {
func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, jwkSrc string, iter func(*ast.Term) error) error {

keys, err := jwk.ParseString(jwkSrc)
if err != nil {
return nil, err
return err
}
key, err := keys.Keys[0].Materialize()
if err != nil {
return nil, err
return err
}
if jwk.GetKeyTypeFromKey(key) != keys.Keys[0].GetKeyType() {
return nil, fmt.Errorf("JWK derived key type and keyType parameter do not match")
return fmt.Errorf("JWK derived key type and keyType parameter do not match")
}

standardHeaders := &jws.StandardHeaders{}
jwsHeaders := []byte(inputHeaders)
err = json.Unmarshal(jwsHeaders, standardHeaders)
if err != nil {
return nil, err
return err
}
alg := standardHeaders.GetAlgorithm()

if (standardHeaders.Type == "" || standardHeaders.Type == headerJwt) && !json.Valid([]byte(jwsPayload)) {
return nil, fmt.Errorf("type is JWT but payload is not JSON")
return fmt.Errorf("type is JWT but payload is not JSON")
}

// process payload and sign
var jwsCompact []byte
jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders)
jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders, bctx.Seed)
if err != nil {
return nil, err
return err
}
return ast.String(jwsCompact), nil
return iter(ast.StringTerm(string(jwsCompact)))

}

func builtinJWTEncodeSign(a ast.Value, b ast.Value, c ast.Value) (ast.Value, error) {

jwkSrc := c.String()
func builtinJWTEncodeSign(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {

inputHeaders := a.String()

jwsPayload := b.String()

return commonBuiltinJWTEncodeSign(inputHeaders, jwsPayload, jwkSrc)
inputHeaders := args[0].String()
jwsPayload := args[1].String()
jwkSrc := args[2].String()
return commonBuiltinJWTEncodeSign(bctx, inputHeaders, jwsPayload, jwkSrc, iter)

}

func builtinJWTEncodeSignRaw(a ast.Value, b ast.Value, c ast.Value) (ast.Value, error) {
func builtinJWTEncodeSignRaw(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {

jwkSrc, err := builtins.StringOperand(c, 1)
jwkSrc, err := builtins.StringOperand(args[2].Value, 3)
if err != nil {
return nil, err
return err
}
inputHeaders, err := builtins.StringOperand(a, 1)
inputHeaders, err := builtins.StringOperand(args[0].Value, 1)
if err != nil {
return nil, err
return err
}
jwsPayload, err := builtins.StringOperand(b, 1)
jwsPayload, err := builtins.StringOperand(args[1].Value, 2)
if err != nil {
return nil, err
return err
}
return commonBuiltinJWTEncodeSign(string(inputHeaders), string(jwsPayload), string(jwkSrc))
return commonBuiltinJWTEncodeSign(bctx, string(inputHeaders), string(jwsPayload), string(jwkSrc), iter)
}

// Implements full JWT decoding, validation and verification.
Expand Down Expand Up @@ -1111,6 +1108,6 @@ func init() {
RegisterBuiltinFunc(ast.JWTVerifyHS384.Name, builtinJWTVerifyHS384)
RegisterBuiltinFunc(ast.JWTVerifyHS512.Name, builtinJWTVerifyHS512)
RegisterBuiltinFunc(ast.JWTDecodeVerify.Name, builtinJWTDecodeVerify)
RegisterFunctionalBuiltin3(ast.JWTEncodeSignRaw.Name, builtinJWTEncodeSignRaw)
RegisterFunctionalBuiltin3(ast.JWTEncodeSign.Name, builtinJWTEncodeSign)
RegisterBuiltinFunc(ast.JWTEncodeSignRaw.Name, builtinJWTEncodeSignRaw)
RegisterBuiltinFunc(ast.JWTEncodeSign.Name, builtinJWTEncodeSign)
}
47 changes: 47 additions & 0 deletions topdown/tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,50 @@ func TestTopDownJWTEncodeSignES512(t *testing.T) {
t.Fatal("Failed to verify message")
}
}

// NOTE(sr): The stdlib ecdsa package will randomly read 1 byte from the source
// and discard it: so passing a fixed-seed `rand.New(rand.Source(seed))` via
// `rego.WithSeed` will not do the trick, the output would still randomly be
// one of two possible signatures. To fix that for testing, we're reaching
// deeper here, and use a "constant number generator". It doesn't matter if the
// first byte is discarded, the second one looks just the same.
type cng struct{}

func (*cng) Read(p []byte) (int, error) {
for i := range p {
p[i] = 4
}
return len(p), nil
}

func TestTopdownJWTEncodeSignECWithSeedReturnsSameSignature(t *testing.T) {
query := `io.jwt.encode_sign({"alg": "ES256"},{"pay": "load"},
{"kty":"EC",
"crv":"P-256",
"x":"f83OJ3D2xF1Bg8vub9tLe1gHMzV76e8Tus9uPHvRVEU",
"y":"x_FEzRu9m36HLN_tue659LNpXW6pCyStikYjKIWI5a0",
"d":"jpsQnnGQmL-YBIffH1136cspYG6-0iY7X1fCE9-E9LI"
}, x)`
encodedSigned := "eyJhbGciOiAiRVMyNTYifQ.eyJwYXkiOiAibG9hZCJ9.-LoHxtbT8t_TnqlLyONI4BtjvfkySO8TcoCFENqTTH2AKxvn29nAjxOdlbY-0EKVM2nJ4ukCx4IGtZtuwXr0VQ"

for i := 0; i < 10; i++ {
q := NewQuery(ast.MustParseBody(query)).
WithSeed(&cng{}).
WithStrictBuiltinErrors(true).
WithCompiler(ast.NewCompiler())

qrs, err := q.Run(context.Background())
if err != nil {
t.Fatal(err)
} else if len(qrs) != 1 {
t.Fatal("expected exactly one result but got:", qrs)
}

if exp, act := 1, len(qrs); exp != act {
t.Fatalf("expected %d results, got %d", exp, act)
}
if exp, act := ast.String(encodedSigned), qrs[0][ast.Var("x")].Value; !exp.Equal(act) {
t.Fatalf("unexpected result: want %v, got %v", exp, act)
}
}
}

0 comments on commit e732b0b

Please sign in to comment.