Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for closing write ends of streams #84

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ var (
// ErrStreamClosed is returned when using a closed stream
ErrStreamClosed = fmt.Errorf("stream closed")

// ErrWriteClosed is returned when using a closed write end of a stream
ErrWriteClosed = fmt.Errorf("write end of stream closed")

// ErrUnexpectedFlag is set when we get an unexpected flag
ErrUnexpectedFlag = fmt.Errorf("unexpected flag")

Expand Down Expand Up @@ -93,6 +96,11 @@ const (

// RST is used to hard close a given stream.
flagRST

// flagCloseWrite is sent to notify the remote end
// that no more data will be written to the stream.
// May be sent with a data payload.
flagCloseWrite
)

const (
Expand Down
31 changes: 31 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1351,3 +1351,34 @@ func TestSession_ConnectionWriteTimeout(t *testing.T) {

wg.Wait()
}

func TestCloseWrite(t *testing.T) {
client, server := testClientServer()
defer client.Close()
defer server.Close()

stream, err := client.OpenStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()

stream2, err := server.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream2.Close()

if _, err := stream.Write([]byte("test")); err != nil {
t.Fatal(err)
} else if err := stream.CloseWrite(); err != nil {
t.Fatal(err)
}

data, err := ioutil.ReadAll(stream2)
if err != nil {
t.Fatal(err)
} else if !bytes.Equal(data, []byte("test")) {
t.Fatalf("got data %q, want %q", data, "test")
}
}
77 changes: 77 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ const (
streamReset
)

type streamFlags uint16

const (
writeCloseFlag streamFlags = 1 << iota
writeCloseFlagSent
readCloseFlag
)

// Stream is used to represent a logical stream
// within a session.
type Stream struct {
Expand All @@ -31,6 +39,7 @@ type Stream struct {
session *Session

state streamState
flags streamFlags
stateLock sync.Mutex

recvBuf *bytes.Buffer
Expand Down Expand Up @@ -104,6 +113,15 @@ START:
s.stateLock.Unlock()
return 0, ErrConnectionReset
}
if (s.flags & readCloseFlag) != 0 {
s.recvLock.Lock()
if s.recvBuf == nil || s.recvBuf.Len() == 0 {
s.recvLock.Unlock()
s.stateLock.Unlock()
return 0, io.EOF
}
s.recvLock.Unlock()
}
s.stateLock.Unlock()

// If there is no data available, block
Expand Down Expand Up @@ -174,6 +192,10 @@ START:
s.stateLock.Unlock()
return 0, ErrConnectionReset
}
if (s.flags & writeCloseFlag) != 0 {
s.stateLock.Unlock()
return 0, ErrWriteClosed
}
s.stateLock.Unlock()

// If there is no data available, block
Expand Down Expand Up @@ -231,6 +253,10 @@ func (s *Stream) sendFlags() uint16 {
flags |= flagACK
s.state = streamEstablished
}
if (s.flags & writeCloseFlag & ^writeCloseFlagSent) != 0 {
flags |= flagCloseWrite
s.flags |= writeCloseFlagSent
}
return flags
}

Expand Down Expand Up @@ -321,6 +347,53 @@ SEND_CLOSE:
return nil
}

// CloseWrite is used to close this side's write end of the stream.
func (s *Stream) CloseWrite() error {
s.stateLock.Lock()
s.flags |= writeCloseFlag
switch s.state {
// Opened means we need to signal a close
case streamSYNSent:
fallthrough
case streamSYNReceived:
fallthrough
case streamEstablished:
goto SEND_CLOSE

case streamLocalClose:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you mean that seeing streamLocalClose means we do nothing and continue after the switch but Go makes it hard to here if that's the case or if it's a programming error. Could you add a comment to the bottom saying something like "ok, handle normally" or something? Thanks!

case streamRemoteClose:
goto SEND_CLOSE
case streamClosed:
case streamReset:
default:
panic("unhandled state")
}
s.stateLock.Unlock()
return nil
SEND_CLOSE:
s.stateLock.Unlock()
s.sendCloseWrite()
s.notifyWaiting()
return nil
}

// sendCloseWrite is used to send a write close notice
func (s *Stream) sendCloseWrite() error {
s.controlHdrLock.Lock()
defer s.controlHdrLock.Unlock()

flags := s.sendFlags()
if (flags & flagCloseWrite) == 0 {
// We have already sent it; no need to do so again
return nil
}
s.controlHdr.encode(typeWindowUpdate, flags, s.id, 0)
if err := s.session.waitForSendErr(s.controlHdr, nil, s.controlErr); err != nil {
return err
}
return nil
}

// forceClose is used for when the session is exiting
func (s *Stream) forceClose() {
s.stateLock.Lock()
Expand Down Expand Up @@ -348,6 +421,10 @@ func (s *Stream) processFlags(flags uint16) error {
}
s.session.establishStream(s.id)
}
if (flags & flagCloseWrite) == flagCloseWrite {
s.flags |= readCloseFlag
s.notifyWaiting()
}
if flags&flagFIN == flagFIN {
switch s.state {
case streamSYNSent:
Expand Down