From aa8b049e706bbde04393f60121f390b4a5cdbb47 Mon Sep 17 00:00:00 2001 From: sukun Date: Mon, 19 Aug 2024 19:00:22 +0530 Subject: [PATCH 1/5] error-codes: add implementation for quic, yamux, websockets, webrtc --- core/network/conn.go | 29 +++ core/network/mux.go | 41 ++++ core/network/stream.go | 4 + go.mod | 2 +- go.sum | 4 +- p2p/muxer/testsuite/mux.go | 3 +- p2p/muxer/yamux/conn.go | 8 +- p2p/muxer/yamux/stream.go | 37 +++- p2p/net/connmgr/connmgr_test.go | 1 + p2p/net/mock/mock_conn.go | 4 + p2p/net/mock/mock_stream.go | 4 + p2p/net/swarm/swarm.go | 8 + p2p/net/swarm/swarm_conn.go | 23 +- p2p/net/swarm/swarm_stream.go | 11 + p2p/net/swarm/swarm_test.go | 2 +- p2p/net/upgrader/conn.go | 7 + p2p/protocol/circuitv2/relay/relay_test.go | 9 +- p2p/test/transport/transport_test.go | 232 +++++++++++++++++++++ p2p/transport/quic/conn.go | 19 +- p2p/transport/quic/conn_test.go | 42 ++++ p2p/transport/quic/listener_test.go | 5 +- p2p/transport/quic/stream.go | 46 +++- p2p/transport/tcp/metrics_unix_test.go | 2 +- p2p/transport/webrtc/connection.go | 4 + p2p/transport/webrtc/pb/message.pb.go | 35 ++-- p2p/transport/webrtc/pb/message.proto | 2 + p2p/transport/webrtc/stream.go | 32 ++- p2p/transport/webrtc/stream_read.go | 16 +- p2p/transport/webrtc/stream_write.go | 10 +- p2p/transport/websocket/conn.go | 8 + p2p/transport/webtransport/conn.go | 4 + p2p/transport/webtransport/stream.go | 4 + test-plans/go.mod | 2 +- test-plans/go.sum | 4 +- 34 files changed, 592 insertions(+), 72 deletions(-) diff --git a/core/network/conn.go b/core/network/conn.go index 3be8cb0d69..64aee09ffb 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -2,6 +2,7 @@ package network import ( "context" + "fmt" "io" ic "github.com/libp2p/go-libp2p/core/crypto" @@ -11,6 +12,29 @@ import ( ma "github.com/multiformats/go-multiaddr" ) +type ConnErrorCode uint32 + +type ConnError struct { + Remote bool + ErrorCode ConnErrorCode + TransportError error +} + +func (c *ConnError) Error() string { + side := "local" + if c.Remote { + side = "remote" + } + if c.TransportError != nil { + return fmt.Sprintf("connection closed (%s): code: %d: transport error: %s", side, c.ErrorCode, c.TransportError) + } + return fmt.Sprintf("connection closed (%s): code: %d", side, c.ErrorCode) +} + +func (c *ConnError) Unwrap() error { + return c.TransportError +} + // Conn is a connection to a remote peer. It multiplexes streams. // Usually there is no need to use a Conn directly, but it may // be useful to get information about the peer on the other side: @@ -24,6 +48,11 @@ type Conn interface { ConnStat ConnScoper + // CloseWithError closes the connection with errCode. The errCode is sent to the + // peer on a best effort basis. For transports that do not support sending error + // codes on connection close, the behavior is identical to calling Close. + CloseWithError(errCode ConnErrorCode) error + // ID returns an identifier that uniquely identifies this Conn within this // host, during this run. Connection IDs may repeat across restarts. ID() string diff --git a/core/network/mux.go b/core/network/mux.go index d12e2ea34b..34d4035429 100644 --- a/core/network/mux.go +++ b/core/network/mux.go @@ -3,6 +3,7 @@ package network import ( "context" "errors" + "fmt" "io" "net" "time" @@ -11,6 +12,33 @@ import ( // ErrReset is returned when reading or writing on a reset stream. var ErrReset = errors.New("stream reset") +type StreamErrorCode uint32 + +type StreamError struct { + ErrorCode StreamErrorCode + Remote bool + TransportError error +} + +func (s *StreamError) Error() string { + side := "local" + if s.Remote { + side = "remote" + } + if s.TransportError != nil { + return fmt.Sprintf("stream reset (%s): code: %d: transport error: %s", side, s.ErrorCode, s.TransportError) + } + return fmt.Sprintf("stream reset (%s): code: %d", side, s.ErrorCode) +} + +func (s *StreamError) Is(target error) bool { + return target == ErrReset +} + +func (s *StreamError) Unwrap() error { + return s.TransportError +} + // MuxedStream is a bidirectional io pipe within a connection. type MuxedStream interface { io.Reader @@ -61,6 +89,13 @@ type MuxedStream interface { SetWriteDeadline(time.Time) error } +type ResetWithErrorer interface { + // ResetWithError closes both ends of the stream with errCode. The errCode is sent + // to the peer on a best effort basis. For transports that do not support sending + // error codes to remote peer, the behavior is identical to calling Reset + ResetWithError(errCode StreamErrorCode) error +} + // MuxedConn represents a connection to a remote peer that has been // extended to support stream multiplexing. // @@ -86,6 +121,12 @@ type MuxedConn interface { AcceptStream() (MuxedStream, error) } +type CloseWithErrorer interface { + // CloseWithError closes the connection with errCode. The errCode is sent + // to the peer. + CloseWithError(errCode ConnErrorCode) error +} + // Multiplexer wraps a net.Conn with a stream multiplexing // implementation and returns a MuxedConn that supports opening // multiple streams over the underlying net.Conn diff --git a/core/network/stream.go b/core/network/stream.go index 62e230034c..f2b6cbcb88 100644 --- a/core/network/stream.go +++ b/core/network/stream.go @@ -27,4 +27,8 @@ type Stream interface { // Scope returns the user's view of this stream's resource scope Scope() StreamScope + + // ResetWithError closes both ends of the stream with errCode. The errCode is sent + // to the peer. + ResetWithError(errCode StreamErrorCode) error } diff --git a/go.mod b/go.mod index c6f3f2b324..b36b99798c 100644 --- a/go.mod +++ b/go.mod @@ -32,7 +32,7 @@ require ( github.com/libp2p/go-nat v0.2.0 github.com/libp2p/go-netroute v0.2.1 github.com/libp2p/go-reuseport v0.4.0 - github.com/libp2p/go-yamux/v4 v4.0.1 + github.com/libp2p/go-yamux/v4 v4.0.2-0.20241120100319-39abe7ed206a github.com/libp2p/zeroconf/v2 v2.2.0 github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b diff --git a/go.sum b/go.sum index 4650eef1f0..d2c99d8702 100644 --- a/go.sum +++ b/go.sum @@ -194,8 +194,8 @@ github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9t github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= -github.com/libp2p/go-yamux/v4 v4.0.1 h1:FfDR4S1wj6Bw2Pqbc8Uz7pCxeRBPbwsBbEdfwiCypkQ= -github.com/libp2p/go-yamux/v4 v4.0.1/go.mod h1:NWjl8ZTLOGlozrXSOZ/HlfG++39iKNnM5wwmtQP1YB4= +github.com/libp2p/go-yamux/v4 v4.0.2-0.20241120100319-39abe7ed206a h1:zc7jPWFFQibZbACDyQdEAWg7yG/fjx5Jmg6djtpjKog= +github.com/libp2p/go-yamux/v4 v4.0.2-0.20241120100319-39abe7ed206a/go.mod h1:PGP+3py2ZWDKABvqstBZtMnixEHNC7U/odnGylzur5o= github.com/libp2p/zeroconf/v2 v2.2.0 h1:Cup06Jv6u81HLhIj1KasuNM/RHHrJ8T7wOTS4+Tv53Q= github.com/libp2p/zeroconf/v2 v2.2.0/go.mod h1:fuJqLnUwZTshS3U/bMRJ3+ow/v9oid1n0DmyYyNO1Xs= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= diff --git a/p2p/muxer/testsuite/mux.go b/p2p/muxer/testsuite/mux.go index 5b47117fd6..93d24785ea 100644 --- a/p2p/muxer/testsuite/mux.go +++ b/p2p/muxer/testsuite/mux.go @@ -4,6 +4,7 @@ import ( "bytes" "context" crand "crypto/rand" + "errors" "fmt" "io" mrand "math/rand" @@ -462,7 +463,7 @@ func SubtestStreamReset(t *testing.T, tr network.Multiplexer) { time.Sleep(time.Millisecond * 50) _, err = s.Write([]byte("foo")) - if err != network.ErrReset { + if !errors.Is(err, network.ErrReset) { t.Error("should have been stream reset") } s.Close() diff --git a/p2p/muxer/yamux/conn.go b/p2p/muxer/yamux/conn.go index 40c4af4052..4531771842 100644 --- a/p2p/muxer/yamux/conn.go +++ b/p2p/muxer/yamux/conn.go @@ -23,6 +23,10 @@ func (c *conn) Close() error { return c.yamux().Close() } +func (c *conn) CloseWithError(errCode network.ConnErrorCode) error { + return c.yamux().CloseWithError(uint32(errCode)) +} + // IsClosed checks if yamux.Session is in closed state. func (c *conn) IsClosed() bool { return c.yamux().IsClosed() @@ -32,7 +36,7 @@ func (c *conn) IsClosed() bool { func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { s, err := c.yamux().OpenStream(ctx) if err != nil { - return nil, err + return nil, parseResetError(err) } return (*stream)(s), nil @@ -41,7 +45,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { // AcceptStream accepts a stream opened by the other side. func (c *conn) AcceptStream() (network.MuxedStream, error) { s, err := c.yamux().AcceptStream() - return (*stream)(s), err + return (*stream)(s), parseResetError(err) } func (c *conn) yamux() *yamux.Session { diff --git a/p2p/muxer/yamux/stream.go b/p2p/muxer/yamux/stream.go index b50bc0bb87..f2151b8658 100644 --- a/p2p/muxer/yamux/stream.go +++ b/p2p/muxer/yamux/stream.go @@ -1,6 +1,8 @@ package yamux import ( + "errors" + "fmt" "time" "github.com/libp2p/go-libp2p/core/network" @@ -13,22 +15,33 @@ type stream yamux.Stream var _ network.MuxedStream = &stream{} -func (s *stream) Read(b []byte) (n int, err error) { - n, err = s.yamux().Read(b) - if err == yamux.ErrStreamReset { - err = network.ErrReset +func parseResetError(err error) error { + if err == nil { + return err + } + se := &yamux.StreamError{} + if errors.As(err, &se) { + return &network.StreamError{Remote: se.Remote, ErrorCode: network.StreamErrorCode(se.ErrorCode)} } + ce := &yamux.GoAwayError{} + if errors.As(err, &ce) { + return &network.ConnError{Remote: ce.Remote, ErrorCode: network.ConnErrorCode(ce.ErrorCode)} + } + // TODO: How should we handle resets for reason other than a remote error + if errors.Is(err, yamux.ErrStreamReset) { + return fmt.Errorf("%w: %w", network.ErrReset, err) + } + return err +} - return n, err +func (s *stream) Read(b []byte) (n int, err error) { + n, err = s.yamux().Read(b) + return n, parseResetError(err) } func (s *stream) Write(b []byte) (n int, err error) { n, err = s.yamux().Write(b) - if err == yamux.ErrStreamReset { - err = network.ErrReset - } - - return n, err + return n, parseResetError(err) } func (s *stream) Close() error { @@ -39,6 +52,10 @@ func (s *stream) Reset() error { return s.yamux().Reset() } +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { + return s.yamux().ResetWithError(uint32(errCode)) +} + func (s *stream) CloseRead() error { return s.yamux().CloseRead() } diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go index 2c657255f0..5955265f9b 100644 --- a/p2p/net/connmgr/connmgr_test.go +++ b/p2p/net/connmgr/connmgr_test.go @@ -794,6 +794,7 @@ type mockConn struct { } func (m mockConn) Close() error { panic("implement me") } +func (m mockConn) CloseWithError(errCode network.ConnErrorCode) error { panic("implement me") } func (m mockConn) LocalPeer() peer.ID { panic("implement me") } func (m mockConn) RemotePeer() peer.ID { panic("implement me") } func (m mockConn) RemotePublicKey() crypto.PubKey { panic("implement me") } diff --git a/p2p/net/mock/mock_conn.go b/p2p/net/mock/mock_conn.go index 8c3dc87299..fc4e0ad670 100644 --- a/p2p/net/mock/mock_conn.go +++ b/p2p/net/mock/mock_conn.go @@ -185,3 +185,7 @@ func (c *conn) Stat() network.ConnStats { func (c *conn) Scope() network.ConnScope { return &network.NullScope{} } + +func (c *conn) CloseWithError(_ network.ConnErrorCode) error { + return c.Close() +} diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index c85cca544d..3ba29ddd80 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -144,6 +144,10 @@ func (s *stream) Reset() error { return nil } +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { + panic("not implemented") +} + func (s *stream) teardown() { // at this point, no streams are writing. s.conn.removeStream(s) diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index ef1fc2a2b3..12e83d38ef 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -840,6 +840,14 @@ func (c connWithMetrics) Close() error { return c.CapableConn.Close() } +func (c connWithMetrics) CloseWithError(errCode network.ConnErrorCode) error { + c.metricsTracer.ClosedConnection(c.dir, time.Since(c.opened), c.ConnState(), c.LocalMultiaddr()) + if ce, ok := c.CapableConn.(network.CloseWithErrorer); ok { + return ce.CloseWithError(errCode) + } + return c.CapableConn.Close() +} + func (c connWithMetrics) Stat() network.ConnStats { if cs, ok := c.CapableConn.(network.ConnStat); ok { return cs.Stat() diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index 5fd41c8d9f..b7cc46fb71 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -58,11 +58,20 @@ func (c *Conn) ID() string { // open notifications must finish before we can fire off the close // notifications). func (c *Conn) Close() error { - c.closeOnce.Do(c.doClose) + c.closeOnce.Do(func() { + c.doClose(0) + }) return c.err } -func (c *Conn) doClose() { +func (c *Conn) CloseWithError(errCode network.ConnErrorCode) error { + c.closeOnce.Do(func() { + c.doClose(errCode) + }) + return c.err +} + +func (c *Conn) doClose(errCode network.ConnErrorCode) { c.swarm.removeConn(c) // Prevent new streams from opening. @@ -71,7 +80,15 @@ func (c *Conn) doClose() { c.streams.m = nil c.streams.Unlock() - c.err = c.conn.Close() + if errCode != 0 { + if ce, ok := c.conn.(network.CloseWithErrorer); ok { + c.err = ce.CloseWithError(errCode) + } else { + c.err = c.conn.Close() + } + } else { + c.err = c.conn.Close() + } // Send the connectedness event after closing the connection. // This ensures that both remote connection close and local connection diff --git a/p2p/net/swarm/swarm_stream.go b/p2p/net/swarm/swarm_stream.go index b7846adec2..437921aaff 100644 --- a/p2p/net/swarm/swarm_stream.go +++ b/p2p/net/swarm/swarm_stream.go @@ -91,6 +91,17 @@ func (s *Stream) Reset() error { return err } +func (s *Stream) ResetWithError(errCode network.StreamErrorCode) error { + var err error + if se, ok := s.stream.(network.ResetWithErrorer); ok { + err = se.ResetWithError(errCode) + } else { + err = s.stream.Reset() + } + s.closeAndRemoveStream() + return err +} + func (s *Stream) closeAndRemoveStream() { s.closeMx.Lock() defer s.closeMx.Unlock() diff --git a/p2p/net/swarm/swarm_test.go b/p2p/net/swarm/swarm_test.go index 3d92690b98..496236f826 100644 --- a/p2p/net/swarm/swarm_test.go +++ b/p2p/net/swarm/swarm_test.go @@ -538,7 +538,7 @@ func TestResourceManagerAcceptStream(t *testing.T) { if err == nil { _, err = str.Read([]byte{0}) } - require.EqualError(t, err, "stream reset") + require.ErrorContains(t, err, "stream reset") } func TestListenCloseCount(t *testing.T) { diff --git a/p2p/net/upgrader/conn.go b/p2p/net/upgrader/conn.go index 1c23a01aed..18e1e6a931 100644 --- a/p2p/net/upgrader/conn.go +++ b/p2p/net/upgrader/conn.go @@ -63,3 +63,10 @@ func (t *transportConn) ConnState() network.ConnectionState { UsedEarlyMuxerNegotiation: t.usedEarlyMuxerNegotiation, } } + +func (t *transportConn) CloseWithError(errCode network.ConnErrorCode) error { + if ce, ok := t.MuxedConn.(network.CloseWithErrorer); ok { + return ce.CloseWithError(errCode) + } + return t.Close() +} diff --git a/p2p/protocol/circuitv2/relay/relay_test.go b/p2p/protocol/circuitv2/relay/relay_test.go index f6b63e32de..228f2d3b4a 100644 --- a/p2p/protocol/circuitv2/relay/relay_test.go +++ b/p2p/protocol/circuitv2/relay/relay_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/rand" + "errors" "fmt" "io" "testing" @@ -267,12 +268,12 @@ func TestRelayLimitTime(t *testing.T) { if n > 0 { t.Fatalf("expected to write 0 bytes, wrote %d", n) } - if err != network.ErrReset { + if !errors.Is(err, network.ErrReset) { t.Fatalf("expected reset, but got %s", err) } err = <-rch - if err != network.ErrReset { + if !errors.Is(err, network.ErrReset) { t.Fatalf("expected reset, but got %s", err) } } @@ -300,7 +301,7 @@ func TestRelayLimitData(t *testing.T) { } n, err := s.Read(buf) - if err != network.ErrReset { + if !errors.Is(err, network.ErrReset) { t.Fatalf("expected reset but got %s", err) } rch <- n @@ -308,6 +309,7 @@ func TestRelayLimitData(t *testing.T) { rc := relay.DefaultResources() rc.Limit.Duration = time.Second + // Due to yamux framing, 4 blocks of 1024 bytes will exceed the data limit rc.Limit.Data = 4096 r, err := relay.New(hosts[1], relay.WithResources(rc)) @@ -367,7 +369,6 @@ func TestRelayLimitData(t *testing.T) { t.Fatalf("expected to read %d bytes but read %d", len(buf), n) } } - buf = make([]byte, 4096) if _, err := rand.Read(buf); err != nil { t.Fatal(err) diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 60f8ca0c06..195eb95886 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -35,6 +35,7 @@ import ( "go.uber.org/mock/gomock" ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -830,3 +831,234 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) { }) } } + +// TestStreamErrorCode tests correctness for resetting stream with errors +func TestStreamErrorCode(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + if tc.Name == "WebTransport" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + checkError := func(err error, code network.StreamErrorCode, remote bool) { + t.Helper() + if err == nil { + t.Fatal("expected non nil error") + } + se := &network.StreamError{} + if errors.As(err, &se) { + require.Equal(t, se.ErrorCode, code) + require.Equal(t, se.Remote, remote) + return + } + t.Fatal("expected network.StreamError, got:", err) + } + + errCh := make(chan error) + server.SetStreamHandler("/test", func(s network.Stream) { + defer s.Reset() + b := make([]byte, 10) + n, err := s.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(b[:n]) + if !assert.NoError(t, err) { + return + } + _, err = s.Read(b) + fmt.Println(err) + errCh <- err + + _, err = s.Write(b) + fmt.Println(err) + errCh <- err + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + + _, err = s.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + n, err := s.Read(buf) + require.NoError(t, err) + require.Equal(t, []byte("hello"), buf[:n]) + + err = s.ResetWithError(42) + require.NoError(t, err) + + _, err = s.Read(buf) + checkError(err, 42, false) + + _, err = s.Write(buf) + checkError(err, 42, false) + + err = <-errCh // read error + checkError(err, 42, true) + + err = <-errCh // write error + checkError(err, 42, true) + }) + } +} + +// TestStreamErrorCodeConnClosed tests correctness for resetting stream with errors +func TestStreamErrorCodeConnClosed(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + if tc.Name == "WebTransport" || tc.Name == "WebRTC" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + checkError := func(err error, code network.ConnErrorCode, remote bool) { + t.Helper() + if err == nil { + t.Fatal("expected non nil error") + } + ce := &network.ConnError{} + if errors.As(err, &ce) { + require.Equal(t, code, ce.ErrorCode) + require.Equal(t, remote, ce.Remote) + return + } + t.Fatal("expected network.ConnError, got:", err) + } + + errCh := make(chan error) + server.SetStreamHandler("/test", func(s network.Stream) { + defer s.Reset() + b := make([]byte, 10) + n, err := s.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(b[:n]) + if !assert.NoError(t, err) { + return + } + _, err = s.Read(b) + errCh <- err + + _, err = s.Write(b) + errCh <- err + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + + _, err = s.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + n, err := s.Read(buf) + require.NoError(t, err) + require.Equal(t, []byte("hello"), buf[:n]) + + err = s.Conn().CloseWithError(42) + require.NoError(t, err) + + _, err = s.Read(buf) + checkError(err, 42, false) + + _, err = s.Write(buf) + checkError(err, 42, false) + + err = <-errCh + checkError(err, 42, true) + + err = <-errCh + checkError(err, 42, true) + }) + } +} + +// TestConnectionErrorCode tests correctness for resetting stream with errors +func TestConnectionErrorCode(t *testing.T) { + for _, tc := range transportsToTest { + t.Run(tc.Name, func(t *testing.T) { + if tc.Name == "WebTransport" || tc.Name == "WebRTC" { + t.Skipf("skipping: %s, not implemented", tc.Name) + return + } + server := tc.HostGenerator(t, TransportTestCaseOpts{}) + client := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true}) + defer server.Close() + defer client.Close() + + checkError := func(err error, code network.ConnErrorCode, remote bool) { + t.Helper() + if err == nil { + t.Fatal("expected non nil error") + } + ce := &network.ConnError{} + if errors.As(err, &ce) { + require.Equal(t, code, ce.ErrorCode) + require.Equal(t, remote, ce.Remote) + return + } + t.Fatal("expected network.ConnError, got:", err) + } + + errCh := make(chan error) + server.SetStreamHandler("/test", func(s network.Stream) { + defer s.Reset() + b := make([]byte, 10) + n, err := s.Read(b) + if !assert.NoError(t, err) { + return + } + _, err = s.Write(b[:n]) + if !assert.NoError(t, err) { + return + } + + _, err = s.Read(b) + if !assert.Error(t, err) { + return + } + _, err = s.Conn().NewStream(context.Background()) + errCh <- err + }) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + client.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL) + s, err := client.NewStream(ctx, server.ID(), "/test") + require.NoError(t, err) + + _, err = s.Write([]byte("hello")) + require.NoError(t, err) + + buf := make([]byte, 10) + n, err := s.Read(buf) + require.NoError(t, err) + require.Equal(t, []byte("hello"), buf[:n]) + + err = s.Conn().CloseWithError(42) + require.NoError(t, err) + + str, err := s.Conn().NewStream(context.Background()) + require.Nil(t, str) + checkError(err, 42, false) + + err = <-errCh + checkError(err, 42, true) + + }) + } +} diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index a2da81eb34..16512b4fbe 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -2,6 +2,7 @@ package libp2pquic import ( "context" + "fmt" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" @@ -34,6 +35,13 @@ func (c *conn) Close() error { return c.closeWithError(0, "") } +// CloseWithError closes the connection +// It must be called even if the peer closed the connection in order for +// garbage collection to properly work in this package. +func (c *conn) CloseWithError(errCode network.ConnErrorCode) error { + return c.closeWithError(quic.ApplicationErrorCode(errCode), "") +} + func (c *conn) closeWithError(errCode quic.ApplicationErrorCode, errString string) error { c.transport.removeConn(c.quicConn) err := c.quicConn.CloseWithError(errCode, errString) @@ -53,13 +61,20 @@ func (c *conn) allowWindowIncrease(size uint64) bool { // OpenStream creates a new stream. func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { qstr, err := c.quicConn.OpenStreamSync(ctx) - return &stream{Stream: qstr}, err + if err != nil { + return nil, parseStreamError(err) + } + return &stream{Stream: qstr}, nil } // AcceptStream accepts a stream opened by the other side. func (c *conn) AcceptStream() (network.MuxedStream, error) { qstr, err := c.quicConn.AcceptStream(context.Background()) - return &stream{Stream: qstr}, err + if err != nil { + fmt.Println("got error", err) + return nil, parseStreamError(err) + } + return &stream{Stream: qstr}, nil } // LocalPeer returns our peer ID diff --git a/p2p/transport/quic/conn_test.go b/p2p/transport/quic/conn_test.go index d3e27a7e16..7baed97efc 100644 --- a/p2p/transport/quic/conn_test.go +++ b/p2p/transport/quic/conn_test.go @@ -270,6 +270,9 @@ func TestStreams(t *testing.T) { t.Run(tc.Name, func(t *testing.T) { testStreams(t, tc) }) + t.Run(tc.Name, func(t *testing.T) { + testStreamsErrorCode(t, tc) + }) } } @@ -305,6 +308,45 @@ func testStreams(t *testing.T, tc *connTestCase) { require.Equal(t, data, []byte("foobar")) } +func testStreamsErrorCode(t *testing.T, tc *connTestCase) { + serverID, serverKey := createPeer(t) + _, clientKey := createPeer(t) + + serverTransport, err := NewTransport(serverKey, newConnManager(t, tc.Options...), nil, nil, nil) + require.NoError(t, err) + defer serverTransport.(io.Closer).Close() + ln := runServer(t, serverTransport, "/ip4/127.0.0.1/udp/0/quic-v1") + defer ln.Close() + + clientTransport, err := NewTransport(clientKey, newConnManager(t, tc.Options...), nil, nil, nil) + require.NoError(t, err) + defer clientTransport.(io.Closer).Close() + conn, err := clientTransport.Dial(context.Background(), ln.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + serverConn, err := ln.Accept() + require.NoError(t, err) + defer serverConn.Close() + + str, err := conn.OpenStream(context.Background()) + require.NoError(t, err) + err = str.(network.ResetWithErrorer).ResetWithError(42) + require.NoError(t, err) + + sstr, err := serverConn.AcceptStream() + require.NoError(t, err) + _, err = io.ReadAll(sstr) + require.Error(t, err) + se := &network.StreamError{} + if errors.As(err, &se) { + require.Equal(t, se.ErrorCode, network.StreamErrorCode(42)) + require.True(t, se.Remote) + } else { + t.Fatalf("expected error to be of network.StreamError type, got %T, %v", err, err) + } + +} + func TestHandshakeFailPeerIDMismatch(t *testing.T) { for _, tc := range connTestCases { t.Run(tc.Name, func(t *testing.T) { diff --git a/p2p/transport/quic/listener_test.go b/p2p/transport/quic/listener_test.go index dbd6d810e4..53d6001d35 100644 --- a/p2p/transport/quic/listener_test.go +++ b/p2p/transport/quic/listener_test.go @@ -159,10 +159,11 @@ func TestCleanupConnWhenBlocked(t *testing.T) { s.SetReadDeadline(time.Now().Add(10 * time.Second)) b := [1]byte{} _, err = s.Read(b[:]) - if err != nil && errors.As(err, &quicErr) { + connError := &network.ConnError{} + if err != nil && errors.As(err, &connError) { // We hit our expected application error return } - t.Fatalf("expected application error, got %v", err) + t.Fatalf("expected network.ConnError, got %v", err) } diff --git a/p2p/transport/quic/stream.go b/p2p/transport/quic/stream.go index 56f12dade2..921e17c76c 100644 --- a/p2p/transport/quic/stream.go +++ b/p2p/transport/quic/stream.go @@ -2,6 +2,7 @@ package libp2pquic import ( "errors" + "math" "github.com/libp2p/go-libp2p/core/network" @@ -18,20 +19,43 @@ type stream struct { var _ network.MuxedStream = &stream{} +func parseStreamError(err error) error { + if err == nil { + return err + } + se := &quic.StreamError{} + if errors.As(err, &se) { + code := se.ErrorCode + if code > math.MaxUint32 { + // TODO(sukunrt): do we need this? + code = reset + } + err = &network.StreamError{ + ErrorCode: network.StreamErrorCode(code), + Remote: se.Remote, + TransportError: se, + } + } + ae := &quic.ApplicationError{} + if errors.As(err, &ae) { + code := ae.ErrorCode + err = &network.ConnError{ + ErrorCode: network.ConnErrorCode(code), + Remote: ae.Remote, + TransportError: ae, + } + } + return err +} + func (s *stream) Read(b []byte) (n int, err error) { n, err = s.Stream.Read(b) - if err != nil && errors.Is(err, &quic.StreamError{}) { - err = network.ErrReset - } - return n, err + return n, parseStreamError(err) } func (s *stream) Write(b []byte) (n int, err error) { n, err = s.Stream.Write(b) - if err != nil && errors.Is(err, &quic.StreamError{}) { - err = network.ErrReset - } - return n, err + return n, parseStreamError(err) } func (s *stream) Reset() error { @@ -40,6 +64,12 @@ func (s *stream) Reset() error { return nil } +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { + s.Stream.CancelRead(quic.StreamErrorCode(errCode)) + s.Stream.CancelWrite(quic.StreamErrorCode(errCode)) + return nil +} + func (s *stream) Close() error { s.Stream.CancelRead(reset) return s.Stream.Close() diff --git a/p2p/transport/tcp/metrics_unix_test.go b/p2p/transport/tcp/metrics_unix_test.go index 0a09526206..fa07217585 100644 --- a/p2p/transport/tcp/metrics_unix_test.go +++ b/p2p/transport/tcp/metrics_unix_test.go @@ -1,4 +1,4 @@ -// go:build: unix +//go:build: unix package tcp diff --git a/p2p/transport/webrtc/connection.go b/p2p/transport/webrtc/connection.go index 2fba37a970..57c0df7e94 100644 --- a/p2p/transport/webrtc/connection.go +++ b/p2p/transport/webrtc/connection.go @@ -132,6 +132,10 @@ func (c *connection) Close() error { return nil } +func (c *connection) CloseWithError(errCode network.ConnErrorCode) error { + return c.Close() +} + // closeWithError is used to Close the connection when the underlying DTLS connection fails func (c *connection) closeWithError(err error) { c.closeOnce.Do(func() { diff --git a/p2p/transport/webrtc/pb/message.pb.go b/p2p/transport/webrtc/pb/message.pb.go index d7d4d583af..6e7b54f2b1 100644 --- a/p2p/transport/webrtc/pb/message.pb.go +++ b/p2p/transport/webrtc/pb/message.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.35.1 -// protoc v5.28.2 +// protoc-gen-go v1.35.2 +// protoc v5.28.3 // source: p2p/transport/webrtc/pb/message.proto package pb @@ -95,8 +95,9 @@ type Message struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Flag *Message_Flag `protobuf:"varint,1,opt,name=flag,enum=Message_Flag" json:"flag,omitempty"` - Message []byte `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"` + Flag *Message_Flag `protobuf:"varint,1,opt,name=flag,enum=Message_Flag" json:"flag,omitempty"` + Message []byte `protobuf:"bytes,2,opt,name=message" json:"message,omitempty"` + ErrorCode *uint32 `protobuf:"varint,3,opt,name=errorCode" json:"errorCode,omitempty"` } func (x *Message) Reset() { @@ -143,24 +144,32 @@ func (x *Message) GetMessage() []byte { return nil } +func (x *Message) GetErrorCode() uint32 { + if x != nil && x.ErrorCode != nil { + return *x.ErrorCode + } + return 0 +} + var File_p2p_transport_webrtc_pb_message_proto protoreflect.FileDescriptor var file_p2p_transport_webrtc_pb_message_proto_rawDesc = []byte{ 0x0a, 0x25, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, 0x2f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x81, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x9f, 0x01, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0d, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x2e, 0x46, 0x6c, 0x61, 0x67, 0x52, 0x04, 0x66, 0x6c, 0x61, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x22, 0x39, 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, - 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, - 0x47, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, - 0x0a, 0x07, 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, - 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, - 0x2f, 0x67, 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, - 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, - 0x70, 0x62, + 0x12, 0x1c, 0x0a, 0x09, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0d, 0x52, 0x09, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x22, 0x39, + 0x0a, 0x04, 0x46, 0x6c, 0x61, 0x67, 0x12, 0x07, 0x0a, 0x03, 0x46, 0x49, 0x4e, 0x10, 0x00, 0x12, + 0x10, 0x0a, 0x0c, 0x53, 0x54, 0x4f, 0x50, 0x5f, 0x53, 0x45, 0x4e, 0x44, 0x49, 0x4e, 0x47, 0x10, + 0x01, 0x12, 0x09, 0x0a, 0x05, 0x52, 0x45, 0x53, 0x45, 0x54, 0x10, 0x02, 0x12, 0x0b, 0x0a, 0x07, + 0x46, 0x49, 0x4e, 0x5f, 0x41, 0x43, 0x4b, 0x10, 0x03, 0x42, 0x35, 0x5a, 0x33, 0x67, 0x69, 0x74, + 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x67, + 0x6f, 0x2d, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2f, 0x70, 0x32, 0x70, 0x2f, 0x74, 0x72, 0x61, + 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2f, 0x70, 0x62, } var ( diff --git a/p2p/transport/webrtc/pb/message.proto b/p2p/transport/webrtc/pb/message.proto index aab885b0da..2401f7c4d2 100644 --- a/p2p/transport/webrtc/pb/message.proto +++ b/p2p/transport/webrtc/pb/message.proto @@ -21,4 +21,6 @@ message Message { optional Flag flag=1; optional bytes message = 2; + + optional uint32 errorCode = 3; } diff --git a/p2p/transport/webrtc/stream.go b/p2p/transport/webrtc/stream.go index 56f869f5e1..54c937c73c 100644 --- a/p2p/transport/webrtc/stream.go +++ b/p2p/transport/webrtc/stream.go @@ -69,8 +69,9 @@ type stream struct { // readerMx ensures that only a single goroutine reads from the reader. Read is not threadsafe // But we may need to read from reader for control messages from a different goroutine. - readerMx sync.Mutex - reader pbio.Reader + readerMx sync.Mutex + reader pbio.Reader + readError error // this buffer is limited up to a single message. Reason we need it // is because a reader might read a message midway, and so we need a @@ -82,6 +83,7 @@ type stream struct { writeStateChanged chan struct{} sendState sendState writeDeadline time.Time + writeError error controlMessageReaderOnce sync.Once // controlMessageReaderEndTime is the end time for reading FIN_ACK from the control @@ -146,6 +148,10 @@ func (s *stream) Close() error { } func (s *stream) Reset() error { + return s.ResetWithError(0) +} + +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { s.mx.Lock() isClosed := s.closeForShutdownErr != nil s.mx.Unlock() @@ -154,8 +160,8 @@ func (s *stream) Reset() error { } defer s.cleanup() - cancelWriteErr := s.cancelWrite() - closeReadErr := s.CloseRead() + cancelWriteErr := s.cancelWrite(errCode) + closeReadErr := s.closeRead(errCode, false) s.setDataChannelReadDeadline(time.Now().Add(-1 * time.Hour)) return errors.Join(closeReadErr, cancelWriteErr) } @@ -175,19 +181,20 @@ func (s *stream) SetDeadline(t time.Time) error { return s.SetWriteDeadline(t) } -// processIncomingFlag process the flag on an incoming message +// processIncomingFlag processes the flag(FIN/RST/etc) on msg. // It needs to be called while the mutex is locked. -func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { - if flag == nil { +func (s *stream) processIncomingFlag(msg *pb.Message) { + if msg.Flag == nil { return } - switch *flag { + switch msg.GetFlag() { case pb.Message_STOP_SENDING: // We must process STOP_SENDING after sending a FIN(sendStateDataSent). Remote peer // may not send a FIN_ACK once it has sent a STOP_SENDING if s.sendState == sendStateSending || s.sendState == sendStateDataSent { s.sendState = sendStateReset + s.writeError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())} } s.notifyWriteStateChanged() case pb.Message_FIN_ACK: @@ -206,6 +213,11 @@ func (s *stream) processIncomingFlag(flag *pb.Message_Flag) { case pb.Message_RESET: if s.receiveState == receiveStateReceiving { s.receiveState = receiveStateReset + s.readError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())} + } + if s.sendState == sendStateSending || s.sendState == sendStateDataSent { + s.sendState = sendStateReset + s.writeError = &network.StreamError{Remote: true, ErrorCode: network.StreamErrorCode(msg.GetErrorCode())} } s.spawnControlMessageReader() } @@ -235,7 +247,7 @@ func (s *stream) spawnControlMessageReader() { s.readerMx.Unlock() if s.nextMessage != nil { - s.processIncomingFlag(s.nextMessage.Flag) + s.processIncomingFlag(s.nextMessage) s.nextMessage = nil } var msg pb.Message @@ -266,7 +278,7 @@ func (s *stream) spawnControlMessageReader() { } return } - s.processIncomingFlag(msg.Flag) + s.processIncomingFlag(&msg) } }() }) diff --git a/p2p/transport/webrtc/stream_read.go b/p2p/transport/webrtc/stream_read.go index 80d99ea91c..826eec3049 100644 --- a/p2p/transport/webrtc/stream_read.go +++ b/p2p/transport/webrtc/stream_read.go @@ -22,7 +22,7 @@ func (s *stream) Read(b []byte) (int, error) { case receiveStateDataRead: return 0, io.EOF case receiveStateReset: - return 0, network.ErrReset + return 0, s.readError } if len(b) == 0 { @@ -52,10 +52,11 @@ func (s *stream) Read(b []byte) (int, error) { // datachannel. For these implementations a stream reset will be observed as an // abrupt closing of the datachannel. s.receiveState = receiveStateReset - return 0, network.ErrReset + s.readError = &network.StreamError{Remote: true} + return 0, s.readError } if s.receiveState == receiveStateReset { - return 0, network.ErrReset + return 0, s.readError } if s.receiveState == receiveStateDataRead { return 0, io.EOF @@ -73,7 +74,7 @@ func (s *stream) Read(b []byte) (int, error) { } // process flags on the message after reading all the data - s.processIncomingFlag(s.nextMessage.Flag) + s.processIncomingFlag(s.nextMessage) s.nextMessage = nil if s.closeForShutdownErr != nil { return read, s.closeForShutdownErr @@ -82,7 +83,7 @@ func (s *stream) Read(b []byte) (int, error) { case receiveStateDataRead: return read, io.EOF case receiveStateReset: - return read, network.ErrReset + return read, s.readError } } } @@ -101,12 +102,17 @@ func (s *stream) setDataChannelReadDeadline(t time.Time) error { } func (s *stream) CloseRead() error { + return s.closeRead(0, false) +} + +func (s *stream) closeRead(errCode network.StreamErrorCode, remote bool) error { s.mx.Lock() defer s.mx.Unlock() var err error if s.receiveState == receiveStateReceiving && s.closeForShutdownErr == nil { err = s.writer.WriteMsg(&pb.Message{Flag: pb.Message_STOP_SENDING.Enum()}) s.receiveState = receiveStateReset + s.readError = &network.StreamError{Remote: remote, ErrorCode: errCode} } s.spawnControlMessageReader() return err diff --git a/p2p/transport/webrtc/stream_write.go b/p2p/transport/webrtc/stream_write.go index 534a8d8e60..01fddac331 100644 --- a/p2p/transport/webrtc/stream_write.go +++ b/p2p/transport/webrtc/stream_write.go @@ -24,7 +24,7 @@ func (s *stream) Write(b []byte) (int, error) { } switch s.sendState { case sendStateReset: - return 0, network.ErrReset + return 0, s.writeError case sendStateDataSent, sendStateDataReceived: return 0, errWriteAfterClose } @@ -48,7 +48,7 @@ func (s *stream) Write(b []byte) (int, error) { } switch s.sendState { case sendStateReset: - return n, network.ErrReset + return n, s.writeError case sendStateDataSent, sendStateDataReceived: return n, errWriteAfterClose } @@ -119,7 +119,7 @@ func (s *stream) availableSendSpace() int { return availableSpace } -func (s *stream) cancelWrite() error { +func (s *stream) cancelWrite(errCode network.StreamErrorCode) error { s.mx.Lock() defer s.mx.Unlock() @@ -129,10 +129,12 @@ func (s *stream) cancelWrite() error { return nil } s.sendState = sendStateReset + s.writeError = &network.StreamError{Remote: false, ErrorCode: errCode} // Remove reference to this stream from data channel s.dataChannel.OnBufferedAmountLow(nil) s.notifyWriteStateChanged() - return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum()}) + code := uint32(errCode) + return s.writer.WriteMsg(&pb.Message{Flag: pb.Message_RESET.Enum(), ErrorCode: &code}) } func (s *stream) CloseWrite() error { diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index ce51611703..eca2262c58 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -203,3 +203,11 @@ func (c *capableConn) ConnState() network.ConnectionState { cs.Transport = "websocket" return cs } + +// CloseWithError implements network.CloseWithErrorer +func (c *capableConn) CloseWithError(errCode network.ConnErrorCode) error { + if ce, ok := c.CapableConn.(network.CloseWithErrorer); ok { + return ce.CloseWithError(errCode) + } + return c.Close() +} diff --git a/p2p/transport/webtransport/conn.go b/p2p/transport/webtransport/conn.go index d914398e0e..3618548d14 100644 --- a/p2p/transport/webtransport/conn.go +++ b/p2p/transport/webtransport/conn.go @@ -78,6 +78,10 @@ func (c *conn) Close() error { return err } +func (c *conn) CloseWithError(errCode network.ConnErrorCode) error { + return c.Close() +} + func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil } func (c *conn) Scope() network.ConnScope { return c.scope } func (c *conn) Transport() tpt.Transport { return c.transport } diff --git a/p2p/transport/webtransport/stream.go b/p2p/transport/webtransport/stream.go index 0849fc9f38..583708edc2 100644 --- a/p2p/transport/webtransport/stream.go +++ b/p2p/transport/webtransport/stream.go @@ -56,6 +56,10 @@ func (s *stream) Reset() error { return nil } +func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { + panic("not implemented") +} + func (s *stream) Close() error { s.Stream.CancelRead(reset) return s.Stream.Close() diff --git a/test-plans/go.mod b/test-plans/go.mod index b2eee27810..13f7c05dcc 100644 --- a/test-plans/go.mod +++ b/test-plans/go.mod @@ -46,7 +46,7 @@ require ( github.com/libp2p/go-nat v0.2.0 // indirect github.com/libp2p/go-netroute v0.2.1 // indirect github.com/libp2p/go-reuseport v0.4.0 // indirect - github.com/libp2p/go-yamux/v4 v4.0.1 // indirect + github.com/libp2p/go-yamux/v4 v4.0.2-0.20241120100319-39abe7ed206a // indirect github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/miekg/dns v1.1.62 // indirect diff --git a/test-plans/go.sum b/test-plans/go.sum index cbb839c369..00d9c56a44 100644 --- a/test-plans/go.sum +++ b/test-plans/go.sum @@ -150,8 +150,8 @@ github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9t github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ= github.com/libp2p/go-reuseport v0.4.0 h1:nR5KU7hD0WxXCJbmw7r2rhRYruNRl2koHw8fQscQm2s= github.com/libp2p/go-reuseport v0.4.0/go.mod h1:ZtI03j/wO5hZVDFo2jKywN6bYKWLOy8Se6DrI2E1cLU= -github.com/libp2p/go-yamux/v4 v4.0.1 h1:FfDR4S1wj6Bw2Pqbc8Uz7pCxeRBPbwsBbEdfwiCypkQ= -github.com/libp2p/go-yamux/v4 v4.0.1/go.mod h1:NWjl8ZTLOGlozrXSOZ/HlfG++39iKNnM5wwmtQP1YB4= +github.com/libp2p/go-yamux/v4 v4.0.2-0.20241120100319-39abe7ed206a h1:zc7jPWFFQibZbACDyQdEAWg7yG/fjx5Jmg6djtpjKog= +github.com/libp2p/go-yamux/v4 v4.0.2-0.20241120100319-39abe7ed206a/go.mod h1:PGP+3py2ZWDKABvqstBZtMnixEHNC7U/odnGylzur5o= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/tcp v0.0.0-20210406111302-dfbc87cc63fd h1:br0buuQ854V8u83wA0rVZ8ttrq5CpaPZdvrK0LP2lOk= From 58f5b89570b42fa2825e6d43e6f8bf0b626e1361 Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 20 Nov 2024 19:27:51 +0530 Subject: [PATCH 2/5] add support for multistream --- core/network/mux.go | 12 ++++++++++++ p2p/host/basic/basic_host.go | 6 +++--- p2p/host/basic/basic_host_test.go | 32 +++++++++++++++++++++++++++++++ p2p/transport/quic/conn.go | 2 -- 4 files changed, 47 insertions(+), 5 deletions(-) diff --git a/core/network/mux.go b/core/network/mux.go index 34d4035429..2e0eb951ee 100644 --- a/core/network/mux.go +++ b/core/network/mux.go @@ -39,6 +39,18 @@ func (s *StreamError) Unwrap() error { return s.TransportError } +const ( + StreamNoError StreamErrorCode = 0 + StreamProtocolNegotiationFailed StreamErrorCode = 1001 + StreamResourceLimitExceeded StreamErrorCode = 1002 + StreamRateLimited StreamErrorCode = 1003 + StreamProtocolViolation StreamErrorCode = 1004 + StreamSupplanted StreamErrorCode = 1005 + StreamGarbageCollected StreamErrorCode = 1006 + StreamShutdown StreamErrorCode = 1007 + StreamGated StreamErrorCode = 1008 +) + // MuxedStream is a bidirectional io pipe within a connection. type MuxedStream interface { io.Reader diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 820411bd27..c9144ff08d 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -464,7 +464,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) { } else { log.Debugf("protocol mux failed: %s (took %s, id:%s, remote peer:%s, remote addr:%v)", err, took, s.ID(), s.Conn().RemotePeer(), s.Conn().RemoteMultiaddr()) } - s.Reset() + s.ResetWithError(network.StreamProtocolNegotiationFailed) return } @@ -478,7 +478,7 @@ func (h *BasicHost) newStreamHandler(s network.Stream) { if err := s.SetProtocol(protoID); err != nil { log.Debugf("error setting stream protocol: %s", err) - s.Reset() + s.ResetWithError(network.StreamResourceLimitExceeded) return } @@ -761,7 +761,7 @@ func (h *BasicHost) NewStream(ctx context.Context, p peer.ID, pids ...protocol.I return nil, fmt.Errorf("failed to negotiate protocol: %w", err) } case <-ctx.Done(): - s.Reset() + s.ResetWithError(network.StreamProtocolNegotiationFailed) // wait for `SelectOneOf` to error out because of resetting the stream. <-errCh return nil, fmt.Errorf("failed to negotiate protocol: %w", ctx.Err()) diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index 2a7a772976..dac280731b 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -995,3 +995,35 @@ func TestHostTimeoutNewStream(t *testing.T) { require.Error(t, err) require.ErrorContains(t, err, "context deadline exceeded") } + +func TestMultistreamFailure(t *testing.T) { + h1, err := NewHost(swarmt.GenSwarm(t), nil) + require.NoError(t, err) + h1.Start() + defer h1.Close() + + h2, err := NewHost(swarmt.GenSwarm(t), nil) + require.NoError(t, err) + h2.Start() + defer h2.Close() + + h2.Peerstore().AddProtocols(h1.ID(), "/test") + + err = h2.Connect(context.Background(), h1.Peerstore().PeerInfo(h1.ID())) + require.NoError(t, err) + h2.Peerstore().AddProtocols(h1.ID(), "/test") + s, err := h2.NewStream(context.Background(), h1.ID(), "/test") + require.NoError(t, err) + // Special string to make the other side fail multistream and reset + buf := make([]byte, 1024) + for i := 0; i < len(buf); i++ { + buf[i] = 0xff + } + _, err = s.Write(buf) + require.NoError(t, err) + _, err = s.Read(buf) + var se *network.StreamError + require.ErrorAs(t, err, &se) + require.True(t, se.Remote) + require.Equal(t, network.StreamProtocolNegotiationFailed, se.ErrorCode) +} diff --git a/p2p/transport/quic/conn.go b/p2p/transport/quic/conn.go index 16512b4fbe..8b381d8eda 100644 --- a/p2p/transport/quic/conn.go +++ b/p2p/transport/quic/conn.go @@ -2,7 +2,6 @@ package libp2pquic import ( "context" - "fmt" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" @@ -71,7 +70,6 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { func (c *conn) AcceptStream() (network.MuxedStream, error) { qstr, err := c.quicConn.AcceptStream(context.Background()) if err != nil { - fmt.Println("got error", err) return nil, parseStreamError(err) } return &stream{Stream: qstr}, nil From e5acd2822e554dd42dd93b90fed9a82d824586d7 Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 21 Nov 2024 12:02:53 +0530 Subject: [PATCH 3/5] add support for conn manager --- core/network/conn.go | 12 ++++++++++++ p2p/net/connmgr/connmgr.go | 5 +++-- p2p/net/connmgr/connmgr_test.go | 8 ++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/core/network/conn.go b/core/network/conn.go index 64aee09ffb..914d3ba878 100644 --- a/core/network/conn.go +++ b/core/network/conn.go @@ -35,6 +35,18 @@ func (c *ConnError) Unwrap() error { return c.TransportError } +const ( + ConnNoError ConnErrorCode = 0 + ConnProtocolNegotiationFailed ConnErrorCode = 1001 + ConnResourceLimitExceeded ConnErrorCode = 1002 + ConnRateLimited ConnErrorCode = 1003 + ConnProtocolViolation ConnErrorCode = 1004 + ConnSupplanted ConnErrorCode = 1005 + ConnGarbageCollected ConnErrorCode = 1006 + ConnShutdown ConnErrorCode = 1007 + ConnGated ConnErrorCode = 1008 +) + // Conn is a connection to a remote peer. It multiplexes streams. // Usually there is no need to use a Conn directly, but it may // be useful to get information about the peer on the other side: diff --git a/p2p/net/connmgr/connmgr.go b/p2p/net/connmgr/connmgr.go index 5033538e3b..c2cd307259 100644 --- a/p2p/net/connmgr/connmgr.go +++ b/p2p/net/connmgr/connmgr.go @@ -175,7 +175,8 @@ func (cm *BasicConnMgr) memoryEmergency() { // Trim connections without paying attention to the silence period. for _, c := range cm.getConnsToCloseEmergency(target) { log.Infow("low on memory. closing conn", "peer", c.RemotePeer()) - c.Close() + + c.CloseWithError(network.ConnGarbageCollected) } // finally, update the last trim time. @@ -388,7 +389,7 @@ func (cm *BasicConnMgr) trim() { // do the actual trim. for _, c := range cm.getConnsToClose() { log.Debugw("closing conn", "peer", c.RemotePeer()) - c.Close() + c.CloseWithError(network.ConnGarbageCollected) } } diff --git a/p2p/net/connmgr/connmgr_test.go b/p2p/net/connmgr/connmgr_test.go index 5955265f9b..b0beecf4e2 100644 --- a/p2p/net/connmgr/connmgr_test.go +++ b/p2p/net/connmgr/connmgr_test.go @@ -33,6 +33,14 @@ func (c *tconn) Close() error { return nil } +func (c *tconn) CloseWithError(code network.ConnErrorCode) error { + atomic.StoreUint32(&c.closed, 1) + if c.disconnectNotify != nil { + c.disconnectNotify(nil, c) + } + return nil +} + func (c *tconn) isClosed() bool { return atomic.LoadUint32(&c.closed) == 1 } From 1ad628f19b9115912446a22963d2e78c04885f6d Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 21 Nov 2024 12:40:12 +0530 Subject: [PATCH 4/5] send error codes on transport conn and stream failures --- p2p/net/mock/mock_stream.go | 14 +++++++++++++- p2p/net/swarm/swarm.go | 8 ++++++-- p2p/net/swarm/swarm_conn.go | 6 +++++- p2p/net/upgrader/listener.go | 6 +++++- p2p/transport/quic/listener.go | 5 ++--- p2p/transport/quic/transport.go | 4 +--- p2p/transport/quic/virtuallistener.go | 3 ++- p2p/transport/quicreuse/listener.go | 3 ++- p2p/transport/webtransport/stream.go | 4 ---- 9 files changed, 36 insertions(+), 17 deletions(-) diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index 3ba29ddd80..d4381af096 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -145,7 +145,19 @@ func (s *stream) Reset() error { } func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { - panic("not implemented") + // Cancel any pending reads/writes with an error. + // TODO: Should these be the other way round(remote=true)? + s.write.CloseWithError(&network.StreamError{Remote: false, ErrorCode: errCode}) + s.read.CloseWithError(&network.StreamError{Remote: false, ErrorCode: errCode}) + + select { + case s.reset <- struct{}{}: + default: + } + <-s.closed + + // No meaningful error case here. + return nil } func (s *stream) teardown() { diff --git a/p2p/net/swarm/swarm.go b/p2p/net/swarm/swarm.go index 12e83d38ef..a0ccb5091f 100644 --- a/p2p/net/swarm/swarm.go +++ b/p2p/net/swarm/swarm.go @@ -385,8 +385,12 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn, // If we do this in the Upgrader, we will not be able to do this. if s.gater != nil { if allow, _ := s.gater.InterceptUpgraded(c); !allow { - // TODO Send disconnect with reason here - err := tc.Close() + var err error + if tcc, ok := tc.(network.CloseWithErrorer); ok { + err = tcc.CloseWithError(network.ConnGated) + } else { + err = tc.Close() + } if err != nil { log.Warnf("failed to close connection with peer %s and addr %s; err: %s", p, addr, err) } diff --git a/p2p/net/swarm/swarm_conn.go b/p2p/net/swarm/swarm_conn.go index b7cc46fb71..80911d0ab5 100644 --- a/p2p/net/swarm/swarm_conn.go +++ b/p2p/net/swarm/swarm_conn.go @@ -138,7 +138,11 @@ func (c *Conn) start() { } scope, err := c.swarm.ResourceManager().OpenStream(c.RemotePeer(), network.DirInbound) if err != nil { - ts.Reset() + if tse, ok := ts.(network.ResetWithErrorer); ok { + tse.ResetWithError(network.StreamResourceLimitExceeded) + } else { + ts.Reset() + } continue } c.swarm.refs.Add(1) diff --git a/p2p/net/upgrader/listener.go b/p2p/net/upgrader/listener.go index c2e81d2e93..f2230e5619 100644 --- a/p2p/net/upgrader/listener.go +++ b/p2p/net/upgrader/listener.go @@ -162,7 +162,11 @@ func (l *listener) handleIncoming() { // if we stop accepting connections for some reason, // we'll eventually close all the open ones // instead of hanging onto them. - conn.Close() + if cc, ok := conn.(network.CloseWithErrorer); ok { + cc.CloseWithError(network.ConnRateLimited) + } else { + conn.Close() + } } }() } diff --git a/p2p/transport/quic/listener.go b/p2p/transport/quic/listener.go index f90bdf53f0..30868e49eb 100644 --- a/p2p/transport/quic/listener.go +++ b/p2p/transport/quic/listener.go @@ -11,7 +11,6 @@ import ( tpt "github.com/libp2p/go-libp2p/core/transport" p2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" - ma "github.com/multiformats/go-multiaddr" "github.com/quic-go/quic-go" ) @@ -54,12 +53,12 @@ func (l *listener) Accept() (tpt.CapableConn, error) { c, err := l.wrapConn(qconn) if err != nil { log.Debugf("failed to setup connection: %s", err) - qconn.CloseWithError(1, "") + qconn.CloseWithError(quic.ApplicationErrorCode(network.ConnResourceLimitExceeded), "") continue } l.transport.addConn(qconn, c) if l.transport.gater != nil && !(l.transport.gater.InterceptAccept(c) && l.transport.gater.InterceptSecured(network.DirInbound, c.remotePeerID, c)) { - c.closeWithError(errorCodeConnectionGating, "connection gated") + c.closeWithError(quic.ApplicationErrorCode(network.ConnGated), "connection gated") continue } diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 4d3d9e551d..62d31a8d2a 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -34,8 +34,6 @@ var ErrHolePunching = errors.New("hole punching attempted; no active dial") var HolePunchTimeout = 5 * time.Second -const errorCodeConnectionGating = 0x47415445 // GATE in ASCII - // The Transport implements the tpt.Transport interface for QUIC connections. type transport struct { privKey ic.PrivKey @@ -169,7 +167,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee remoteMultiaddr: raddr, } if t.gater != nil && !t.gater.InterceptSecured(network.DirOutbound, p, c) { - pconn.CloseWithError(errorCodeConnectionGating, "connection gated") + pconn.CloseWithError(quic.ApplicationErrorCode(network.ConnGated), "connection gated") return nil, fmt.Errorf("secured connection gated") } t.addConn(pconn, c) diff --git a/p2p/transport/quic/virtuallistener.go b/p2p/transport/quic/virtuallistener.go index 7927225567..ceee530b7d 100644 --- a/p2p/transport/quic/virtuallistener.go +++ b/p2p/transport/quic/virtuallistener.go @@ -3,6 +3,7 @@ package libp2pquic import ( "sync" + "github.com/libp2p/go-libp2p/core/network" tpt "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-libp2p/p2p/transport/quicreuse" @@ -142,8 +143,8 @@ func (r *acceptLoopRunner) innerAccept(l *listener, expectedVersion quic.Version select { case ch <- acceptVal{conn: conn}: default: + conn.(network.CloseWithErrorer).CloseWithError(network.ConnRateLimited) // accept queue filled up, drop the connection - conn.Close() log.Warn("Accept queue filled. Dropping connection.") } diff --git a/p2p/transport/quicreuse/listener.go b/p2p/transport/quicreuse/listener.go index 4ee20042d3..31daf59df4 100644 --- a/p2p/transport/quicreuse/listener.go +++ b/p2p/transport/quicreuse/listener.go @@ -10,6 +10,7 @@ import ( "strings" "sync" + "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" "github.com/quic-go/quic-go" @@ -212,7 +213,7 @@ func (l *listener) Close() error { close(l.queue) // drain the queue for conn := range l.queue { - conn.CloseWithError(1, "closing") + conn.CloseWithError(quic.ApplicationErrorCode(network.ConnShutdown), "closing") } }) return nil diff --git a/p2p/transport/webtransport/stream.go b/p2p/transport/webtransport/stream.go index 583708edc2..0849fc9f38 100644 --- a/p2p/transport/webtransport/stream.go +++ b/p2p/transport/webtransport/stream.go @@ -56,10 +56,6 @@ func (s *stream) Reset() error { return nil } -func (s *stream) ResetWithError(errCode network.StreamErrorCode) error { - panic("not implemented") -} - func (s *stream) Close() error { s.Stream.CancelRead(reset) return s.Stream.Close() From 0e1eb185ddb47515d1bf951ad0e94791d6c0f38a Mon Sep 17 00:00:00 2001 From: sukun Date: Thu, 21 Nov 2024 19:24:01 +0530 Subject: [PATCH 5/5] close rcmgr scope for transportConn --- p2p/net/upgrader/conn.go | 1 + 1 file changed, 1 insertion(+) diff --git a/p2p/net/upgrader/conn.go b/p2p/net/upgrader/conn.go index 18e1e6a931..cdeab3b07d 100644 --- a/p2p/net/upgrader/conn.go +++ b/p2p/net/upgrader/conn.go @@ -65,6 +65,7 @@ func (t *transportConn) ConnState() network.ConnectionState { } func (t *transportConn) CloseWithError(errCode network.ConnErrorCode) error { + defer t.scope.Done() if ce, ok := t.MuxedConn.(network.CloseWithErrorer); ok { return ce.CloseWithError(errCode) }