diff --git a/client_integration_test.go b/client_integration_test.go index 933897b4..8ea340ec 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -7,6 +7,7 @@ import ( "bytes" "crypto/sha1" "errors" + "fmt" "io" "io/ioutil" "math/rand" @@ -164,7 +165,7 @@ func netPipe(t testing.TB) (io.ReadWriteCloser, io.ReadWriteCloser) { return c1, r.Conn } -func testClientGoSvr(t testing.TB, readonly bool, delay time.Duration) (*Client, *exec.Cmd) { +func testClientGoSvr(t testing.TB, readonly bool, delay time.Duration, opts ...ClientOption) (*Client, *exec.Cmd) { c1, c2 := netPipe(t) options := []ServerOption{WithDebug(os.Stderr)} @@ -183,7 +184,7 @@ func testClientGoSvr(t testing.TB, readonly bool, delay time.Duration) (*Client, wr = newDelayedWriter(wr, delay) } - client, err := NewClientPipe(c2, wr) + client, err := NewClientPipe(c2, wr, opts...) if err != nil { t.Fatal(err) } @@ -194,13 +195,13 @@ func testClientGoSvr(t testing.TB, readonly bool, delay time.Duration) (*Client, // testClient returns a *Client connected to a locally running sftp-server // the *exec.Cmd returned must be defer Wait'd. -func testClient(t testing.TB, readonly bool, delay time.Duration) (*Client, *exec.Cmd) { +func testClient(t testing.TB, readonly bool, delay time.Duration, opts ...ClientOption) (*Client, *exec.Cmd) { if !*testIntegration { t.Skip("skipping integration test") } if *testServerImpl { - return testClientGoSvr(t, readonly, delay) + return testClientGoSvr(t, readonly, delay, opts...) } cmd := exec.Command(*testSftp, "-e", "-R", "-l", debuglevel) // log to stderr, read only @@ -228,7 +229,7 @@ func testClient(t testing.TB, readonly bool, delay time.Duration) (*Client, *exe t.Skipf("could not start sftp-server process: %v", err) } - sftp, err := NewClientPipe(pr, pw) + sftp, err := NewClientPipe(pr, pw, opts...) if err != nil { t.Fatal(err) } @@ -1489,6 +1490,59 @@ func TestClientReadFrom(t *testing.T) { } } +// A wrongSizeReader reads zeros and has an arbitrary Size. +type wrongSizeReader struct{ actual, reported int } + +func (r *wrongSizeReader) Read(p []byte) (n int, err error) { + if len(p) >= r.actual { + n, r.actual = r.actual, 0 + return n, io.EOF + } + r.actual -= len(p) + return len(p), nil +} + +func (r *wrongSizeReader) Size() int { return r.reported } + +// Test File.ReadFrom's handling of a Reader's Size: +// it should be used as a heuristic for determining concurrency only. +func TestClientReadFromSizeMismatch(t *testing.T) { + const ( + packetSize = 1024 + filesize = 4 * packetSize + ) + + sftp, cmd := testClient(t, READWRITE, NODELAY, MaxPacketChecked(packetSize), UseConcurrentWrites(true)) + defer cmd.Wait() + defer sftp.Close() + + d, err := ioutil.TempDir("", "sftptest-readfrom-size-mismatch") + if err != nil { + t.Fatal("cannot create temp dir:", err) + } + defer os.RemoveAll(d) + + for i, reported := range []int{-1, filesize - 100, filesize, filesize + 100} { + r := &wrongSizeReader{filesize, reported} + + f := path.Join(d, fmt.Sprintf("writeTest%d", i)) + w, err := sftp.Create(f) + if err != nil { + t.Fatal("unexpected error:", err) + } + defer w.Close() + + n, err := w.ReadFrom(r) + assert.EqualValues(t, filesize, n) + + fi, err := os.Stat(f) + if err != nil { + t.Fatal("unexpected error:", err) + } + assert.EqualValues(t, filesize, fi.Size()) + } +} + // Issue #145 in github // Deadlock in ReadFrom when network drops after 1 good packet. // Deadlock would occur anytime desiredInFlight-inFlight==2 and 2 errors