diff --git a/internal/encoding/asn1/asn1.go b/internal/encoding/asn1/asn1.go new file mode 100644 index 00000000..73e435ab --- /dev/null +++ b/internal/encoding/asn1/asn1.go @@ -0,0 +1,139 @@ +// Package asn1 decodes BER-encoded ASN.1 data structures and encodes in DER. +// Note: DER is a subset of BER. +// Reference: http://luca.ntop.org/Teaching/Appunti/asn1.html +package asn1 + +import ( + "bytes" + "encoding/asn1" + "io" +) + +// Common errors +var ( + ErrEarlyEOF = asn1.SyntaxError{Msg: "early EOF"} + ErrExpectConstructed = asn1.SyntaxError{Msg: "constructed value expected"} + ErrExpectPrimitive = asn1.SyntaxError{Msg: "primitive value expected"} + ErrUnsupportedLength = asn1.StructuralError{Msg: "length method not supported"} + ErrInvalidSlice = asn1.StructuralError{Msg: "invalid slice"} + ErrInvalidOffset = asn1.StructuralError{Msg: "invalid offset"} +) + +// Value represents an ASN.1 value. +type Value interface { + // Encode encodes the value to the value writer in DER. + Encode(ValueWriter) error + + // EncodedLen returns the length in bytes of the encoded data. + EncodedLen() int +} + +// ConvertToDER converts BER-encoded ASN.1 data structures to DER-encoded. +func ConvertToDER(ber []byte) ([]byte, error) { + v, err := decode(newReadOnlySlice(ber)) + if err != nil { + return nil, err + } + buf := bytes.NewBuffer(make([]byte, 0, v.EncodedLen())) + if err = v.Encode(buf); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// decode decodes BER-encoded ASN.1 data structures. +func decode(r ReadOnlySlice) (Value, error) { + identifier, isPrimitiveValue, err := identifierValue(r) + if err != nil { + return nil, err + } + expectedLength, err := decodeLength(r) + if err != nil { + return nil, err + } + content, err := r.Slice(r.Offset(), r.Offset()+expectedLength) + if err != nil { + return nil, err + } + if err = r.Seek(r.Offset() + expectedLength); err != nil { + return nil, err + } + + if isPrimitiveValue { + return newPrimitiveValue(identifier, content) + } + return newConstructedValue(identifier, expectedLength, content) +} + +// identifierValue decodes identifierValue octets. +func identifierValue(r ReadOnlySlice) (ReadOnlySlice, bool, error) { + b, err := r.ReadByte() + if err != nil { + return nil, false, err + } + isPrimitiveValue := isPrimitive(b) + + tagBytesCount := 1 + // high-tag-number form + if b&0x1f == 0x1f { + for { + b, err = r.ReadByte() + if err != nil { + return nil, false, err + } + tagBytesCount++ + if b&0x80 != 0 { + break + } + } + } + + identifier, err := r.Slice(r.Offset()-tagBytesCount, r.Offset()) + if err != nil { + return nil, false, err + } + return identifier, isPrimitiveValue, nil +} + +// isPrimitive checks the primitive flag in the identifier. +// Returns true if the value is primitive. +func isPrimitive(identifier byte) bool { + return identifier&0x20 == 0 +} + +// decodeLength decodes length octets. +// Indefinite length is not supported +func decodeLength(r io.ByteReader) (int, error) { + b, err := r.ReadByte() + if err != nil { + return 0, err + } + switch { + case b < 0x80: + // short form + return int(b), nil + case b == 0x80: + // Indefinite-length method is not supported. + return 0, ErrUnsupportedLength + } + + // long form + n := int(b & 0x7f) + if n > 4 { + // length must fit the memory space of the int type. + return 0, ErrUnsupportedLength + } + var length int + for i := 0; i < n; i++ { + b, err = r.ReadByte() + if err != nil { + return 0, err + } + length = (length << 8) | int(b) + } + if length < 0 { + // double check in case that length is over 31 bits. + return 0, ErrUnsupportedLength + } + return length, nil +} diff --git a/internal/encoding/asn1/asn1_test.go b/internal/encoding/asn1/asn1_test.go new file mode 100644 index 00000000..c2fa90d1 --- /dev/null +++ b/internal/encoding/asn1/asn1_test.go @@ -0,0 +1,68 @@ +package asn1 + +import ( + "encoding/asn1" + "reflect" + "testing" +) + +func TestConvertToDER(t *testing.T) { + type data struct { + Type asn1.ObjectIdentifier + Value []byte + } + + want := data{ + Type: asn1.ObjectIdentifier{2, 16, 840, 1, 101, 3, 4, 2, 1}, + Value: []byte{ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, + 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + }, + } + + ber := []byte{ + // Constructed value + 0x30, + // Constructed value length + 0x2e, + + // Type identifier + 0x06, + // Type length + 0x09, + // Type content + 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, + + // Value identifier + 0x04, + // Value length in BER + 0x81, 0x20, + // Value content + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, + 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, + 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + } + + der, err := ConvertToDER(ber) + if err != nil { + t.Errorf("ConvertToDER() error = %v", err) + return + } + + var got data + rest, err := asn1.Unmarshal(der, &got) + if err != nil { + t.Errorf("Failed to decode converted data: %v", err) + return + } + if len(rest) > 0 { + t.Errorf("Unexpected rest data: %v", rest) + return + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got = %v, want %v", got, want) + } +} diff --git a/internal/encoding/asn1/common.go b/internal/encoding/asn1/common.go new file mode 100644 index 00000000..dbdc9912 --- /dev/null +++ b/internal/encoding/asn1/common.go @@ -0,0 +1,37 @@ +package asn1 + +import "io" + +// encodeLength encodes length octets in DER. +func encodeLength(w io.ByteWriter, length int) error { + // DER restriction: short form must be used for length less than 128 + if length < 0x80 { + return w.WriteByte(byte(length)) + } + + // DER restriction: long form must be encoded in the minimum number of octets + lengthSize := encodedLengthSize(length) + err := w.WriteByte(0x80 | byte(lengthSize-1)) + if err != nil { + return err + } + for i := lengthSize - 1; i > 0; i-- { + if err = w.WriteByte(byte(length >> (8 * (i - 1)))); err != nil { + return err + } + } + return nil +} + +// encodedLengthSize gives the number of octets used for encoding the length. +func encodedLengthSize(length int) int { + if length < 0x80 { + return 1 + } + + lengthSize := 1 + for ; length > 0; lengthSize++ { + length >>= 8 + } + return lengthSize +} diff --git a/internal/encoding/asn1/constructed.go b/internal/encoding/asn1/constructed.go new file mode 100644 index 00000000..cb559e18 --- /dev/null +++ b/internal/encoding/asn1/constructed.go @@ -0,0 +1,50 @@ +package asn1 + +// ConstructedValue represents a value in constructed encoding. +type ConstructedValue struct { + identifier ReadOnlySlice + length int + members []Value +} + +// newConstructedValue builds the constructed value. +func newConstructedValue(identifier ReadOnlySlice, expectedLength int, content ReadOnlySlice) (Value, error) { + var members []Value + encodedLength := 0 + for content.Offset() < content.Length() { + value, err := decode(content) + if err != nil { + return nil, err + } + members = append(members, value) + encodedLength += value.EncodedLen() + } + + return ConstructedValue{ + identifier: identifier, + length: encodedLength, + members: members, + }, nil +} + +// Encode encodes the constructed value to the value writer in DER. +func (v ConstructedValue) Encode(w ValueWriter) error { + _, err := w.ReadFrom(v.identifier) + if err != nil { + return err + } + if err = encodeLength(w, v.length); err != nil { + return err + } + for _, value := range v.members { + if err = value.Encode(w); err != nil { + return err + } + } + return nil +} + +// EncodedLen returns the length in bytes of the encoded data. +func (v ConstructedValue) EncodedLen() int { + return v.identifier.Length() + encodedLengthSize(v.length) + v.length +} diff --git a/internal/encoding/asn1/io.go b/internal/encoding/asn1/io.go new file mode 100644 index 00000000..737235fa --- /dev/null +++ b/internal/encoding/asn1/io.go @@ -0,0 +1,9 @@ +package asn1 + +import "io" + +// ValueWriter is the interface for writing a value. +type ValueWriter interface { + io.ReaderFrom + io.ByteWriter +} diff --git a/internal/encoding/asn1/primitive.go b/internal/encoding/asn1/primitive.go new file mode 100644 index 00000000..b19cb9fc --- /dev/null +++ b/internal/encoding/asn1/primitive.go @@ -0,0 +1,33 @@ +package asn1 + +// PrimitiveValue represents a value in primitive encoding. +type PrimitiveValue struct { + identifier ReadOnlySlice + content ReadOnlySlice +} + +// newPrimitiveValue builds the primitive value. +func newPrimitiveValue(identifier ReadOnlySlice, content ReadOnlySlice) (Value, error) { + return PrimitiveValue{ + identifier: identifier, + content: content, + }, nil +} + +// Encode encodes the primitive value to the value writer in DER. +func (v PrimitiveValue) Encode(w ValueWriter) error { + _, err := w.ReadFrom(v.identifier) + if err != nil { + return err + } + if err = encodeLength(w, v.content.Length()); err != nil { + return err + } + _, err = w.ReadFrom(v.content) + return err +} + +// EncodedLen returns the length in bytes of the encoded data. +func (v PrimitiveValue) EncodedLen() int { + return v.identifier.Length() + encodedLengthSize(v.content.Length()) + v.content.Length() +} diff --git a/internal/encoding/asn1/slice.go b/internal/encoding/asn1/slice.go new file mode 100644 index 00000000..af124559 --- /dev/null +++ b/internal/encoding/asn1/slice.go @@ -0,0 +1,67 @@ +package asn1 + +import "io" + +type ReadOnlySlice interface { + io.ByteReader + io.Reader + Length() int + Offset() int + Seek(offset int) error + Slice(begin int, end int) (ReadOnlySlice, error) +} + +type readOnlySlice struct { + data []byte + offset int +} + +func newReadOnlySlice(data []byte) ReadOnlySlice { + return &readOnlySlice{ + data: data, + offset: 0, + } +} + +func (r *readOnlySlice) ReadByte() (byte, error) { + if r.offset >= len(r.data) { + return 0, ErrEarlyEOF + } + defer func() { r.offset++ }() + return r.data[r.offset], nil +} + +func (r *readOnlySlice) Read(p []byte) (int, error) { + if r.offset >= len(r.data) { + return 0, io.EOF + } + n := copy(p, r.data[r.offset:]) + r.offset += n + return n, nil +} + +func (r *readOnlySlice) Length() int { + return len(r.data) +} + +func (r *readOnlySlice) Offset() int { + return r.offset +} + +func (r *readOnlySlice) Seek(offset int) error { + if offset < 0 || offset > len(r.data) { + return ErrInvalidOffset + } + r.offset = offset + return nil +} + +func (r *readOnlySlice) Slice(begin int, end int) (ReadOnlySlice, error) { + if begin < 0 || end < 0 || begin > end || end > len(r.data) { + return nil, ErrInvalidSlice + } + return &readOnlySlice{ + data: r.data[begin:end], + offset: 0, + }, nil +}