Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

topdown/buitins: io.jwt.encode_sign uses BuiltinContext random source #3738

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion bundle/sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package bundle

import (
"crypto/rand"
"encoding/json"
"fmt"

Expand Down Expand Up @@ -76,7 +77,11 @@ func (*DefaultSigner) GenerateSignedToken(files []FileInfo, sc *SigningConfig, k
return "", err
}

token, err := jws.SignLiteral(payload, jwa.SignatureAlgorithm(sc.Algorithm), privateKey, hdr)
token, err := jws.SignLiteral(payload,
jwa.SignatureAlgorithm(sc.Algorithm),
privateKey,
hdr,
rand.Reader)
if err != nil {
return "", err
}
Expand Down
17 changes: 14 additions & 3 deletions internal/jwx/jws/jws.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ package jws

import (
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/json"
"io"
"strings"

"github.com/open-policy-agent/opa/internal/jwx/jwa"
Expand All @@ -37,7 +39,7 @@ import (
// it in compact serialization format. In this format you may NOT use
// multiple signers.
//
func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte) ([]byte, error) {
func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hdrBuf []byte, rnd io.Reader) ([]byte, error) {
encodedHdr := base64.RawURLEncoding.EncodeToString(hdrBuf)
encodedPayload := base64.RawURLEncoding.EncodeToString(payload)
signingInput := strings.Join(
Expand All @@ -50,7 +52,14 @@ func SignLiteral(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, hd
if err != nil {
return nil, errors.Wrap(err, `failed to create signer`)
}
signature, err := signer.Sign([]byte(signingInput), key)

var signature []byte
switch s := signer.(type) {
case *sign.ECDSASigner:
signature, err = s.SignWithRand([]byte(signingInput), key, rnd)
default:
signature, err = signer.Sign([]byte(signingInput), key)
}
if err != nil {
return nil, errors.Wrap(err, `failed to sign Payload`)
}
Expand Down Expand Up @@ -81,7 +90,9 @@ func SignWithOption(payload []byte, alg jwa.SignatureAlgorithm, key interface{})
if err != nil {
return nil, errors.Wrap(err, `failed to marshal Headers`)
}
return SignLiteral(payload, alg, key, hdrBuf)
// NOTE(sr): we don't use SignWithOption -- if we did, this rand.Reader
// should come from the BuiltinContext's Seed, too.
return SignLiteral(payload, alg, key, hdrBuf, rand.Reader)
}

// Verify checks if the given JWS message is verifiable using `alg` and `key`.
Expand Down
6 changes: 3 additions & 3 deletions internal/jwx/jws/jws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ func TestEncode(t *testing.T) {
t.Fatal("Failed to parse key")
}
var jwsCompact []byte
jwsCompact, err = jws.SignLiteral([]byte(examplePayload), alg, key, hdrBytes)
jwsCompact, err = jws.SignLiteral([]byte(examplePayload), alg, key, hdrBytes, rand.Reader)
if err != nil {
t.Fatal("Failed to sign message")
}
Expand Down Expand Up @@ -599,13 +599,13 @@ func TestSignErrors(t *testing.T) {
}
})
t.Run("Invalid signature algorithm", func(t *testing.T) {
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"))
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"), rand.Reader)
if err == nil {
t.Fatal("JWS signing should have failed")
}
})
t.Run("Invalid signature algorithm", func(t *testing.T) {
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"))
_, err := jws.SignLiteral([]byte("payload"), jwa.SignatureAlgorithm("dummy"), nil, []byte("header"), rand.Reader)
if err == nil {
t.Fatal("JWS signing should have failed")
}
Expand Down
16 changes: 11 additions & 5 deletions internal/jwx/jws/sign/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rand"
"io"

"github.com/open-policy-agent/opa/internal/jwx/jwa"

Expand All @@ -25,7 +26,7 @@ func init() {
}

func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc {
return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey) ([]byte, error) {
return ecdsaSignFunc(func(payload []byte, key *ecdsa.PrivateKey, rnd io.Reader) ([]byte, error) {
curveBits := key.Curve.Params().BitSize
keyBytes := curveBits / 8
// Curve bits do not need to be a multiple of 8.
Expand All @@ -34,7 +35,7 @@ func makeECDSASignFunc(hash crypto.Hash) ecdsaSignFunc {
}
h := hash.New()
h.Write(payload)
r, s, err := ecdsa.Sign(rand.Reader, key, h.Sum(nil))
r, s, err := ecdsa.Sign(rnd, key, h.Sum(nil))
if err != nil {
return nil, errors.Wrap(err, "failed to sign payload using ecdsa")
}
Expand Down Expand Up @@ -69,8 +70,9 @@ func (s ECDSASigner) Algorithm() jwa.SignatureAlgorithm {
return s.alg
}

// Sign signs payload with a ECDSA private key
func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
// SignWithRand signs payload with a ECDSA private key and a provided randomness
// source (such as `rand.Reader`).
func (s ECDSASigner) SignWithRand(payload []byte, key interface{}, r io.Reader) ([]byte, error) {
if key == nil {
return nil, errors.New(`missing private key while signing payload`)
}
Expand All @@ -79,6 +81,10 @@ func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
if !ok {
return nil, errors.Errorf(`invalid key type %T. *ecdsa.PrivateKey is required`, key)
}
return s.sign(payload, privateKey, r)
}

return s.sign(payload, privateKey)
// Sign signs payload with a ECDSA private key
func (s ECDSASigner) Sign(payload []byte, key interface{}) ([]byte, error) {
return s.SignWithRand(payload, key, rand.Reader)
}
3 changes: 2 additions & 1 deletion internal/jwx/jws/sign/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sign
import (
"crypto/ecdsa"
"crypto/rsa"
"io"

"github.com/open-policy-agent/opa/internal/jwx/jwa"
)
Expand All @@ -28,7 +29,7 @@ type RSASigner struct {
sign rsaSignFunc
}

type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey) ([]byte, error)
type ecdsaSignFunc func([]byte, *ecdsa.PrivateKey, io.Reader) ([]byte, error)

// ECDSASigner uses crypto/ecdsa to sign the payloads.
type ECDSASigner struct {
Expand Down
6 changes: 5 additions & 1 deletion plugins/rest/rest_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,11 @@ func (ap *oauth2ClientCredentialsAuthPlugin) createAuthJWT(claims map[string]int
jwsHeaders = []byte(fmt.Sprintf(`{"typ":"JWT","alg":"%s"}`, ap.signingKey.Algorithm))
}

jwsCompact, err := jws.SignLiteral(payload, jwa.SignatureAlgorithm(ap.signingKey.Algorithm), signingKey, jwsHeaders)
jwsCompact, err := jws.SignLiteral(payload,
jwa.SignatureAlgorithm(ap.signingKey.Algorithm),
signingKey,
jwsHeaders,
rand.Reader)
if err != nil {
return nil, err
}
Expand Down
51 changes: 24 additions & 27 deletions topdown/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -823,69 +823,66 @@ func (header *tokenHeader) valid() bool {
return true
}

func commonBuiltinJWTEncodeSign(inputHeaders, jwsPayload, jwkSrc string) (ast.Value, error) {
func commonBuiltinJWTEncodeSign(bctx BuiltinContext, inputHeaders, jwsPayload, jwkSrc string, iter func(*ast.Term) error) error {

keys, err := jwk.ParseString(jwkSrc)
if err != nil {
return nil, err
return err
}
key, err := keys.Keys[0].Materialize()
if err != nil {
return nil, err
return err
}
if jwk.GetKeyTypeFromKey(key) != keys.Keys[0].GetKeyType() {
return nil, fmt.Errorf("JWK derived key type and keyType parameter do not match")
return fmt.Errorf("JWK derived key type and keyType parameter do not match")
}

standardHeaders := &jws.StandardHeaders{}
jwsHeaders := []byte(inputHeaders)
err = json.Unmarshal(jwsHeaders, standardHeaders)
if err != nil {
return nil, err
return err
}
alg := standardHeaders.GetAlgorithm()

if (standardHeaders.Type == "" || standardHeaders.Type == headerJwt) && !json.Valid([]byte(jwsPayload)) {
return nil, fmt.Errorf("type is JWT but payload is not JSON")
return fmt.Errorf("type is JWT but payload is not JSON")
}

// process payload and sign
var jwsCompact []byte
jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders)
jwsCompact, err = jws.SignLiteral([]byte(jwsPayload), alg, key, jwsHeaders, bctx.Seed)
if err != nil {
return nil, err
return err
}
return ast.String(jwsCompact), nil
return iter(ast.StringTerm(string(jwsCompact)))

}

func builtinJWTEncodeSign(a ast.Value, b ast.Value, c ast.Value) (ast.Value, error) {

jwkSrc := c.String()
func builtinJWTEncodeSign(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {

inputHeaders := a.String()

jwsPayload := b.String()

return commonBuiltinJWTEncodeSign(inputHeaders, jwsPayload, jwkSrc)
inputHeaders := args[0].String()
jwsPayload := args[1].String()
jwkSrc := args[2].String()
return commonBuiltinJWTEncodeSign(bctx, inputHeaders, jwsPayload, jwkSrc, iter)

}

func builtinJWTEncodeSignRaw(a ast.Value, b ast.Value, c ast.Value) (ast.Value, error) {
func builtinJWTEncodeSignRaw(bctx BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {

jwkSrc, err := builtins.StringOperand(c, 1)
jwkSrc, err := builtins.StringOperand(args[2].Value, 3)
if err != nil {
return nil, err
return err
}
inputHeaders, err := builtins.StringOperand(a, 1)
inputHeaders, err := builtins.StringOperand(args[0].Value, 1)
if err != nil {
return nil, err
return err
}
jwsPayload, err := builtins.StringOperand(b, 1)
jwsPayload, err := builtins.StringOperand(args[1].Value, 2)
if err != nil {
return nil, err
return err
}
return commonBuiltinJWTEncodeSign(string(inputHeaders), string(jwsPayload), string(jwkSrc))
return commonBuiltinJWTEncodeSign(bctx, string(inputHeaders), string(jwsPayload), string(jwkSrc), iter)
}

// Implements full JWT decoding, validation and verification.
Expand Down Expand Up @@ -1111,6 +1108,6 @@ func init() {
RegisterBuiltinFunc(ast.JWTVerifyHS384.Name, builtinJWTVerifyHS384)
RegisterBuiltinFunc(ast.JWTVerifyHS512.Name, builtinJWTVerifyHS512)
RegisterBuiltinFunc(ast.JWTDecodeVerify.Name, builtinJWTDecodeVerify)
RegisterFunctionalBuiltin3(ast.JWTEncodeSignRaw.Name, builtinJWTEncodeSignRaw)
RegisterFunctionalBuiltin3(ast.JWTEncodeSign.Name, builtinJWTEncodeSign)
RegisterBuiltinFunc(ast.JWTEncodeSignRaw.Name, builtinJWTEncodeSignRaw)
RegisterBuiltinFunc(ast.JWTEncodeSign.Name, builtinJWTEncodeSign)
}
47 changes: 47 additions & 0 deletions topdown/tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,50 @@ func TestTopDownJWTEncodeSignES512(t *testing.T) {
t.Fatal("Failed to verify message")
}
}

// NOTE(sr): The stdlib ecdsa package will randomly read 1 byte from the source
// and discard it: so passing a fixed-seed `rand.New(rand.Source(seed))` via
// `rego.WithSeed` will not do the trick, the output would still randomly be
// one of two possible signatures. To fix that for testing, we're reaching
// deeper here, and use a "constant number generator". It doesn't matter if the
// first byte is discarded, the second one looks just the same.
type cng struct{}

func (*cng) Read(p []byte) (int, error) {
for i := range p {
p[i] = 4
}
return len(p), nil
}

func TestTopdownJWTEncodeSignECWithSeedReturnsSameSignature(t *testing.T) {
query := `io.jwt.encode_sign({"alg": "ES256"},{"pay": "load"},
{"kty":"EC",
"crv":"P-256",
"x":"f83OJ3D2xF1Bg8vub9tLe1gHMzV76e8Tus9uPHvRVEU",
"y":"x_FEzRu9m36HLN_tue659LNpXW6pCyStikYjKIWI5a0",
"d":"jpsQnnGQmL-YBIffH1136cspYG6-0iY7X1fCE9-E9LI"
}, x)`
encodedSigned := "eyJhbGciOiAiRVMyNTYifQ.eyJwYXkiOiAibG9hZCJ9.-LoHxtbT8t_TnqlLyONI4BtjvfkySO8TcoCFENqTTH2AKxvn29nAjxOdlbY-0EKVM2nJ4ukCx4IGtZtuwXr0VQ"

for i := 0; i < 10; i++ {
q := NewQuery(ast.MustParseBody(query)).
WithSeed(&cng{}).
WithStrictBuiltinErrors(true).
WithCompiler(ast.NewCompiler())

qrs, err := q.Run(context.Background())
if err != nil {
t.Fatal(err)
} else if len(qrs) != 1 {
t.Fatal("expected exactly one result but got:", qrs)
}

if exp, act := 1, len(qrs); exp != act {
t.Fatalf("expected %d results, got %d", exp, act)
}
if exp, act := ast.String(encodedSigned), qrs[0][ast.Var("x")].Value; !exp.Equal(act) {
t.Fatalf("unexpected result: want %v, got %v", exp, act)
}
}
}