Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webrtc: run onDone callback immediately on close #2729

Merged
merged 3 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ require (
github.com/pion/datachannel v1.5.5
github.com/pion/ice/v2 v2.3.11
github.com/pion/logging v0.2.2
github.com/pion/sctp v1.8.9
github.com/pion/stun v0.6.1
github.com/pion/webrtc/v3 v3.2.23
github.com/prometheus/client_golang v1.18.0
Expand Down Expand Up @@ -105,7 +106,6 @@ require (
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/rtcp v1.2.13 // indirect
github.com/pion/rtp v1.8.3 // indirect
github.com/pion/sctp v1.8.9 // indirect
github.com/pion/sdp/v3 v3.0.6 // indirect
github.com/pion/srtp/v2 v2.0.18 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
Expand Down
50 changes: 22 additions & 28 deletions p2p/transport/webrtc/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ type stream struct {
// SetReadDeadline
// See: https://github.com/pion/sctp/pull/290
controlMessageReaderEndTime time.Time
controlMessageReaderDone sync.WaitGroup

onDoneOnce sync.Once
onDone func()
id uint16 // for logging purposes
dataChannel *datachannel.DataChannel
Expand All @@ -113,8 +113,6 @@ func newStream(
dataChannel: rwc.(*datachannel.DataChannel),
onDone: onDone,
}
// released when the controlMessageReader goroutine exits
s.controlMessageReaderDone.Add(1)
s.dataChannel.SetBufferedAmountLowThreshold(sendBufferLowThreshold)
s.dataChannel.OnBufferedAmountLow(func() {
s.notifyWriteStateChanged()
Expand All @@ -130,7 +128,7 @@ func (s *stream) Close() error {
if isClosed {
return nil
}

defer s.cleanup()
closeWriteErr := s.CloseWrite()
closeReadErr := s.CloseRead()
if closeWriteErr != nil || closeReadErr != nil {
Expand All @@ -142,10 +140,6 @@ func (s *stream) Close() error {
if s.controlMessageReaderEndTime.IsZero() {
s.controlMessageReaderEndTime = time.Now().Add(maxFINACKWait)
s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
go func() {
s.controlMessageReaderDone.Wait()
s.cleanup()
}()
}
s.mx.Unlock()
return nil
Expand Down Expand Up @@ -222,17 +216,10 @@ func (s *stream) spawnControlMessageReader() {
s.controlMessageReaderOnce.Do(func() {
// Spawn a goroutine to ensure that we're not holding any locks
go func() {
defer s.controlMessageReaderDone.Done()
// cleanup the sctp deadline timer goroutine
defer s.setDataChannelReadDeadline(time.Time{})

setDeadline := func() bool {
if s.controlMessageReaderEndTime.IsZero() || time.Now().Before(s.controlMessageReaderEndTime) {
s.setDataChannelReadDeadline(s.controlMessageReaderEndTime)
return true
}
return false
}
defer s.dataChannel.Close()

// Unblock any Read call waiting on reader.ReadMsg
s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour))
Expand All @@ -251,12 +238,22 @@ func (s *stream) spawnControlMessageReader() {
s.processIncomingFlag(s.nextMessage.Flag)
s.nextMessage = nil
}
for s.closeForShutdownErr == nil &&
s.sendState != sendStateDataReceived && s.sendState != sendStateReset {
var msg pb.Message
if !setDeadline() {
var msg pb.Message
for {
// Connection closed. No need to cleanup the data channel.
if s.closeForShutdownErr != nil {
return
}
// Write half of the stream completed.
if s.sendState == sendStateDataReceived || s.sendState == sendStateReset {
return
}
// FIN_ACK wait deadling exceeded.
if !s.controlMessageReaderEndTime.IsZero() && time.Now().After(s.controlMessageReaderEndTime) {
return
}

s.setDataChannelReadDeadline(s.controlMessageReaderEndTime)
s.mx.Unlock()
err := s.reader.ReadMsg(&msg)
s.mx.Lock()
Expand All @@ -276,12 +273,9 @@ func (s *stream) spawnControlMessageReader() {
}

func (s *stream) cleanup() {
// Even if we close the datachannel pion keeps a reference to the datachannel around.
// Remove the onBufferedAmountLow callback to ensure that we at least garbage collect
// memory we allocated for this stream.
s.dataChannel.OnBufferedAmountLow(nil)
s.dataChannel.Close()
if s.onDone != nil {
s.onDone()
}
s.onDoneOnce.Do(func() {
if s.onDone != nil {
s.onDone()
}
})
}
109 changes: 72 additions & 37 deletions p2p/transport/webrtc/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ import (

"github.com/libp2p/go-libp2p/p2p/transport/webrtc/pb"
"github.com/libp2p/go-msgio/pbio"
"google.golang.org/protobuf/proto"

"github.com/libp2p/go-libp2p/core/network"

"github.com/pion/datachannel"
"github.com/pion/sctp"
"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -98,6 +100,50 @@ func getDetachedDataChannels(t *testing.T) (detachedChan, detachedChan) {
return <-answerChan, <-offerRWCChan
}

// assertDataChannelOpen checks if the datachannel is open.
// It sends empty messages on the data channel to check if the channel is still open.
// The control message reader goroutine depends on exclusive access to datachannel.Read
// so we have to depend on Write to determine whether the channel has been closed.
func assertDataChannelOpen(t *testing.T, dc *datachannel.DataChannel) {
t.Helper()
emptyMsg := &pb.Message{}
msg, err := proto.Marshal(emptyMsg)
if err != nil {
t.Fatal("unexpected mashalling error", err)
}
for i := 0; i < 3; i++ {
_, err := dc.Write(msg)
if err != nil {
t.Fatal("unexpected write err: ", err)
}
time.Sleep(50 * time.Millisecond)
}
}

// assertDataChannelClosed checks if the datachannel is closed.
// It sends empty messages on the data channel to check if the channel has been closed.
// The control message reader goroutine depends on exclusive access to datachannel.Read
// so we have to depend on Write to determine whether the channel has been closed.
func assertDataChannelClosed(t *testing.T, dc *datachannel.DataChannel) {
t.Helper()
emptyMsg := &pb.Message{}
msg, err := proto.Marshal(emptyMsg)
if err != nil {
t.Fatal("unexpected mashalling error", err)
}
for i := 0; i < 5; i++ {
_, err := dc.Write(msg)
if err != nil {
if errors.Is(err, sctp.ErrStreamClosed) {
return
} else {
t.Fatal("unexpected write err: ", err)
}
}
time.Sleep(50 * time.Millisecond)
}
}

func TestStreamSimpleReadWriteClose(t *testing.T) {
client, server := getDetachedDataChannels(t)

Expand Down Expand Up @@ -357,27 +403,22 @@ func TestStreamCloseAfterFINACK(t *testing.T) {
serverStr := newStream(server.dc, server.rwc, func() {})

go func() {
done <- true
err := clientStr.Close()
assert.NoError(t, err)
}()
<-done

select {
case <-done:
t.Fatalf("Close should not have completed without processing FIN_ACK")
case <-time.After(200 * time.Millisecond):
t.Fatalf("Close should signal OnDone immediately")
}

// Reading FIN_ACK on server should trigger data channel close on the client
b := make([]byte, 1)
_, err := serverStr.Read(b)
require.Error(t, err)
require.ErrorIs(t, err, io.EOF)
select {
case <-done:
case <-time.After(3 * time.Second):
t.Errorf("Close should have completed")
}
assertDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
}

// TestStreamFinAckAfterStopSending tests that FIN_ACK is sent even after the write half
Expand All @@ -400,8 +441,8 @@ func TestStreamFinAckAfterStopSending(t *testing.T) {

select {
case <-done:
t.Errorf("Close should not have completed without processing FIN_ACK")
case <-time.After(500 * time.Millisecond):
t.Errorf("Close should signal onDone immediately")
}

// serverStr has write half closed and read half open
Expand All @@ -410,11 +451,8 @@ func TestStreamFinAckAfterStopSending(t *testing.T) {
_, err := serverStr.Read(b)
require.NoError(t, err)
serverStr.Close() // Sends stop_sending, fin
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatalf("Close should have completed")
}
assertDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))
assertDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
}

func TestStreamConcurrentClose(t *testing.T) {
Expand Down Expand Up @@ -446,26 +484,35 @@ func TestStreamConcurrentClose(t *testing.T) {
case <-time.After(2 * time.Second):
t.Fatalf("concurrent close should succeed quickly")
}

// Wait for FIN_ACK AND datachannel close
assertDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
assertDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))

}

func TestStreamResetAfterClose(t *testing.T) {
client, _ := getDetachedDataChannels(t)
client, server := getDetachedDataChannels(t)

done := make(chan bool, 2)
clientStr := newStream(client.dc, client.rwc, func() { done <- true })
clientStr.Close()

select {
case <-done:
t.Fatalf("Close shouldn't run cleanup immediately")
case <-time.After(500 * time.Millisecond):
t.Fatalf("Close should run cleanup immediately")
}

// The server data channel should still be open
assertDataChannelOpen(t, server.rwc.(*datachannel.DataChannel))
clientStr.Reset()
// Reset closes the datachannels
assertDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))
assertDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatalf("Reset should run callback immediately")
t.Fatalf("onDone should not be called twice")
case <-time.After(50 * time.Millisecond):
}
}

Expand All @@ -478,30 +525,18 @@ func TestStreamDataChannelCloseOnFINACK(t *testing.T) {
clientStr.Close()

select {
case <-done:
t.Fatalf("Close shouldn't run cleanup immediately")
case <-time.After(500 * time.Millisecond):
t.Fatalf("Close should run cleanup immediately")
case <-done:
}

// sending FIN_ACK closes the datachannel
serverWriter := pbio.NewDelimitedWriter(server.rwc)
err := serverWriter.WriteMsg(&pb.Message{Flag: pb.Message_FIN_ACK.Enum()})
require.NoError(t, err)
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatalf("Callback should be run on reading FIN_ACK")
}
b := make([]byte, 100)
N := 0
for {
n, err := server.rwc.Read(b)
N += n
if err != nil {
require.ErrorIs(t, err, io.EOF)
break
}
}
require.Less(t, N, 10)

assertDataChannelClosed(t, server.rwc.(*datachannel.DataChannel))
assertDataChannelClosed(t, client.rwc.(*datachannel.DataChannel))
}

func TestStreamChunking(t *testing.T) {
Expand Down
14 changes: 6 additions & 8 deletions p2p/transport/webrtc/stream_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,10 @@ func (s *stream) cancelWrite() error {
return nil
}
s.sendState = sendStateReset
// Remove reference to this stream from data channel
s.dataChannel.OnBufferedAmountLow(nil)
s.notifyWriteStateChanged()
if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}); err != nil {
return err
}
return nil
return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()})
}

func (s *stream) CloseWrite() error {
Expand All @@ -144,11 +143,10 @@ func (s *stream) CloseWrite() error {
return nil
}
s.sendState = sendStateDataSent
// Remove reference to this stream from data channel
s.dataChannel.OnBufferedAmountLow(nil)
s.notifyWriteStateChanged()
if err := s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()}); err != nil {
return err
}
return nil
return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_FIN.Enum()})
}

func (s *stream) notifyWriteStateChanged() {
Expand Down
Loading