summaryrefslogtreecommitdiffstats
path: root/vendor/go.mau.fi/libsignal/state/record/SessionState.go
blob: d0f61d5c72f96f7176019ee50bbeca1c483ebabe (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
package record

import (
	"go.mau.fi/libsignal/ecc"
	"go.mau.fi/libsignal/kdf"
	"go.mau.fi/libsignal/keys/chain"
	"go.mau.fi/libsignal/keys/identity"
	"go.mau.fi/libsignal/keys/message"
	"go.mau.fi/libsignal/keys/root"
	"go.mau.fi/libsignal/keys/session"
	"go.mau.fi/libsignal/logger"
	"go.mau.fi/libsignal/util/errorhelper"
	"go.mau.fi/libsignal/util/optional"
)

const (
	// maxMessageKeys is the most message keys SetMessageKeys lets a single
	// receiver chain hold before it drops the oldest one.
	maxMessageKeys int = 2000

	// maxReceiverChains is the most receiver chains AddReceiverChain keeps
	// in a session state before it drops the oldest one.
	maxReceiverChains int = 5
)

// StateSerializer is an interface for serializing and deserializing
// a Signal State into bytes. An implementation of this interface should be
// used to encode/decode the object into JSON, Protobuffers, etc.
//
// Serialize is used by State.Serialize to turn the state's StateStructure
// into bytes; Deserialize is used by NewStateFromBytes to turn bytes back
// into a StateStructure, reporting any decoding failure as an error.
type StateSerializer interface {
	Serialize(state *StateStructure) []byte
	Deserialize(serialized []byte) (*StateStructure, error)
}

// NewStateFromBytes decodes serialized with the given serializer and builds
// a session State from the resulting structure. The serializer is retained
// by the returned State for later calls to Serialize. Any error from
// decoding or from building the state is returned unchanged.
func NewStateFromBytes(serialized []byte, serializer StateSerializer) (*State, error) {
	structure, decodeErr := serializer.Deserialize(serialized)
	if decodeErr != nil {
		return nil, decodeErr
	}
	return NewStateFromStructure(structure, serializer)
}

// NewState returns an empty session state whose only populated field is
// the serializer used later by Serialize.
func NewState(serializer StateSerializer) *State {
	state := new(State)
	state.serializer = serializer
	return state
}

// NewStateFromStructure converts a serializable StateStructure into a
// session State that keeps a reference to serializer.
//
// Keys are decoded with ecc.DecodePoint and chains with
// NewChainFromStructure. Every failure is collected into a single
// multi-error; if any occurred, nil and that multi-error are returned.
// The multi-error reports the first error that was recorded.
func NewStateFromStructure(structure *StateStructure, serializer StateSerializer) (*State, error) {
	errs := errorhelper.NewMultiError()

	// decodeKey decodes a serialized public key and records any failure.
	decodeKey := func(serialized []byte) ecc.ECPublicKeyable {
		key, err := ecc.DecodePoint(serialized, 0)
		errs.Add(err)
		return key
	}

	// Decode in the same order as before, so the first recorded error is stable.
	localIdentityPublic := decodeKey(structure.LocalIdentityPublic)
	remoteIdentityPublic := decodeKey(structure.RemoteIdentityPublic)
	senderBaseKey := decodeKey(structure.SenderBaseKey)

	var pendingPreKey *PendingPreKey
	if structure.PendingPreKey != nil {
		var pendingErr error
		pendingPreKey, pendingErr = NewPendingPreKeyFromStruct(structure.PendingPreKey)
		errs.Add(pendingErr)
	}

	senderChain, err := NewChainFromStructure(structure.SenderChain)
	errs.Add(err)

	receiverChains := make([]*Chain, len(structure.ReceiverChains))
	for idx, chainStructure := range structure.ReceiverChains {
		receiverChains[idx], err = NewChainFromStructure(chainStructure)
		errs.Add(err)
	}

	if errs.HasErrors() {
		return nil, errs
	}

	return &State{
		localIdentityPublic:  identity.NewKey(localIdentityPublic),
		localRegistrationID:  structure.LocalRegistrationID,
		needsRefresh:         structure.NeedsRefresh,
		pendingKeyExchange:   NewPendingKeyExchangeFromStruct(structure.PendingKeyExchange),
		pendingPreKey:        pendingPreKey,
		previousCounter:      structure.PreviousCounter,
		receiverChains:       receiverChains,
		remoteIdentityPublic: identity.NewKey(remoteIdentityPublic),
		remoteRegistrationID: structure.RemoteRegistrationID,
		rootKey:              root.NewKey(kdf.DeriveSecrets, structure.RootKey),
		senderBaseKey:        senderBaseKey,
		senderChain:          senderChain,
		serializer:           serializer,
		sessionVersion:       structure.SessionVersion,
	}, nil
}

// StateStructure is the structure of a session state. Fields are public
// to be used for serialization and deserialization.
//
// Byte-slice key fields hold serialized keys. NewStateFromStructure decodes
// LocalIdentityPublic, RemoteIdentityPublic and SenderBaseKey with
// ecc.DecodePoint, and passes RootKey to root.NewKey. State.structure fills
// this struct back in from a State.
//
// NOTE(review): field order is left as-is because reflection-based
// serializers (e.g. JSON) may depend on it — confirm before reordering.
type StateStructure struct {
	LocalIdentityPublic  []byte
	LocalRegistrationID  uint32
	NeedsRefresh         bool
	PendingKeyExchange   *PendingKeyExchangeStructure
	PendingPreKey        *PendingPreKeyStructure
	PreviousCounter      uint32
	ReceiverChains       []*ChainStructure
	RemoteIdentityPublic []byte
	RemoteRegistrationID uint32
	RootKey              []byte
	SenderBaseKey        []byte
	SenderChain          *ChainStructure
	SessionVersion       int
}

// State is the in-memory state of a single session. Session states are
// contained inside session records.
//
// The session state is implemented as a struct rather than protobuffers
// to allow other serialization methods. The configured StateSerializer
// converts it to and from bytes via its StateStructure form (see Serialize
// and NewStateFromBytes).
//
// receiverChains is kept at no more than maxReceiverChains entries by
// AddReceiverChain. No synchronization is present here; presumably callers
// serialize access to a State — TODO confirm.
type State struct {
	localIdentityPublic  *identity.Key
	localRegistrationID  uint32
	needsRefresh         bool
	pendingKeyExchange   *PendingKeyExchange
	pendingPreKey        *PendingPreKey
	previousCounter      uint32
	receiverChains       []*Chain
	remoteIdentityPublic *identity.Key
	remoteRegistrationID uint32
	rootKey              *root.Key
	senderBaseKey        ecc.ECPublicKeyable
	senderChain          *Chain
	serializer           StateSerializer
	sessionVersion       int
}

// SenderBaseKey returns the serialized form of the sender's base key, or nil
// when no base key is set.
func (s *State) SenderBaseKey() []byte {
	if baseKey := s.senderBaseKey; baseKey != nil {
		return baseKey.Serialize()
	}
	return nil
}

// SetSenderBaseKey decodes senderBaseKey with ecc.DecodePoint and stores it as
// the sender's base key.
//
// The method has no error return, so a decoding failure is logged rather than
// silently discarded. In that case the base key is cleared, so SenderBaseKey
// reports nil instead of a partially decoded value.
func (s *State) SetSenderBaseKey(senderBaseKey []byte) {
	decoded, err := ecc.DecodePoint(senderBaseKey, 0)
	if err != nil {
		logger.Error("Error decoding sender base key: ", err)
		s.senderBaseKey = nil
		return
	}
	s.senderBaseKey = decoded
}

// Version reports the session version number held by this state.
func (s *State) Version() int {
	version := s.sessionVersion
	return version
}

// SetVersion replaces the session version number held by this state.
func (s *State) SetVersion(version int) {
	s.sessionVersion = version
}

// RemoteIdentityKey reports the remote party's identity key for this session.
func (s *State) RemoteIdentityKey() *identity.Key {
	remote := s.remoteIdentityPublic
	return remote
}

// SetRemoteIdentityKey replaces the remote party's identity key for this
// session.
func (s *State) SetRemoteIdentityKey(identityKey *identity.Key) {
	s.remoteIdentityPublic = identityKey
}

// LocalIdentityKey reports the local user's identity key for this session.
func (s *State) LocalIdentityKey() *identity.Key {
	local := s.localIdentityPublic
	return local
}

// SetLocalIdentityKey replaces the local user's identity key for this
// session.
func (s *State) SetLocalIdentityKey(identityKey *identity.Key) {
	s.localIdentityPublic = identityKey
}

// PreviousCounter reports the stored previous-message counter.
func (s *State) PreviousCounter() uint32 {
	counter := s.previousCounter
	return counter
}

// SetPreviousCounter replaces the stored previous-message counter.
func (s *State) SetPreviousCounter(previousCounter uint32) {
	s.previousCounter = previousCounter
}

// RootKey returns the root key for the session, or nil when none is set.
//
// The explicit nil check matters. Returning a nil *root.Key directly would
// wrap it in a non-nil session.RootKeyable interface value, and callers
// testing RootKey() == nil would be misled.
func (s *State) RootKey() session.RootKeyable {
	if s.rootKey == nil {
		return nil
	}
	return s.rootKey
}

// SetRootKey stores rootKey as the session's root key. The value must be a
// *root.Key; any other implementation panics on the type assertion.
func (s *State) SetRootKey(rootKey session.RootKeyable) {
	concrete := rootKey.(*root.Key)
	s.rootKey = concrete
}

// SenderRatchetKey returns the public half of the sender chain's ratchet key
// pair. The sender chain must be set (see HasSenderChain).
func (s *State) SenderRatchetKey() ecc.ECPublicKeyable {
	keyPair := s.SenderRatchetKeyPair()
	return keyPair.PublicKey()
}

// SenderRatchetKeyPair returns the sender chain's ratchet key pair. The sender
// chain must be set (see HasSenderChain).
func (s *State) SenderRatchetKeyPair() *ecc.ECKeyPair {
	sender := s.senderChain
	return sender.senderRatchetKeyPair
}

// HasReceiverChain reports whether any receiver chain in this state matches
// the given sender ephemeral key.
func (s *State) HasReceiverChain(senderEphemeral ecc.ECPublicKeyable) bool {
	match := s.receiverChain(senderEphemeral)
	return match != nil
}

// HasSenderChain reports whether this state currently has a sender chain.
func (s *State) HasSenderChain() bool {
	hasChain := s.senderChain != nil
	return hasChain
}

// receiverChain searches the session state's receiver chains for the one whose
// sender ratchet public key matches senderEphemeral. It returns that chain
// paired with its index in s.receiverChains. It returns nil when no chain
// matches or when senderEphemeral is nil.
//
// A chain whose stored key fails to re-decode is logged and skipped. The
// previous code logged the error but then dereferenced the nil key and
// panicked.
func (s *State) receiverChain(senderEphemeral ecc.ECPublicKeyable) *ReceiverChainPair {
	if senderEphemeral == nil {
		return nil
	}

	for i, receiverChain := range s.receiverChains {
		chainSenderRatchetKey, err := ecc.DecodePoint(receiverChain.senderRatchetKeyPair.PublicKey().Serialize(), 0)
		if err != nil {
			logger.Error("Error getting receiverchain: ", err)
			continue
		}

		// If the chainSenderRatchetKey equals our senderEphemeral key, return it.
		if chainSenderRatchetKey.PublicKey() == senderEphemeral.PublicKey() {
			return NewReceiverChainPair(receiverChain, i)
		}
	}

	return nil
}

// ReceiverChainKey returns a fresh chain.Key copying the key bytes and index
// of the receiver chain that matches senderEphemeral. It returns nil when
// there is no matching chain.
//
// The pair returned by receiverChain is nil-checked before its ReceiverChain
// field is read. The previous code read the field first, so an unknown key
// panicked instead of returning nil.
func (s *State) ReceiverChainKey(senderEphemeral ecc.ECPublicKeyable) *chain.Key {
	receiverChainAndIndex := s.receiverChain(senderEphemeral)
	if receiverChainAndIndex == nil || receiverChainAndIndex.ReceiverChain == nil {
		return nil
	}
	receiverChain := receiverChainAndIndex.ReceiverChain

	return chain.NewKey(
		kdf.DeriveSecrets,
		receiverChain.chainKey.Key(),
		receiverChain.chainKey.Index(),
	)
}

// AddReceiverChain appends a new receiver chain built from the remote
// sender's ratchet public key and the given chain key. chainKey must be a
// *chain.Key. Afterwards the oldest chains are dropped so that at most
// maxReceiverChains remain.
func (s *State) AddReceiverChain(senderRatchetKey ecc.ECPublicKeyable, chainKey session.ChainKeyable) {
	// Only a public ratchet key is supplied for a receiver chain, so the
	// private-key slot of the pair is nil.
	senderKey := ecc.NewECKeyPair(senderRatchetKey, nil)

	// Named newChain rather than chain so the imported chain package is not
	// shadowed.
	newChain := NewChain(senderKey, chainKey.(*chain.Key), []*message.Keys{})
	s.receiverChains = append(s.receiverChains, newChain)

	// Drop as many of the oldest (front) chains as needed to respect the cap,
	// even if the list was already over the limit (e.g. after deserialization).
	if excess := len(s.receiverChains) - maxReceiverChains; excess > 0 {
		s.receiverChains = append(s.receiverChains[:0], s.receiverChains[excess:]...)
	}
}

// SetSenderChain replaces the sender chain with a new chain built from the
// given ratchet key pair and chain key (which must be a *chain.Key). The new
// chain starts with no message keys.
func (s *State) SetSenderChain(senderRatchetKeyPair *ecc.ECKeyPair, chainKey session.ChainKeyable) {
	key := chainKey.(*chain.Key)
	s.senderChain = NewChain(senderRatchetKeyPair, key, []*message.Keys{})
}

// SenderChainKey returns a new chain.Key holding a copy of the sender chain's
// current key bytes and index.
func (s *State) SenderChainKey() session.ChainKeyable {
	current := s.senderChain.chainKey
	keyBytes, index := current.Key(), current.Index()
	return chain.NewKey(kdf.DeriveSecrets, keyBytes, index)
}

// SetSenderChainKey stores nextChainKey (which must be a *chain.Key) as the
// sender chain's current chain key.
func (s *State) SetSenderChainKey(nextChainKey session.ChainKeyable) {
	s.senderChain.SetChainKey(nextChainKey.(*chain.Key))
}

// HasMessageKeys reports whether the receiver chain matching senderEphemeral
// holds a message key whose index equals counter.
//
// It returns false when no receiver chain matches. Previously the nil pair
// returned by receiverChain was dereferenced and the call panicked.
func (s *State) HasMessageKeys(senderEphemeral ecc.ECPublicKeyable, counter uint32) bool {
	chainAndIndex := s.receiverChain(senderEphemeral)
	if chainAndIndex == nil || chainAndIndex.ReceiverChain == nil {
		return false
	}

	for _, messageKey := range chainAndIndex.ReceiverChain.MessageKeys() {
		if messageKey.Index() == counter {
			return true
		}
	}

	return false
}

// RemoveMessageKeys removes the message key with index counter from the
// receiver chain matching senderEphemeral and returns a copy of it.
//
// It returns nil, without modifying anything, when no receiver chain matches
// or the chain holds no key with that index. Previously a missing chain
// panicked on a nil dereference. A missing counter left the removal index at
// 0, so the first stored key was removed and returned (or an empty list
// panicked).
func (s *State) RemoveMessageKeys(senderEphemeral ecc.ECPublicKeyable, counter uint32) *message.Keys {
	chainAndIndex := s.receiverChain(senderEphemeral)
	if chainAndIndex == nil || chainAndIndex.ReceiverChain == nil {
		return nil
	}
	receiverChain := chainAndIndex.ReceiverChain

	// Search and delete on the same slice so the found index is the one removed.
	for i, messageKey := range receiverChain.messageKeys {
		if messageKey.Index() != counter {
			continue
		}

		receiverChain.messageKeys = append(receiverChain.messageKeys[:i], receiverChain.messageKeys[i+1:]...)

		return message.NewKeys(
			messageKey.CipherKey(),
			messageKey.MacKey(),
			messageKey.Iv(),
			messageKey.Index(),
		)
	}

	return nil
}

// SetMessageKeys stores a copy of messageKeys on the receiver chain matching
// senderEphemeral. If the chain then holds more than maxMessageKeys keys, its
// oldest key is dropped.
//
// When no receiver chain matches, the call is logged and ignored. Previously
// the nil pair was dereferenced and the call panicked.
func (s *State) SetMessageKeys(senderEphemeral ecc.ECPublicKeyable, messageKeys *message.Keys) {
	chainAndIndex := s.receiverChain(senderEphemeral)
	if chainAndIndex == nil || chainAndIndex.ReceiverChain == nil {
		logger.Error("Unable to set message keys: no receiver chain for the given sender ephemeral key")
		return
	}
	chainState := chainAndIndex.ReceiverChain

	// Store a copy so later changes to the caller's value don't leak in.
	chainState.AddMessageKeys(
		message.NewKeys(
			messageKeys.CipherKey(),
			messageKeys.MacKey(),
			messageKeys.Iv(),
			messageKeys.Index(),
		),
	)

	if len(chainState.MessageKeys()) > maxMessageKeys {
		chainState.PopFirstMessageKeys()
	}
}

// SetReceiverChainKey stores chainKey (which must be a *chain.Key) as the
// chain key of the receiver chain matching senderEphemeral.
//
// When no receiver chain matches, the call is logged and ignored. Previously
// the nil pair was dereferenced and the call panicked.
func (s *State) SetReceiverChainKey(senderEphemeral ecc.ECPublicKeyable, chainKey session.ChainKeyable) {
	chainAndIndex := s.receiverChain(senderEphemeral)
	if chainAndIndex == nil || chainAndIndex.ReceiverChain == nil {
		logger.Error("Unable to set receiver chain key: no receiver chain for the given sender ephemeral key")
		return
	}
	chainAndIndex.ReceiverChain.SetChainKey(chainKey.(*chain.Key))
}

// SetPendingKeyExchange records a pending key exchange made of the given
// sequence number, base and ratchet key pairs, and identity key pair,
// replacing any previously pending exchange.
func (s *State) SetPendingKeyExchange(sequence uint32, ourBaseKey, ourRatchetKey *ecc.ECKeyPair,
	ourIdentityKey *identity.KeyPair) {

	exchange := NewPendingKeyExchange(sequence, ourBaseKey, ourRatchetKey, ourIdentityKey)
	s.pendingKeyExchange = exchange
}

// PendingKeyExchangeSequence returns the sequence number of the pending key
// exchange. A pending exchange must be present (see HasPendingKeyExchange).
func (s *State) PendingKeyExchangeSequence() uint32 {
	exchange := s.pendingKeyExchange
	return exchange.sequence
}

// PendingKeyExchangeBaseKeyPair returns the local base key pair of the pending
// key exchange. A pending exchange must be present.
func (s *State) PendingKeyExchangeBaseKeyPair() *ecc.ECKeyPair {
	exchange := s.pendingKeyExchange
	return exchange.localBaseKeyPair
}

// PendingKeyExchangeRatchetKeyPair returns the local ratchet key pair of the
// pending key exchange. A pending exchange must be present.
func (s *State) PendingKeyExchangeRatchetKeyPair() *ecc.ECKeyPair {
	exchange := s.pendingKeyExchange
	return exchange.localRatchetKeyPair
}

// PendingKeyExchangeIdentityKeyPair returns the local identity key pair of the
// pending key exchange. A pending exchange must be present.
func (s *State) PendingKeyExchangeIdentityKeyPair() *identity.KeyPair {
	exchange := s.pendingKeyExchange
	return exchange.localIdentityKeyPair
}

// HasPendingKeyExchange reports whether a pending key exchange is recorded.
func (s *State) HasPendingKeyExchange() bool {
	pending := s.pendingKeyExchange != nil
	return pending
}

// SetUnacknowledgedPreKeyMessage records a pending (not yet acknowledged) pre
// key message made of the optional pre key ID, the signed pre key ID and the
// base key, replacing any previous one.
func (s *State) SetUnacknowledgedPreKeyMessage(preKeyID *optional.Uint32, signedPreKeyID uint32, baseKey ecc.ECPublicKeyable) {
	pending := NewPendingPreKey(preKeyID, signedPreKeyID, baseKey)
	s.pendingPreKey = pending
}

// HasUnacknowledgedPreKeyMessage reports whether a pending pre key message is
// recorded on this state.
func (s *State) HasUnacknowledgedPreKeyMessage() bool {
	hasPending := s.pendingPreKey != nil
	return hasPending
}

// noPendingPreKeyError is returned by State.UnackPreKeyMessageItems when the
// state holds no pending pre key.
type noPendingPreKeyError struct{}

// Error implements the error interface.
func (noPendingPreKeyError) Error() string {
	return "session state has no unacknowledged pre key message"
}

// UnackPreKeyMessageItems returns the pre key ID, signed pre key ID and base
// key of the session's unacknowledged pre key message. The base key is
// re-decoded from its serialized form, and any decoding error is returned.
//
// If no pending pre key is set (see HasUnacknowledgedPreKeyMessage), an error
// is returned. The function already has an error result, so this replaces the
// nil-pointer panic the previous code produced.
func (s *State) UnackPreKeyMessageItems() (*UnackPreKeyMessageItems, error) {
	pending := s.pendingPreKey
	if pending == nil {
		return nil, noPendingPreKeyError{}
	}

	baseKey, err := ecc.DecodePoint(pending.baseKey.Serialize(), 0)
	if err != nil {
		return nil, err
	}
	return NewUnackPreKeyMessageItems(pending.preKeyID, pending.signedPreKeyID, baseKey), nil
}

// ClearUnackPreKeyMessage discards any pending pre key message on this state.
func (s *State) ClearUnackPreKeyMessage() {
	s.pendingPreKey = nil
}

// SetRemoteRegistrationID replaces the remote user's registration ID.
func (s *State) SetRemoteRegistrationID(registrationID uint32) {
	s.remoteRegistrationID = registrationID
}

// RemoteRegistrationID reports the remote user's registration ID.
func (s *State) RemoteRegistrationID() uint32 {
	id := s.remoteRegistrationID
	return id
}

// SetLocalRegistrationID replaces the local user's registration ID.
func (s *State) SetLocalRegistrationID(registrationID uint32) {
	s.localRegistrationID = registrationID
}

// LocalRegistrationID reports the local user's registration ID.
func (s *State) LocalRegistrationID() uint32 {
	id := s.localRegistrationID
	return id
}

// Serialize converts the state to its StateStructure form and encodes it
// with the state's configured serializer.
func (s *State) Serialize() []byte {
	snapshot := s.structure()
	return s.serializer.Serialize(snapshot)
}

// structure returns a serializable StateStructure snapshot of the state so it
// can be persistently stored.
//
// Optional parts are guarded against nil in the same way. The pending key
// exchange and pending pre key become nil structures when absent. The sender
// base key goes through SenderBaseKey, which already handles nil. Previously
// a nil base key (e.g. after SetSenderBaseKey failed to decode) made
// serialization panic.
func (s *State) structure() *StateStructure {
	// Convert our receiver chains into a serializable structure.
	receiverChains := make([]*ChainStructure, len(s.receiverChains))
	for i, receiverChain := range s.receiverChains {
		receiverChains[i] = receiverChain.structure()
	}

	// Convert our pending key exchange into a serializable structure.
	var pendingKeyExchange *PendingKeyExchangeStructure
	if s.pendingKeyExchange != nil {
		pendingKeyExchange = s.pendingKeyExchange.structure()
	}

	// Convert our pending pre key into a serializable structure.
	var pendingPreKey *PendingPreKeyStructure
	if s.pendingPreKey != nil {
		pendingPreKey = s.pendingPreKey.structure()
	}

	// Build and return our state structure.
	return &StateStructure{
		LocalIdentityPublic:  s.localIdentityPublic.Serialize(),
		LocalRegistrationID:  s.localRegistrationID,
		NeedsRefresh:         s.needsRefresh,
		PendingKeyExchange:   pendingKeyExchange,
		PendingPreKey:        pendingPreKey,
		PreviousCounter:      s.previousCounter,
		ReceiverChains:       receiverChains,
		RemoteIdentityPublic: s.remoteIdentityPublic.Serialize(),
		RemoteRegistrationID: s.remoteRegistrationID,
		RootKey:              s.rootKey.Bytes(),
		SenderBaseKey:        s.SenderBaseKey(),
		SenderChain:          s.senderChain.structure(),
		SessionVersion:       s.sessionVersion,
	}
}