From bc18cae59daa3700b49503c46bd6555634c599e5 Mon Sep 17 00:00:00 2001 From: Junjie Gao Date: Thu, 3 Aug 2023 16:14:57 +0800 Subject: [PATCH] fix: resolve comments Signed-off-by: Junjie Gao --- internal/encoding/asn1/asn1.go | 77 +++++++++++++++------------ internal/encoding/asn1/common.go | 12 ++--- internal/encoding/asn1/constructed.go | 12 ++--- internal/encoding/asn1/primitive.go | 4 +- 4 files changed, 58 insertions(+), 47 deletions(-) diff --git a/internal/encoding/asn1/asn1.go b/internal/encoding/asn1/asn1.go index 3fac5e3b..1d27cc5f 100644 --- a/internal/encoding/asn1/asn1.go +++ b/internal/encoding/asn1/asn1.go @@ -19,17 +19,16 @@ package asn1 import ( "bytes" "encoding/asn1" + "math" ) // Common errors var ( - ErrBytesAtTheEnd = asn1.StructuralError{Msg: "invalid bytes at the end of the BER data"} - ErrEarlyEOF = asn1.SyntaxError{Msg: "early EOF"} - ErrInvalidBerData = asn1.StructuralError{Msg: "invalid BER data"} - ErrInvalidOffset = asn1.StructuralError{Msg: "invalid offset"} - ErrInvalidSlice = asn1.StructuralError{Msg: "invalid slice"} - ErrUnsupportedLen = asn1.StructuralError{Msg: "length method not supported"} - ErrUnsupportedIndefinedLen = asn1.StructuralError{Msg: "indefinite length not supported"} + ErrBytesAtTheEnd = asn1.StructuralError{Msg: "invalid bytes at the end of the BER data"} + ErrEarlyEOF = asn1.SyntaxError{Msg: "early EOF"} + ErrInvalidBERData = asn1.StructuralError{Msg: "invalid BER data"} + ErrUnsupportedLength = asn1.StructuralError{Msg: "length method not supported"} + ErrUnsupportedIndefinedLength = asn1.StructuralError{Msg: "indefinite length not supported"} ) // value represents an ASN.1 value. @@ -60,17 +59,16 @@ func decode(r []byte) (value, error) { identifier []byte contentLen int berValueLen int - isPrimitive bool err error ) // prepare the first value - identifier, contentLen, _, isPrimitive, r, err = decodeMetadata(r) + identifier, contentLen, _, r, err = decodeMetadata(r) if err != nil { return nil, err } // primitive value - if isPrimitive { + if isPrimitive(identifier[0]) { if contentLen != len(r) { return nil, ErrBytesAtTheEnd } @@ -93,7 +91,7 @@ func decode(r []byte) (value, error) { v := valueStack[stackLen-1] if v.expectedLen < 0 { - return nil, ErrInvalidBerData + return nil, ErrInvalidBERData } if v.expectedLen == 0 { @@ -108,11 +106,11 @@ func decode(r []byte) (value, error) { } for v.expectedLen > 0 { - identifier, contentLen, berValueLen, isPrimitive, r, err = decodeMetadata(r) + identifier, contentLen, berValueLen, r, err = decodeMetadata(r) if err != nil { return nil, err } - if isPrimitive { + if isPrimitive(identifier[0]) { // primitive value pv := primitiveValue{ identifier: identifier, @@ -142,32 +140,35 @@ func decode(r []byte) (value, error) { return rootConstructed, nil } -func decodeMetadata(r []byte) ([]byte, int, int, bool, []byte, error) { +func decodeMetadata(r []byte) ([]byte, int, int, []byte, error) { length := len(r) identifier, r, err := decodeIdentifier(r) if err != nil { - return nil, 0, 0, false, nil, err + return nil, 0, 0, nil, err } - contentLen, r, err := decodeLen(r) + contentLen, r, err := decodeLength(r) if err != nil { - return nil, 0, 0, false, nil, err + return nil, 0, 0, nil, err } if contentLen > len(r) { - return nil, 0, 0, false, nil, ErrEarlyEOF + return nil, 0, 0, nil, ErrEarlyEOF } - isPrimitive := identifier[0]&0x20 == 0 metadataLen := length - len(r) berValueLen := metadataLen + contentLen - return identifier, contentLen, berValueLen, isPrimitive, r, nil + return identifier, contentLen, berValueLen, r, nil } // decodeIdentifier decodes decodeIdentifier octets. +// +// r is the input byte slice. +// The first return value is the identifier octets. +// The second return value is the subsequent value after the identifiers octets. func decodeIdentifier(r []byte) ([]byte, []byte, error) { - offset := 0 if len(r) < 1 { return nil, nil, ErrEarlyEOF } + offset := 0 b := r[offset] offset++ @@ -180,41 +181,51 @@ func decodeIdentifier(r []byte) ([]byte, []byte, error) { return r[:offset], r[offset:], nil } -// decodeLen decodes length octets. +// decodeLength decodes length octets. // Indefinite length is not supported -func decodeLen(r []byte) (int, []byte, error) { +// +// r is the input byte slice. +// The first return value is the length. +// The second return value is the subsequent value after the length octets. +func decodeLength(r []byte) (int, []byte, error) { offset := 0 if len(r) < 1 { return 0, nil, ErrEarlyEOF } b := r[offset] offset++ - switch { - case b < 0x80: + + if b < 0x80 { // short form return int(b), r[offset:], nil - case b == 0x80: + } else if b == 0x80 { // Indefinite-length method is not supported. - return 0, nil, ErrUnsupportedIndefinedLen + return 0, nil, ErrUnsupportedIndefinedLength } // long form n := int(b & 0x7f) if n > 4 { // length must fit the memory space of the int type. - return 0, nil, ErrUnsupportedLen + return 0, nil, ErrUnsupportedLength } - var length int + var length uint64 for i := 0; i < n; i++ { if offset >= len(r) { return 0, nil, ErrEarlyEOF } - length = (length << 8) | int(r[offset]) + length = (length << 8) | uint64(r[offset]) offset++ } - if length < 0 { + if length > uint64(math.MaxInt64) { // double check in case that length is over 31 bits. - return 0, nil, ErrUnsupportedLen + return 0, nil, ErrUnsupportedLength } - return length, r[offset:], nil + return int(length), r[offset:], nil +} + +// isPrimitive returns true if the first identifier octet is marked +// as primitive. +func isPrimitive(b byte) bool { + return b&0x20 == 0 } diff --git a/internal/encoding/asn1/common.go b/internal/encoding/asn1/common.go index b84f24dc..eb93ba78 100644 --- a/internal/encoding/asn1/common.go +++ b/internal/encoding/asn1/common.go @@ -14,18 +14,18 @@ package asn1 import ( - "bytes" + "io" ) -// encodeLen encodes length octets in DER. -func encodeLen(w *bytes.Buffer, length int) error { +// encodeLength encodes length octets in DER. +func encodeLength(w io.ByteWriter, length int) error { // DER restriction: short form must be used for length less than 128 if length < 0x80 { return w.WriteByte(byte(length)) } // DER restriction: long form must be encoded in the minimum number of octets - lengthSize := encodedLenSize(length) + lengthSize := encodedLengthSize(length) err := w.WriteByte(0x80 | byte(lengthSize-1)) if err != nil { return err @@ -38,8 +38,8 @@ func encodeLen(w *bytes.Buffer, length int) error { return nil } -// encodedLenSize gives the number of octets used for encoding the length. -func encodedLenSize(length int) int { +// encodedLengthSize gives the number of octets used for encoding the length. +func encodedLengthSize(length int) int { if length < 0x80 { return 1 } diff --git a/internal/encoding/asn1/constructed.go b/internal/encoding/asn1/constructed.go index 440e2692..eeac2b93 100644 --- a/internal/encoding/asn1/constructed.go +++ b/internal/encoding/asn1/constructed.go @@ -1,7 +1,3 @@ -package asn1 - -import "bytes" - // Copyright The Notary Project Authors. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,6 +11,10 @@ import "bytes" // See the License for the specific language governing permissions and // limitations under the License. +package asn1 + +import "bytes" + // constructedValue represents a value in constructed encoding. type constructedValue struct { identifier []byte @@ -29,7 +29,7 @@ func (v constructedValue) Encode(w *bytes.Buffer) error { if err != nil { return err } - if err = encodeLen(w, v.length); err != nil { + if err = encodeLength(w, v.length); err != nil { return err } for _, value := range v.members { @@ -42,5 +42,5 @@ func (v constructedValue) Encode(w *bytes.Buffer) error { // EncodedLen returns the length in bytes of the encoded data. func (v constructedValue) EncodedLen() int { - return len(v.identifier) + encodedLenSize(v.length) + v.length + return len(v.identifier) + encodedLengthSize(v.length) + v.length } diff --git a/internal/encoding/asn1/primitive.go b/internal/encoding/asn1/primitive.go index c7390815..a2e9115b 100644 --- a/internal/encoding/asn1/primitive.go +++ b/internal/encoding/asn1/primitive.go @@ -27,7 +27,7 @@ func (v primitiveValue) Encode(w *bytes.Buffer) error { if err != nil { return err } - if err = encodeLen(w, len(v.content)); err != nil { + if err = encodeLength(w, len(v.content)); err != nil { return err } _, err = w.Write(v.content) @@ -36,5 +36,5 @@ func (v primitiveValue) Encode(w *bytes.Buffer) error { // EncodedLen returns the length in bytes of the encoded data. func (v primitiveValue) EncodedLen() int { - return len(v.identifier) + encodedLenSize(len(v.content)) + len(v.content) + return len(v.identifier) + encodedLengthSize(len(v.content)) + len(v.content) }