-
Notifications
You must be signed in to change notification settings - Fork 1
/
proto.go
450 lines (379 loc) · 9.94 KB
/
proto.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
/* Snowflake client--server protocol
This package implements SnowflakeConn, a net.Conn for use between a Snowflake
client and server that implements a sequenced and reliable Snowflake protocol.
The first 8 bytes sent from the client to the server at the start of every
connection are the session ID. They are meant to be read by the server
and mapped to a long-lived SnowflakeConn.
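The protocol sends data in chunks, accompanied by a header:

	0               4               8
	+---------------+---------------+
	|  Seq Number   |  Ack Number   |
	+-------+-------+---------------+
	|  Len  |
	+-------+

The header carries a 4 byte sequence number, a 4 byte acknowledgement number,
and a 2 byte length. Each header occupies snowflakeHeaderLen (18) bytes on the
wire; only the first 10 bytes carry the fields above and the rest are zero.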
Each SnowflakeConn is initialized with a call to NewSnowflakeConn() and
an underlying connection is set with the call NewSnowflake(). Since Snowflakes
are ephemeral, a new snowflake can be set at any time.
This net.Conn is reliable, so any bytes sent as a call to SnowflakeConn's Write
method will be buffered until they are acknowledged by the other end. If a new
snowflake is provided, buffered bytes will be resent through the new connection
and remain buffered until they are acknowledged.
When a SnowflakeConn reads in bytes, it automatically sends an empty
acknowledgement packet to the other end of the connection with the Ack number
updated to reflect the most recently received data. Only when an endpoint
receives a packet with an updated acknowledgement number will it remove that
data from the stored buffer.
*/
package proto
import (
"bytes"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"strings"
"sync"
"time"
)
const (
	// snowflakeHeaderLen is the number of bytes every header occupies on the
	// wire. Only the first 10 bytes carry fields (seq, ack, length); marshal
	// leaves the remaining bytes zeroed.
	snowflakeHeaderLen = 18
	// maxLength is the largest payload length the 16-bit header length field
	// can describe.
	maxLength = 65535
	// sessionIDLength is the size in bytes of a session ID.
	sessionIDLength = 8
)

// snowflakeTimeout is the delay after which a snowflakeTimer's callback runs.
var snowflakeTimeout = 10 * time.Second

// snowflakeHeader is the decoded form of the header that precedes every chunk
// of data on the wire.
type snowflakeHeader struct {
	seq    uint32 // sequence number of the first byte of the accompanying data
	ack    uint32 // acknowledgement number reported by the sender
	length uint16 // length of the accompanying data (excluding header length)
}

// Parse fills h from the leading bytes of b, decoding seq, ack and length as
// big-endian integers at offsets 0, 4 and 8.
//
// b must hold at least the 10 bytes that carry header fields; a shorter slice
// yields an error instead of a panic (or a silent read past len(b) into spare
// capacity). Bytes after the first 10 are ignored.
func (h *snowflakeHeader) Parse(b []byte) error {
	const fieldsLen = 10
	if len(b) < fieldsLen {
		return fmt.Errorf("snowflake header too short: got %d bytes, need %d", len(b), fieldsLen)
	}
	h.seq = binary.BigEndian.Uint32(b[0:4])
	h.ack = binary.BigEndian.Uint32(b[4:8])
	h.length = binary.BigEndian.Uint16(b[8:10])
	return nil
}

// marshal encodes h into a freshly allocated slice of snowflakeHeaderLen
// bytes. The fields occupy the first 10 bytes in big-endian order; the
// remaining bytes stay zero.
func (h *snowflakeHeader) marshal() []byte {
	b := make([]byte, snowflakeHeaderLen)
	binary.BigEndian.PutUint32(b[0:4], h.seq)
	binary.BigEndian.PutUint32(b[4:8], h.ack)
	binary.BigEndian.PutUint16(b[8:10], h.length)
	return b
}
// readHeader reads one snowflakeHeaderLen-byte header from the current
// underlying connection and parses it into h.
//
// The read is performed while holding readLock. If no connection is attached
// (s.conn is nil) it returns io.ErrClosedPipe rather than parsing an all-zero
// buffer and reporting success, so readLoop terminates instead of spinning on
// phantom empty headers. Any error from the read or from Parse is returned
// unchanged.
func (s *SnowflakeConn) readHeader(h *snowflakeHeader) error {
	b := make([]byte, snowflakeHeaderLen)

	var err error
	s.readLock.Lock()
	conn := s.conn
	if conn != nil {
		_, err = io.ReadFull(conn, b)
	}
	s.readLock.Unlock()

	if conn == nil {
		return io.ErrClosedPipe
	}
	if err != nil {
		return err
	}
	return h.Parse(b)
}
// snowflakeTimer pairs a time.Timer with the sequence number that the timer's
// callback compares against SnowflakeConn.acked when it fires.
type snowflakeTimer struct {
// t is the underlying timer, created with time.AfterFunc in newSnowflakeTimer.
t *time.Timer
// seq is the sequence number checked against the connection's acked value
// when the timer fires; update overwrites it.
seq uint32
}
// newSnowflakeTimer starts a timer that fires after snowflakeTimeout and
// closes s if the peer has not yet acknowledged up to the timer's sequence
// number.
//
// seq is the initial sequence number to wait for; s is the connection whose
// acked counter is inspected and which is closed on timeout.
func newSnowflakeTimer(seq uint32, s *SnowflakeConn) *snowflakeTimer {
	timer := &snowflakeTimer{seq: seq}
	timer.t = time.AfterFunc(snowflakeTimeout, func() {
		// Write calls update (which stores timer.seq) while holding
		// timerLock, so read the value under the same lock.
		s.timerLock.Lock()
		want := timer.seq
		s.timerLock.Unlock()

		s.seqLock.Lock()
		defer s.seqLock.Unlock()
		// Compare with wraparound-safe serial arithmetic, matching the
		// int32(header.ack - s.acked) test in readBody: the counters are
		// uint32 values that wrap, so a plain "<" misfires after wrapping.
		if int32(want-s.acked) > 0 {
			log.Println("Closing WebRTC connection, timed out waiting for ACK")
			s.Close()
		}
	})
	return timer
}
// update re-arms the timer for another snowflakeTimeout and records seq as
// the sequence number that must be acknowledged before it fires.
//
// The timer is re-armed even if it has already fired or been stopped.
// Previously update gave up in that case, so once the timer had expired a
// single time (or Close had stopped it) no later write could ever trigger an
// ACK timeout again. The returned error is always nil and is kept only so
// the signature stays unchanged.
func (timer *snowflakeTimer) update(seq uint32) error {
	// Stop only reports whether the timer was still pending; for a timer
	// created with time.AfterFunc, Reset schedules the callback again
	// either way.
	timer.t.Stop()
	timer.seq = seq
	timer.t.Reset(snowflakeTimeout)
	return nil
}
// SessionAddr is a snowflake session ID exposed through the net.Addr
// interface; SnowflakeConn reports its SessionID as its address.
type SessionAddr []byte

// Network returns the fixed network name "session".
func (addr SessionAddr) Network() string {
	const network = "session"
	return network
}

// String returns the session ID encoded as standard base64 with the trailing
// '=' padding removed.
func (addr SessionAddr) String() string {
	encoded := base64.StdEncoding.EncodeToString(addr)
	return strings.TrimRight(encoded, "=")
}
// SnowflakeConn is a connection that layers sequence/acknowledgement headers
// over a replaceable underlying io.ReadWriteCloser. Written bytes are kept in
// buf until the peer acknowledges them, and the underlying connection can be
// swapped with NewSnowflake.
type SnowflakeConn struct {
// seq is placed in the seq field of outgoing headers; Write advances it by
// the number of bytes written.
seq uint32
// ack is the sequence number expected on the next incoming data packet and
// the value placed in the ack field of outgoing headers.
ack uint32
// SessionID is generated by genSessionID and returned by LocalAddr and
// RemoteAddr.
SessionID SessionAddr
// conn is the current underlying connection; it is set by NewSnowflake and
// reset to nil by Close.
conn io.ReadWriteCloser
// pr is the read end of the pipe that readLoop fills with de-headered data;
// it is assigned in NewSnowflake.
pr *io.PipeReader
// readBody also holds seqLock while it updates acked and trims buf.
seqLock sync.Mutex //lock for the seq and ack numbers
// readLock is held around reads from conn.
readLock sync.Mutex //lock for the underlying connection
// writeLock is held around writes to conn.
writeLock sync.Mutex //lock for the underlying connection
timerLock sync.Mutex //lock for timers
// timer is created in NewSnowflakeConn and updated after each successful
// Write.
timer *snowflakeTimer
// acked is the most recent acknowledgement number accepted from the peer.
acked uint32
// buf holds written bytes that the peer has not yet acknowledged.
buf bytes.Buffer
}
// NewSnowflakeConn returns a SnowflakeConn with a freshly generated session
// ID and a started acknowledgement timer. No underlying connection is
// attached yet; call NewSnowflake to provide one.
//
// If session ID generation fails the error is logged (it used to be silently
// dropped) and the connection is returned with an empty SessionID.
func NewSnowflakeConn() *SnowflakeConn {
	s := &SnowflakeConn{}
	if err := s.genSessionID(); err != nil {
		log.Printf("Error generating session ID: %s", err.Error())
	}
	s.timer = newSnowflakeTimer(0, s)
	return s
}
// SetLog sends the standard logger's output to w and configures it to print
// the standard date and time prefix in UTC.
func SetLog(w io.Writer) {
	const flags = log.LstdFlags | log.LUTC
	log.SetOutput(w)
	log.SetFlags(flags)
}
// genSessionID fills s.SessionID with sessionIDLength bytes read from
// crypto/rand. If the read fails the error is returned and SessionID is left
// unchanged.
func (s *SnowflakeConn) genSessionID() error {
	id := make([]byte, sessionIDLength)
	if _, err := rand.Read(id); err != nil {
		return err
	}
	s.SessionID = id
	return nil
}
// ReadSessionID reads the first sessionIDLength bytes from conn and returns
// them as a SessionAddr. The bytes are consumed from conn, not peeked. If
// fewer bytes are available, the io.ReadFull error is returned with a nil
// address.
func ReadSessionID(conn io.ReadWriteCloser) (net.Addr, error) {
	id := make(SessionAddr, sessionIDLength)
	if _, err := io.ReadFull(conn, id); err != nil {
		return nil, err
	}
	return id, nil
}
// sendSessionID writes s.SessionID to the underlying connection while holding
// writeLock and returns the result of that write. When no connection is
// attached it writes nothing and returns 0 and a nil error.
func (s *SnowflakeConn) sendSessionID() (int, error) {
	s.writeLock.Lock()
	defer s.writeLock.Unlock()

	if s.conn == nil {
		return 0, nil
	}
	return s.conn.Write(s.SessionID)
}
// NewSnowflake replaces the underlying connection with conn, starts a read
// loop for it, and resends any bytes that are still unacknowledged.
//
// conn is the new underlying connection. If isClient is true the session ID
// is written to conn first. An error is returned if the session ID cannot be
// written in full or if resending buffered data fails.
func (s *SnowflakeConn) NewSnowflake(conn io.ReadWriteCloser, isClient bool) error {
	// The previous connection is closed before the locks are taken (as in the
	// original ordering); its Close error is deliberately ignored because the
	// connection is being discarded.
	if s.conn != nil {
		s.conn.Close()
	}

	s.readLock.Lock()
	s.writeLock.Lock()
	s.conn = conn
	s.writeLock.Unlock()
	s.readLock.Unlock()

	pr, pw := io.Pipe()
	s.pr = pr
	go s.readLoop(pw)

	// if this is a client connection, send the session ID as the first 8 bytes
	if isClient {
		n, err := s.sendSessionID()
		if err != nil {
			return err
		}
		if n != sessionIDLength {
			return fmt.Errorf("failed to write session id")
		}
	}

	// Take the unacknowledged bytes out of the buffer under seqLock: the read
	// loop started above trims s.buf under that lock as acknowledgements
	// arrive. The bytes are copied so that Write, which appends to s.buf
	// again, is not handed a slice aliasing the buffer's own storage.
	var pending []byte
	s.seqLock.Lock()
	if s.buf.Len() > 0 {
		pending = append(pending, s.buf.Bytes()...)
		s.buf.Reset()
		// Resume sending from the last acknowledged position.
		s.seq = s.acked
	}
	s.seqLock.Unlock()

	// Write out bytes that were in the buffer
	if len(pending) > 0 {
		if _, err := s.Write(pending); err != nil {
			return err
		}
	}
	return nil
}
// readBody consumes the payload that follows header on the underlying
// connection and applies the header's acknowledgement number.
//
// If header.seq matches the expected sequence number (s.ack), the payload is
// copied into pw and s.ack advances by the number of bytes actually
// delivered; otherwise the payload is read and discarded. If header.ack is
// newer than s.acked, the newly acknowledged bytes are dropped from s.buf.
// When any payload bytes were delivered, an acknowledgement is sent from a
// separate goroutine. Copy errors are logged, not returned.
func (s *SnowflakeConn) readBody(header snowflakeHeader, pw *io.PipeWriter) {
	var n int64
	var err error

	s.seqLock.Lock()
	if header.seq == s.ack {
		s.readLock.Lock()
		if s.conn != nil {
			n, err = io.CopyN(pw, s.conn, int64(header.length))
		}
		s.readLock.Unlock()
		if err != nil {
			log.Printf("Error copying bytes from WebRTC connection to pipe: %s", err.Error())
		}
		// Acknowledge only what reached the pipe. Adding header.length
		// unconditionally would acknowledge bytes that were never received
		// when the copy fails part-way or no connection is attached.
		s.ack += uint32(n)
	} else {
		s.readLock.Lock()
		if s.conn != nil {
			_, err = io.CopyN(ioutil.Discard, s.conn, int64(header.length))
		}
		s.readLock.Unlock()
		if err != nil {
			log.Printf("Error discarding bytes from WebRTC connection to pipe: %s", err.Error())
		}
	}

	// Signed difference keeps the comparison correct across uint32 wraparound.
	if newlyAcked := int32(header.ack - s.acked); newlyAcked > 0 {
		// remove newly acknowledged bytes from buffer
		s.buf.Next(int(newlyAcked))
		s.acked = header.ack
	}
	s.seqLock.Unlock()

	if n > 0 {
		//send acknowledgement
		go s.sendAck()
	}
}
// readLoop repeatedly reads a header and then its body, so that payload data
// with the headers stripped ends up in pw. It runs until readHeader fails and
// then closes pw with that error.
func (s *SnowflakeConn) readLoop(pw *io.PipeWriter) {
	for {
		var header snowflakeHeader
		if err := s.readHeader(&header); err != nil {
			pw.CloseWithError(err)
			return
		}
		s.readBody(header, pw)
	}
}
// Read reads de-headered payload data from the pipe filled by readLoop and
// returns whatever the pipe reader returns.
// NOTE(review): s.pr is only assigned in NewSnowflake, so this presumably
// must not be called before a snowflake has been attached — confirm.
func (s *SnowflakeConn) Read(b []byte) (int, error) {
	n, err := s.pr.Read(b)
	return n, err
}
// sendAck writes a header-only packet (length 0) carrying the current
// sequence and acknowledgement numbers to the underlying connection. Nothing
// is written when no connection is attached. A write error is logged, not
// returned.
func (s *SnowflakeConn) sendAck() {
	// Read seq and ack together under seqLock so the header is a consistent
	// snapshot; seq used to be read before the lock was taken.
	s.seqLock.Lock()
	h := snowflakeHeader{seq: s.seq, ack: s.ack}
	s.seqLock.Unlock()
	packet := h.marshal()

	var err error
	s.writeLock.Lock()
	if s.conn != nil {
		_, err = s.conn.Write(packet)
	}
	s.writeLock.Unlock()

	if err != nil {
		log.Printf("Error sending acknowledgment packet: %s", err.Error())
	}
}
// Write sends b to the underlying connection, saving the bytes in a buffer
// first. The bytes remain in the buffer until they are acknowledged by the
// other end of the connection.
//
// b is split into packets of at most maxLength payload bytes, each preceded
// by its own header, so the header's length field always matches the payload
// that follows it. An empty b still produces one header-only packet.
//
// The returned count is the number of bytes accepted into the buffer. If the
// underlying connection reports a write error, that error is returned and any
// remaining, not yet buffered, part of b is not sent.
//
// Note: Write will not return an error if no underlying connection is
// attached (for example after Close); the bytes are only buffered.
func (s *SnowflakeConn) Write(b []byte) (n int, err error) {
	for {
		chunk := b
		if len(chunk) > maxLength {
			chunk = chunk[:maxLength]
		}
		b = b[len(chunk):]

		// Build the header, advance seq and buffer the payload in one
		// critical section so seq, ack and buf stay consistent.
		s.seqLock.Lock()
		h := snowflakeHeader{seq: s.seq, ack: s.ack, length: uint16(len(chunk))}
		s.seq += uint32(len(chunk))
		nextSeq := s.seq
		//save bytes to buffer until they have been acked
		s.buf.Write(chunk)
		s.seqLock.Unlock()
		n += len(chunk)

		packet := append(h.marshal(), chunk...)

		var writeErr error
		s.writeLock.Lock()
		conn := s.conn
		if conn != nil {
			_, writeErr = conn.Write(packet)
		}
		s.writeLock.Unlock()

		switch {
		case conn == nil:
			log.Printf("Buffering %d bytes, no connection yet.", len(chunk))
		case writeErr != nil:
			// Log the write error itself; the previous code dereferenced a
			// different, usually nil, error here and panicked.
			log.Printf("Error writing to connection: %s", writeErr.Error())
			return n, writeErr
		default:
			s.timerLock.Lock()
			// update's error result carries no information for Write.
			_ = s.timer.update(nextSeq)
			s.timerLock.Unlock()
		}

		if len(b) == 0 {
			return n, nil
		}
	}
}
// Close closes the underlying connection if one is attached, detaches it by
// setting s.conn to nil, and stops the acknowledgement timer. It returns the
// error from closing the underlying connection, or nil if none was attached.
func (s *SnowflakeConn) Close() error {
	var closeErr error
	if s.conn != nil {
		closeErr = s.conn.Close()
	}

	// Detach the connection while holding both connection locks.
	func() {
		s.readLock.Lock()
		defer s.readLock.Unlock()
		s.writeLock.Lock()
		defer s.writeLock.Unlock()
		s.conn = nil
	}()

	//terminate all waiting timers
	func() {
		s.timerLock.Lock()
		defer s.timerLock.Unlock()
		s.timer.t.Stop()
	}()

	return closeErr
}
// LocalAddr returns the connection's session ID as its local address.
func (s *SnowflakeConn) LocalAddr() net.Addr {
	var addr net.Addr = s.SessionID
	return addr
}
// RemoteAddr returns the connection's session ID as its remote address, the
// same value LocalAddr reports.
func (s *SnowflakeConn) RemoteAddr() net.Addr {
	var addr net.Addr = s.SessionID
	return addr
}
// SetDeadline is not supported: it ignores t and always returns an error.
func (s *SnowflakeConn) SetDeadline(t time.Time) error {
	err := fmt.Errorf("SetDeadline not implemented")
	return err
}
// SetReadDeadline is not supported: it ignores t and always returns an error.
func (s *SnowflakeConn) SetReadDeadline(t time.Time) error {
	err := fmt.Errorf("SetReadDeadline not implemented")
	return err
}
// SetWriteDeadline is not supported: it ignores t and always returns an error.
func (s *SnowflakeConn) SetWriteDeadline(t time.Time) error {
	err := fmt.Errorf("SetWriteDeadline not implemented")
	return err
}
// Proxy copies from src to dst until either side fails. It functions
// similarly to io.Copy, except that it always returns the terminating error
// (including io.EOF) together with a bool that is true if the call to
// src.Read caused the error and false if the call to dst.Write caused it
// (a short write is reported as io.ErrShortWrite).
//
// Bytes returned by src.Read are written to dst before the accompanying read
// error is considered, as the io.Reader contract requires; previously data
// returned together with an error such as io.EOF was dropped.
func Proxy(dst io.WriteCloser, src io.ReadCloser) (bool, error) {
	buf := make([]byte, 32*1024)
	for {
		nr, readErr := src.Read(buf)
		if nr > 0 {
			nw, writeErr := dst.Write(buf[:nr])
			if writeErr != nil {
				return false, writeErr
			}
			if nw != nr {
				return false, io.ErrShortWrite
			}
		}
		if readErr != nil {
			return true, readErr
		}
	}
}