Skip to content

Commit

Permalink
refactor: add common logic to base envelope
Browse files Browse the repository at this point in the history
Signed-off-by: Binbin Li <[email protected]>
  • Loading branch information
binbin-li committed Aug 12, 2022
1 parent ab66bb3 commit 44d50bd
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 23 deletions.
14 changes: 14 additions & 0 deletions signature/algorithm.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package signature

import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
Expand Down Expand Up @@ -35,6 +36,19 @@ type KeySpec struct {
Size int
}

// Hash returns the hash function of the algorithm
func (alg Algorithm) Hash() crypto.Hash {
switch alg {
case AlgorithmPS256, AlgorithmES256:
return crypto.SHA256
case AlgorithmPS384, AlgorithmES384:
return crypto.SHA384
case AlgorithmPS512, AlgorithmES512:
return crypto.SHA512
}
return 0
}

// ExtractKeySpec extracts keySpec from the signing certificate
func ExtractKeySpec(signingCert *x509.Certificate) (KeySpec, error) {
switch key := signingCert.PublicKey.(type) {
Expand Down
17 changes: 14 additions & 3 deletions signature/envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type Envelope interface {
// NewEnvelopeFunc defines a function to create a new Envelope
type NewEnvelopeFunc func() Envelope

// ParseEnvelopeFunc defines a function to create a new Envelope with given
// ParseEnvelopeFunc defines a function to create a new Envelope with given
// envelope bytes
type ParseEnvelopeFunc func([]byte) (Envelope, error)

Expand All @@ -30,14 +30,25 @@ func RegisterEnvelopeType(mediaType string, newFunc NewEnvelopeFunc, parseFunc P
if newFunc == nil || parseFunc == nil {
return fmt.Errorf("required functions not provided")
}

envelopeFuncs[mediaType] = envelopeFunc{
newFunc: newFunc,
parseFunc: parseFunc,
}
return nil
}

// RegisteredEnvelopeTypes lists registered envelope media types.
func RegisteredEnvelopeTypes() []string {
types := []string{}

for envelopeType := range envelopeFuncs {
types = append(types, envelopeType)
}

return types
}

// NewEnvelope returns an envelope of given media type
func NewEnvelope(mediaType string) (Envelope, error) {
envelopeFunc, ok := envelopeFuncs[mediaType]
Expand All @@ -54,4 +65,4 @@ func ParseEnvelope(mediaType string, envelopeBytes []byte) (Envelope, error) {
return nil, fmt.Errorf("envelope is not set for type: %s", mediaType)
}
return envelopeFunc.parseFunc(envelopeBytes)
}
}
42 changes: 36 additions & 6 deletions signature/internal/base/envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,45 @@ func (e *Envelope) Verify() (*signature.Payload, *signature.SignerInfo, error) {
return nil, nil, &signature.MalformedSignatureError{}
}

return e.Envelope.Verify()
if err := e.validatePayload(); err != nil {
return nil, nil, err
}

payload, signerInfo, err := e.Envelope.Verify()
if err != nil {
return nil, nil, err
}

signingTime := signerInfo.SignedAttributes.SigningTime

if err = validateCertificateChain(
signerInfo.CertificateChain,
signingTime,
signerInfo.SignatureAlgorithm,
); err != nil {
return nil, nil, err
}

if err = validateSigningTime(signingTime, signerInfo.SignedAttributes.Expiry); err != nil {
return nil, nil, err
}
return payload, signerInfo, nil
}

// Payload returns the payload to be signed.
func (e *Envelope) Payload() (*signature.Payload, error) {
if len(e.Raw) == 0 {
return nil, &signature.MalformedSignatureError{Msg: "raw signature is empty"}
}
return e.Envelope.Payload()
payload, err := e.Envelope.Payload()
if err != nil {
return nil, err
}

if err = validatePayload(payload); err != nil {
return nil, err
}
return payload, nil
}

// SignerInfo returns information about the Signature envelope.
Expand Down Expand Up @@ -140,16 +170,16 @@ func validateSigningTime(signingTime, expireTime time.Time) error {

// validatePayload performs validation of the payload.
func validatePayload(payload *signature.Payload) error {
if len(payload.Content) == 0 {
return &signature.MalformedSignatureError{Msg: "content not present"}
}

if payload.ContentType != signature.MediaTypePayloadV1 {
return &signature.MalformedSignatureError{
Msg: fmt.Sprintf("payload content type: {%s} not supported", payload.ContentType),
}
}

if len(payload.Content) == 0 {
return &signature.MalformedSignatureError{Msg: "content not present"}
}

return nil
}

Expand Down
2 changes: 1 addition & 1 deletion signature/jws/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func (s *JwsSigner) Sign(digest []byte) ([]byte, error) {
if err != nil {
return nil, err
}
hasher := hash(keySpec.SignatureAlgorithm())
hasher := keySpec.SignatureAlgorithm().Hash()
h := hasher.New()
h.Write(digest)
hash := h.Sum(nil)
Expand Down
13 changes: 0 additions & 13 deletions signature/jws/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,6 @@ func mergeMaps(maps ...map[string]interface{}) map[string]interface{} {
return result
}

func hash(algorithm signature.Algorithm) crypto.Hash {
var hash crypto.Hash
switch algorithm {
case signature.AlgorithmPS256, signature.AlgorithmES256:
hash = crypto.SHA256
case signature.AlgorithmPS384, signature.AlgorithmES384:
hash = crypto.SHA384
case signature.AlgorithmPS512, signature.AlgorithmES512:
hash = crypto.SHA512
}
return hash
}

// getSigningMethod picks up a recommended algorithm for given public keys.
func getSigningMethod(key crypto.PublicKey) (jwt.SigningMethod, error) {
switch key := key.(type) {
Expand Down

0 comments on commit 44d50bd

Please sign in to comment.