From 3f969fcd59affc53b98049c985bd1f92f57515a5 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 14 Mar 2020 19:42:19 +0100 Subject: [PATCH 1/5] add optional AllocationModeOptimized after processing a packet we keep in memory the allocated slices and we reuse them for new packets. Slices are allocated in: - recvPacket - when we receive a sshFxpReadPacket (downloads) The allocated slices have a fixed size = maxMsgLength. Allocated slices are referenced to the request order id and are marked for reuse after a request is served in maybeSendPackets. The allocator is added to the packetManager struct and it is cleaned at the end of the Serve() function. This allocation mode is optional and disabled by default --- .travis.yml | 6 ++ Makefile | 6 ++ allocator.go | 99 +++++++++++++++++++++++++++ allocator_test.go | 136 +++++++++++++++++++++++++++++++++++++ client.go | 2 +- conn.go | 8 ++- packet-manager.go | 17 +++++ packet.go | 28 +++++--- packet_test.go | 2 +- request-server.go | 18 +++-- request-server_test.go | 32 +++++++++ request.go | 18 ++--- request_test.go | 22 +++--- server.go | 44 ++++++++++-- server_integration_test.go | 23 +++++++ server_test.go | 12 ++++ 16 files changed, 426 insertions(+), 47 deletions(-) create mode 100644 allocator.go create mode 100644 allocator_test.go diff --git a/.travis.yml b/.travis.yml index 51dead07..36b761bb 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,6 +35,12 @@ script: - go test -integration -v ./... - go test -testserver -v ./... - go test -integration -testserver -v ./... + - go test -integration -optimized-allocator -v ./... + - go test -testserver -optimized-allocator -v ./... + - go test -integration -testserver -optimized-allocator -v ./... - go test -race -integration -v ./... - go test -race -testserver -v ./... - go test -race -integration -testserver -v ./... + - go test -race -integration -optimized-allocator -v ./... + - go test -race -testserver -optimized-allocator -v ./... + - go test -race -integration -optimized-allocator -testserver -v ./... diff --git a/Makefile b/Makefile index 781fe1f5..e7106f04 100644 --- a/Makefile +++ b/Makefile @@ -2,10 +2,16 @@ integration: go test -integration -v go test -testserver -v go test -integration -testserver -v + go test -integration -optimized-allocator -v ./... + go test -testserver -optimized-allocator -v ./... + go test -integration -testserver -optimized-allocator -v ./... integration_w_race: go test -race -integration -v go test -race -testserver -v go test -race -integration -testserver -v + go test -race -integration -optimized-allocator -v ./... + go test -race -testserver -optimized-allocator -v ./... + go test -race -integration -optimized-allocator -testserver -v ./... diff --git a/allocator.go b/allocator.go new file mode 100644 index 00000000..af8f8be9 --- /dev/null +++ b/allocator.go @@ -0,0 +1,99 @@ +package sftp + +import ( + "sync" +) + +type allocator struct { + available [][]byte + // map key is the request order + used map[uint32][][]byte + sync.Mutex +} + +func newAllocator() *allocator { + return &allocator{ + available: nil, + used: make(map[uint32][][]byte), + } +} + +// GetPage returns a previously allocated and unused []byte or create a new one. +// The slice have a fixed size = maxMsgLength, this value is suitable for both +// receiving new packets and reading the files to serve +func (a *allocator) GetPage(requestOrderID uint32) []byte { + a.Lock() + defer a.Unlock() + + var result []byte + + // get an available page and remove it from the available ones + if len(a.available) > 0 { + truncLength := len(a.available) - 1 + result = a.available[truncLength] + + a.available[truncLength] = nil // clear out the internal pointer + a.available = a.available[:truncLength] // truncate the slice + } + + // no preallocated slice found, just allocate a new one + if result == nil { + result = make([]byte, maxMsgLength) + } + + // put result in used pages + a.used[requestOrderID] = append(a.used[requestOrderID], result) + + return result +} + +// ReleasePages marks unused all pages in use for the given requestID +func (a *allocator) ReleasePages(requestOrderID uint32) { + a.Lock() + defer a.Unlock() + + if used, ok := a.used[requestOrderID]; ok && len(used) > 0 { + a.available = append(a.available, used...) + // this is probably useless + a.used[requestOrderID] = nil + } + delete(a.used, requestOrderID) +} + +// Free removes all the used and free pages. +// Call this method when the allocator is not needed anymore +func (a *allocator) Free() { + a.Lock() + defer a.Unlock() + + a.available = nil + a.used = make(map[uint32][][]byte) +} + +func (a *allocator) countUsedPages() int { + a.Lock() + defer a.Unlock() + + num := 0 + for _, p := range a.used { + num += len(p) + } + return num +} + +func (a *allocator) countAvailablePages() int { + a.Lock() + defer a.Unlock() + + return len(a.available) +} + +func (a *allocator) isRequestOrderIDUsed(requestOrderID uint32) bool { + a.Lock() + defer a.Unlock() + + if _, ok := a.used[requestOrderID]; ok { + return true + } + return false +} diff --git a/allocator_test.go b/allocator_test.go new file mode 100644 index 00000000..9fc03276 --- /dev/null +++ b/allocator_test.go @@ -0,0 +1,136 @@ +package sftp + +import ( + "strconv" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" +) + +// I like the full flow of the test here, but probably I will be asked to split it in separate test cases +func TestAllocator(t *testing.T) { + allocator := newAllocator() + // get a page for request order id 1 + page := allocator.GetPage(1) + page[1] = uint8(1) + assert.Equal(t, maxMsgLength, len(page)) + assert.Equal(t, 1, allocator.countUsedPages()) + // get another page for request order id 1, we now have 2 used pages + page = allocator.GetPage(1) + page[0] = uint8(2) + assert.Equal(t, 2, allocator.countUsedPages()) + // get another page for request order id 1, we now have 3 used pages + page = allocator.GetPage(1) + page[2] = uint8(3) + assert.Equal(t, 3, allocator.countUsedPages()) + // release the page for request order id 1, we now have 3 available pages + allocator.ReleasePages(1) + assert.NotContains(t, allocator.used, 1) + assert.Equal(t, 3, allocator.countAvailablePages()) + // get a page for request order id 2 + // we get the latest released page, let's verify that by checking the previously written values + // so we are sure we are reusing a previously allocated page + page = allocator.GetPage(2) + assert.Equal(t, uint8(3), page[2]) + assert.Equal(t, 2, allocator.countAvailablePages()) + assert.Equal(t, 1, allocator.countUsedPages()) + page = allocator.GetPage(2) + assert.Equal(t, uint8(2), page[0]) + assert.Equal(t, 1, allocator.countAvailablePages()) + assert.Equal(t, 2, allocator.countUsedPages()) + page = allocator.GetPage(2) + assert.Equal(t, uint8(1), page[1]) + // we now have 3 used pages for request order id 2 and no available pages + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // release some request order id with no allocated pages, should have no effect + allocator.ReleasePages(1) + allocator.ReleasePages(3) + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // now get some pages for another request order id + allocator.GetPage(3) + // we now must have 3 used pages for request order id 2 and 1 used page for request order id 3 + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 4, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.True(t, allocator.isRequestOrderIDUsed(3), "page with request order id 3 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // get another page for request order id 3 + allocator.GetPage(3) + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 5, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.True(t, allocator.isRequestOrderIDUsed(3), "page with request order id 3 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + // now release the pages for request order id 3 + allocator.ReleasePages(3) + assert.Equal(t, 2, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be used") + assert.False(t, allocator.isRequestOrderIDUsed(1), "page with request order id 1 must be not used") + assert.False(t, allocator.isRequestOrderIDUsed(3), "page with request order id 3 must be not used") + // again check we are reusing previously allocated pages. + // We have written nothing to the 2 last requested page so release them and get the third one + allocator.ReleasePages(2) + assert.Equal(t, 5, allocator.countAvailablePages()) + assert.Equal(t, 0, allocator.countUsedPages()) + assert.False(t, allocator.isRequestOrderIDUsed(2), "page with request order id 2 must be not used") + allocator.GetPage(4) + allocator.GetPage(4) + page = allocator.GetPage(4) + assert.Equal(t, uint8(3), page[2]) + assert.Equal(t, 2, allocator.countAvailablePages()) + assert.Equal(t, 3, allocator.countUsedPages()) + assert.True(t, allocator.isRequestOrderIDUsed(4), "page with request order id 4 must be used") + // free the allocator + allocator.Free() + assert.Equal(t, 0, allocator.countAvailablePages()) + assert.Equal(t, 0, allocator.countUsedPages()) +} + +func BenchmarkAllocatorSerial(b *testing.B) { + allocator := newAllocator() + for i := 0; i < b.N; i++ { + benchAllocator(allocator, uint32(i)) + } +} + +func BenchmarkAllocatorParallel(b *testing.B) { + var counter uint32 + allocator := newAllocator() + for i := 1; i <= 8; i *= 2 { + b.Run(strconv.Itoa(i), func(b *testing.B) { + b.SetParallelism(i) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + benchAllocator(allocator, atomic.AddUint32(&counter, 1)) + } + }) + }) + } +} + +func benchAllocator(allocator *allocator, requestOrderID uint32) { + // simulates the page requested in recvPacket + allocator.GetPage(requestOrderID) + // simulates the page requested in fileget for downloads + allocator.GetPage(requestOrderID) + // release the allocated pages + allocator.ReleasePages(requestOrderID) +} + +// useful for debug +func printAllocatorContents(allocator *allocator) { + for o, u := range allocator.used { + debug("used order id: %v, values: %+v", o, u) + } + for _, v := range allocator.available { + debug("available, values: %+v", v) + } +} diff --git a/client.go b/client.go index 0d09d2a2..4b5eb4d9 100644 --- a/client.go +++ b/client.go @@ -214,7 +214,7 @@ func (c *Client) nextID() uint32 { } func (c *Client) recvVersion() error { - typ, data, err := c.recvPacket() + typ, data, err := c.recvPacket(nil, 0) if err != nil { return err } diff --git a/conn.go b/conn.go index cfbc1d21..1fbce8c1 100644 --- a/conn.go +++ b/conn.go @@ -18,8 +18,10 @@ type conn struct { sendPacketTest func(w io.Writer, m encoding.BinaryMarshaler) error } -func (c *conn) recvPacket() (uint8, []byte, error) { - return recvPacket(c) +// the allocator and the orderID are used in server mode if AllocationModeOptimized is enabled. +// For the client just pass nil and 0 +func (c *conn) recvPacket(alloc *allocator, orderID uint32) (uint8, []byte, error) { + return recvPacket(c, alloc, orderID) } func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { @@ -76,7 +78,7 @@ func (c *clientConn) recv() error { c.conn.Close() }() for { - typ, data, err := c.recvPacket() + typ, data, err := c.recvPacket(nil, 0) if err != nil { return err } diff --git a/packet-manager.go b/packet-manager.go index 45d29956..f10935e3 100644 --- a/packet-manager.go +++ b/packet-manager.go @@ -18,6 +18,8 @@ type packetManager struct { sender packetSender // connection object working *sync.WaitGroup packetCount uint32 + // it is not nil if AllocationModeOptimized is enabled + alloc *allocator } type packetSender interface { @@ -34,6 +36,9 @@ func newPktMgr(sender packetSender) *packetManager { sender: sender, working: &sync.WaitGroup{}, } + if enabledAllocationMode == AllocationModeOptimized { + s.alloc = newAllocator() + } go s.controller() return s } @@ -44,6 +49,14 @@ func (s *packetManager) newOrderID() uint32 { return s.packetCount } +// returns the next orderID without incrementing it. +// This is used before receiving a new packet in AllocationModeOptimized to associate +// the slice allocated for the received packet with the orderID that will be used to mark +// the allocated slices for reuse once the request is served +func (s *packetManager) getNextOrderID() uint32 { + return s.packetCount + 1 +} + type orderedRequest struct { requestPacket orderid uint32 @@ -174,6 +187,10 @@ func (s *packetManager) maybeSendPackets() { if in.orderID() == out.orderID() { debug("Sending packet: %v", out.id()) s.sender.sendPacket(out.(encoding.BinaryMarshaler)) + if s.alloc != nil { + // mark for reuse the slices allocated for this request + s.alloc.ReleasePages(in.orderID()) + } // pop off heads copy(s.incoming, s.incoming[1:]) // shift left s.incoming[len(s.incoming)-1] = nil // clear last diff --git a/packet.go b/packet.go index 7f55e542..68407184 100644 --- a/packet.go +++ b/packet.go @@ -139,9 +139,14 @@ func sendPacket(w io.Writer, m encoding.BinaryMarshaler) error { return nil } -func recvPacket(r io.Reader) (uint8, []byte, error) { - var b = []byte{0, 0, 0, 0} - if _, err := io.ReadFull(r, b); err != nil { +func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, error) { + var b []byte + if alloc != nil { + b = alloc.GetPage(orderID) + } else { + b = make([]byte, 4) + } + if _, err := io.ReadFull(r, b[:4]); err != nil { return 0, nil, err } l, _ := unmarshalUint32(b) @@ -149,17 +154,19 @@ func recvPacket(r io.Reader) (uint8, []byte, error) { debug("recv packet %d bytes too long", l) return 0, nil, errLongPacket } - b = make([]byte, l) - if _, err := io.ReadFull(r, b); err != nil { + if alloc == nil { + b = make([]byte, l) + } + if _, err := io.ReadFull(r, b[0:l]); err != nil { debug("recv packet %d bytes: err %v", l, err) return 0, nil, err } if debugDumpRxPacketBytes { - debug("recv packet: %s %d bytes %x", fxp(b[0]), l, b[1:]) + debug("recv packet: %s %d bytes %x", fxp(b[0]), l, b[1:l]) } else if debugDumpRxPacket { debug("recv packet: %s %d bytes", fxp(b[0]), l) } - return b[0], b[1:], nil + return b[0], b[1:l], nil } type extensionPair struct { @@ -584,10 +591,13 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error { return nil } -func (p *sshFxpReadPacket) getDataSlice() []byte { +func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte { dataLen := clamp(p.Len, maxTxPacket) // we allocate a slice with a bigger capacity so we avoid a new allocation in sshFxpDataPacket.MarshalBinary - // and in sendPacket, we need 9 bytes in MarshalBinary and 4 bytes in sendPacket + // and in sendPacket, we need 9 bytes in MarshalBinary and 4 bytes in sendPacket. + if alloc != nil { + return alloc.GetPage(orderID)[:dataLen] + } return make([]byte, dataLen, dataLen+9+4) } diff --git a/packet_test.go b/packet_test.go index 8378fb63..8b16be6e 100644 --- a/packet_test.go +++ b/packet_test.go @@ -206,7 +206,7 @@ var recvPacketTests = []struct { func TestRecvPacket(t *testing.T) { for _, tt := range recvPacketTests { r := bytes.NewReader(tt.b) - got, rest, _ := recvPacket(r) + got, rest, _ := recvPacket(r, nil, 0) if got != tt.want || !bytes.Equal(rest, tt.rest) { t.Errorf("recvPacket(%#v): want %v, %#v, got %v, %#v", tt.b, tt.want, tt.rest, got, rest) } diff --git a/request-server.go b/request-server.go index d15f3a11..e415a9e7 100644 --- a/request-server.go +++ b/request-server.go @@ -107,11 +107,11 @@ func (rs *RequestServer) Serve() error { var pktType uint8 var pktBytes []byte for { - pktType, pktBytes, err = rs.recvPacket() + pktType, pktBytes, err = rs.recvPacket(rs.pktMgr.alloc, rs.pktMgr.getNextOrderID()) if err != nil { + // we don't care about releasing allocated pages here, the server will quit and the allocator freed break } - pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes}) if err != nil { switch errors.Cause(err) { @@ -150,6 +150,9 @@ func (rs *RequestServer) Serve() error { delete(rs.openRequests, handle) req.close() } + if rs.pktMgr.alloc != nil { + rs.pktMgr.alloc.Free() + } return err } @@ -158,6 +161,7 @@ func (rs *RequestServer) packetWorker( ctx context.Context, pktChan chan orderedRequest, ) error { for pkt := range pktChan { + orderID := pkt.orderID() if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok { if epkt.SpecificPacket != nil { pkt.requestPacket = epkt.SpecificPacket @@ -188,30 +192,30 @@ func (rs *RequestServer) packetWorker( rpkt = statusFromError(pkt, syscall.EBADF) } else { request = NewRequest("Stat", request.Filepath) - rpkt = request.call(rs.Handlers, pkt) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) } case *sshFxpExtendedPacketPosixRename: request := NewRequest("Rename", pkt.Oldpath) request.Target = pkt.Newpath - rpkt = request.call(rs.Handlers, pkt) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) case hasHandle: handle := pkt.getHandle() request, ok := rs.getRequest(handle) if !ok { rpkt = statusFromError(pkt, syscall.EBADF) } else { - rpkt = request.call(rs.Handlers, pkt) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) } case hasPath: request := requestFromPacket(ctx, pkt) - rpkt = request.call(rs.Handlers, pkt) + rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID) request.close() default: rpkt = statusFromError(pkt, ErrSSHFxOpUnsupported) } rs.pktMgr.readyPacket( - rs.pktMgr.newOrderedResponse(rpkt, pkt.orderID())) + rs.pktMgr.newOrderedResponse(rpkt, orderID)) } return nil } diff --git a/request-server_test.go b/request-server_test.go index 6f06bd22..ecea6a88 100644 --- a/request-server_test.go +++ b/request-server_test.go @@ -60,6 +60,15 @@ func clientRequestServerPair(t *testing.T) *csPair { return &csPair{client, server} } +func checkRequestServerAllocator(t *testing.T, p *csPair) { + if p.svr.pktMgr.alloc == nil { + return + } + checkAllocatorBeforeServerClose(t, p.svr.pktMgr.alloc) + p.Close() + checkAllocatorAfterServerClose(t, p.svr.pktMgr.alloc) +} + // after adding logging, maybe check log to make sure packet handling // was split over more than one worker func TestRequestSplitWrite(t *testing.T) { @@ -74,6 +83,7 @@ func TestRequestSplitWrite(t *testing.T) { r := p.testHandler() f, _ := r.fetch("/foo") assert.Equal(t, contents, string(f.content)) + checkRequestServerAllocator(t, p) } func TestRequestCache(t *testing.T) { @@ -101,6 +111,7 @@ func TestRequestCache(t *testing.T) { assert.Equal(t, _foo.Context().Err(), context.Canceled, "context is now canceled") p.svr.closeRequest(bh) assert.Len(t, p.svr.openRequests, 0) + checkRequestServerAllocator(t, p) } func TestRequestCacheState(t *testing.T) { @@ -114,6 +125,7 @@ func TestRequestCacheState(t *testing.T) { err = p.cli.Remove("/foo") assert.Nil(t, err) assert.Len(t, p.svr.openRequests, 0) + checkRequestServerAllocator(t, p) } func putTestFile(cli *Client, path, content string) (int, error) { @@ -136,6 +148,7 @@ func TestRequestWrite(t *testing.T) { assert.Nil(t, err) assert.False(t, f.isdir) assert.Equal(t, f.content, []byte("hello")) + checkRequestServerAllocator(t, p) } func TestRequestWriteEmpty(t *testing.T) { @@ -156,6 +169,7 @@ func TestRequestWriteEmpty(t *testing.T) { assert.Error(t, err) r.returnErr(nil) assert.Equal(t, 0, n) + checkRequestServerAllocator(t, p) } func TestRequestFilename(t *testing.T) { @@ -169,6 +183,7 @@ func TestRequestFilename(t *testing.T) { assert.Equal(t, f.Name(), "foo") _, err = r.fetch("/bar") assert.Error(t, err) + checkRequestServerAllocator(t, p) } func TestRequestJustRead(t *testing.T) { @@ -186,6 +201,7 @@ func TestRequestJustRead(t *testing.T) { } assert.Equal(t, 5, n) assert.Equal(t, "hello", string(contents[0:5])) + checkRequestServerAllocator(t, p) } func TestRequestOpenFail(t *testing.T) { @@ -194,6 +210,7 @@ func TestRequestOpenFail(t *testing.T) { rf, err := p.cli.Open("/foo") assert.Exactly(t, os.ErrNotExist, err) assert.Nil(t, rf) + checkRequestServerAllocator(t, p) } func TestRequestCreate(t *testing.T) { @@ -203,6 +220,7 @@ func TestRequestCreate(t *testing.T) { assert.Nil(t, err) err = fh.Close() assert.Nil(t, err) + checkRequestServerAllocator(t, p) } func TestRequestMkdir(t *testing.T) { @@ -214,6 +232,7 @@ func TestRequestMkdir(t *testing.T) { f, err := r.fetch("/foo") assert.Nil(t, err) assert.True(t, f.isdir) + checkRequestServerAllocator(t, p) } func TestRequestRemove(t *testing.T) { @@ -228,6 +247,7 @@ func TestRequestRemove(t *testing.T) { assert.Nil(t, err) _, err = r.fetch("/foo") assert.Equal(t, err, os.ErrNotExist) + checkRequestServerAllocator(t, p) } func TestRequestRename(t *testing.T) { @@ -256,6 +276,7 @@ func TestRequestRename(t *testing.T) { assert.Equal(t, "baz", f.Name()) _, err = r.fetch("/bar") assert.Equal(t, os.ErrNotExist, err) + checkRequestServerAllocator(t, p) } func TestRequestRenameFail(t *testing.T) { @@ -267,6 +288,7 @@ func TestRequestRenameFail(t *testing.T) { assert.Nil(t, err) err = p.cli.Rename("/foo", "/bar") assert.IsType(t, &StatusError{}, err) + checkRequestServerAllocator(t, p) } func TestRequestStat(t *testing.T) { @@ -280,6 +302,7 @@ func TestRequestStat(t *testing.T) { assert.Equal(t, fi.Mode(), os.FileMode(0644)) assert.NoError(t, testOsSys(fi.Sys())) assert.NoError(t, err) + checkRequestServerAllocator(t, p) } // NOTE: Setstat is a noop in the request server tests, but we want to test @@ -298,6 +321,7 @@ func TestRequestSetstat(t *testing.T) { assert.Equal(t, fi.Size(), int64(5)) assert.Equal(t, fi.Mode(), os.FileMode(0644)) assert.NoError(t, testOsSys(fi.Sys())) + checkRequestServerAllocator(t, p) } func TestRequestFstat(t *testing.T) { @@ -314,6 +338,7 @@ func TestRequestFstat(t *testing.T) { assert.Equal(t, fi.Mode(), os.FileMode(0644)) assert.NoError(t, testOsSys(fi.Sys())) } + checkRequestServerAllocator(t, p) } func TestRequestStatFail(t *testing.T) { @@ -322,6 +347,7 @@ func TestRequestStatFail(t *testing.T) { fi, err := p.cli.Stat("/foo") assert.Nil(t, fi) assert.True(t, os.IsNotExist(err)) + checkRequestServerAllocator(t, p) } func TestRequestLink(t *testing.T) { @@ -335,6 +361,7 @@ func TestRequestLink(t *testing.T) { fi, err := r.fetch("/bar") assert.Nil(t, err) assert.True(t, int(fi.Size()) == len("hello")) + checkRequestServerAllocator(t, p) } func TestRequestLinkFail(t *testing.T) { @@ -343,6 +370,7 @@ func TestRequestLinkFail(t *testing.T) { err := p.cli.Link("/foo", "/bar") t.Log(err) assert.True(t, os.IsNotExist(err)) + checkRequestServerAllocator(t, p) } func TestRequestSymlink(t *testing.T) { @@ -356,6 +384,7 @@ func TestRequestSymlink(t *testing.T) { fi, err := r.fetch("/bar") assert.Nil(t, err) assert.True(t, fi.Mode()&os.ModeSymlink == os.ModeSymlink) + checkRequestServerAllocator(t, p) } func TestRequestSymlinkFail(t *testing.T) { @@ -363,6 +392,7 @@ func TestRequestSymlinkFail(t *testing.T) { defer p.Close() err := p.cli.Symlink("/foo", "/bar") assert.True(t, os.IsNotExist(err)) + checkRequestServerAllocator(t, p) } func TestRequestReadlink(t *testing.T) { @@ -375,6 +405,7 @@ func TestRequestReadlink(t *testing.T) { rl, err := p.cli.ReadLink("/bar") assert.Nil(t, err) assert.Equal(t, "foo", rl) + checkRequestServerAllocator(t, p) } func TestRequestReaddir(t *testing.T) { @@ -398,6 +429,7 @@ func TestRequestReaddir(t *testing.T) { assert.Len(t, di, 100) names := []string{di[18].Name(), di[81].Name()} assert.Equal(t, []string{"foo_18", "foo_81"}, names) + checkRequestServerAllocator(t, p) } func TestCleanPath(t *testing.T) { diff --git a/request.go b/request.go index c81bb784..772628bf 100644 --- a/request.go +++ b/request.go @@ -154,12 +154,12 @@ func (r *Request) close() error { } // called from worker to handle packet/request -func (r *Request) call(handlers Handlers, pkt requestPacket) responsePacket { +func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { switch r.Method { case "Get": - return fileget(handlers.FileGet, r, pkt) + return fileget(handlers.FileGet, r, pkt, alloc, orderID) case "Put": - return fileput(handlers.FilePut, r, pkt) + return fileput(handlers.FilePut, r, pkt, alloc, orderID) case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove": return filecmd(handlers.FileCmd, r, pkt) case "List": @@ -206,7 +206,7 @@ func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket { } // wrap FileReader handler -func fileget(h FileReader, r *Request, pkt requestPacket) responsePacket { +func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { //fmt.Println("fileget", r) r.state.RLock() reader := r.state.readerAt @@ -215,7 +215,7 @@ func fileget(h FileReader, r *Request, pkt requestPacket) responsePacket { return statusFromError(pkt, errors.New("unexpected read packet")) } - data, offset, _ := packetData(pkt) + data, offset, _ := packetData(pkt, alloc, orderID) n, err := reader.ReadAt(data, offset) // only return EOF erro if no data left to read if err != nil && (err != io.EOF || n == 0) { @@ -229,7 +229,7 @@ func fileget(h FileReader, r *Request, pkt requestPacket) responsePacket { } // wrap FileWriter handler -func fileput(h FileWriter, r *Request, pkt requestPacket) responsePacket { +func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket { //fmt.Println("fileput", r) r.state.RLock() writer := r.state.writerAt @@ -238,18 +238,18 @@ func fileput(h FileWriter, r *Request, pkt requestPacket) responsePacket { return statusFromError(pkt, errors.New("unexpected write packet")) } - data, offset, _ := packetData(pkt) + data, offset, _ := packetData(pkt, alloc, orderID) _, err := writer.WriteAt(data, offset) return statusFromError(pkt, err) } // file data for additional read/write packets -func packetData(p requestPacket) (data []byte, offset int64, length uint32) { +func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) { switch p := p.(type) { case *sshFxpReadPacket: length = p.Len offset = int64(p.Offset) - data = p.getDataSlice() + data = p.getDataSlice(alloc, orderID) case *sshFxpWritePacket: data = p.Data length = p.Length diff --git a/request_test.go b/request_test.go index ea14168e..9f1ed661 100644 --- a/request_test.go +++ b/request_test.go @@ -152,7 +152,7 @@ func TestRequestGet(t *testing.T) { for i, txt := range []string{"file-", "data."} { pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a", Offset: uint64(i * 5), Len: 5} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) dpkt := rpkt.(*sshFxpDataPacket) assert.Equal(t, dpkt.id(), uint32(i)) assert.Equal(t, string(dpkt.Data), txt) @@ -165,7 +165,7 @@ func TestRequestCustomError(t *testing.T) { pkt := fakePacket{myid: 1} cmdErr := errors.New("stat not supported") handlers.returnError(cmdErr) - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) assert.Equal(t, rpkt, statusFromError(rpkt, cmdErr)) } @@ -176,11 +176,11 @@ func TestRequestPut(t *testing.T) { request.state.writerAt, _ = handlers.FilePut.Filewrite(request) pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5, Data: []byte("file-")} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) checkOkStatus(t, rpkt) pkt = &sshFxpWritePacket{ID: 1, Handle: "a", Offset: 5, Length: 5, Data: []byte("data.")} - rpkt = request.call(handlers, pkt) + rpkt = request.call(handlers, pkt, nil, 0) checkOkStatus(t, rpkt) assert.Equal(t, "file-data.", handlers.getOutString()) } @@ -189,11 +189,11 @@ func TestRequestCmdr(t *testing.T) { handlers := newTestHandlers() request := testRequest("Mkdir") pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) checkOkStatus(t, rpkt) handlers.returnError(errTest) - rpkt = request.call(handlers, pkt) + rpkt = request.call(handlers, pkt, nil, 0) assert.Equal(t, rpkt, statusFromError(rpkt, errTest)) } @@ -201,7 +201,7 @@ func TestRequestInfoStat(t *testing.T) { handlers := newTestHandlers() request := testRequest("Stat") pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) spkt, ok := rpkt.(*sshFxpStatResponse) assert.True(t, ok) assert.Equal(t, spkt.info.Name(), "request_test.go") @@ -218,13 +218,13 @@ func TestRequestInfoList(t *testing.T) { assert.Equal(t, hpkt.Handle, "1") } pkt = fakePacket{myid: 2} - request.call(handlers, pkt) + request.call(handlers, pkt, nil, 0) } func TestRequestInfoReadlink(t *testing.T) { handlers := newTestHandlers() request := testRequest("Readlink") pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) npkt, ok := rpkt.(*sshFxpNamePacket) if assert.True(t, ok) { assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0]) @@ -237,7 +237,7 @@ func TestOpendirHandleReuse(t *testing.T) { request := testRequest("Stat") request.handle = "1" pkt := fakePacket{myid: 1} - rpkt := request.call(handlers, pkt) + rpkt := request.call(handlers, pkt, nil, 0) assert.IsType(t, &sshFxpStatResponse{}, rpkt) request.Method = "List" @@ -247,6 +247,6 @@ func TestOpendirHandleReuse(t *testing.T) { hpkt := rpkt.(*sshFxpHandlePacket) assert.Equal(t, hpkt.Handle, "1") } - rpkt = request.call(handlers, pkt) + rpkt = request.call(handlers, pkt, nil, 0) assert.IsType(t, &sshFxpNamePacket{}, rpkt) } diff --git a/server.go b/server.go index 7802e9b6..06c08f1a 100644 --- a/server.go +++ b/server.go @@ -22,6 +22,34 @@ const ( SftpServerWorkerCount = 8 ) +// AllocationMode defines allocation modes +type AllocationMode uint8 + +const ( + // AllocationModeStandard a new slice is allocated for each new SFTP packet + AllocationModeStandard AllocationMode = iota + // AllocationModeOptimized after processing a packet we keep in memory the allocated slices + // and we reuse them for new packets. This mode is experimental + AllocationModeOptimized +) + +// enabledAllocationMode defines the allocation mode. +var enabledAllocationMode = AllocationModeStandard + +// SetEnabledAllocationMode sets the allocation mode. +// Allowed values are: +// - AllocationModeStandard, default +// - AllocationModeOptimized +// This method must be called before creating the Server or the RequestServer +func SetEnabledAllocationMode(mode AllocationMode) { + enabledAllocationMode = mode +} + +// GetEnabledAllocationMode returns the configured AllocationMode +func GetEnabledAllocationMode() AllocationMode { + return enabledAllocationMode +} + // Server is an SSH File Transfer Protocol (sftp) server. // This is intended to provide the sftp subsystem to an ssh server daemon. // This implementation currently supports most of sftp server protocol version 3, @@ -138,9 +166,8 @@ func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error { // If server is operating read-only and a write operation is requested, // return permission denied if !readonly && svr.readOnly { - svr.sendPacket(orderedResponse{ - responsePacket: statusFromError(pkt, syscall.EPERM), - orderid: pkt.orderID()}) + svr.pktMgr.readyPacket( + svr.pktMgr.newOrderedResponse(statusFromError(pkt, syscall.EPERM), pkt.orderID())) continue } @@ -153,6 +180,7 @@ func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error { func handlePacket(s *Server, p orderedRequest) error { var rpkt responsePacket + orderID := p.orderID() switch p := p.requestPacket.(type) { case *sshFxInitPacket: rpkt = sshFxVersionPacket{ @@ -256,7 +284,7 @@ func handlePacket(s *Server, p orderedRequest) error { f, ok := s.getHandle(p.Handle) if ok { err = nil - data := p.getDataSlice() + data := p.getDataSlice(s.pktMgr.alloc, orderID) n, _err := f.ReadAt(data, int64(p.Offset)) if _err != nil && (_err != io.EOF || n == 0) { err = _err @@ -291,7 +319,7 @@ func handlePacket(s *Server, p orderedRequest) error { return errors.Errorf("unexpected packet type %T", p) } - s.pktMgr.readyPacket(s.pktMgr.newOrderedResponse(rpkt, p.orderID())) + s.pktMgr.readyPacket(s.pktMgr.newOrderedResponse(rpkt, orderID)) return nil } @@ -315,8 +343,9 @@ func (svr *Server) Serve() error { var pktType uint8 var pktBytes []byte for { - pktType, pktBytes, err = svr.recvPacket() + pktType, pktBytes, err = svr.recvPacket(svr.pktMgr.alloc, svr.pktMgr.getNextOrderID()) if err != nil { + // we don't care about releasing allocated pages here, the server will quit and the allocator freed break } @@ -347,6 +376,9 @@ func (svr *Server) Serve() error { fmt.Fprintf(svr.debugStream, "sftp server file with handle %q left open: %v\n", handle, file.Name()) file.Close() } + if svr.pktMgr.alloc != nil { + svr.pktMgr.alloc.Free() + } return err // error from recvPacket } diff --git a/server_integration_test.go b/server_integration_test.go index f15a6e09..a24257a5 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -29,6 +29,7 @@ import ( "time" "github.com/kr/fs" + "github.com/stretchr/testify/assert" "golang.org/x/crypto/ssh" ) @@ -52,6 +53,9 @@ func TestMain(m *testing.M) { } testSftp = flag.String("sftp", sftpServer, "location of the sftp server binary") flag.Parse() + if *testOptimizedAllocator { + SetEnabledAllocationMode(AllocationModeOptimized) + } os.Exit(m.Run()) } @@ -64,6 +68,7 @@ func skipIfWindows(t testing.TB) { var testServerImpl = flag.Bool("testserver", false, "perform integration tests against sftp package server instance") var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process") +var testOptimizedAllocator = flag.Bool("optimized-allocator", false, "perform tests using AllocationModeOptimized instead of AllocationModeStandard") var testSftp *string var testSftpClientBin *string @@ -468,6 +473,24 @@ func runSftpClient(t *testing.T, script string, path string, host string, port i return stdout.String(), err } +func checkAllocatorBeforeServerClose(t *testing.T, alloc *allocator) { + if alloc != nil { + // before closing the server we are, generally, waiting for new packets in recvPacket and we have a page allocated. + // Sometime the sendPacket returns some milliseconds after the client receives the response, and so we have 2 + // allocated pages here, so wait some milliseconds. To avoid crashes we must be sure to not release the pages + // too soon. + assert.Eventually(t, func() bool { return alloc.countUsedPages() <= 1 }, 100*time.Millisecond, 10*time.Millisecond) + } +} + +func checkAllocatorAfterServerClose(t *testing.T, alloc *allocator) { + if alloc != nil { + // wait for the server cleanup + assert.Eventually(t, func() bool { return alloc.countUsedPages() == 0 }, 100*time.Millisecond, 10*time.Millisecond) + assert.Eventually(t, func() bool { return alloc.countAvailablePages() == 0 }, 100*time.Millisecond, 10*time.Millisecond) + } +} + func TestServerCompareSubsystems(t *testing.T) { listenerGo, hostGo, portGo := testServer(t, GolangSFTP, READONLY) listenerOp, hostOp, portOp := testServer(t, OpenSSHSFTP, READONLY) diff --git a/server_test.go b/server_test.go index 5191415f..59e8bb4a 100644 --- a/server_test.go +++ b/server_test.go @@ -198,6 +198,15 @@ func (p sshFxpTestBadExtendedPacket) MarshalBinary() ([]byte, error) { return b, nil } +func checkServerAllocator(t *testing.T, server *Server) { + if server.pktMgr.alloc == nil { + return + } + checkAllocatorBeforeServerClose(t, server.pktMgr.alloc) + server.Close() + checkAllocatorAfterServerClose(t, server.pktMgr.alloc) +} + // test that errors are sent back when we request an invalid extended packet operation // this validates the following rfc draft is followed https://tools.ietf.org/html/draft-ietf-secsh-filexfer-extensions-00 func TestInvalidExtendedPacket(t *testing.T) { @@ -222,6 +231,7 @@ func TestInvalidExtendedPacket(t *testing.T) { if statusErr.Code != sshFxOPUnsupported { t.Errorf("statusErr.Code => %d, wanted %d", statusErr.Code, sshFxOPUnsupported) } + checkServerAllocator(t, server) } // test that server handles concurrent requests correctly @@ -251,6 +261,7 @@ func TestConcurrentRequests(t *testing.T) { }() } wg.Wait() + checkServerAllocator(t, server) } // Test error conversion @@ -327,4 +338,5 @@ func TestOpenStatRace(t *testing.T) { testreply(id1, ch) testreply(id2, ch) os.Remove(tmppath) + checkServerAllocator(t, server) } From 7168541b0e868098e6e50c969d3084c6758ecaa3 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 14 Mar 2020 19:53:41 +0100 Subject: [PATCH 2/5] travis: add go 1.14 and remove 1.12 --- .travis.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index 36b761bb..e701bf9d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,8 +4,8 @@ go_import_path: github.com/pkg/sftp # current and previous stable releases, plus tip # remember to exclude previous and tip for macs below go: - - 1.12.x - 1.13.x + - 1.14.x - tip os: @@ -15,7 +15,7 @@ os: matrix: exclude: - os: osx - go: 1.12.x + go: 1.13.x - os: osx go: tip From 1f178f9671d23dfd7e44ad4480938287542424bd Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Wed, 18 Mar 2020 09:36:07 +0100 Subject: [PATCH 3/5] the allocator can now be enabled per request Other minor changes as per review comments --- .travis.yml | 12 ++++----- Makefile | 24 ++++++++--------- allocator.go | 19 ++++++-------- allocator_test.go | 1 - client.go | 2 +- conn.go | 12 +++++---- packet-manager.go | 7 ++--- packet.go | 24 +++++++++-------- request-server.go | 46 +++++++++++++++++++++++++++----- request-server_test.go | 7 ++++- server.go | 54 +++++++++++++++----------------------- server_integration_test.go | 25 +++++++++++++----- server_test.go | 6 ++++- 13 files changed, 139 insertions(+), 100 deletions(-) diff --git a/.travis.yml b/.travis.yml index e701bf9d..e9490286 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,12 +35,12 @@ script: - go test -integration -v ./... - go test -testserver -v ./... - go test -integration -testserver -v ./... - - go test -integration -optimized-allocator -v ./... - - go test -testserver -optimized-allocator -v ./... - - go test -integration -testserver -optimized-allocator -v ./... + - go test -integration -allocator -v ./... + - go test -testserver -allocator -v ./... + - go test -integration -testserver -allocator -v ./... - go test -race -integration -v ./... - go test -race -testserver -v ./... - go test -race -integration -testserver -v ./... - - go test -race -integration -optimized-allocator -v ./... - - go test -race -testserver -optimized-allocator -v ./... - - go test -race -integration -optimized-allocator -testserver -v ./... + - go test -race -integration -allocator -v ./... + - go test -race -testserver -allocator -v ./... + - go test -race -integration -allocator -testserver -v ./... diff --git a/Makefile b/Makefile index e7106f04..0afad584 100644 --- a/Makefile +++ b/Makefile @@ -1,17 +1,17 @@ integration: - go test -integration -v - go test -testserver -v - go test -integration -testserver -v - go test -integration -optimized-allocator -v ./... - go test -testserver -optimized-allocator -v ./... - go test -integration -testserver -optimized-allocator -v ./... + go test -integration -v ./... + go test -testserver -v ./... + go test -integration -testserver -v ./... + go test -integration -allocator -v ./... + go test -testserver -allocator -v ./... + go test -integration -testserver -allocator -v ./... integration_w_race: - go test -race -integration -v - go test -race -testserver -v - go test -race -integration -testserver -v - go test -race -integration -optimized-allocator -v ./... - go test -race -testserver -optimized-allocator -v ./... - go test -race -integration -optimized-allocator -testserver -v ./... + go test -race -integration -v ./... + go test -race -testserver -v ./... + go test -race -integration -testserver -v ./... + go test -race -integration -allocator -v ./... + go test -race -testserver -allocator -v ./... + go test -race -integration -allocator -testserver -v ./... diff --git a/allocator.go b/allocator.go index af8f8be9..3e67e543 100644 --- a/allocator.go +++ b/allocator.go @@ -5,15 +5,16 @@ import ( ) type allocator struct { + sync.Mutex available [][]byte // map key is the request order used map[uint32][][]byte - sync.Mutex } func newAllocator() *allocator { return &allocator{ - available: nil, + // micro optimization: initialize available pages with an initial capacity + available: make([][]byte, 0, SftpServerWorkerCount*2), used: make(map[uint32][][]byte), } } @@ -27,7 +28,7 @@ func (a *allocator) GetPage(requestOrderID uint32) []byte { var result []byte - // get an available page and remove it from the available ones + // get an available page and remove it from the available ones. if len(a.available) > 0 { truncLength := len(a.available) - 1 result = a.available[truncLength] @@ -52,15 +53,13 @@ func (a *allocator) ReleasePages(requestOrderID uint32) { a.Lock() defer a.Unlock() - if used, ok := a.used[requestOrderID]; ok && len(used) > 0 { + if used := a.used[requestOrderID]; len(used) > 0 { a.available = append(a.available, used...) - // this is probably useless - a.used[requestOrderID] = nil } delete(a.used, requestOrderID) } -// Free removes all the used and free pages. +// Free removes all the used and available pages. // Call this method when the allocator is not needed anymore func (a *allocator) Free() { a.Lock() @@ -92,8 +91,6 @@ func (a *allocator) isRequestOrderIDUsed(requestOrderID uint32) bool { a.Lock() defer a.Unlock() - if _, ok := a.used[requestOrderID]; ok { - return true - } - return false + _, ok := a.used[requestOrderID] + return ok } diff --git a/allocator_test.go b/allocator_test.go index 9fc03276..74f4da1a 100644 --- a/allocator_test.go +++ b/allocator_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/assert" ) -// I like the full flow of the test here, but probably I will be asked to split it in separate test cases func TestAllocator(t *testing.T) { allocator := newAllocator() // get a page for request order id 1 diff --git a/client.go b/client.go index 4b5eb4d9..f5905cf9 100644 --- a/client.go +++ b/client.go @@ -214,7 +214,7 @@ func (c *Client) nextID() uint32 { } func (c *Client) recvVersion() error { - typ, data, err := c.recvPacket(nil, 0) + typ, data, err := c.recvPacket(0) if err != nil { return err } diff --git a/conn.go b/conn.go index 1fbce8c1..0d8de601 100644 --- a/conn.go +++ b/conn.go @@ -13,15 +13,17 @@ import ( type conn struct { io.Reader io.WriteCloser + // this is the same allocator used in packet manager + alloc *allocator sync.Mutex // used to serialise writes to sendPacket // sendPacketTest is needed to replicate packet issues in testing sendPacketTest func(w io.Writer, m encoding.BinaryMarshaler) error } -// the allocator and the orderID are used in server mode if AllocationModeOptimized is enabled. -// For the client just pass nil and 0 -func (c *conn) recvPacket(alloc *allocator, orderID uint32) (uint8, []byte, error) { - return recvPacket(c, alloc, orderID) +// the orderID is used in server mode if the allocator is enabled. +// For the client mode just pass 0 +func (c *conn) recvPacket(orderID uint32) (uint8, []byte, error) { + return recvPacket(c, c.alloc, orderID) } func (c *conn) sendPacket(m encoding.BinaryMarshaler) error { @@ -78,7 +80,7 @@ func (c *clientConn) recv() error { c.conn.Close() }() for { - typ, data, err := c.recvPacket(nil, 0) + typ, data, err := c.recvPacket(0) if err != nil { return err } diff --git a/packet-manager.go b/packet-manager.go index f10935e3..c870c378 100644 --- a/packet-manager.go +++ b/packet-manager.go @@ -18,7 +18,7 @@ type packetManager struct { sender packetSender // connection object working *sync.WaitGroup packetCount uint32 - // it is not nil if AllocationModeOptimized is enabled + // it is not nil if the allocator is enabled alloc *allocator } @@ -36,9 +36,6 @@ func newPktMgr(sender packetSender) *packetManager { sender: sender, working: &sync.WaitGroup{}, } - if enabledAllocationMode == AllocationModeOptimized { - s.alloc = newAllocator() - } go s.controller() return s } @@ -50,7 +47,7 @@ func (s *packetManager) newOrderID() uint32 { } // returns the next orderID without incrementing it. -// This is used before receiving a new packet in AllocationModeOptimized to associate +// This is used before receiving a new packet, with the allocator enabled, to associate // the slice allocated for the received packet with the orderID that will be used to mark // the allocated slices for reuse once the request is served func (s *packetManager) getNextOrderID() uint32 { diff --git a/packet.go b/packet.go index 68407184..dba34d2b 100644 --- a/packet.go +++ b/packet.go @@ -149,24 +149,24 @@ func recvPacket(r io.Reader, alloc *allocator, orderID uint32) (uint8, []byte, e if _, err := io.ReadFull(r, b[:4]); err != nil { return 0, nil, err } - l, _ := unmarshalUint32(b) - if l > maxMsgLength { - debug("recv packet %d bytes too long", l) + length, _ := unmarshalUint32(b) + if length > maxMsgLength { + debug("recv packet %d bytes too long", length) return 0, nil, errLongPacket } if alloc == nil { - b = make([]byte, l) + b = make([]byte, length) } - if _, err := io.ReadFull(r, b[0:l]); err != nil { - debug("recv packet %d bytes: err %v", l, err) + if _, err := io.ReadFull(r, b[:length]); err != nil { + debug("recv packet %d bytes: err %v", length, err) return 0, nil, err } if debugDumpRxPacketBytes { - debug("recv packet: %s %d bytes %x", fxp(b[0]), l, b[1:l]) + debug("recv packet: %s %d bytes %x", fxp(b[0]), length, b[1:length]) } else if debugDumpRxPacket { - debug("recv packet: %s %d bytes", fxp(b[0]), l) + debug("recv packet: %s %d bytes", fxp(b[0]), length) } - return b[0], b[1:l], nil + return b[0], b[1:length], nil } type extensionPair struct { @@ -593,11 +593,13 @@ func (p *sshFxpReadPacket) UnmarshalBinary(b []byte) error { func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte { dataLen := clamp(p.Len, maxTxPacket) - // we allocate a slice with a bigger capacity so we avoid a new allocation in sshFxpDataPacket.MarshalBinary - // and in sendPacket, we need 9 bytes in MarshalBinary and 4 bytes in sendPacket. if alloc != nil { + // GetPage returns a slice with capacity = maxMsgLength this is enough to avoid new allocations in + // sshFxpDataPacket.MarshalBinary and sendPacket return alloc.GetPage(orderID)[:dataLen] } + // we allocate a slice with a bigger capacity so we avoid a new allocation in sshFxpDataPacket.MarshalBinary + // and in sendPacket, we need 9 bytes in MarshalBinary and 4 bytes in sendPacket. return make([]byte, dataLen, dataLen+9+4) } diff --git a/request-server.go b/request-server.go index e415a9e7..41050cbf 100644 --- a/request-server.go +++ b/request-server.go @@ -33,20 +33,52 @@ type RequestServer struct { } // NewRequestServer creates/allocates/returns new RequestServer. -// Normally there there will be one server per user-session. +// Normally there will be one server per user-session. +// +// Deprecated: please use NewRequestServerWithOptions func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer { + rs, _ := NewRequestServerWithOptions(rwc, h) + return rs +} + +// A RequestServerOption is a function which applies configuration to a RequestServer. +type RequestServerOption func(*RequestServer) error + +// WithRSAllocator enable the allocator. +// After processing a packet we keep in memory the allocated slices +// and we reuse them for new packets. +// The allocator is experimental +func WithRSAllocator() RequestServerOption { + return func(rs *RequestServer) error { + alloc := newAllocator() + rs.pktMgr.alloc = alloc + rs.conn.alloc = alloc + return nil + } +} + +// NewRequestServerWithOptions creates/allocates/returns new RequestServer adding the specified options +// If options is nil or empty this is equivalent to NewRequestServer +func NewRequestServerWithOptions(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) (*RequestServer, error) { svrConn := &serverConn{ conn: conn{ Reader: rwc, WriteCloser: rwc, }, } - return &RequestServer{ + rs := &RequestServer{ serverConn: svrConn, Handlers: h, pktMgr: newPktMgr(svrConn), openRequests: make(map[string]*Request), } + + for _, o := range options { + if err := o(rs); err != nil { + return nil, err + } + } + return rs, nil } // New Open packet/Request @@ -88,6 +120,11 @@ func (rs *RequestServer) Close() error { return rs.conn.Close() } // Serve requests for user session func (rs *RequestServer) Serve() error { + defer func() { + if rs.pktMgr.alloc != nil { + rs.pktMgr.alloc.Free() + } + }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() var wg sync.WaitGroup @@ -107,7 +144,7 @@ func (rs *RequestServer) Serve() error { var pktType uint8 var pktBytes []byte for { - pktType, pktBytes, err = rs.recvPacket(rs.pktMgr.alloc, rs.pktMgr.getNextOrderID()) + pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID()) if err != nil { // we don't care about releasing allocated pages here, the server will quit and the allocator freed break @@ -150,9 +187,6 @@ func (rs *RequestServer) Serve() error { delete(rs.openRequests, handle) req.close() } - if rs.pktMgr.alloc != nil { - rs.pktMgr.alloc.Free() - } return err } diff --git a/request-server_test.go b/request-server_test.go index ecea6a88..b85b4f2b 100644 --- a/request-server_test.go +++ b/request-server_test.go @@ -46,7 +46,12 @@ func clientRequestServerPair(t *testing.T) *csPair { fd, err := l.Accept() assert.Nil(t, err) handlers := InMemHandler() - server = NewRequestServer(fd, handlers) + if *testAllocator { + options := []RequestServerOption{WithRSAllocator()} + server, _ = NewRequestServerWithOptions(fd, handlers, options...) + } else { + server = NewRequestServer(fd, handlers) + } server.Serve() }() <-ready diff --git a/server.go b/server.go index 06c08f1a..013350cc 100644 --- a/server.go +++ b/server.go @@ -22,34 +22,6 @@ const ( SftpServerWorkerCount = 8 ) -// AllocationMode defines allocation modes -type AllocationMode uint8 - -const ( - // AllocationModeStandard a new slice is allocated for each new SFTP packet - AllocationModeStandard AllocationMode = iota - // AllocationModeOptimized after processing a packet we keep in memory the allocated slices - // and we reuse them for new packets. This mode is experimental - AllocationModeOptimized -) - -// enabledAllocationMode defines the allocation mode. -var enabledAllocationMode = AllocationModeStandard - -// SetEnabledAllocationMode sets the allocation mode. -// Allowed values are: -// - AllocationModeStandard, default -// - AllocationModeOptimized -// This method must be called before creating the Server or the RequestServer -func SetEnabledAllocationMode(mode AllocationMode) { - enabledAllocationMode = mode -} - -// GetEnabledAllocationMode returns the configured AllocationMode -func GetEnabledAllocationMode() AllocationMode { - return enabledAllocationMode -} - // Server is an SSH File Transfer Protocol (sftp) server. // This is intended to provide the sftp subsystem to an ssh server daemon. // This implementation currently supports most of sftp server protocol version 3, @@ -144,6 +116,19 @@ func ReadOnly() ServerOption { } } +// WithAllocator enable the allocator. +// After processing a packet we keep in memory the allocated slices +// and we reuse them for new packets. +// The allocator is experimental +func WithAllocator() ServerOption { + return func(s *Server) error { + alloc := newAllocator() + s.pktMgr.alloc = alloc + s.conn.alloc = alloc + return nil + } +} + type rxPacket struct { pktType fxp pktBytes []byte @@ -167,7 +152,8 @@ func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error { // return permission denied if !readonly && svr.readOnly { svr.pktMgr.readyPacket( - svr.pktMgr.newOrderedResponse(statusFromError(pkt, syscall.EPERM), pkt.orderID())) + svr.pktMgr.newOrderedResponse(statusFromError(pkt, syscall.EPERM), pkt.orderID()), + ) continue } @@ -326,6 +312,11 @@ func handlePacket(s *Server, p orderedRequest) error { // Serve serves SFTP connections until the streams stop or the SFTP subsystem // is stopped. func (svr *Server) Serve() error { + defer func() { + if svr.pktMgr.alloc != nil { + svr.pktMgr.alloc.Free() + } + }() var wg sync.WaitGroup runWorker := func(ch chan orderedRequest) { wg.Add(1) @@ -343,7 +334,7 @@ func (svr *Server) Serve() error { var pktType uint8 var pktBytes []byte for { - pktType, pktBytes, err = svr.recvPacket(svr.pktMgr.alloc, svr.pktMgr.getNextOrderID()) + pktType, pktBytes, err = svr.serverConn.recvPacket(svr.pktMgr.getNextOrderID()) if err != nil { // we don't care about releasing allocated pages here, the server will quit and the allocator freed break @@ -376,9 +367,6 @@ func (svr *Server) Serve() error { fmt.Fprintf(svr.debugStream, "sftp server file with handle %q left open: %v\n", handle, file.Name()) file.Close() } - if svr.pktMgr.alloc != nil { - svr.pktMgr.alloc.Free() - } return err // error from recvPacket } diff --git a/server_integration_test.go b/server_integration_test.go index a24257a5..0ad87e02 100644 --- a/server_integration_test.go +++ b/server_integration_test.go @@ -53,9 +53,6 @@ func TestMain(m *testing.M) { } testSftp = flag.String("sftp", sftpServer, "location of the sftp server binary") flag.Parse() - if *testOptimizedAllocator { - SetEnabledAllocationMode(AllocationModeOptimized) - } os.Exit(m.Run()) } @@ -68,7 +65,7 @@ func skipIfWindows(t testing.TB) { var testServerImpl = flag.Bool("testserver", false, "perform integration tests against sftp package server instance") var testIntegration = flag.Bool("integration", false, "perform integration tests against sftp server process") -var testOptimizedAllocator = flag.Bool("optimized-allocator", false, "perform tests using AllocationModeOptimized instead of AllocationModeStandard") +var testAllocator = flag.Bool("allocator", false, "perform tests using the allocator") var testSftp *string var testSftpClientBin *string @@ -473,21 +470,35 @@ func runSftpClient(t *testing.T, script string, path string, host string, port i return stdout.String(), err } +// assert.Eventually seems to have a data rate on macOS with go 1.14 so replace it with this simpler function +func waitForCondition(t *testing.T, condition func() bool) { + start := time.Now() + tick := 10 * time.Millisecond + waitFor := 100 * time.Millisecond + for !condition() { + time.Sleep(tick) + if time.Since(start) > waitFor { + break + } + } + assert.True(t, condition()) +} + func checkAllocatorBeforeServerClose(t *testing.T, alloc *allocator) { if alloc != nil { // before closing the server we are, generally, waiting for new packets in recvPacket and we have a page allocated. // Sometime the sendPacket returns some milliseconds after the client receives the response, and so we have 2 // allocated pages here, so wait some milliseconds. To avoid crashes we must be sure to not release the pages // too soon. - assert.Eventually(t, func() bool { return alloc.countUsedPages() <= 1 }, 100*time.Millisecond, 10*time.Millisecond) + waitForCondition(t, func() bool { return alloc.countUsedPages() <= 1 }) } } func checkAllocatorAfterServerClose(t *testing.T, alloc *allocator) { if alloc != nil { // wait for the server cleanup - assert.Eventually(t, func() bool { return alloc.countUsedPages() == 0 }, 100*time.Millisecond, 10*time.Millisecond) - assert.Eventually(t, func() bool { return alloc.countAvailablePages() == 0 }, 100*time.Millisecond, 10*time.Millisecond) + waitForCondition(t, func() bool { return alloc.countUsedPages() == 0 }) + waitForCondition(t, func() bool { return alloc.countAvailablePages() == 0 }) } } diff --git a/server_test.go b/server_test.go index 59e8bb4a..6995af95 100644 --- a/server_test.go +++ b/server_test.go @@ -162,10 +162,14 @@ func runLsTestHelper(t *testing.T, result, expectedType, path string) { func clientServerPair(t *testing.T) (*Client, *Server) { cr, sw := io.Pipe() sr, cw := io.Pipe() + var options []ServerOption + if *testAllocator { + options = append(options, WithAllocator()) + } server, err := NewServer(struct { io.Reader io.WriteCloser - }{sr, sw}) + }{sr, sw}, options...) if err != nil { t.Fatal(err) } From 2fc68482d27f8c25e1a37f827d1220667c11e8ba Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Thu, 19 Mar 2020 13:20:22 +0100 Subject: [PATCH 4/5] remove NewRequestServerWithOptions we can keep compatibility removing the error return value from RequestServerOption --- request-server.go | 26 +++++++------------------- request-server_test.go | 7 +++---- 2 files changed, 10 insertions(+), 23 deletions(-) diff --git a/request-server.go b/request-server.go index 41050cbf..9e51b012 100644 --- a/request-server.go +++ b/request-server.go @@ -32,34 +32,24 @@ type RequestServer struct { handleCount int } -// NewRequestServer creates/allocates/returns new RequestServer. -// Normally there will be one server per user-session. -// -// Deprecated: please use NewRequestServerWithOptions -func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer { - rs, _ := NewRequestServerWithOptions(rwc, h) - return rs -} - // A RequestServerOption is a function which applies configuration to a RequestServer. -type RequestServerOption func(*RequestServer) error +type RequestServerOption func(*RequestServer) // WithRSAllocator enable the allocator. // After processing a packet we keep in memory the allocated slices // and we reuse them for new packets. // The allocator is experimental func WithRSAllocator() RequestServerOption { - return func(rs *RequestServer) error { + return func(rs *RequestServer) { alloc := newAllocator() rs.pktMgr.alloc = alloc rs.conn.alloc = alloc - return nil } } -// NewRequestServerWithOptions creates/allocates/returns new RequestServer adding the specified options -// If options is nil or empty this is equivalent to NewRequestServer -func NewRequestServerWithOptions(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) (*RequestServer, error) { +// NewRequestServer creates/allocates/returns new RequestServer. +// Normally there will be one server per user-session. +func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer { svrConn := &serverConn{ conn: conn{ Reader: rwc, @@ -74,11 +64,9 @@ func NewRequestServerWithOptions(rwc io.ReadWriteCloser, h Handlers, options ... } for _, o := range options { - if err := o(rs); err != nil { - return nil, err - } + o(rs) } - return rs, nil + return rs } // New Open packet/Request diff --git a/request-server_test.go b/request-server_test.go index b85b4f2b..703b7d92 100644 --- a/request-server_test.go +++ b/request-server_test.go @@ -46,12 +46,11 @@ func clientRequestServerPair(t *testing.T) *csPair { fd, err := l.Accept() assert.Nil(t, err) handlers := InMemHandler() + var options []RequestServerOption if *testAllocator { - options := []RequestServerOption{WithRSAllocator()} - server, _ = NewRequestServerWithOptions(fd, handlers, options...) - } else { - server = NewRequestServer(fd, handlers) + options = append(options, WithRSAllocator()) } + server = NewRequestServer(fd, handlers, options...) server.Serve() }() <-ready From 118ca5720446d2b8512e9cbbb8944a8f4c3a4085 Mon Sep 17 00:00:00 2001 From: Nicola Murino Date: Sat, 6 Jun 2020 19:26:12 +0200 Subject: [PATCH 5/5] cleanPath: use path.IsAbs after converting ToSlash we need a POSIX path filepath.IsAbs can give unexpected results on Windows --- request-server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/request-server.go b/request-server.go index 9e51b012..cb357e3b 100644 --- a/request-server.go +++ b/request-server.go @@ -258,7 +258,7 @@ func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket { // Makes sure we have a clean POSIX (/) absolute path to work with func cleanPath(p string) string { p = filepath.ToSlash(p) - if !filepath.IsAbs(p) { + if !path.IsAbs(p) { p = "/" + p } return path.Clean(p)