Skip to content

Commit

Permalink
be smarter about buffering multilayered multistream
Browse files Browse the repository at this point in the history
  • Loading branch information
whyrusleeping committed Aug 10, 2016
1 parent 9cad71a commit d78b705
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
53 changes: 29 additions & 24 deletions lazy.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
package multistream

import (
"bufio"
"fmt"
"io"
"sync"
)

type Multistream interface {
io.ReadWriteCloser
Protocol() string
}

func NewMSSelect(c io.ReadWriteCloser, proto string) Multistream {
return NewMultistream(NewMultistream(c, ProtocolID), proto)
return &lazyConn{
protos: []string{ProtocolID, proto},
con: c,
}
}

func NewMultistream(c io.ReadWriteCloser, proto string) Multistream {
return &lazyConn{
proto: proto,
con: c,
protos: []string{proto},
con: c,
}
}

Expand All @@ -35,12 +38,8 @@ type lazyConn struct {
whsync bool
werr error

proto string
con io.ReadWriteCloser
}

func (l *lazyConn) Protocol() string {
return l.proto
protos []string
con io.ReadWriteCloser
}

func (l *lazyConn) Read(b []byte) (int, error) {
Expand Down Expand Up @@ -71,16 +70,18 @@ func (l *lazyConn) readHandshake() error {
}
l.rhsync = true

// read protocol
tok, err := ReadNextToken(l.con)
if err != nil {
l.rerr = err
return err
}
for _, proto := range l.protos {
// read protocol
tok, err := ReadNextToken(l.con)
if err != nil {
l.rerr = err
return err
}

if tok != l.proto {
l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", tok, l.proto)
return l.rerr
if tok != proto {
l.rerr = fmt.Errorf("protocol mismatch in lazy handshake ( %s != %s )", tok, proto)
return l.rerr
}
}

return nil
Expand All @@ -96,13 +97,17 @@ func (l *lazyConn) writeHandshake() error {

l.whsync = true

err := delimWriteBuffered(l.con, []byte(l.proto))
if err != nil {
l.werr = err
return err
buf := bufio.NewWriter(l.con)
for _, proto := range l.protos {
err := delimWrite(buf, []byte(proto))
if err != nil {
l.werr = err
return err
}
}

return nil
l.werr = buf.Flush()
return l.werr
}

func (l *lazyConn) Write(b []byte) (int, error) {
Expand Down
16 changes: 16 additions & 0 deletions multistream_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package multistream

import (
"bytes"
"crypto/rand"
"io"
"net"
Expand Down Expand Up @@ -349,3 +350,18 @@ func verifyPipe(t *testing.T, a, b io.ReadWriter) {
t.Fatal("somehow read wrong message")
}
}

func TestTooLargeMessage(t *testing.T) {
buf := new(bytes.Buffer)
mes := make([]byte, 100*1024)

err := delimWrite(buf, mes)
if err != nil {
t.Fatal(err)
}

_, err = ReadNextToken(buf)
if err == nil {
t.Fatal("should have failed to read message larger than 64k")
}
}

0 comments on commit d78b705

Please sign in to comment.