summaryrefslogtreecommitdiffstats
path: root/vendor/go.mau.fi/libsignal/session
diff options
context:
space:
mode:
authorWim <wim@42.be>2022-01-31 00:27:37 +0100
committerWim <wim@42.be>2022-03-20 14:57:48 +0100
commite3cafeaf9292f67459ff1d186f68283bfaedf2ae (patch)
treeb69c39620aa91dba695b3b935c6651c0fb37ce75 /vendor/go.mau.fi/libsignal/session
parente7b193788a56ee7cdb02a87a9db0ad6724ef66d5 (diff)
downloadmatterbridge-msglm-e3cafeaf9292f67459ff1d186f68283bfaedf2ae.tar.gz
matterbridge-msglm-e3cafeaf9292f67459ff1d186f68283bfaedf2ae.tar.bz2
matterbridge-msglm-e3cafeaf9292f67459ff1d186f68283bfaedf2ae.zip
Add dependencies/vendor (whatsapp)
Diffstat (limited to 'vendor/go.mau.fi/libsignal/session')
-rw-r--r--vendor/go.mau.fi/libsignal/session/Session.go272
-rw-r--r--vendor/go.mau.fi/libsignal/session/SessionCipher.go366
2 files changed, 638 insertions, 0 deletions
diff --git a/vendor/go.mau.fi/libsignal/session/Session.go b/vendor/go.mau.fi/libsignal/session/Session.go
new file mode 100644
index 00000000..aafac760
--- /dev/null
+++ b/vendor/go.mau.fi/libsignal/session/Session.go
@@ -0,0 +1,272 @@
+// Package session provides the methods necessary to build sessions
+package session
+
+import (
+ "fmt"
+
+ "go.mau.fi/libsignal/ecc"
+ "go.mau.fi/libsignal/keys/prekey"
+ "go.mau.fi/libsignal/logger"
+ "go.mau.fi/libsignal/protocol"
+ "go.mau.fi/libsignal/ratchet"
+ "go.mau.fi/libsignal/serialize"
+ "go.mau.fi/libsignal/signalerror"
+ "go.mau.fi/libsignal/state/record"
+ "go.mau.fi/libsignal/state/store"
+ "go.mau.fi/libsignal/util/medium"
+ "go.mau.fi/libsignal/util/optional"
+)
+
+// NewBuilder constructs a session builder.
+func NewBuilder(sessionStore store.Session, preKeyStore store.PreKey,
+ signedStore store.SignedPreKey, identityStore store.IdentityKey,
+ remoteAddress *protocol.SignalAddress, serializer *serialize.Serializer) *Builder {
+
+ builder := Builder{
+ sessionStore: sessionStore,
+ preKeyStore: preKeyStore,
+ signedPreKeyStore: signedStore,
+ identityKeyStore: identityStore,
+ remoteAddress: remoteAddress,
+ serializer: serializer,
+ }
+
+ return &builder
+}
+
+// NewBuilderFromSignal Store constructs a session builder using a
+// SignalProtocol Store.
+func NewBuilderFromSignal(signalStore store.SignalProtocol,
+ remoteAddress *protocol.SignalAddress, serializer *serialize.Serializer) *Builder {
+
+ builder := Builder{
+ sessionStore: signalStore,
+ preKeyStore: signalStore,
+ signedPreKeyStore: signalStore,
+ identityKeyStore: signalStore,
+ remoteAddress: remoteAddress,
+ serializer: serializer,
+ }
+
+ return &builder
+}
+
+// Builder is responsible for setting up encrypted sessions.
+// Once a session has been established, SessionCipher can be
+// used to encrypt/decrypt messages in that session.
+//
+// Sessions are built from one of three different vectors:
+// * PreKeyBundle retrieved from a server.
+// * PreKeySignalMessage received from a client.
+// * KeyExchangeMessage sent to or received from a client.
+//
+// Sessions are constructed per recipientId + deviceId tuple.
+// Remote logical users are identified by their recipientId,
+// and each logical recipientId can have multiple physical
+// devices.
+type Builder struct {
+ sessionStore store.Session
+ preKeyStore store.PreKey
+ signedPreKeyStore store.SignedPreKey
+ identityKeyStore store.IdentityKey
+ remoteAddress *protocol.SignalAddress
+ serializer *serialize.Serializer
+}
+
+// Process builds a new session from a session record and pre
+// key signal message.
+func (b *Builder) Process(sessionRecord *record.Session, message *protocol.PreKeySignalMessage) (unsignedPreKeyID *optional.Uint32, err error) {
+
+ // Check to see if the keys are trusted.
+ theirIdentityKey := message.IdentityKey()
+ if !(b.identityKeyStore.IsTrustedIdentity(b.remoteAddress, theirIdentityKey)) {
+ return nil, signalerror.ErrUntrustedIdentity
+ }
+
+ // Use version 3 of the signal/axolotl protocol.
+ unsignedPreKeyID, err = b.processV3(sessionRecord, message)
+ if err != nil {
+ return nil, err
+ }
+
+ // Save the identity key to our identity store.
+ b.identityKeyStore.SaveIdentity(b.remoteAddress, theirIdentityKey)
+
+ // Return the unsignedPreKeyID
+ return unsignedPreKeyID, nil
+}
+
+// ProcessV3 builds a new session from a session record and pre key
+// signal message. After a session is constructed in this way, the embedded
+// SignalMessage can be decrypted.
+func (b *Builder) processV3(sessionRecord *record.Session,
+ message *protocol.PreKeySignalMessage) (unsignedPreKeyID *optional.Uint32, err error) {
+
+ logger.Debug("Processing message with PreKeyID: ", message.PreKeyID())
+ // Check to see if we've already set up a session for this V3 message.
+ sessionExists := sessionRecord.HasSessionState(
+ message.MessageVersion(),
+ message.BaseKey().Serialize(),
+ )
+ if sessionExists {
+ logger.Debug("We've already setup a session for this V3 message, letting bundled message fall through...")
+ return optional.NewEmptyUint32(), nil
+ }
+
+ // Load our signed prekey from our signed prekey store.
+ ourSignedPreKeyRecord := b.signedPreKeyStore.LoadSignedPreKey(message.SignedPreKeyID())
+ if ourSignedPreKeyRecord == nil {
+ return nil, fmt.Errorf("%w with ID %d", signalerror.ErrNoSignedPreKey, message.SignedPreKeyID())
+ }
+ ourSignedPreKey := ourSignedPreKeyRecord.KeyPair()
+
+ // Build the parameters of the session.
+ parameters := ratchet.NewEmptyReceiverParameters()
+ parameters.SetTheirBaseKey(message.BaseKey())
+ parameters.SetTheirIdentityKey(message.IdentityKey())
+ parameters.SetOurIdentityKeyPair(b.identityKeyStore.GetIdentityKeyPair())
+ parameters.SetOurSignedPreKey(ourSignedPreKey)
+ parameters.SetOurRatchetKey(ourSignedPreKey)
+
+ // Set our one time pre key with the one from our prekey store
+ // if the message contains a valid pre key id
+ if !message.PreKeyID().IsEmpty {
+ oneTimePreKey := b.preKeyStore.LoadPreKey(message.PreKeyID().Value)
+ if oneTimePreKey == nil {
+ return nil, fmt.Errorf("%w with ID %d", signalerror.ErrNoOneTimeKeyFound, message.PreKeyID().Value)
+ }
+ parameters.SetOurOneTimePreKey(oneTimePreKey.KeyPair())
+ } else {
+ parameters.SetOurOneTimePreKey(nil)
+ }
+
+ // If this is a fresh record, archive our current state.
+ if !sessionRecord.IsFresh() {
+ sessionRecord.ArchiveCurrentState()
+ }
+
+ ///////// Initialize our session /////////
+ sessionState := sessionRecord.SessionState()
+ derivedKeys, sessionErr := ratchet.CalculateReceiverSession(parameters)
+ if sessionErr != nil {
+ return nil, sessionErr
+ }
+ sessionState.SetVersion(protocol.CurrentVersion)
+ sessionState.SetRemoteIdentityKey(parameters.TheirIdentityKey())
+ sessionState.SetLocalIdentityKey(parameters.OurIdentityKeyPair().PublicKey())
+ sessionState.SetSenderChain(parameters.OurRatchetKey(), derivedKeys.ChainKey)
+ sessionState.SetRootKey(derivedKeys.RootKey)
+
+ // Set the session's registration ids and base key
+ sessionState.SetLocalRegistrationID(b.identityKeyStore.GetLocalRegistrationId())
+ sessionState.SetRemoteRegistrationID(message.RegistrationID())
+ sessionState.SetSenderBaseKey(message.BaseKey().Serialize())
+
+ // Remove the PreKey from our store and return the message prekey id if it is valid.
+ if message.PreKeyID() != nil && message.PreKeyID().Value != medium.MaxValue {
+ return message.PreKeyID(), nil
+ }
+ return nil, nil
+}
+
+// ProcessBundle builds a new session from a PreKeyBundle retrieved
+// from a server.
+func (b *Builder) ProcessBundle(preKey *prekey.Bundle) error {
+ // Check to see if the keys are trusted.
+ if !(b.identityKeyStore.IsTrustedIdentity(b.remoteAddress, preKey.IdentityKey())) {
+ return signalerror.ErrUntrustedIdentity
+ }
+
+ // Check to see if the bundle has a signed pre key.
+ if preKey.SignedPreKey() == nil {
+ return signalerror.ErrNoSignedPreKey
+ }
+
+ // Verify the signature of the pre key
+ preKeyPublic := preKey.IdentityKey().PublicKey()
+ preKeyBytes := preKey.SignedPreKey().Serialize()
+ preKeySignature := preKey.SignedPreKeySignature()
+ if !ecc.VerifySignature(preKeyPublic, preKeyBytes, preKeySignature) {
+ return signalerror.ErrInvalidSignature
+ }
+
+ // Load our session and generate keys.
+ sessionRecord := b.sessionStore.LoadSession(b.remoteAddress)
+ ourBaseKey, err := ecc.GenerateKeyPair()
+ if err != nil {
+ return err
+ }
+ theirSignedPreKey := preKey.SignedPreKey()
+ theirOneTimePreKey := preKey.PreKey()
+ theirOneTimePreKeyID := preKey.PreKeyID()
+
+ // Build the parameters of the session
+ parameters := ratchet.NewEmptySenderParameters()
+ parameters.SetOurBaseKey(ourBaseKey)
+ parameters.SetOurIdentityKey(b.identityKeyStore.GetIdentityKeyPair())
+ parameters.SetTheirIdentityKey(preKey.IdentityKey())
+ parameters.SetTheirSignedPreKey(theirSignedPreKey)
+ parameters.SetTheirRatchetKey(theirSignedPreKey)
+ parameters.SetTheirOneTimePreKey(theirOneTimePreKey)
+
+ // If this is a fresh record, archive our current state.
+ if !sessionRecord.IsFresh() {
+ sessionRecord.ArchiveCurrentState()
+ }
+
+ ///////// Initialize our session /////////
+ sessionState := sessionRecord.SessionState()
+ derivedKeys, sessionErr := ratchet.CalculateSenderSession(parameters)
+ if sessionErr != nil {
+ return sessionErr
+ }
+ // Generate an ephemeral "ratchet" key that will be advertised to
+ // the receiving user.
+ sendingRatchetKey, keyErr := ecc.GenerateKeyPair()
+ if keyErr != nil {
+ return keyErr
+ }
+ sendingChain, chainErr := derivedKeys.RootKey.CreateChain(
+ parameters.TheirRatchetKey(),
+ sendingRatchetKey,
+ )
+ if chainErr != nil {
+ return chainErr
+ }
+
+ // Calculate the sender session.
+ sessionState.SetVersion(protocol.CurrentVersion)
+ sessionState.SetRemoteIdentityKey(parameters.TheirIdentityKey())
+ sessionState.SetLocalIdentityKey(parameters.OurIdentityKey().PublicKey())
+ sessionState.AddReceiverChain(parameters.TheirRatchetKey(), derivedKeys.ChainKey.Current())
+ sessionState.SetSenderChain(sendingRatchetKey, sendingChain.ChainKey)
+ sessionState.SetRootKey(sendingChain.RootKey)
+
+ // Update our session record with the unackowledged prekey message
+ sessionState.SetUnacknowledgedPreKeyMessage(
+ theirOneTimePreKeyID,
+ preKey.SignedPreKeyID(),
+ ourBaseKey.PublicKey(),
+ )
+
+ // Set the local registration ID based on the registration id in our identity key store.
+ sessionState.SetLocalRegistrationID(
+ b.identityKeyStore.GetLocalRegistrationId(),
+ )
+
+ // Set the remote registration ID based on the given prekey bundle registrationID.
+ sessionState.SetRemoteRegistrationID(
+ preKey.RegistrationID(),
+ )
+
+ // Set the sender base key in our session record state.
+ sessionState.SetSenderBaseKey(
+ ourBaseKey.PublicKey().Serialize(),
+ )
+
+ // Store the session in our session store and save the identity in our identity store.
+ b.sessionStore.StoreSession(b.remoteAddress, sessionRecord)
+ b.identityKeyStore.SaveIdentity(b.remoteAddress, preKey.IdentityKey())
+
+ return nil
+}
diff --git a/vendor/go.mau.fi/libsignal/session/SessionCipher.go b/vendor/go.mau.fi/libsignal/session/SessionCipher.go
new file mode 100644
index 00000000..a70812b9
--- /dev/null
+++ b/vendor/go.mau.fi/libsignal/session/SessionCipher.go
@@ -0,0 +1,366 @@
+package session
+
+import (
+ "fmt"
+
+ "go.mau.fi/libsignal/cipher"
+ "go.mau.fi/libsignal/ecc"
+ "go.mau.fi/libsignal/keys/chain"
+ "go.mau.fi/libsignal/keys/message"
+ "go.mau.fi/libsignal/logger"
+ "go.mau.fi/libsignal/protocol"
+ "go.mau.fi/libsignal/signalerror"
+ "go.mau.fi/libsignal/state/record"
+ "go.mau.fi/libsignal/state/store"
+ "go.mau.fi/libsignal/util/bytehelper"
+)
+
+const maxFutureMessages = 2000
+
+// NewCipher constructs a session cipher for encrypt/decrypt operations on a
+// session. In order to use the session cipher, a session must have already
+// been created and stored using session.Builder.
+func NewCipher(builder *Builder, remoteAddress *protocol.SignalAddress) *Cipher {
+ cipher := &Cipher{
+ sessionStore: builder.sessionStore,
+ preKeyMessageSerializer: builder.serializer.PreKeySignalMessage,
+ signalMessageSerializer: builder.serializer.SignalMessage,
+ preKeyStore: builder.preKeyStore,
+ remoteAddress: remoteAddress,
+ builder: builder,
+ identityKeyStore: builder.identityKeyStore,
+ }
+
+ return cipher
+}
+
+func NewCipherFromSession(remoteAddress *protocol.SignalAddress,
+ sessionStore store.Session, preKeyStore store.PreKey, identityKeyStore store.IdentityKey,
+ preKeyMessageSerializer protocol.PreKeySignalMessageSerializer,
+ signalMessageSerializer protocol.SignalMessageSerializer) *Cipher {
+ cipher := &Cipher{
+ sessionStore: sessionStore,
+ preKeyMessageSerializer: preKeyMessageSerializer,
+ signalMessageSerializer: signalMessageSerializer,
+ preKeyStore: preKeyStore,
+ remoteAddress: remoteAddress,
+ identityKeyStore: identityKeyStore,
+ }
+
+ return cipher
+}
+
+// Cipher is the main entry point for Signal Protocol encrypt/decrypt operations.
+// Once a session has been established with session.Builder, this can be used for
+// all encrypt/decrypt operations within that session.
+type Cipher struct {
+ sessionStore store.Session
+ preKeyMessageSerializer protocol.PreKeySignalMessageSerializer
+ signalMessageSerializer protocol.SignalMessageSerializer
+ preKeyStore store.PreKey
+ remoteAddress *protocol.SignalAddress
+ builder *Builder
+ identityKeyStore store.IdentityKey
+}
+
+// Encrypt will take the given message in bytes and return an object that follows
+// the CiphertextMessage interface.
+func (d *Cipher) Encrypt(plaintext []byte) (protocol.CiphertextMessage, error) {
+ sessionRecord := d.sessionStore.LoadSession(d.remoteAddress)
+ sessionState := sessionRecord.SessionState()
+ chainKey := sessionState.SenderChainKey()
+ messageKeys := chainKey.MessageKeys()
+ senderEphemeral := sessionState.SenderRatchetKey()
+ previousCounter := sessionState.PreviousCounter()
+ sessionVersion := sessionState.Version()
+
+ ciphertextBody, err := encrypt(messageKeys, plaintext)
+ logger.Debug("Got ciphertextBody: ", ciphertextBody)
+ if err != nil {
+ return nil, err
+ }
+
+ var ciphertextMessage protocol.CiphertextMessage
+ ciphertextMessage, err = protocol.NewSignalMessage(
+ sessionVersion,
+ chainKey.Index(),
+ previousCounter,
+ messageKeys.MacKey(),
+ senderEphemeral,
+ ciphertextBody,
+ sessionState.LocalIdentityKey(),
+ sessionState.RemoteIdentityKey(),
+ d.signalMessageSerializer,
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ // If we haven't established a session with the recipient yet,
+ // send our message as a PreKeySignalMessage.
+ if sessionState.HasUnacknowledgedPreKeyMessage() {
+ items, err := sessionState.UnackPreKeyMessageItems()
+ if err != nil {
+ return nil, err
+ }
+ localRegistrationID := sessionState.LocalRegistrationID()
+
+ ciphertextMessage, err = protocol.NewPreKeySignalMessage(
+ sessionVersion,
+ localRegistrationID,
+ items.PreKeyID(),
+ items.SignedPreKeyID(),
+ items.BaseKey(),
+ sessionState.LocalIdentityKey(),
+ ciphertextMessage.(*protocol.SignalMessage),
+ d.preKeyMessageSerializer,
+ d.signalMessageSerializer,
+ )
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ sessionState.SetSenderChainKey(chainKey.NextKey())
+ if !d.identityKeyStore.IsTrustedIdentity(d.remoteAddress, sessionState.RemoteIdentityKey()) {
+ // return err
+ }
+ d.identityKeyStore.SaveIdentity(d.remoteAddress, sessionState.RemoteIdentityKey())
+ d.sessionStore.StoreSession(d.remoteAddress, sessionRecord)
+ return ciphertextMessage, nil
+}
+
+// Decrypt decrypts the given message using an existing session that
+// is stored in the session store.
+func (d *Cipher) Decrypt(ciphertextMessage *protocol.SignalMessage) ([]byte, error) {
+ plaintext, _, err := d.DecryptAndGetKey(ciphertextMessage)
+
+ return plaintext, err
+}
+
+// DecryptAndGetKey decrypts the given message using an existing session that
+// is stored in the session store and returns the message keys used for encryption.
+func (d *Cipher) DecryptAndGetKey(ciphertextMessage *protocol.SignalMessage) ([]byte, *message.Keys, error) {
+ if !d.sessionStore.ContainsSession(d.remoteAddress) {
+ return nil, nil, fmt.Errorf("%w %s", signalerror.ErrNoSessionForUser, d.remoteAddress.String())
+ }
+
+ // Load the session record from our session store and decrypt the message.
+ sessionRecord := d.sessionStore.LoadSession(d.remoteAddress)
+ plaintext, messageKeys, err := d.DecryptWithRecord(sessionRecord, ciphertextMessage)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ if !d.identityKeyStore.IsTrustedIdentity(d.remoteAddress, sessionRecord.SessionState().RemoteIdentityKey()) {
+ // return err
+ }
+ d.identityKeyStore.SaveIdentity(d.remoteAddress, sessionRecord.SessionState().RemoteIdentityKey())
+
+ // Store the session record in our session store.
+ d.sessionStore.StoreSession(d.remoteAddress, sessionRecord)
+ return plaintext, messageKeys, nil
+}
+
+func (d *Cipher) DecryptMessage(ciphertextMessage *protocol.PreKeySignalMessage) ([]byte, error) {
+ plaintext, _, err := d.DecryptMessageReturnKey(ciphertextMessage)
+ return plaintext, err
+}
+
+func (d *Cipher) DecryptMessageReturnKey(ciphertextMessage *protocol.PreKeySignalMessage) ([]byte, *message.Keys, error) {
+ // Load or create session record for this session.
+ sessionRecord := d.sessionStore.LoadSession(d.remoteAddress)
+ unsignedPreKeyID, err := d.builder.Process(sessionRecord, ciphertextMessage)
+ if err != nil {
+ return nil, nil, err
+ }
+ plaintext, keys, err := d.DecryptWithRecord(sessionRecord, ciphertextMessage.WhisperMessage())
+ if err != nil {
+ return nil, nil, err
+ }
+ // Store the session record in our session store.
+ d.sessionStore.StoreSession(d.remoteAddress, sessionRecord)
+ if !unsignedPreKeyID.IsEmpty {
+ d.preKeyStore.RemovePreKey(unsignedPreKeyID.Value)
+ }
+ return plaintext, keys, nil
+}
+
+// DecryptWithKey will decrypt the given message using the given symmetric key. This
+// can be used when decrypting messages at a later time if the message key was saved.
+func (d *Cipher) DecryptWithKey(ciphertextMessage *protocol.SignalMessage, key *message.Keys) ([]byte, error) {
+ logger.Debug("Decrypting ciphertext body: ", ciphertextMessage.Body())
+ plaintext, err := decrypt(key, ciphertextMessage.Body())
+ if err != nil {
+ logger.Error("Unable to get plain text from ciphertext: ", err)
+ return nil, err
+ }
+
+ return plaintext, nil
+}
+
+// DecryptWithRecord decrypts the given message using the given session record.
+func (d *Cipher) DecryptWithRecord(sessionRecord *record.Session, ciphertext *protocol.SignalMessage) ([]byte, *message.Keys, error) {
+ logger.Debug("Decrypting ciphertext with record: ", sessionRecord)
+ previousStates := sessionRecord.PreviousSessionStates()
+ sessionState := sessionRecord.SessionState()
+
+ // Try and decrypt the message with the current session state.
+ plaintext, messageKeys, err := d.DecryptWithState(sessionState, ciphertext)
+
+ // If we received an error using the current session state, loop
+ // through all previous states.
+ if err != nil {
+ logger.Warning(err)
+ for i, state := range previousStates {
+ // Try decrypting the message with previous states
+ plaintext, messageKeys, err = d.DecryptWithState(state, ciphertext)
+ if err != nil {
+ continue
+ }
+
+ // If successful, remove and promote the state.
+ previousStates = append(previousStates[:i], previousStates[i+1:]...)
+ sessionRecord.PromoteState(state)
+
+ return plaintext, messageKeys, nil
+ }
+
+ return nil, nil, signalerror.ErrNoValidSessions
+ }
+
+ // If decryption was successful, set the session state and return the plain text.
+ sessionRecord.SetState(sessionState)
+
+ return plaintext, messageKeys, nil
+}
+
+// DecryptWithState decrypts the given message with the given session state.
+func (d *Cipher) DecryptWithState(sessionState *record.State, ciphertextMessage *protocol.SignalMessage) ([]byte, *message.Keys, error) {
+ logger.Debug("Decrypting ciphertext with session state: ", sessionState)
+ if !sessionState.HasSenderChain() {
+ logger.Error("Unable to decrypt message with state: ", signalerror.ErrUninitializedSession)
+ return nil, nil, signalerror.ErrUninitializedSession
+ }
+
+ if ciphertextMessage.MessageVersion() != sessionState.Version() {
+ logger.Error("Unable to decrypt message with state: ", signalerror.ErrWrongMessageVersion)
+ return nil, nil, signalerror.ErrWrongMessageVersion
+ }
+
+ messageVersion := ciphertextMessage.MessageVersion()
+ theirEphemeral := ciphertextMessage.SenderRatchetKey()
+ counter := ciphertextMessage.Counter()
+ chainKey, chainCreateErr := getOrCreateChainKey(sessionState, theirEphemeral)
+ if chainCreateErr != nil {
+ logger.Error("Unable to get or create chain key: ", chainCreateErr)
+ return nil, nil, fmt.Errorf("failed to get or create chain key: %w", chainCreateErr)
+ }
+
+ messageKeys, keysCreateErr := getOrCreateMessageKeys(sessionState, theirEphemeral, chainKey, counter)
+ if keysCreateErr != nil {
+ logger.Error("Unable to get or create message keys: ", keysCreateErr)
+ return nil, nil, fmt.Errorf("failed to get or create message keys: %w", keysCreateErr)
+ }
+
+ err := ciphertextMessage.VerifyMac(messageVersion, sessionState.RemoteIdentityKey(), sessionState.LocalIdentityKey(), messageKeys.MacKey())
+ if err != nil {
+ logger.Error("Unable to verify ciphertext mac: ", err)
+ return nil, nil, fmt.Errorf("failed to verify ciphertext MAC: %w", err)
+ }
+
+ plaintext, err := d.DecryptWithKey(ciphertextMessage, messageKeys)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ sessionState.ClearUnackPreKeyMessage()
+
+ return plaintext, messageKeys, nil
+}
+
+func getOrCreateMessageKeys(sessionState *record.State, theirEphemeral ecc.ECPublicKeyable,
+ chainKey *chain.Key, counter uint32) (*message.Keys, error) {
+
+ if chainKey.Index() > counter {
+ if sessionState.HasMessageKeys(theirEphemeral, counter) {
+ return sessionState.RemoveMessageKeys(theirEphemeral, counter), nil
+ }
+ return nil, fmt.Errorf("%w (index: %d, count: %d)", signalerror.ErrOldCounter, chainKey.Index(), counter)
+ }
+
+ if counter-chainKey.Index() > maxFutureMessages {
+ return nil, signalerror.ErrTooFarIntoFuture
+ }
+
+ for chainKey.Index() < counter {
+ messageKeys := chainKey.MessageKeys()
+ sessionState.SetMessageKeys(theirEphemeral, messageKeys)
+ chainKey = chainKey.NextKey()
+ }
+
+ sessionState.SetReceiverChainKey(theirEphemeral, chainKey.NextKey())
+ return chainKey.MessageKeys(), nil
+}
+
+// getOrCreateChainKey will either return the existing chain key or
+// create a new one with the given session state and ephemeral key.
+func getOrCreateChainKey(sessionState *record.State, theirEphemeral ecc.ECPublicKeyable) (*chain.Key, error) {
+
+ // If our session state already has a receiver chain, use their
+ // ephemeral key in the existing chain.
+ if sessionState.HasReceiverChain(theirEphemeral) {
+ return sessionState.ReceiverChainKey(theirEphemeral), nil
+ }
+
+ // If we don't have a chain key, create one with ephemeral keys.
+ rootKey := sessionState.RootKey()
+ ourEphemeral := sessionState.SenderRatchetKeyPair()
+ receiverChain, rErr := rootKey.CreateChain(theirEphemeral, ourEphemeral)
+ if rErr != nil {
+ return nil, rErr
+ }
+
+ // Generate a new ephemeral key pair.
+ ourNewEphemeral, gErr := ecc.GenerateKeyPair()
+ if gErr != nil {
+ return nil, gErr
+ }
+
+ // Create a new chain using our new ephemeral key.
+ senderChain, cErr := receiverChain.RootKey.CreateChain(theirEphemeral, ourNewEphemeral)
+ if cErr != nil {
+ return nil, cErr
+ }
+
+ // Set our session state parameters.
+ sessionState.SetRootKey(senderChain.RootKey)
+ sessionState.AddReceiverChain(theirEphemeral, receiverChain.ChainKey)
+ previousCounter := max(sessionState.SenderChainKey().Index()-1, 0)
+ sessionState.SetPreviousCounter(previousCounter)
+ sessionState.SetSenderChain(ourNewEphemeral, senderChain.ChainKey)
+
+ return receiverChain.ChainKey.(*chain.Key), nil
+}
+
+// decrypt will use the given message keys and ciphertext and return
+// the plaintext bytes.
+func decrypt(keys *message.Keys, body []byte) ([]byte, error) {
+ logger.Debug("Using cipherKey: ", keys.CipherKey())
+ return cipher.DecryptCbc(keys.Iv(), keys.CipherKey(), bytehelper.CopySlice(body))
+}
+
+// encrypt will use the given cipher, message keys, and plaintext bytes
+// and return ciphertext bytes.
+func encrypt(messageKeys *message.Keys, plaintext []byte) ([]byte, error) {
+ logger.Debug("Using cipherKey: ", messageKeys.CipherKey())
+ return cipher.EncryptCbc(messageKeys.Iv(), messageKeys.CipherKey(), plaintext)
+}
+
+// Max is a uint32 implementation of math.Max
+func max(x, y uint32) uint32 {
+ if x > y {
+ return x
+ }
+ return y
+}