Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write salt/iv, addr together with content #160

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion core/stream.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
package core

import "net"
import (
"bytes"
"io"
"net"
"sync"
"sync/atomic"
"time"

"github.com/shadowsocks/go-shadowsocks2/socks"
)

type listener struct {
net.Listener
Expand All @@ -21,3 +30,72 @@ func Dial(network, address string, ciph StreamConnCipher) (net.Conn, error) {
c, err := net.Dial(network, address)
return ciph.StreamConn(c), err
}

// Connect sends the shadowsocks standard header to underlying ciphered connection
// in the next Write/ReadFrom call.
func Connect(c net.Conn, addr socks.Addr) net.Conn {
return newSSConn(c, addr, 5 * time.Millisecond)
}

type ssconn struct {
net.Conn
addr socks.Addr
done uint32
m sync.Mutex
t *time.Timer
}

func newSSConn(c net.Conn, addr socks.Addr, delay time.Duration) *ssconn {
sc := &ssconn{Conn: c, addr: addr}
if delay > 0 {
sc.t = time.AfterFunc(delay, func() {
sc.Write([]byte{})
})
}
return sc
}

func (c *ssconn) Write(b []byte) (int, error) {
n, err := c.ReadFrom(bytes.NewBuffer(b))
return int(n), err
}

func (c *ssconn) ReadFrom(r io.Reader) (int64, error) {
if atomic.LoadUint32(&c.done) == 0 {
c.m.Lock()
defer c.m.Unlock()
if c.done == 0 {
defer atomic.StoreUint32(&c.done, 1)
if c.t != nil {
c.t.Stop()
c.t = nil
}
ra := readerWithAddr{Reader: r, b: c.addr}
n, err := io.Copy(c.Conn, &ra)
n -= int64(len(c.addr))
if n < 0 { n = 0 }
c.addr = nil
return n, err
}
}
return io.Copy(c.Conn, r)
}

func (c *ssconn) WriteTo(w io.Writer) (int64, error) {
return io.Copy(w, c.Conn)
}

type readerWithAddr struct {
io.Reader
b []byte
}

func (r *readerWithAddr) Read(b []byte) (n int, err error) {
nc := copy(b, r.b)
r.b = r.b[nc:]
if len(b) == nc {
return nc, nil
}
nr, err := r.Reader.Read(b[nc:])
return nc + nr, err
}
31 changes: 23 additions & 8 deletions shadowaead/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,19 @@ type writer struct {
cipher.AEAD
nonce []byte
buf []byte
salt []byte
}

// NewWriter wraps an io.Writer with AEAD encryption.
func NewWriter(w io.Writer, aead cipher.AEAD) io.Writer { return newWriter(w, aead) }
func NewWriter(w io.Writer, aead cipher.AEAD, salt []byte) io.Writer { return newWriter(w, aead, salt) }

func newWriter(w io.Writer, aead cipher.AEAD) *writer {
func newWriter(w io.Writer, aead cipher.AEAD, salt []byte) *writer {
return &writer{
Writer: w,
AEAD: aead,
buf: make([]byte, 2+aead.Overhead()+payloadSizeMask+aead.Overhead()),
nonce: make([]byte, aead.NonceSize()),
salt: salt,
}
}

Expand All @@ -36,6 +38,23 @@ func (w *writer) Write(b []byte) (int, error) {
return int(n), err
}

// Write salt before encrypted buffer to io.Writer.
func (w *writer) write(b []byte) (int, error) {
if len(w.salt) == 0 {
return w.Writer.Write(b)
}
buf := make([]byte, len(w.salt) + len(b))
copy(buf[:len(w.salt)], w.salt)
copy(buf[len(w.salt):], b)
nw, err := w.Writer.Write(buf)
if nw < len(w.salt) {
w.salt = w.salt[nw:]
return 0, err
}
w.salt = nil
return nw - len(w.salt), err
}

// ReadFrom reads from the given io.Reader until EOF or error, encrypts and
// writes to the embedded io.Writer. Returns number of bytes read from r and
// any error encountered.
Expand All @@ -56,7 +75,7 @@ func (w *writer) ReadFrom(r io.Reader) (n int64, err error) {
w.Seal(payloadBuf[:0], w.nonce, payloadBuf, nil)
increment(w.nonce)

_, ew := w.Writer.Write(buf)
_, ew := w.write(buf)
if ew != nil {
err = ew
break
Expand Down Expand Up @@ -240,11 +259,7 @@ func (c *streamConn) initWriter() error {
if err != nil {
return err
}
_, err = c.Conn.Write(salt)
if err != nil {
return err
}
c.w = newWriter(c.Conn, aead)
c.w = newWriter(c.Conn, aead, salt)
return nil
}

Expand Down
29 changes: 22 additions & 7 deletions shadowstream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@ type writer struct {
io.Writer
cipher.Stream
buf []byte
iv []byte
}

// NewWriter wraps an io.Writer with stream cipher encryption.
func NewWriter(w io.Writer, s cipher.Stream) io.Writer {
return &writer{Writer: w, Stream: s, buf: make([]byte, bufSize)}
func NewWriter(w io.Writer, s cipher.Stream, iv []byte) io.Writer {
return &writer{Writer: w, Stream: s, buf: make([]byte, bufSize), iv: iv}
}

func (w *writer) ReadFrom(r io.Reader) (n int64, err error) {
Expand All @@ -29,7 +30,7 @@ func (w *writer) ReadFrom(r io.Reader) (n int64, err error) {
n += int64(nr)
buf = buf[:nr]
w.XORKeyStream(buf, buf)
_, ew := w.Writer.Write(buf)
_, ew := w.write(buf)
if ew != nil {
err = ew
return
Expand All @@ -50,6 +51,23 @@ func (w *writer) Write(b []byte) (int, error) {
return int(n), err
}

// Write IV before encrypted buffer to io.Writer.
func (w *writer) write(b []byte) (int, error) {
if len(w.iv) == 0 {
return w.Writer.Write(b)
}
buf := make([]byte, len(w.iv) + len(b))
copy(buf[:len(w.iv)], w.iv)
copy(buf[len(w.iv):], b)
nw, err := w.Writer.Write(buf)
if nw < len(w.iv) {
w.iv = w.iv[nw:]
return 0, err
}
w.iv = nil
return nw - len(w.iv), err
}

type reader struct {
io.Reader
cipher.Stream
Expand Down Expand Up @@ -144,10 +162,7 @@ func (c *conn) initWriter() error {
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return err
}
if _, err := c.Conn.Write(iv); err != nil {
return err
}
c.w = &writer{Writer: c.Conn, Stream: c.Encrypter(iv), buf: buf}
c.w = &writer{Writer: c.Conn, Stream: c.Encrypter(iv), buf: buf, iv: iv}
}
return nil
}
Expand Down
8 changes: 3 additions & 5 deletions tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net"
"time"

"github.com/shadowsocks/go-shadowsocks2/core"
"github.com/shadowsocks/go-shadowsocks2/socks"
)

Expand Down Expand Up @@ -72,11 +73,8 @@ func tcpLocal(addr, server string, shadow func(net.Conn) net.Conn, getAddr func(
defer rc.Close()
rc.(*net.TCPConn).SetKeepAlive(true)
rc = shadow(rc)

if _, err = rc.Write(tgt); err != nil {
logf("failed to send target address: %v", err)
return
}
// Connect to target
rc = core.Connect(rc, tgt)

logf("proxy %s <-> %s <-> %s", c.RemoteAddr(), server, tgt)
_, _, err = relay(rc, c)
Expand Down