Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RequestServer.Serve bugs found looking at PR-361 #363

Merged
merged 5 commits into from
Jul 31, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 34 additions & 37 deletions request-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,37 +106,19 @@ func (rs *RequestServer) closeRequest(handle string) error {
// Close the read/write/closer to trigger exiting the main server loop
func (rs *RequestServer) Close() error { return rs.conn.Close() }

// Serve requests for user session
func (rs *RequestServer) Serve() error {
defer func() {
if rs.pktMgr.alloc != nil {
rs.pktMgr.alloc.Free()
}
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var wg sync.WaitGroup
runWorker := func(ch chan orderedRequest) {
wg.Add(1)
go func() {
defer wg.Done()
if err := rs.packetWorker(ctx, ch); err != nil {
rs.conn.Close() // shuts down recvPacket
}
}()
}
pktChan := rs.pktMgr.workerChan(runWorker)

func (rs *RequestServer) serveLoop(pktChan chan<- orderedRequest) error {
var err error
var pkt requestPacket
var pktType uint8
var pktBytes []byte

for {
pktType, pktBytes, err = rs.serverConn.recvPacket(rs.pktMgr.getNextOrderID())
if err != nil {
// we don't care about releasing allocated pages here, the server will quit and the allocator freed
break
return err
}

pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
if err != nil {
switch errors.Cause(err) {
Expand All @@ -145,33 +127,48 @@ func (rs *RequestServer) Serve() error {
default:
debug("makePacket err: %v", err)
rs.conn.Close() // shuts down recvPacket
break
return err
}
}

pktChan <- rs.pktMgr.newOrderedRequest(pkt)
}
}

// Serve requests for user session
func (rs *RequestServer) Serve() error {
defer func() {
if rs.pktMgr.alloc != nil {
rs.pktMgr.alloc.Free()
}
}()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var wg sync.WaitGroup
runWorker := func(ch chan orderedRequest) {
wg.Add(1)
go func() {
defer wg.Done()
if err := rs.packetWorker(ctx, ch); err != nil {
rs.conn.Close() // shuts down recvPacket
}
}()
}
pktChan := rs.pktMgr.workerChan(runWorker)

err := rs.serveLoop(pktChan)

close(pktChan) // shuts down sftpServerWorkers
wg.Wait() // wait for all workers to exit

rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()

// make sure all open requests are properly closed
// (eg. possible on dropped connections, client crashes, etc.)
for handle, req := range rs.openRequests {
if err != nil {
req.state.RLock()
writer := req.state.writerAt
reader := req.state.readerAt
req.state.RUnlock()
if t, ok := writer.(TransferError); ok {
debug("notify error: %v to writer: %v\n", err, writer)
t.TransferError(err)
}
if t, ok := reader.(TransferError); ok {
debug("notify error: %v to reader: %v\n", err, reader)
t.TransferError(err)
}
}
req.transferError(err)

delete(rs.openRequests, handle)
req.close()
}
Expand Down
40 changes: 35 additions & 5 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,49 @@ func (r *Request) close() error {
r.cancelCtx()
}
}()

r.state.RLock()
rd := r.state.readerAt
wr := r.state.writerAt
r.state.RUnlock()

var err error

if c, ok := rd.(io.Closer); ok {
return c.Close()
if err2 := c.Close(); err == nil {
// update error if it is still nil
err = err2
}
}

if c, ok := wr.(io.Closer); ok {
if err2 := c.Close(); err == nil {
// update error if it is still nil
err = err2
}
}
puellanivis marked this conversation as resolved.
Show resolved Hide resolved

return err
}

// Close reader/writer if possible
func (r *Request) transferError(err error) {
if err == nil {
return
}

r.state.RLock()
wt := r.state.writerAt
rd := r.state.readerAt
wr := r.state.writerAt
r.state.RUnlock()
if c, ok := wt.(io.Closer); ok {
return c.Close()

if t, ok := rd.(TransferError); ok {
t.TransferError(err)
}

if t, ok := wr.(TransferError); ok {
t.TransferError(err)
}
return nil
}

// called from worker to handle packet/request
Expand Down