Skip to content

Commit

Permalink
fix: prevent 1st Read() after record() to return 0 bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
ydylla committed Sep 9, 2022
1 parent f0934af commit 1e85a1a
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 1 deletion.
2 changes: 1 addition & 1 deletion layer4/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (cx *Connection) Read(p []byte) (n int, err error) {
// if there is a buffer we should read from, start
// with that; we only read from the underlying conn
// after the buffer has been "depleted"
if cx.bufReader != nil {
if cx.bufReader != nil && cx.buf.Len() > 0 { // buf len check to prevent first read from returning 0 bytes
n, err = cx.bufReader.Read(p)
if err == io.EOF {
cx.bufReader = nil
Expand Down
86 changes: 86 additions & 0 deletions layer4/connection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package layer4

import (
"bytes"
"net"
"testing"
)

func TestConnection_RecordAndRewind(t *testing.T) {
in, out := net.Pipe()
defer in.Close()
defer out.Close()

cx := WrapConnection(out, &bytes.Buffer{})
defer cx.Close()

matcherData := []byte("foo")
consumeData := []byte("bar")

buf := make([]byte, len(matcherData))

go func() {
in.Write(matcherData)
in.Write(consumeData)
}()

// 1st matcher

cx.record()

n, err := cx.Read(buf)
if err != nil {
t.Fatal(err)
}
if n != len(matcherData) {
t.Fatalf("expected to read %d bytes but got %d", len(matcherData), n)
}
if bytes.Compare(matcherData, buf) != 0 {
t.Fatalf("expected %s but received %s", matcherData, buf)
}

cx.rewind()

// 2nd matcher (reads same data)

cx.record()

n, err = cx.Read(buf)
if err != nil {
t.Fatal(err)
}
if n != len(matcherData) {
t.Fatalf("expected to read %d bytes but got %d", len(matcherData), n)
}
if bytes.Compare(matcherData, buf) != 0 {
t.Fatalf("expected %s but received %s", matcherData, buf)
}

cx.rewind()

// 1st consumer (no record call)

n, err = cx.Read(buf)
if err != nil {
t.Fatal(err)
}
if n != len(matcherData) {
t.Fatalf("expected to read %d bytes but got %d", len(matcherData), n)
}
if bytes.Compare(matcherData, buf) != 0 {
t.Fatalf("expected %s but received %s", matcherData, buf)
}

// 2nd consumer (reads other data)

n, err = cx.Read(buf)
if err != nil {
t.Fatal(err)
}
if n != len(consumeData) {
t.Fatalf("expected to read %d bytes but got %d", len(consumeData), n)
}
if bytes.Compare(consumeData, buf) != 0 {
t.Fatalf("expected %s but received %s", consumeData, buf)
}
}

0 comments on commit 1e85a1a

Please sign in to comment.