forked from pion/dtls
-
Notifications
You must be signed in to change notification settings - Fork 0
/
state.go
293 lines (240 loc) · 8.52 KB
/
state.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package dtls
import (
"bytes"
"encoding/gob"
"errors"
"sync/atomic"
"github.com/pion/dtls/v3/pkg/crypto/elliptic"
"github.com/pion/dtls/v3/pkg/crypto/prf"
"github.com/pion/dtls/v3/pkg/crypto/signaturehash"
"github.com/pion/dtls/v3/pkg/protocol/handshake"
"github.com/pion/transport/v3/replaydetector"
)
// State holds the dtls connection state and implements both encoding.BinaryMarshaler and encoding.BinaryUnmarshaler.
//
// Only the fields mirrored in serializedState take part in MarshalBinary and
// UnmarshalBinary; every other field is left out of the serialized form.
type State struct {
// localEpoch and remoteEpoch store uint16 epoch values. Read them through
// getLocalEpoch and getRemoteEpoch, which report 0 while nothing is stored.
localEpoch, remoteEpoch atomic.Value
// localSequenceNumber is indexed by epoch; its elements are loaded and
// stored with sync/atomic.
localSequenceNumber []uint64 // uint48
localRandom, remoteRandom handshake.Random
masterSecret []byte
cipherSuite CipherSuite // nil if a cipherSuite hasn't been chosen
// CipherSuiteID is assigned by deserialize. Note that serialize does not
// read this field; it takes the ID from cipherSuite instead.
CipherSuiteID CipherSuiteID
srtpProtectionProfile atomic.Value // Negotiated SRTPProtectionProfile
remoteSRTPMasterKeyIdentifier []byte
PeerCertificates [][]byte
IdentityHint []byte
SessionID []byte
// Connection Identifiers must be negotiated afresh on session resumption.
// https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension
// localConnectionID is the locally generated connection ID that is expected
// to be received from the remote endpoint.
// For a server, this is the connection ID sent in ServerHello.
// For a client, this is the connection ID sent in the ClientHello.
// The value is stored as a []byte; see getLocalConnectionID and
// setLocalConnectionID.
localConnectionID atomic.Value
// remoteConnectionID is the connection ID that the remote endpoint
// specifies should be sent.
// For a server, this is the connection ID received in the ClientHello.
// For a client, this is the connection ID received in the ServerHello.
remoteConnectionID []byte
// isClient selects the order in which the local and remote randoms are
// combined in initCipherSuite and ExportKeyingMaterial.
isClient bool
preMasterSecret []byte
extendedMasterSecret bool
namedCurve elliptic.Curve
localKeypair *elliptic.Keypair
cookie []byte
handshakeSendSequence int
handshakeRecvSequence int
serverName string
remoteCertRequestAlgs []signaturehash.Algorithm
remoteRequestedCertificate bool // Did we get a CertificateRequest
localCertificatesVerify []byte // cache CertificateVerify
localVerifyData []byte // cached VerifyData
localKeySignature []byte // cached keySignature
peerCertificatesVerified bool
replayDetector []replaydetector.ReplayDetector
peerSupportedProtocols []string
NegotiatedProtocol string
}
// serializedState is the snapshot of State that MarshalBinary encodes with
// encoding/gob and UnmarshalBinary decodes again. It is produced by
// State.serialize and consumed by State.deserialize.
type serializedState struct {
LocalEpoch uint16
RemoteEpoch uint16
LocalRandom [handshake.RandomLength]byte
RemoteRandom [handshake.RandomLength]byte
CipherSuiteID uint16
MasterSecret []byte
// SequenceNumber is the local sequence number of LocalEpoch only; the
// sequence numbers of other epochs are not carried.
SequenceNumber uint64
SRTPProtectionProfile uint16
PeerCertificates [][]byte
IdentityHint []byte
SessionID []byte
LocalConnectionID []byte
RemoteConnectionID []byte
IsClient bool
NegotiatedProtocol string
}
// errCipherSuiteNotSet is returned by State.serialize when the state has no
// cipher suite selected (cipherSuite is nil).
var errCipherSuiteNotSet = &InternalError{Err: errors.New("cipher suite not set")} //nolint:goerr113
// clone returns a new State populated from a serialized snapshot of s.
//
// Only the fields carried by serializedState are transferred. Slices in the
// snapshot are assigned rather than copied, so the returned State refers to
// the same underlying slice data as s. clone does not call initCipherSuite on
// the returned State.
//
// An error is returned when s cannot be serialized.
func (s *State) clone() (*State, error) {
	snapshot, err := s.serialize()
	if err != nil {
		return nil, err
	}

	cloned := new(State)
	cloned.deserialize(*snapshot)

	return cloned, nil
}
// serialize captures the persistable subset of the connection state in a
// serializedState.
//
// It returns errCipherSuiteNotSet when no cipher suite has been selected.
// Slices (master secret, certificates, identity hint, session ID, connection
// IDs) are placed in the snapshot without copying, so the snapshot shares
// their underlying data with s.
func (s *State) serialize() (*serializedState, error) {
	if s.cipherSuite == nil {
		return nil, errCipherSuiteNotSet
	}
	cipherSuiteID := uint16(s.cipherSuite.ID())

	// Marshal random values
	localRnd := s.localRandom.MarshalFixed()
	remoteRnd := s.remoteRandom.MarshalFixed()

	// Load the local epoch exactly once so that LocalEpoch and SequenceNumber
	// describe the same epoch even if the stored epoch changes in between.
	epoch := s.getLocalEpoch()

	// A missing slot means no sequence number has been recorded for this
	// epoch yet. Report 0 (the value deserialize allocates for a new slot)
	// instead of indexing out of range.
	var sequenceNumber uint64
	if int(epoch) < len(s.localSequenceNumber) {
		sequenceNumber = atomic.LoadUint64(&s.localSequenceNumber[epoch])
	}

	return &serializedState{
		LocalEpoch:            epoch,
		RemoteEpoch:           s.getRemoteEpoch(),
		CipherSuiteID:         cipherSuiteID,
		MasterSecret:          s.masterSecret,
		SequenceNumber:        sequenceNumber,
		LocalRandom:           localRnd,
		RemoteRandom:          remoteRnd,
		SRTPProtectionProfile: uint16(s.getSRTPProtectionProfile()),
		PeerCertificates:      s.PeerCertificates,
		IdentityHint:          s.IdentityHint,
		SessionID:             s.SessionID,
		LocalConnectionID:     s.getLocalConnectionID(),
		RemoteConnectionID:    s.remoteConnectionID,
		IsClient:              s.isClient,
		NegotiatedProtocol:    s.NegotiatedProtocol,
	}, nil
}
// deserialize overwrites the fields of s that are carried by serialized.
//
// Slices from serialized are assigned without copying. The cipher suite is
// looked up from the serialized ID via cipherSuiteForID; it is not
// initialized here (see initCipherSuite).
func (s *State) deserialize(serialized serializedState) {
	// Epochs.
	localEpoch := serialized.LocalEpoch
	s.localEpoch.Store(localEpoch)
	s.remoteEpoch.Store(serialized.RemoteEpoch)

	// Make sure a sequence number slot exists for the local epoch, padding
	// with zeros, then restore the sequence number of that epoch.
	if missing := int(localEpoch) + 1 - len(s.localSequenceNumber); missing > 0 {
		s.localSequenceNumber = append(s.localSequenceNumber, make([]uint64, missing)...)
	}
	atomic.StoreUint64(&s.localSequenceNumber[localEpoch], serialized.SequenceNumber)

	// Handshake randoms.
	var local, remote handshake.Random
	local.UnmarshalFixed(serialized.LocalRandom)
	remote.UnmarshalFixed(serialized.RemoteRandom)
	s.localRandom, s.remoteRandom = local, remote

	s.isClient = serialized.IsClient

	// Key material and cipher suite.
	s.masterSecret = serialized.MasterSecret
	s.CipherSuiteID = CipherSuiteID(serialized.CipherSuiteID)
	s.cipherSuite = cipherSuiteForID(s.CipherSuiteID, nil)

	s.setSRTPProtectionProfile(SRTPProtectionProfile(serialized.SRTPProtectionProfile))

	// Peer identity.
	s.PeerCertificates = serialized.PeerCertificates
	s.IdentityHint = serialized.IdentityHint

	// Local and remote connection IDs.
	s.setLocalConnectionID(serialized.LocalConnectionID)
	s.remoteConnectionID = serialized.RemoteConnectionID

	s.SessionID = serialized.SessionID
	s.NegotiatedProtocol = serialized.NegotiatedProtocol
}
// initCipherSuite initializes the selected cipher suite from the master
// secret and the handshake randoms, unless it already reports itself as
// initialized.
//
// It returns errCipherSuiteNotSet when no cipher suite is selected. This can
// happen after deserialize if cipherSuiteForID yields nil for the serialized
// ID, and would otherwise cause a nil dereference. Any error from the cipher
// suite's Init is returned unchanged.
func (s *State) initCipherSuite() error {
	if s.cipherSuite == nil {
		return errCipherSuiteNotSet
	}
	if s.cipherSuite.IsInitialized() {
		return nil
	}

	localRandom := s.localRandom.MarshalFixed()
	remoteRandom := s.remoteRandom.MarshalFixed()

	// Init always receives the client random before the server random, so
	// which of local/remote comes first depends on our role.
	if s.isClient {
		return s.cipherSuite.Init(s.masterSecret, localRandom[:], remoteRandom[:], true)
	}

	return s.cipherSuite.Init(s.masterSecret, remoteRandom[:], localRandom[:], false)
}
// MarshalBinary implements encoding.BinaryMarshaler.
//
// The state is first reduced to a serializedState and then gob-encoded. An
// error from serialize or from the gob encoder is returned as is.
func (s *State) MarshalBinary() ([]byte, error) {
	snapshot, err := s.serialize()
	if err != nil {
		return nil, err
	}

	out := new(bytes.Buffer)
	if err = gob.NewEncoder(out).Encode(*snapshot); err != nil {
		return nil, err
	}

	return out.Bytes(), nil
}
// UnmarshalBinary implements encoding.BinaryUnmarshaler.
//
// data is gob-decoded into a serializedState, which is applied to s with
// deserialize; the cipher suite is then initialized. A gob decoding error is
// returned before s is modified; otherwise the result of initCipherSuite is
// returned.
func (s *State) UnmarshalBinary(data []byte) error {
	var snapshot serializedState

	dec := gob.NewDecoder(bytes.NewBuffer(data))
	if err := dec.Decode(&snapshot); err != nil {
		return err
	}

	s.deserialize(snapshot)

	return s.initCipherSuite()
}
// ExportKeyingMaterial returns length bytes of exported key material in a new
// slice as defined in RFC 5705.
// This allows protocols to use DTLS for key establishment, but
// then use some of the keying material for their own purposes.
//
// It fails with errHandshakeInProgress while the local epoch is 0, with
// errContextUnsupported when context is non-empty, and with
// errReservedExportKeyingMaterial when label is one of invalidKeyingLabels.
func (s *State) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) {
	switch {
	case s.getLocalEpoch() == 0:
		return nil, errHandshakeInProgress
	case len(context) != 0:
		return nil, errContextUnsupported
	}
	if _, reserved := invalidKeyingLabels()[label]; reserved {
		return nil, errReservedExportKeyingMaterial
	}

	localRandom := s.localRandom.MarshalFixed()
	remoteRandom := s.remoteRandom.MarshalFixed()

	// The seed is label || client random || server random, so the order of
	// local/remote depends on our role.
	first, second := remoteRandom[:], localRandom[:]
	if s.isClient {
		first, second = second, first
	}

	seed := make([]byte, 0, len(label)+len(first)+len(second))
	seed = append(seed, label...)
	seed = append(seed, first...)
	seed = append(seed, second...)

	return prf.PHash(s.masterSecret, seed, length, s.cipherSuite.HashFunc())
}
// getRemoteEpoch returns the uint16 held in remoteEpoch, or 0 when nothing
// is stored or the stored value is not a uint16.
func (s *State) getRemoteEpoch() uint16 {
	// A failed assertion leaves epoch at its zero value.
	epoch, _ := s.remoteEpoch.Load().(uint16)

	return epoch
}
// getLocalEpoch returns the uint16 held in localEpoch, or 0 when nothing is
// stored or the stored value is not a uint16.
func (s *State) getLocalEpoch() uint16 {
	// A failed assertion leaves epoch at its zero value.
	epoch, _ := s.localEpoch.Load().(uint16)

	return epoch
}
// setSRTPProtectionProfile stores profile in the srtpProtectionProfile
// atomic.Value, from where getSRTPProtectionProfile reads it back.
func (s *State) setSRTPProtectionProfile(profile SRTPProtectionProfile) {
s.srtpProtectionProfile.Store(profile)
}
// getSRTPProtectionProfile returns the SRTPProtectionProfile held in
// srtpProtectionProfile, or 0 when nothing is stored or the stored value has
// a different type.
func (s *State) getSRTPProtectionProfile() SRTPProtectionProfile {
	// A failed assertion leaves profile at its zero value.
	profile, _ := s.srtpProtectionProfile.Load().(SRTPProtectionProfile)

	return profile
}
// getLocalConnectionID returns the []byte held in localConnectionID, or nil
// when nothing is stored or the stored value is not a []byte. The returned
// slice is the stored one, not a copy.
func (s *State) getLocalConnectionID() []byte {
	// A failed assertion leaves cid nil.
	cid, _ := s.localConnectionID.Load().([]byte)

	return cid
}
// setLocalConnectionID stores v in the localConnectionID atomic.Value, from
// where getLocalConnectionID reads it back. The slice is stored as passed,
// without copying.
func (s *State) setLocalConnectionID(v []byte) {
s.localConnectionID.Store(v)
}
// RemoteRandomBytes returns the RandomBytes field of the remote handshake
// random held in the state. The array is returned by value.
func (s *State) RemoteRandomBytes() [handshake.RandomBytesLength]byte {
return s.remoteRandom.RandomBytes
}