Skip to content

Commit

Permalink
Merge pull request #2903 from lucas-clemente/fix-packet-number-decoding
Browse files Browse the repository at this point in the history
fix decoding of packet numbers in different packet number spaces
  • Loading branch information
marten-seemann authored Dec 4, 2020
2 parents baa0e42 + 9533420 commit d35cce1
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 57 deletions.
10 changes: 9 additions & 1 deletion internal/handshake/aead.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qtls"
"github.com/lucas-clemente/quic-go/internal/utils"
)

func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) cipher.AEAD {
Expand Down Expand Up @@ -50,6 +51,7 @@ func (s *longHeaderSealer) Overhead() int {
type longHeaderOpener struct {
aead cipher.AEAD
headerProtector headerProtector
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)

// use a single slice to avoid allocations
nonceBuf []byte
Expand All @@ -65,12 +67,18 @@ func newLongHeaderOpener(aead cipher.AEAD, headerProtector headerProtector) Long
}
}

func (o *longHeaderOpener) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, o.highestRcvdPN, wirePN)
}

func (o *longHeaderOpener) Open(dst, src []byte, pn protocol.PacketNumber, ad []byte) ([]byte, error) {
binary.BigEndian.PutUint64(o.nonceBuf[len(o.nonceBuf)-8:], uint64(pn))
// The AEAD we're using here will be the qtls.aeadAESGCM13.
// It uses the nonce provided here and XOR it with the IV.
dec, err := o.aead.Open(dst, o.nonceBuf, src, ad)
if err != nil {
if err == nil {
o.highestRcvdPN = utils.MaxPacketNumber(o.highestRcvdPN, pn)
} else {
err = ErrDecryptionFailed
}
return dec, err
Expand Down
16 changes: 16 additions & 0 deletions internal/handshake/aead_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,22 @@ var _ = Describe("Long Header AEAD", func() {
_, err := opener.Open(nil, encrypted, 0x42, ad)
Expect(err).To(MatchError(ErrDecryptionFailed))
})

It("decodes the packet number", func() {
sealer, opener := getSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted, 0x1337, ad)
Expect(err).ToNot(HaveOccurred())
Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338))
})

It("ignores packets it can't decrypt for packet number derivation", func() {
sealer, opener := getSealerAndOpener()
encrypted := sealer.Seal(nil, msg, 0x1337, ad)
_, err := opener.Open(nil, encrypted[:len(encrypted)-1], 0x1337, ad)
Expect(err).To(HaveOccurred())
Expect(opener.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38))
})
})

Context("header encryption", func() {
Expand Down
2 changes: 2 additions & 0 deletions internal/handshake/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ type headerDecryptor interface {
// LongHeaderOpener opens a long header packet
type LongHeaderOpener interface {
headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) ([]byte, error)
}

// ShortHeaderOpener opens a short header packet
type ShortHeaderOpener interface {
headerDecryptor
DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber
Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, associatedData []byte) ([]byte, error)
}

Expand Down
8 changes: 8 additions & 0 deletions internal/handshake/updatable_aead.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type updatableAEAD struct {

firstRcvdWithCurrentKey protocol.PacketNumber
firstSentWithCurrentKey protocol.PacketNumber
highestRcvdPN protocol.PacketNumber // highest packet number received (which could be successfully unprotected)
numRcvdWithCurrentKey uint64
numSentWithCurrentKey uint64
rcvAEAD cipher.AEAD
Expand Down Expand Up @@ -153,6 +154,10 @@ func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSu
}
}

func (a *updatableAEAD) DecodePacketNumber(wirePN protocol.PacketNumber, wirePNLen protocol.PacketNumberLen) protocol.PacketNumber {
return protocol.DecodePacketNumber(wirePNLen, a.highestRcvdPN, wirePN)
}

func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.PacketNumber, kp protocol.KeyPhaseBit, ad []byte) ([]byte, error) {
dec, err := a.open(dst, src, rcvTime, pn, kp, ad)
if err == ErrDecryptionFailed {
Expand All @@ -161,6 +166,9 @@ func (a *updatableAEAD) Open(dst, src []byte, rcvTime time.Time, pn protocol.Pac
return nil, qerr.AEADLimitReached
}
}
if err == nil {
a.highestRcvdPN = utils.MaxPacketNumber(a.highestRcvdPN, pn)
}
return dec, err
}

Expand Down
14 changes: 14 additions & 0 deletions internal/handshake/updatable_aead_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,20 @@ var _ = Describe("Updatable AEAD", func() {
Expect(err).To(MatchError(ErrDecryptionFailed))
})

It("decodes the packet number", func() {
encrypted := server.Seal(nil, msg, 0x1337, ad)
_, err := client.Open(nil, encrypted, time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).ToNot(HaveOccurred())
Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x1338))
})

It("ignores packets it can't decrypt for packet number derivation", func() {
encrypted := server.Seal(nil, msg, 0x1337, ad)
_, err := client.Open(nil, encrypted[:len(encrypted)-1], time.Now(), 0x1337, protocol.KeyPhaseZero, ad)
Expect(err).To(HaveOccurred())
Expect(client.DecodePacketNumber(0x38, protocol.PacketNumberLen1)).To(BeEquivalentTo(0x38))
})

It("returns an AEAD_LIMIT_REACHED error when reaching the AEAD limit", func() {
client.invalidPacketLimit = 10
for i := 0; i < 9; i++ {
Expand Down
14 changes: 14 additions & 0 deletions internal/mocks/long_header_opener.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions internal/mocks/short_header_opener.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 2 additions & 11 deletions packet_unpacker.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/internal/wire"
)

Expand Down Expand Up @@ -43,8 +42,6 @@ type unpackedPacket struct {
type packetUnpacker struct {
cs handshake.CryptoSetup

largestRcvdPacketNumber protocol.PacketNumber

version protocol.VersionNumber
}

Expand Down Expand Up @@ -111,9 +108,6 @@ func (u *packetUnpacker) Unpack(hdr *wire.Header, rcvTime time.Time, data []byte
}
}

// Only do this after decrypting, so we are sure the packet is not attacker-controlled
u.largestRcvdPacketNumber = utils.MaxPacketNumber(u.largestRcvdPacketNumber, extHdr.PacketNumber)

return &unpackedPacket{
hdr: extHdr,
packetNumber: extHdr.PacketNumber,
Expand All @@ -131,6 +125,7 @@ func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpene
return nil, nil, parseErr
}
extHdrLen := extHdr.ParsedLen()
extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
if err != nil {
return nil, nil, err
Expand All @@ -154,6 +149,7 @@ func (u *packetUnpacker) unpackShortHeaderPacket(
if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
return nil, nil, parseErr
}
extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
extHdrLen := extHdr.ParsedLen()
decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], rcvTime, extHdr.PacketNumber, extHdr.KeyPhase, data[:extHdrLen])
if err != nil {
Expand All @@ -171,11 +167,6 @@ func (u *packetUnpacker) unpackHeader(hd headerDecryptor, hdr *wire.Header, data
if err != nil && err != wire.ErrInvalidReservedBits {
return nil, &headerParseError{err: err}
}
extHdr.PacketNumber = protocol.DecodePacketNumber(
extHdr.PacketNumberLen,
u.largestRcvdPacketNumber,
extHdr.PacketNumber,
)
return extHdr, err
}

Expand Down
91 changes: 46 additions & 45 deletions packet_unpacker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ import (
"errors"
"time"

"github.com/lucas-clemente/quic-go/internal/qerr"

"github.com/golang/mock/gomock"
"github.com/lucas-clemente/quic-go/internal/handshake"
"github.com/lucas-clemente/quic-go/internal/mocks"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/qerr"
"github.com/lucas-clemente/quic-go/internal/wire"

"github.com/golang/mock/gomock"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)
Expand Down Expand Up @@ -97,9 +97,12 @@ var _ = Describe("Packet Unpacker", func() {
}
hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetInitialOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
gomock.InOrder(
cs.EXPECT().GetInitialOpener().Return(opener, nil),
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(2), protocol.PacketNumberLen3).Return(protocol.PacketNumber(1234)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(1234), hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.EncryptionInitial))
Expand All @@ -115,20 +118,45 @@ var _ = Describe("Packet Unpacker", func() {
DestConnectionID: connID,
Version: version,
},
PacketNumber: 2,
PacketNumberLen: 3,
PacketNumber: 20,
PacketNumberLen: 2,
}
hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().Get0RTTOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), payload, extHdr.PacketNumber, hdrRaw).Return([]byte("decrypted"), nil)
gomock.InOrder(
cs.EXPECT().Get0RTTOpener().Return(opener, nil),
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(20), protocol.PacketNumberLen2).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, protocol.PacketNumber(321), hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption0RTT))
Expect(packet.data).To(Equal([]byte("decrypted")))
})

It("opens short header packets", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
KeyPhase: protocol.KeyPhaseOne,
PacketNumber: 99,
PacketNumberLen: protocol.PacketNumberLen4,
}
hdr, hdrRaw := getHeader(extHdr)
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
now := time.Now()
gomock.InOrder(
cs.EXPECT().Get1RTTOpener().Return(opener, nil),
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any()),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(99), protocol.PacketNumberLen4).Return(protocol.PacketNumber(321)),
opener.EXPECT().Open(gomock.Any(), payload, now, protocol.PacketNumber(321), protocol.KeyPhaseOne, hdrRaw).Return([]byte("decrypted"), nil),
)
packet, err := unpacker.Unpack(hdr, now, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.encryptionLevel).To(Equal(protocol.Encryption1RTT))
Expect(packet.data).To(Equal([]byte("decrypted")))
})

It("returns the error when getting the sealer fails", func() {
extHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
Expand Down Expand Up @@ -157,6 +185,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, qerr.CryptoBufferExceeded)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(qerr.CryptoBufferExceeded))
Expand All @@ -178,6 +207,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockLongHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().GetHandshakeOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
Expand All @@ -194,6 +224,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return([]byte("payload"), nil)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(wire.ErrInvalidReservedBits))
Expand All @@ -210,6 +241,7 @@ var _ = Describe("Packet Unpacker", func() {
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
cs.EXPECT().Get1RTTOpener().Return(opener, nil)
opener.EXPECT().DecodePacketNumber(gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, handshake.ErrDecryptionFailed)
_, err := unpacker.Unpack(hdr, time.Now(), append(hdrRaw, payload...))
Expect(err).To(MatchError(handshake.ErrDecryptionFailed))
Expand Down Expand Up @@ -247,46 +279,15 @@ var _ = Describe("Packet Unpacker", func() {
pnBytes[i] ^= 0xff // invert the packet number bytes
}
}),
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1337), origHdrRaw).Return([]byte{0}, nil),
opener.EXPECT().DecodePacketNumber(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2).Return(protocol.PacketNumber(0x7331)),
opener.EXPECT().Open(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x7331), origHdrRaw).Return([]byte{0}, nil),
)
data := hdrRaw
for i := 1; i <= 100; i++ {
data = append(data, uint8(i))
}
packet, err := unpacker.Unpack(hdr, time.Now(), data)
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
})

It("decodes the packet number", func() {
rcvTime := time.Now().Add(-time.Hour)
firstHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x1337,
PacketNumberLen: 2,
KeyPhase: protocol.KeyPhaseOne,
}
opener := mocks.NewMockShortHeaderOpener(mockCtrl)
cs.EXPECT().Get1RTTOpener().Return(opener, nil).Times(2)
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, firstHdr.PacketNumber, protocol.KeyPhaseOne, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw := getHeader(firstHdr)
packet, err := unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1337)))
// the real packet number is 0x1338, but only the last byte is sent
secondHdr := &wire.ExtendedHeader{
Header: wire.Header{DestConnectionID: connID},
PacketNumber: 0x38,
PacketNumberLen: 1,
KeyPhase: protocol.KeyPhaseZero,
}
// expect the call with the decoded packet number
opener.EXPECT().DecryptHeader(gomock.Any(), gomock.Any(), gomock.Any())
opener.EXPECT().Open(gomock.Any(), gomock.Any(), rcvTime, protocol.PacketNumber(0x1338), protocol.KeyPhaseZero, gomock.Any()).Return([]byte{0}, nil)
hdr, hdrRaw = getHeader(secondHdr)
packet, err = unpacker.Unpack(hdr, rcvTime, append(hdrRaw, payload...))
Expect(err).ToNot(HaveOccurred())
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x1338)))
Expect(packet.packetNumber).To(Equal(protocol.PacketNumber(0x7331)))
})
})

0 comments on commit d35cce1

Please sign in to comment.