From 8be14f480de60eacd57f57baed78419dd455d38a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 2 Sep 2022 11:49:45 +0300 Subject: [PATCH] noise: implement an API to send and receive early data --- p2p/security/noise/crypto_test.go | 2 +- p2p/security/noise/handshake.go | 41 +++---- p2p/security/noise/session.go | 20 ++-- p2p/security/noise/session_transport.go | 23 +++- p2p/security/noise/transport.go | 4 +- p2p/security/noise/transport_test.go | 135 +++++++++++++++++++++--- 6 files changed, 178 insertions(+), 47 deletions(-) diff --git a/p2p/security/noise/crypto_test.go b/p2p/security/noise/crypto_test.go index 35837a5a6c..9b7d390829 100644 --- a/p2p/security/noise/crypto_test.go +++ b/p2p/security/noise/crypto_test.go @@ -93,7 +93,7 @@ func TestCryptoFailsIfHandshakeIncomplete(t *testing.T) { init, resp := net.Pipe() _ = resp.Close() - session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, true) + session, _ := newSecureSession(initTransport, context.TODO(), init, "remote-peer", nil, nil, true) _, err := session.encrypt(nil, []byte("hi")) if err == nil { t.Error("expected encryption error when handshake incomplete") diff --git a/p2p/security/noise/handshake.go b/p2p/security/noise/handshake.go index b71d7ef89b..bb68196c03 100644 --- a/p2p/security/noise/handshake.go +++ b/p2p/security/noise/handshake.go @@ -10,17 +10,16 @@ import ( "runtime/debug" "time" - "github.com/minio/sha256-simd" "golang.org/x/crypto/chacha20poly1305" - "github.com/libp2p/go-libp2p/p2p/security/noise/pb" - "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/p2p/security/noise/pb" "github.com/flynn/noise" "github.com/gogo/protobuf/proto" pool "github.com/libp2p/go-buffer-pool" + "github.com/minio/sha256-simd" ) // payloadSigPrefix is prepended to our Noise static key before signing with @@ -89,10 +88,12 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { if s.initiator { // stage 0 // - // do not send the payload just yet, as it would be plaintext; not secret. // Handshake Msg Len = len(DH ephemeral key) - err = s.sendHandshakeMessage(hs, nil, hbuf) - if err != nil { + var ed []byte + if s.earlyDataHandler != nil { + ed = s.earlyDataHandler.Send(ctx, s.insecureConn, s.remoteID) + } + if err := s.sendHandshakeMessage(hs, ed, hbuf); err != nil { return fmt.Errorf("error sending handshake message: %w", err) } @@ -101,29 +102,34 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { if err != nil { return fmt.Errorf("error reading handshake message: %w", err) } - err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) - if err != nil { + if err := s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()); err != nil { return err } // stage 2 // // Handshake Msg Len = len(DHT static key) + MAC(static key is encrypted) + len(Payload) + MAC(payload is encrypted) - err = s.sendHandshakeMessage(hs, payload, hbuf) - if err != nil { + if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil { return fmt.Errorf("error sending handshake message: %w", err) } + return nil } else { // stage 0 // - // We don't expect any payload on the first message. - if _, err := s.readHandshakeMessage(hs); err != nil { + initialPayload, err := s.readHandshakeMessage(hs) + if err != nil { return fmt.Errorf("error reading handshake message: %w", err) } + if s.earlyDataHandler != nil { + if err := s.earlyDataHandler.Received(ctx, s.insecureConn, initialPayload); err != nil { + return err + } + } else if len(initialPayload) > 0 { + return fmt.Errorf("received unexpected early data (%d bytes)", len(initialPayload)) + } // stage 1 // // Handshake Msg Len = len(DH ephemeral key) + len(DHT static key) + MAC(static key is encrypted) + len(Payload) + // MAC(payload is encrypted) - err = s.sendHandshakeMessage(hs, payload, hbuf) - if err != nil { + if err := s.sendHandshakeMessage(hs, payload, hbuf); err != nil { return fmt.Errorf("error sending handshake message: %w", err) } @@ -132,13 +138,8 @@ func (s *secureSession) runHandshake(ctx context.Context) (err error) { if err != nil { return fmt.Errorf("error reading handshake message: %w", err) } - err = s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) - if err != nil { - return err - } + return s.handleRemoteHandshakePayload(plaintext, hs.PeerStatic()) } - - return nil } // setCipherStates sets the initial cipher states that will be used to protect diff --git a/p2p/security/noise/session.go b/p2p/security/noise/session.go index a563e58b5d..5e3d0956cf 100644 --- a/p2p/security/noise/session.go +++ b/p2p/security/noise/session.go @@ -36,20 +36,22 @@ type secureSession struct { dec *noise.CipherState // noise prologue - prologue []byte + prologue []byte + earlyDataHandler EarlyDataHandler } // newSecureSession creates a Noise session over the given insecureConn Conn, using // the libp2p identity keypair from the given Transport. -func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, initiator bool) (*secureSession, error) { +func newSecureSession(tpt *Transport, ctx context.Context, insecure net.Conn, remote peer.ID, prologue []byte, edh EarlyDataHandler, initiator bool) (*secureSession, error) { s := &secureSession{ - insecureConn: insecure, - insecureReader: bufio.NewReader(insecure), - initiator: initiator, - localID: tpt.localID, - localKey: tpt.privateKey, - remoteID: remote, - prologue: prologue, + insecureConn: insecure, + insecureReader: bufio.NewReader(insecure), + initiator: initiator, + localID: tpt.localID, + localKey: tpt.privateKey, + remoteID: remote, + prologue: prologue, + earlyDataHandler: edh, } // the go-routine we create to run the handshake will diff --git a/p2p/security/noise/session_transport.go b/p2p/security/noise/session_transport.go index 5aa24ee595..973f6facf2 100644 --- a/p2p/security/noise/session_transport.go +++ b/p2p/security/noise/session_transport.go @@ -22,6 +22,22 @@ func Prologue(prologue []byte) SessionOption { } } +// EarlyDataHandler allows attaching an (unencrypted) application payload to the first handshake message. +// While unencrypted, the integrity of this early data is retroactively authenticated on completion of the handshake. +type EarlyDataHandler interface { + // Send is called for the client before sending the first handshake message. + Send(context.Context, net.Conn, peer.ID) []byte + // Received is called for the server when the first handshake message from the client is received. + Received(context.Context, net.Conn, []byte) error +} + +func EarlyData(h EarlyDataHandler) SessionOption { + return func(s *SessionTransport) error { + s.earlyDataHandler = h + return nil + } +} + var _ sec.SecureTransport = &SessionTransport{} // SessionTransport can be used @@ -29,13 +45,14 @@ var _ sec.SecureTransport = &SessionTransport{} type SessionTransport struct { t *Transport // options - prologue []byte + prologue []byte + earlyDataHandler EarlyDataHandler } // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, false) + c, err := newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) if maErr == nil { @@ -47,5 +64,5 @@ func (i *SessionTransport) SecureInbound(ctx context.Context, insecure net.Conn, // SecureOutbound runs the Noise handshake as the initiator. func (i *SessionTransport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(i.t, ctx, insecure, p, i.prologue, true) + return newSecureSession(i.t, ctx, insecure, p, i.prologue, i.earlyDataHandler, true) } diff --git a/p2p/security/noise/transport.go b/p2p/security/noise/transport.go index fd33e24917..bd66d0fdd1 100644 --- a/p2p/security/noise/transport.go +++ b/p2p/security/noise/transport.go @@ -41,7 +41,7 @@ func New(privkey crypto.PrivKey) (*Transport, error) { // SecureInbound runs the Noise handshake as the responder. // If p is empty, connections from any peer are accepted. func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - c, err := newSecureSession(t, ctx, insecure, p, nil, false) + c, err := newSecureSession(t, ctx, insecure, p, nil, nil, false) if err != nil { addr, maErr := manet.FromNetAddr(insecure.RemoteAddr()) if maErr == nil { @@ -53,7 +53,7 @@ func (t *Transport) SecureInbound(ctx context.Context, insecure net.Conn, p peer // SecureOutbound runs the Noise handshake as the initiator. func (t *Transport) SecureOutbound(ctx context.Context, insecure net.Conn, p peer.ID) (sec.SecureConn, error) { - return newSecureSession(t, ctx, insecure, p, nil, true) + return newSecureSession(t, ctx, insecure, p, nil, nil, true) } func (t *Transport) WithSessionOptions(opts ...SessionOption) (sec.SecureTransport, error) { diff --git a/p2p/security/noise/transport_test.go b/p2p/security/noise/transport_test.go index 5269be2945..108efec816 100644 --- a/p2p/security/noise/transport_test.go +++ b/p2p/security/noise/transport_test.go @@ -78,10 +78,10 @@ func connect(t *testing.T, initTransport, respTransport *Transport) (*secureSess done := make(chan struct{}) go func() { defer close(done) - initConn, initErr = initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) + initConn, initErr = initTransport.SecureOutbound(context.Background(), init, respTransport.localID) }() - respConn, respErr := respTransport.SecureInbound(context.TODO(), resp, "") + respConn, respErr := respTransport.SecureInbound(context.Background(), resp, "") <-done if initErr != nil { @@ -171,7 +171,7 @@ func TestPeerIDMatch(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) + conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID) assert.NoError(t, err) assert.Equal(t, conn.RemotePeer(), respTransport.localID) b := make([]byte, 6) @@ -180,7 +180,7 @@ func TestPeerIDMatch(t *testing.T) { assert.Equal(t, b, []byte("foobar")) }() - conn, err := respTransport.SecureInbound(context.TODO(), resp, initTransport.localID) + conn, err := respTransport.SecureInbound(context.Background(), resp, initTransport.localID) require.NoError(t, err) require.Equal(t, conn.RemotePeer(), initTransport.localID) _, err = conn.Write([]byte("foobar")) @@ -194,11 +194,11 @@ func TestPeerIDMismatchOutboundFailsHandshake(t *testing.T) { errChan := make(chan error) go func() { - _, err := initTransport.SecureOutbound(context.TODO(), init, "a-random-peer-id") + _, err := initTransport.SecureOutbound(context.Background(), init, "a-random-peer-id") errChan <- err }() - _, err := respTransport.SecureInbound(context.TODO(), resp, "") + _, err := respTransport.SecureInbound(context.Background(), resp, "") require.Error(t, err) initErr := <-errChan @@ -214,13 +214,13 @@ func TestPeerIDMismatchInboundFailsHandshake(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - conn, err := initTransport.SecureOutbound(context.TODO(), init, respTransport.localID) + conn, err := initTransport.SecureOutbound(context.Background(), init, respTransport.localID) assert.NoError(t, err) _, err = conn.Read([]byte{0}) assert.Error(t, err) }() - _, err := respTransport.SecureInbound(context.TODO(), resp, "a-random-peer-id") + _, err := respTransport.SecureInbound(context.Background(), resp, "a-random-peer-id") require.Error(t, err, "expected responder to fail with peer ID mismatch error") <-done } @@ -387,7 +387,7 @@ func TestPrologueMatches(t *testing.T) { tpt, err := initTransport. WithSessionOptions(Prologue(commonPrologue)) require.NoError(t, err) - conn, err := tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID) + conn, err := tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) require.NoError(t, err) defer conn.Close() }() @@ -395,7 +395,7 @@ func TestPrologueMatches(t *testing.T) { tpt, err := respTransport. WithSessionOptions(Prologue(commonPrologue)) require.NoError(t, err) - conn, err := tpt.SecureInbound(context.TODO(), respConn, "") + conn, err := tpt.SecureInbound(context.Background(), respConn, "") require.NoError(t, err) defer conn.Close() <-done @@ -415,14 +415,125 @@ func TestPrologueDoesNotMatchFailsHandshake(t *testing.T) { tpt, err := initTransport. WithSessionOptions(Prologue(initPrologue)) require.NoError(t, err) - _, err = tpt.SecureOutbound(context.TODO(), initConn, respTransport.localID) + _, err = tpt.SecureOutbound(context.Background(), initConn, respTransport.localID) require.Error(t, err) }() tpt, err := respTransport.WithSessionOptions(Prologue(respPrologue)) require.NoError(t, err) - _, err = tpt.SecureInbound(context.TODO(), respConn, "") + _, err = tpt.SecureInbound(context.Background(), respConn, "") require.Error(t, err) <-done } + +type earlyDataHandler struct { + send func(context.Context, net.Conn, peer.ID) []byte + received func(context.Context, net.Conn, []byte) error +} + +func (e *earlyDataHandler) Send(ctx context.Context, conn net.Conn, id peer.ID) []byte { + if e.send == nil { + return nil + } + return e.send(ctx, conn, id) +} + +func (e *earlyDataHandler) Received(ctx context.Context, conn net.Conn, data []byte) error { + if e.received == nil { + return nil + } + return e.received(ctx, conn, data) +} + +func TestEarlyDataAccepted(t *testing.T) { + var receivedEarlyData []byte + serverEDH := &earlyDataHandler{ + received: func(_ context.Context, _ net.Conn, data []byte) error { + receivedEarlyData = data + return nil + }, + } + clientEDH := &earlyDataHandler{ + send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, + } + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) + require.NoError(t, err) + tpt := newTestTransport(t, crypto.Ed25519, 2048) + respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH)) + require.NoError(t, err) + + initConn, respConn := newConnPair(t) + + errChan := make(chan error) + go func() { + _, err := respTransport.SecureInbound(context.Background(), initConn, "") + errChan <- err + }() + + conn, err := initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) + require.NoError(t, err) + defer conn.Close() + + require.Equal(t, []byte("foobar"), receivedEarlyData) +} + +func TestEarlyDataRejected(t *testing.T) { + serverEDH := &earlyDataHandler{ + received: func(_ context.Context, _ net.Conn, data []byte) error { return errors.New("nope") }, + } + clientEDH := &earlyDataHandler{ + send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, + } + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) + require.NoError(t, err) + tpt := newTestTransport(t, crypto.Ed25519, 2048) + respTransport, err := tpt.WithSessionOptions(EarlyData(serverEDH)) + require.NoError(t, err) + + initConn, respConn := newConnPair(t) + + errChan := make(chan error) + go func() { + _, err := respTransport.SecureInbound(context.Background(), initConn, "") + errChan <- err + }() + + _, err = initTransport.SecureOutbound(context.Background(), respConn, tpt.localID) + require.Error(t, err) + + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("timeout") + case err := <-errChan: + require.EqualError(t, err, "nope") + } +} + +func TestEarlyDataRejectedWithNoHandler(t *testing.T) { + clientEDH := &earlyDataHandler{ + send: func(ctx context.Context, conn net.Conn, id peer.ID) []byte { return []byte("foobar") }, + } + initTransport, err := newTestTransport(t, crypto.Ed25519, 2048).WithSessionOptions(EarlyData(clientEDH)) + require.NoError(t, err) + respTransport := newTestTransport(t, crypto.Ed25519, 2048) + + initConn, respConn := newConnPair(t) + + errChan := make(chan error) + go func() { + _, err := respTransport.SecureInbound(context.Background(), initConn, "") + errChan <- err + }() + + _, err = initTransport.SecureOutbound(context.Background(), respConn, respTransport.localID) + require.Error(t, err) + + select { + case <-time.After(500 * time.Millisecond): + t.Fatal("timeout") + case err := <-errChan: + require.Error(t, err) + require.Contains(t, err.Error(), "received unexpected early data") + } +}