diff --git a/core/stream.go b/core/stream.go index 5c773cd2..9cd4c31a 100644 --- a/core/stream.go +++ b/core/stream.go @@ -1,6 +1,15 @@ package core -import "net" +import ( + "bytes" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/shadowsocks/go-shadowsocks2/socks" +) type listener struct { net.Listener @@ -21,3 +30,72 @@ func Dial(network, address string, ciph StreamConnCipher) (net.Conn, error) { c, err := net.Dial(network, address) return ciph.StreamConn(c), err } + +// Connect sends the shadowsocks standard header to underlying ciphered connection +// in the next Write/ReadFrom call. +func Connect(c net.Conn, addr socks.Addr) net.Conn { + return newSSConn(c, addr, 5 * time.Millisecond) +} + +type ssconn struct { + net.Conn + addr socks.Addr + done uint32 + m sync.Mutex + t *time.Timer +} + +func newSSConn(c net.Conn, addr socks.Addr, delay time.Duration) *ssconn { + sc := &ssconn{Conn: c, addr: addr} + if delay > 0 { + sc.t = time.AfterFunc(delay, func() { + sc.Write([]byte{}) + }) + } + return sc +} + +func (c *ssconn) Write(b []byte) (int, error) { + n, err := c.ReadFrom(bytes.NewBuffer(b)) + return int(n), err +} + +func (c *ssconn) ReadFrom(r io.Reader) (int64, error) { + if atomic.LoadUint32(&c.done) == 0 { + c.m.Lock() + defer c.m.Unlock() + if c.done == 0 { + defer atomic.StoreUint32(&c.done, 1) + if c.t != nil { + c.t.Stop() + c.t = nil + } + ra := readerWithAddr{Reader: r, b: c.addr} + n, err := io.Copy(c.Conn, &ra) + n -= int64(len(c.addr)) + if n < 0 { n = 0 } + c.addr = nil + return n, err + } + } + return io.Copy(c.Conn, r) +} + +func (c *ssconn) WriteTo(w io.Writer) (int64, error) { + return io.Copy(w, c.Conn) +} + +type readerWithAddr struct { + io.Reader + b []byte +} + +func (r *readerWithAddr) Read(b []byte) (n int, err error) { + nc := copy(b, r.b) + r.b = r.b[nc:] + if len(b) == nc { + return nc, nil + } + nr, err := r.Reader.Read(b[nc:]) + return nc + nr, err +} diff --git a/shadowaead/stream.go b/shadowaead/stream.go index 5f499a21..054f9419 100644 --- a/shadowaead/stream.go +++ b/shadowaead/stream.go @@ -16,17 +16,19 @@ type writer struct { cipher.AEAD nonce []byte buf []byte + salt []byte } // NewWriter wraps an io.Writer with AEAD encryption. -func NewWriter(w io.Writer, aead cipher.AEAD) io.Writer { return newWriter(w, aead) } +func NewWriter(w io.Writer, aead cipher.AEAD, salt []byte) io.Writer { return newWriter(w, aead, salt) } -func newWriter(w io.Writer, aead cipher.AEAD) *writer { +func newWriter(w io.Writer, aead cipher.AEAD, salt []byte) *writer { return &writer{ Writer: w, AEAD: aead, buf: make([]byte, 2+aead.Overhead()+payloadSizeMask+aead.Overhead()), nonce: make([]byte, aead.NonceSize()), + salt: salt, } } @@ -36,6 +38,23 @@ func (w *writer) Write(b []byte) (int, error) { return int(n), err } +// Write salt before encrypted buffer to io.Writer. +func (w *writer) write(b []byte) (int, error) { + if len(w.salt) == 0 { + return w.Writer.Write(b) + } + buf := make([]byte, len(w.salt) + len(b)) + copy(buf[:len(w.salt)], w.salt) + copy(buf[len(w.salt):], b) + nw, err := w.Writer.Write(buf) + if nw < len(w.salt) { + w.salt = w.salt[nw:] + return 0, err + } + w.salt = nil + return nw - len(w.salt), err +} + // ReadFrom reads from the given io.Reader until EOF or error, encrypts and // writes to the embedded io.Writer. Returns number of bytes read from r and // any error encountered. @@ -56,7 +75,7 @@ func (w *writer) ReadFrom(r io.Reader) (n int64, err error) { w.Seal(payloadBuf[:0], w.nonce, payloadBuf, nil) increment(w.nonce) - _, ew := w.Writer.Write(buf) + _, ew := w.write(buf) if ew != nil { err = ew break @@ -240,11 +259,7 @@ func (c *streamConn) initWriter() error { if err != nil { return err } - _, err = c.Conn.Write(salt) - if err != nil { - return err - } - c.w = newWriter(c.Conn, aead) + c.w = newWriter(c.Conn, aead, salt) return nil } diff --git a/shadowstream/stream.go b/shadowstream/stream.go index eb4d9679..5d1d09bb 100644 --- a/shadowstream/stream.go +++ b/shadowstream/stream.go @@ -14,11 +14,12 @@ type writer struct { io.Writer cipher.Stream buf []byte + iv []byte } // NewWriter wraps an io.Writer with stream cipher encryption. -func NewWriter(w io.Writer, s cipher.Stream) io.Writer { - return &writer{Writer: w, Stream: s, buf: make([]byte, bufSize)} +func NewWriter(w io.Writer, s cipher.Stream, iv []byte) io.Writer { + return &writer{Writer: w, Stream: s, buf: make([]byte, bufSize), iv: iv} } func (w *writer) ReadFrom(r io.Reader) (n int64, err error) { @@ -29,7 +30,7 @@ func (w *writer) ReadFrom(r io.Reader) (n int64, err error) { n += int64(nr) buf = buf[:nr] w.XORKeyStream(buf, buf) - _, ew := w.Writer.Write(buf) + _, ew := w.write(buf) if ew != nil { err = ew return @@ -50,6 +51,23 @@ func (w *writer) Write(b []byte) (int, error) { return int(n), err } +// Write IV before encrypted buffer to io.Writer. +func (w *writer) write(b []byte) (int, error) { + if len(w.iv) == 0 { + return w.Writer.Write(b) + } + buf := make([]byte, len(w.iv) + len(b)) + copy(buf[:len(w.iv)], w.iv) + copy(buf[len(w.iv):], b) + nw, err := w.Writer.Write(buf) + if nw < len(w.iv) { + w.iv = w.iv[nw:] + return 0, err + } + w.iv = nil + return nw - len(w.iv), err +} + type reader struct { io.Reader cipher.Stream @@ -144,10 +162,7 @@ func (c *conn) initWriter() error { if _, err := io.ReadFull(rand.Reader, iv); err != nil { return err } - if _, err := c.Conn.Write(iv); err != nil { - return err - } - c.w = &writer{Writer: c.Conn, Stream: c.Encrypter(iv), buf: buf} + c.w = &writer{Writer: c.Conn, Stream: c.Encrypter(iv), buf: buf, iv: iv} } return nil } diff --git a/tcp.go b/tcp.go index 243b2704..b410c83f 100644 --- a/tcp.go +++ b/tcp.go @@ -5,6 +5,7 @@ import ( "net" "time" + "github.com/shadowsocks/go-shadowsocks2/core" "github.com/shadowsocks/go-shadowsocks2/socks" ) @@ -72,11 +73,8 @@ func tcpLocal(addr, server string, shadow func(net.Conn) net.Conn, getAddr func( defer rc.Close() rc.(*net.TCPConn).SetKeepAlive(true) rc = shadow(rc) - - if _, err = rc.Write(tgt); err != nil { - logf("failed to send target address: %v", err) - return - } + // Connect to target + rc = core.Connect(rc, tgt) logf("proxy %s <-> %s <-> %s", c.RemoteAddr(), server, tgt) _, _, err = relay(rc, c)