Skip to content

Commit

Permalink
Merge pull request #3138 from lucas-clemente/pmtud-after-handshake-co…
Browse files Browse the repository at this point in the history
…nfirmation

only start PMTUD after handshake confirmation
  • Loading branch information
marten-seemann authored Apr 2, 2021
2 parents 37a3938 + 91629cd commit 7f55fc7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 28 deletions.
61 changes: 33 additions & 28 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -747,10 +747,10 @@ func (s *session) maybeResetTimer() {
} else {
deadline = s.idleTimeoutStartTime().Add(s.idleTimeout)
}
if !s.config.DisablePathMTUDiscovery {
if probeTime := s.mtuDiscoverer.NextProbeTime(); !probeTime.IsZero() {
deadline = utils.MinTime(deadline, probeTime)
}
}
if s.handshakeConfirmed && !s.config.DisablePathMTUDiscovery {
if probeTime := s.mtuDiscoverer.NextProbeTime(); !probeTime.IsZero() {
deadline = utils.MinTime(deadline, probeTime)
}
}

Expand Down Expand Up @@ -782,30 +782,13 @@ func (s *session) handleHandshakeComplete() {
s.connIDManager.SetHandshakeComplete()
s.connIDGenerator.SetHandshakeComplete()

if !s.config.DisablePathMTUDiscovery {
maxPacketSize := s.peerParams.MaxUDPPayloadSize
if maxPacketSize == 0 {
maxPacketSize = protocol.MaxByteCount
}
maxPacketSize = utils.MinByteCount(maxPacketSize, protocol.MaxPacketBufferSize)
s.mtuDiscoverer = newMTUDiscoverer(
s.rttStats,
getMaxPacketSize(s.conn.RemoteAddr()),
maxPacketSize,
func(size protocol.ByteCount) {
s.sentPacketHandler.SetMaxDatagramSize(size)
s.packer.SetMaxPacketSize(size)
},
)
}

if s.perspective == protocol.PerspectiveClient {
s.applyTransportParameters()
return
}

s.handshakeConfirmed = true
s.sentPacketHandler.SetHandshakeConfirmed()
s.handleHandshakeConfirmed()

ticket, err := s.cryptoStreamHandler.GetSessionTicket()
if err != nil {
s.closeLocal(err)
Expand All @@ -821,10 +804,32 @@ func (s *session) handleHandshakeComplete() {
s.closeLocal(err)
}
s.queueControlFrame(&wire.NewTokenFrame{Token: token})
s.cryptoStreamHandler.SetHandshakeConfirmed()
s.queueControlFrame(&wire.HandshakeDoneFrame{})
}

func (s *session) handleHandshakeConfirmed() {
s.handshakeConfirmed = true
s.sentPacketHandler.SetHandshakeConfirmed()
s.cryptoStreamHandler.SetHandshakeConfirmed()

if !s.config.DisablePathMTUDiscovery {
maxPacketSize := s.peerParams.MaxUDPPayloadSize
if maxPacketSize == 0 {
maxPacketSize = protocol.MaxByteCount
}
maxPacketSize = utils.MinByteCount(maxPacketSize, protocol.MaxPacketBufferSize)
s.mtuDiscoverer = newMTUDiscoverer(
s.rttStats,
getMaxPacketSize(s.conn.RemoteAddr()),
maxPacketSize,
func(size protocol.ByteCount) {
s.sentPacketHandler.SetMaxDatagramSize(size)
s.packer.SetMaxPacketSize(size)
},
)
}
}

func (s *session) handlePacketImpl(rp *receivedPacket) bool {
if wire.IsVersionNegotiationPacket(rp.data) {
s.handleVersionNegotiationPacket(rp)
Expand Down Expand Up @@ -1349,9 +1354,9 @@ func (s *session) handleHandshakeDoneFrame() error {
if s.perspective == protocol.PerspectiveServer {
return qerr.NewError(qerr.ProtocolViolation, "received a HANDSHAKE_DONE frame")
}
s.handshakeConfirmed = true
s.sentPacketHandler.SetHandshakeConfirmed()
s.cryptoStreamHandler.SetHandshakeConfirmed()
if !s.handshakeConfirmed {
s.handleHandshakeConfirmed()
}
return nil
}

Expand Down Expand Up @@ -1716,7 +1721,7 @@ func (s *session) sendPacket() (bool, error) {
s.sendQueue.Send(packet.buffer)
return true, nil
}
if !s.config.DisablePathMTUDiscovery && s.handshakeComplete && s.mtuDiscoverer.ShouldSendProbe(now) {
if !s.config.DisablePathMTUDiscovery && s.mtuDiscoverer.ShouldSendProbe(now) {
packet, err := s.packer.PackMTUProbePacket(s.mtuDiscoverer.GetPing())
if err != nil {
return false, err
Expand Down
1 change: 1 addition & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2555,6 +2555,7 @@ var _ = Describe("Client Session", func() {
})

It("handles HANDSHAKE_DONE frames", func() {
sess.peerParams = &wire.TransportParameters{}
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sess.sentPacketHandler = sph
sph.EXPECT().SetHandshakeConfirmed()
Expand Down

0 comments on commit 7f55fc7

Please sign in to comment.