This repository has been archived by the owner on Mar 5, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathconn.go
164 lines (132 loc) · 3.57 KB
/
conn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
package dshardorchestrator
import (
"encoding/binary"
"fmt"
"github.com/pkg/errors"
"io"
"net"
"strconv"
"sync"
"sync/atomic"
)
// Conn represents a connection from either node to the orchestrator or the other way around.
// It implements the logic common to both sides: reading and dispatching
// incoming messages (Listen), sending (Send/SendNoLock) and logging (Log).
//
// Conn contains a sync.Mutex and an atomic.Value, so it must not be copied
// after first use; pass it around as *Conn.
type Conn struct {
	// logger receives output from Log; when nil, Log falls back to
	// StdLogInstance.
	logger Logger
	// netConn is the underlying connection that Listen reads from and
	// SendNoLock writes to.
	netConn net.Conn
	// sendmu serializes writes made through Send so concurrent senders do
	// not interleave their bytes on the wire.
	sendmu sync.Mutex
	// ID holds the connection's identifier as a string (GetID type-asserts
	// it to string). ConnFromNetCon stores a provisional "unknown-<n>" value.
	// As an atomic.Value it can be loaded and stored concurrently.
	ID atomic.Value
	// MessageHandler is called by Listen once per incoming message,
	// synchronously on the goroutine running Listen. Listen calls it without
	// a nil check, so it must be set before Listen runs.
	MessageHandler func(*Message)
	// ConnClosedHanlder (sic: the misspelling is part of the exported API) is
	// called by Listen after it has closed the connection. Optional; nil is
	// skipped.
	ConnClosedHanlder func()
}
// ConnFromNetCon wraps conn in a new *Conn that logs through logger (a nil
// logger makes Log fall back to StdLogInstance).
//
// The new Conn starts with a provisional ID of the form "unknown-<n>", where n
// is taken from getNewID. ID is an exported field, so callers can store a
// different identifier later.
func ConnFromNetCon(conn net.Conn, logger Logger) *Conn {
	wrapped := new(Conn)
	wrapped.netConn = conn
	wrapped.logger = logger

	provisionalID := "unknown-" + strconv.FormatInt(getNewID(), 10)
	wrapped.ID.Store(provisionalID)

	return wrapped
}
// Close closes the underlying network connection. Close has no error result,
// so any error from the underlying Close is discarded.
func (c *Conn) Close() {
	// Intentionally ignored: there is no way to report it to the caller.
	_ = c.netConn.Close()
}
// Listen reads messages from the connection in a loop and hands each one to
// MessageHandler. It blocks until a read fails (for example because the peer
// disconnected or the connection was closed) and then tears the connection
// down.
//
// Each frame is expected to be laid out as:
//
//	event type  uint32, 4 bytes, little-endian
//	body length uint32, 4 bytes, little-endian
//	body        exactly "body length" bytes
//
// For event types below 100 the body is decoded with DecodePayload into
// Message.DecodedBody. For any other event type it is passed through unchanged
// in Message.RawBody. If decoding fails, the error is logged and the message
// is still delivered, with DecodedBody set to whatever DecodePayload returned.
//
// MessageHandler is called synchronously on the goroutine running Listen and
// must be non-nil. When the loop ends, the underlying net.Conn is closed and
// ConnClosedHanlder, if set, is called.
func (c *Conn) Listen() {
	c.Log(LogInfo, nil, "started listening for events...")

	// err is read by the deferred cleanup, so every fatal read below assigns
	// to it (with = rather than :=).
	var err error
	defer func() {
		if err != nil {
			c.Log(LogError, err, "an error occured while handling a connection")
		}

		c.netConn.Close()
		c.Log(LogInfo, nil, "connection closed")

		if c.ConnClosedHanlder != nil {
			c.ConnClosedHanlder()
		}
	}()

	idBuf := make([]byte, 4)
	lenBuf := make([]byte, 4)

	for {
		// io.ReadFull rather than a bare Read: on a stream connection a
		// single Read may return fewer than 4 bytes. That would misalign the
		// framing for this and every following message.
		if _, err = io.ReadFull(c.netConn, idBuf); err != nil {
			c.Log(LogError, err, "failed reading event id")
			return
		}

		if _, err = io.ReadFull(c.netConn, lenBuf); err != nil {
			c.Log(LogError, err, "failed reading event length")
			return
		}

		id := EventType(binary.LittleEndian.Uint32(idBuf))
		l := binary.LittleEndian.Uint32(lenBuf)
		c.Log(LogDebug, nil, fmt.Sprintf("inc message evt: %s, payload lenght: %d", id.String(), l))

		// NOTE(review): l comes straight from the peer and is allocated
		// without an upper bound, so a buggy or hostile peer can request up
		// to 4 GiB here. On 32-bit platforms int(l) can also go negative and
		// make() would panic. Consider enforcing a maximum frame size.
		body := make([]byte, int(l))
		if l > 0 {
			if _, err = io.ReadFull(c.netConn, body); err != nil {
				c.Log(LogError, err, "failed reading message body")
				return
			}
		}

		msg := &Message{
			EvtID: id,
		}
		// Event types below 100 are decoded here; everything else is handed
		// to the handler raw. (What the 100 boundary means is defined with
		// DecodePayload, not in this file; confirm there.)
		if id < 100 {
			decoded, decodeErr := DecodePayload(id, body)
			if decodeErr != nil {
				// Not fatal for the connection: log it and still deliver.
				c.Log(LogError, decodeErr, "failed decoding message payload")
			}
			msg.DecodedBody = decoded
		} else {
			msg.RawBody = body
		}

		c.MessageHandler(msg)
	}
}
// Send encodes data as an evtID message via EncodeMessage and writes it to the
// connection. Encoding happens before the send mutex is taken. The mutex is
// then held for the debug log and the write, so concurrent Send calls never
// interleave their bytes.
//
// An encoding failure is returned annotated with "EncodeEvent". A write
// failure is returned as produced by SendNoLock.
func (c *Conn) Send(evtID EventType, data interface{}) error {
	payload, encodeErr := EncodeMessage(evtID, data)
	if encodeErr != nil {
		return errors.WithMessage(encodeErr, "EncodeEvent")
	}

	c.sendmu.Lock()
	defer c.sendmu.Unlock()

	debugLine := fmt.Sprintf("sending evt %s, len: %d", evtID.String(), len(payload))
	c.Log(LogDebug, nil, debugLine)

	return c.SendNoLock(payload)
}
// SendLogErr behaves like Send but, instead of returning an error, logs it at
// LogError level. This is useful for fire-and-forget sends, such as ones
// launched in their own goroutine.
func (c *Conn) SendLogErr(evtID EventType, data interface{}) {
	if sendErr := c.Send(evtID, data); sendErr != nil {
		c.Log(LogError, sendErr, "failed sending message")
	}
}
// SendNoLock writes data, which must already be encoded, directly to the
// underlying connection without taking the send mutex. The caller is
// responsible for ensuring it is never called from multiple goroutines at
// once.
//
// On success it returns nil. A write failure is returned annotated with
// "netConn.Write".
func (c *Conn) SendNoLock(data []byte) error {
	if _, writeErr := c.netConn.Write(data); writeErr != nil {
		return errors.WithMessage(writeErr, "netConn.Write")
	}
	return nil
}
// GetID returns the connection's current ID.
//
// ID is expected to hold a string, and ConnFromNetCon stores one. A Conn
// constructed any other way may have nothing stored yet. In that case, or if
// a non-string was stored, GetID returns "unknown". The previous unchecked
// type assertion panicked, which also made every Log call on such a Conn panic.
func (c *Conn) GetID() string {
	if id, ok := c.ID.Load().(string); ok {
		return id
	}
	return "unknown"
}
// Log emits msg at the given level, prefixed with this connection's ID as
// "<id>: ". When err is non-nil, its text is appended as ": <err>". Output goes
// to the Conn's own logger if one is set, otherwise to StdLogInstance.
func (c *Conn) Log(level LogLevel, err error, msg string) {
	text := msg
	if err != nil {
		text = msg + ": " + err.Error()
	}
	line := c.GetID() + ": " + text

	if c.logger != nil {
		c.logger.Log(level, line)
		return
	}
	StdLogInstance.Log(level, line)
}