Skip to content

Commit

Permalink
fix: resolve comments
Browse files Browse the repository at this point in the history
Signed-off-by: Junjie Gao <[email protected]>
  • Loading branch information
JeyJeyGao committed Aug 3, 2023
1 parent 643f388 commit bc18cae
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 47 deletions.
77 changes: 44 additions & 33 deletions internal/encoding/asn1/asn1.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,16 @@ package asn1
import (
"bytes"
"encoding/asn1"
"math"
)

// Common errors
var (
ErrBytesAtTheEnd = asn1.StructuralError{Msg: "invalid bytes at the end of the BER data"}
ErrEarlyEOF = asn1.SyntaxError{Msg: "early EOF"}
ErrInvalidBerData = asn1.StructuralError{Msg: "invalid BER data"}
ErrInvalidOffset = asn1.StructuralError{Msg: "invalid offset"}
ErrInvalidSlice = asn1.StructuralError{Msg: "invalid slice"}
ErrUnsupportedLen = asn1.StructuralError{Msg: "length method not supported"}
ErrUnsupportedIndefinedLen = asn1.StructuralError{Msg: "indefinite length not supported"}
ErrBytesAtTheEnd = asn1.StructuralError{Msg: "invalid bytes at the end of the BER data"}
ErrEarlyEOF = asn1.SyntaxError{Msg: "early EOF"}
ErrInvalidBERData = asn1.StructuralError{Msg: "invalid BER data"}
ErrUnsupportedLength = asn1.StructuralError{Msg: "length method not supported"}
ErrUnsupportedIndefinedLength = asn1.StructuralError{Msg: "indefinite length not supported"}
)

// value represents an ASN.1 value.
Expand Down Expand Up @@ -60,17 +59,16 @@ func decode(r []byte) (value, error) {
identifier []byte
contentLen int
berValueLen int
isPrimitive bool
err error
)
// prepare the first value
identifier, contentLen, _, isPrimitive, r, err = decodeMetadata(r)
identifier, contentLen, _, r, err = decodeMetadata(r)
if err != nil {
return nil, err
}

// primitive value
if isPrimitive {
if isPrimitive(identifier[0]) {
if contentLen != len(r) {
return nil, ErrBytesAtTheEnd
}
Expand All @@ -93,7 +91,7 @@ func decode(r []byte) (value, error) {
v := valueStack[stackLen-1]

if v.expectedLen < 0 {
return nil, ErrInvalidBerData
return nil, ErrInvalidBERData
}

if v.expectedLen == 0 {
Expand All @@ -108,11 +106,11 @@ func decode(r []byte) (value, error) {
}

for v.expectedLen > 0 {
identifier, contentLen, berValueLen, isPrimitive, r, err = decodeMetadata(r)
identifier, contentLen, berValueLen, r, err = decodeMetadata(r)
if err != nil {
return nil, err
}
if isPrimitive {
if isPrimitive(identifier[0]) {
// primitive value
pv := primitiveValue{
identifier: identifier,
Expand Down Expand Up @@ -142,32 +140,35 @@ func decode(r []byte) (value, error) {
return rootConstructed, nil
}

func decodeMetadata(r []byte) ([]byte, int, int, bool, []byte, error) {
func decodeMetadata(r []byte) ([]byte, int, int, []byte, error) {
length := len(r)
identifier, r, err := decodeIdentifier(r)
if err != nil {
return nil, 0, 0, false, nil, err
return nil, 0, 0, nil, err
}
contentLen, r, err := decodeLen(r)
contentLen, r, err := decodeLength(r)
if err != nil {
return nil, 0, 0, false, nil, err
return nil, 0, 0, nil, err
}

if contentLen > len(r) {
return nil, 0, 0, false, nil, ErrEarlyEOF
return nil, 0, 0, nil, ErrEarlyEOF
}
isPrimitive := identifier[0]&0x20 == 0
metadataLen := length - len(r)
berValueLen := metadataLen + contentLen
return identifier, contentLen, berValueLen, isPrimitive, r, nil
return identifier, contentLen, berValueLen, r, nil
}

// decodeIdentifier decodes decodeIdentifier octets.
//
// r is the input byte slice.
// The first return value is the identifier octets.
// The second return value is the subsequent value after the identifiers octets.
func decodeIdentifier(r []byte) ([]byte, []byte, error) {
offset := 0
if len(r) < 1 {
return nil, nil, ErrEarlyEOF
}
offset := 0
b := r[offset]
offset++

Expand All @@ -180,41 +181,51 @@ func decodeIdentifier(r []byte) ([]byte, []byte, error) {
return r[:offset], r[offset:], nil
}

// decodeLen decodes length octets.
// decodeLength decodes length octets.
// Indefinite length is not supported
func decodeLen(r []byte) (int, []byte, error) {
//
// r is the input byte slice.
// The first return value is the length.
// The second return value is the subsequent value after the length octets.
func decodeLength(r []byte) (int, []byte, error) {
offset := 0
if len(r) < 1 {
return 0, nil, ErrEarlyEOF
}
b := r[offset]
offset++
switch {
case b < 0x80:

if b < 0x80 {
// short form
return int(b), r[offset:], nil
case b == 0x80:
} else if b == 0x80 {
// Indefinite-length method is not supported.
return 0, nil, ErrUnsupportedIndefinedLen
return 0, nil, ErrUnsupportedIndefinedLength
}

// long form
n := int(b & 0x7f)
if n > 4 {
// length must fit the memory space of the int type.
return 0, nil, ErrUnsupportedLen
return 0, nil, ErrUnsupportedLength
}
var length int
var length uint64
for i := 0; i < n; i++ {
if offset >= len(r) {
return 0, nil, ErrEarlyEOF
}
length = (length << 8) | int(r[offset])
length = (length << 8) | uint64(r[offset])
offset++
}
if length < 0 {
if length > uint64(math.MaxInt64) {
// double check in case that length is over 31 bits.
return 0, nil, ErrUnsupportedLen
return 0, nil, ErrUnsupportedLength
}
return length, r[offset:], nil
return int(length), r[offset:], nil
}

// isPrimitive returns true if the first identifier octet is marked
// as primitive.
func isPrimitive(b byte) bool {
return b&0x20 == 0
}
12 changes: 6 additions & 6 deletions internal/encoding/asn1/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@
package asn1

import (
"bytes"
"io"
)

// encodeLen encodes length octets in DER.
func encodeLen(w *bytes.Buffer, length int) error {
// encodeLength encodes length octets in DER.
func encodeLength(w io.ByteWriter, length int) error {
// DER restriction: short form must be used for length less than 128
if length < 0x80 {
return w.WriteByte(byte(length))
}

// DER restriction: long form must be encoded in the minimum number of octets
lengthSize := encodedLenSize(length)
lengthSize := encodedLengthSize(length)
err := w.WriteByte(0x80 | byte(lengthSize-1))
if err != nil {
return err
Expand All @@ -38,8 +38,8 @@ func encodeLen(w *bytes.Buffer, length int) error {
return nil
}

// encodedLenSize gives the number of octets used for encoding the length.
func encodedLenSize(length int) int {
// encodedLengthSize gives the number of octets used for encoding the length.
func encodedLengthSize(length int) int {
if length < 0x80 {
return 1
}
Expand Down
12 changes: 6 additions & 6 deletions internal/encoding/asn1/constructed.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
package asn1

import "bytes"

// Copyright The Notary Project Authors.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -15,6 +11,10 @@ import "bytes"
// See the License for the specific language governing permissions and
// limitations under the License.

package asn1

import "bytes"

// constructedValue represents a value in constructed encoding.
type constructedValue struct {
identifier []byte
Expand All @@ -29,7 +29,7 @@ func (v constructedValue) Encode(w *bytes.Buffer) error {
if err != nil {
return err
}
if err = encodeLen(w, v.length); err != nil {
if err = encodeLength(w, v.length); err != nil {
return err
}
for _, value := range v.members {
Expand All @@ -42,5 +42,5 @@ func (v constructedValue) Encode(w *bytes.Buffer) error {

// EncodedLen returns the length in bytes of the encoded data.
func (v constructedValue) EncodedLen() int {
return len(v.identifier) + encodedLenSize(v.length) + v.length
return len(v.identifier) + encodedLengthSize(v.length) + v.length
}
4 changes: 2 additions & 2 deletions internal/encoding/asn1/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (v primitiveValue) Encode(w *bytes.Buffer) error {
if err != nil {
return err
}
if err = encodeLen(w, len(v.content)); err != nil {
if err = encodeLength(w, len(v.content)); err != nil {
return err
}
_, err = w.Write(v.content)
Expand All @@ -36,5 +36,5 @@ func (v primitiveValue) Encode(w *bytes.Buffer) error {

// EncodedLen returns the length in bytes of the encoded data.
func (v primitiveValue) EncodedLen() int {
return len(v.identifier) + encodedLenSize(len(v.content)) + len(v.content)
return len(v.identifier) + encodedLengthSize(len(v.content)) + len(v.content)
}

0 comments on commit bc18cae

Please sign in to comment.