diff --git a/bundle/sign.go b/bundle/sign.go index 13d6667f58..cf9a3e183a 100644 --- a/bundle/sign.go +++ b/bundle/sign.go @@ -6,6 +6,7 @@ package bundle import ( + "crypto/rand" "encoding/json" "fmt" @@ -76,7 +77,11 @@ func (*DefaultSigner) GenerateSignedToken(files []FileInfo, sc *SigningConfig, k return "", err } - token, err := jws.SignLiteral(payload, jwa.SignatureAlgorithm(sc.Algorithm), privateKey, hdr) + token, err := jws.SignLiteral(payload, + jwa.SignatureAlgorithm(sc.Algorithm), + privateKey, + hdr, + rand.Reader) if err != nil { return "", err } diff --git a/internal/jwx/jws/jws.go b/internal/jwx/jws/jws.go index 6fca28d23c..bfa498bb0f 100644 --- a/internal/jwx/jws/jws.go +++ b/internal/jwx/jws/jws.go @@ -21,8 +21,10 @@ package jws import ( "bytes" + "crypto/rand" "encoding/base64" "encoding/json" + "io" "strings" "github.com/open-policy-agent/opa/internal/jwx/jwa" @@ -37,7 +39,7 @@ import ( // it in compact serialization format. In this format you may NOT use // multiple signers. // -func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte) ([]byte, error) { +func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte, rnd io.Reader) ([]byte, error) { encodedHdr := base64.RawURLEncoding.EncodeToString(hdrBuf) encodedPayload := base64.RawURLEncoding.EncodeToString(payload) signingInput := strings.Join( @@ -50,7 +52,14 @@ func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hd if err != nil { return nil, errors.Wrap(err, `failed to create signer`) } - signature, err := signer.Sign([]byte(signingInput), key) + + var signature []byte + switch s := signer.(type) { + case *sign.ECDSASigner: + signature, err = s.SignWithRand([]byte(signingInput), key, rnd) + default: + signature, err = signer.Sign([]byte(signingInput), key) + } if err != nil { return nil, errors.Wrap(err, `failed to sign Payload`) } @@ -81,7 +90,9 @@ func SignWithOption(payload []byte, alg jwa.SignatureAlgorithm, key interface{}) if err != nil { return nil, errors.Wrap(err, `failed to marshal Headers`) } - return SignLiteral(payload, alg, key, hdrBuf) + // NOTE(sr): we don't use SignWithOption -- if we did, this rand.Reader + // should come from the BuiltinContext's Seed, too. + return SignLiteral(payload, alg, key, hdrBuf, rand.Reader) } // Verify checks if the given JWS message is verifiable using `alg` and `key`. diff --git a/internal/jwx/jws/jws_test.go b/internal/jwx/jws/jws_test.go index b5df98b31e..815210db86 100644 --- a/internal/jwx/jws/jws_test.go +++ b/internal/jwx/jws/jws_test.go @@ -261,7 +261,7 @@ func TestEncode(t *testing.T) { t.Fatal("Failed to parse key") } var jwsCompact []byte - jwsCompact, err = jws.SignLiteral([]byte(examplePayload), alg, key, hdrBytes) + jwsCompact, err = jws.SignLiteral([]byte(examplePayload), alg, key, hdrBytes, rand.Reader) if err != nil { t.Fatal("Failed to sign message") } @@ -599,13 +599,13 @@ func TestSignErrors(t *testing.T) { } }) t.Run("Invalid signature algorithm", func(t *testing.T) { - _, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header")) + _, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"), rand.Reader) if err == nil { t.Fatal("JWS signing should have failed") } }) t.Run("Invalid signature algorithm", func(t *testing.T) { - _, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header")) + _, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"), rand.Reader) if err == nil { t.Fatal("JWS signing should have failed") } diff --git a/internal/jwx/jws/sign/ecdsa.go b/internal/jwx/jws/sign/ecdsa.go index 7023906806..62af72b6c1 100644 --- a/internal/jwx/jws/sign/ecdsa.go +++ b/internal/jwx/jws/sign/ecdsa.go @@ -4,6 +4,7 @@ import ( "crypto" "crypto/ecdsa" "crypto/rand" + "io" "github.com/open-policy-agent/opa/internal/jwx/jwa" @@ -25,7 +26,7 @@ func init() { } func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc { - return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey) ([]byte, error) { + return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey, rnd io.Reader) ([]byte, error) { curveBits := key.Curve.Params().BitSize keyBytes := curveBits / 8 // Curve bits do not need to be a multiple of 8. @@ -34,7 +35,7 @@ func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc { } h := hash.New() h.Write(payload) - r, s, err := ecdsa.Sign(rand.Reader, key, h.Sum(nil)) + r, s, err := ecdsa.Sign(rnd, key, h.Sum(nil)) if err != nil { return nil, errors.Wrap(err, "failed to sign payload using ecdsa") } @@ -69,8 +70,9 @@ func (s ECDSASigner) Algorithm() jwa.SignatureAlgorithm { return s.alg } -// Sign signs payload with a ECDSA private key -func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) { +// SignWithRand signs payload with a ECDSA private key and a provided randomness +// source (such as `rand.Reader`). +func (s ECDSASigner) SignWithRand(payload []byte, key interface{}, r io.Reader) ([]byte, error) { if key == nil { return nil, errors.New(`missing private key while signing payload`) } @@ -79,6 +81,10 @@ func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) { if !ok { return nil, errors.Errorf(`invalid key type %T. *ecdsa.PrivateKey is required`, key) } + return s.sign(payload, privateKey, r) +} - return s.sign(payload, privateKey) +// Sign signs payload with a ECDSA private key +func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) { + return s.SignWithRand(payload, key, rand.Reader) } diff --git a/internal/jwx/jws/sign/interface.go b/internal/jwx/jws/sign/interface.go index 42a10c42e4..2ef2bee486 100644 --- a/internal/jwx/jws/sign/interface.go +++ b/internal/jwx/jws/sign/interface.go @@ -3,6 +3,7 @@ package sign import ( "crypto/ecdsa" "crypto/rsa" + "io" "github.com/open-policy-agent/opa/internal/jwx/jwa" ) @@ -28,7 +29,7 @@ type RSASigner struct { sign rsaSignFunc } -type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey) ([]byte, error) +type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey, io.Reader) ([]byte, error) // ECDSASigner uses crypto/ecdsa to sign the payloads. type ECDSASigner struct { diff --git a/plugins/rest/rest_auth.go b/plugins/rest/rest_auth.go index 5151af439c..f6b5168371 100644 --- a/plugins/rest/rest_auth.go +++ b/plugins/rest/rest_auth.go @@ -219,7 +219,11 @@ func (ap *oauth2ClientCredentialsAuthPlugin) createAuthJWT(claims map[string]int jwsHeaders = []byte(fmt.Sprintf(`{"typ":"JWT","alg":"%s"}`, ap.signingKey.Algorithm)) } - jwsCompact, err := jws.SignLiteral(payload, jwa.SignatureAlgorithm(ap.signingKey.Algorithm), signingKey, jwsHeaders) + jwsCompact, err := jws.SignLiteral(payload, + jwa.SignatureAlgorithm(ap.signingKey.Algorithm), + signingKey, + jwsHeaders, + rand.Reader) if err != nil { return nil, err } diff --git a/topdown/tokens.go b/topdown/tokens.go index 49045e32df..64828831fd 100644 --- a/topdown/tokens.go +++ b/topdown/tokens.go @@ -823,69 +823,66 @@ func (header *tokenHeader) valid() bool { return true } -func commonBuiltinJWTEncodeSign(inputHeaders, jwsPayload, jwkSrc string) (ast.Value, error) { +func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, jwkSrc string, iter func(*ast.Term) error) error { keys, err := jwk.ParseString(jwkSrc) if err != nil { - return nil, err + return err } key, err := keys.Keys[0].Materialize() if err != nil { - return nil, err + return err } if jwk.GetKeyTypeFromKey(key) != keys.Keys[0].GetKeyType() { - return nil, fmt.Errorf("JWK derived key type and keyType parameter do not match") + return fmt.Errorf("JWK derived key type and keyType parameter do not match") } standardHeaders := &jws.StandardHeaders{} jwsHeaders := []byte(inputHeaders) err = json.Unmarshal(jwsHeaders, standardHeaders) if err != nil { - return nil, err + return err } alg := standardHeaders.GetAlgorithm() if (standardHeaders.Type == "" || standardHeaders.Type == headerJwt) && !json.Valid([]byte(jwsPayload)) { - return nil, fmt.Errorf("type is JWT but payload is not JSON") + return fmt.Errorf("type is JWT but payload is not JSON") } // process payload and sign var jwsCompact []byte - jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders) + jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders, bctx.Seed) if err != nil { - return nil, err + return err } - return ast.String(jwsCompact), nil + return iter(ast.StringTerm(string(jwsCompact))) } -func builtinJWTEncodeSign(a ast.Value, b ast.Value, c ast.Value) (ast.Value, error) { - - jwkSrc := c.String() +func builtinJWTEncodeSign(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { - inputHeaders := a.String() - - jwsPayload := b.String() - - return commonBuiltinJWTEncodeSign(inputHeaders, jwsPayload, jwkSrc) + inputHeaders := args[0].String() + jwsPayload := args[1].String() + jwkSrc := args[2].String() + return commonBuiltinJWTEncodeSign(bctx, inputHeaders, jwsPayload, jwkSrc, iter) } -func builtinJWTEncodeSignRaw(a ast.Value, b ast.Value, c ast.Value) (ast.Value, error) { +func builtinJWTEncodeSignRaw(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { - jwkSrc, err := builtins.StringOperand(c, 1) + jwkSrc, err := builtins.StringOperand(args[2].Value, 3) if err != nil { - return nil, err + return err } - inputHeaders, err := builtins.StringOperand(a, 1) + inputHeaders, err := builtins.StringOperand(args[0].Value, 1) if err != nil { - return nil, err + return err } - jwsPayload, err := builtins.StringOperand(b, 1) + jwsPayload, err := builtins.StringOperand(args[1].Value, 2) if err != nil { - return nil, err + return err } - return commonBuiltinJWTEncodeSign(string(inputHeaders), string(jwsPayload), string(jwkSrc)) + return commonBuiltinJWTEncodeSign(bctx, string(inputHeaders), string(jwsPayload), string(jwkSrc), iter) } // Implements full JWT decoding, validation and verification. @@ -1111,6 +1108,6 @@ func init() { RegisterBuiltinFunc(ast.JWTVerifyHS384.Name, builtinJWTVerifyHS384) RegisterBuiltinFunc(ast.JWTVerifyHS512.Name, builtinJWTVerifyHS512) RegisterBuiltinFunc(ast.JWTDecodeVerify.Name, builtinJWTDecodeVerify) - RegisterFunctionalBuiltin3(ast.JWTEncodeSignRaw.Name, builtinJWTEncodeSignRaw) - RegisterFunctionalBuiltin3(ast.JWTEncodeSign.Name, builtinJWTEncodeSign) + RegisterBuiltinFunc(ast.JWTEncodeSignRaw.Name, builtinJWTEncodeSignRaw) + RegisterBuiltinFunc(ast.JWTEncodeSign.Name, builtinJWTEncodeSign) } diff --git a/topdown/tokens_test.go b/topdown/tokens_test.go index 305b48b8a0..48c2d08e8f 100644 --- a/topdown/tokens_test.go +++ b/topdown/tokens_test.go @@ -445,3 +445,50 @@ func TestTopDownJWTEncodeSignES512(t *testing.T) { t.Fatal("Failed to verify message") } } + +// NOTE(sr): The stdlib ecdsa package will randomly read 1 byte from the source +// and discard it: so passing a fixed-seed `rand.New(rand.Source(seed))` via +// `rego.WithSeed` will not do the trick, the output would still randomly be +// one of two possible signatures. To fix that for testing, we're reaching +// deeper here, and use a "constant number generator". It doesn't matter if the +// first byte is discarded, the second one looks just the same. +type cng struct{} + +func (*cng) Read(p []byte) (int, error) { + for i := range p { + p[i] = 4 + } + return len(p), nil +} + +func TestTopdownJWTEncodeSignECWithSeedReturnsSameSignature(t *testing.T) { + query := `io.jwt.encode_sign({"alg": "ES256"},{"pay": "load"}, + {"kty":"EC", + "crv":"P-256", + "x":"f83OJ3D2xF1Bg8vub9tLe1gHMzV76e8Tus9uPHvRVEU", + "y":"x_FEzRu9m36HLN_tue659LNpXW6pCyStikYjKIWI5a0", + "d":"jpsQnnGQmL-YBIffH1136cspYG6-0iY7X1fCE9-E9LI" + }, x)` + encodedSigned := "eyJhbGciOiAiRVMyNTYifQ.eyJwYXkiOiAibG9hZCJ9.-LoHxtbT8t_TnqlLyONI4BtjvfkySO8TcoCFENqTTH2AKxvn29nAjxOdlbY-0EKVM2nJ4ukCx4IGtZtuwXr0VQ" + + for i := 0; i < 10; i++ { + q := NewQuery(ast.MustParseBody(query)). + WithSeed(&cng{}). + WithStrictBuiltinErrors(true). + WithCompiler(ast.NewCompiler()) + + qrs, err := q.Run(context.Background()) + if err != nil { + t.Fatal(err) + } else if len(qrs) != 1 { + t.Fatal("expected exactly one result but got:", qrs) + } + + if exp, act := 1, len(qrs); exp != act { + t.Fatalf("expected %d results, got %d", exp, act) + } + if exp, act := ast.String(encodedSigned), qrs[0][ast.Var("x")].Value; !exp.Equal(act) { + t.Fatalf("unexpected result: want %v, got %v", exp, act) + } + } +}