diff options
Diffstat (limited to 'vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go')
-rw-r--r-- | vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go | 135 |
1 files changed, 135 insertions, 0 deletions
diff --git a/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go b/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go new file mode 100644 index 00000000..26f7f263 --- /dev/null +++ b/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go @@ -0,0 +1,135 @@ +// Copyright (c) 2021 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package socket + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "fmt" + "io" + "sync/atomic" + + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +type NoiseHandshake struct { + hash []byte + salt []byte + key cipher.AEAD + counter uint32 +} + +func NewNoiseHandshake() *NoiseHandshake { + return &NoiseHandshake{} +} + +func newCipher(key []byte) (cipher.AEAD, error) { + aesCipher, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aesGCM, err := cipher.NewGCM(aesCipher) + if err != nil { + return nil, err + } + return aesGCM, nil +} + +func sha256Slice(data []byte) []byte { + hash := sha256.Sum256(data) + return hash[:] +} + +func (nh *NoiseHandshake) Start(pattern string, header []byte) { + data := []byte(pattern) + if len(data) == 32 { + nh.hash = data + } else { + nh.hash = sha256Slice(data) + } + nh.salt = nh.hash + var err error + nh.key, err = newCipher(nh.hash) + if err != nil { + panic(err) + } + nh.Authenticate(header) +} + +func (nh *NoiseHandshake) Authenticate(data []byte) { + nh.hash = sha256Slice(append(nh.hash, data...)) +} + +func (nh *NoiseHandshake) postIncrementCounter() uint32 { + count := atomic.AddUint32(&nh.counter, 1) + return count - 1 +} + +func (nh *NoiseHandshake) Encrypt(plaintext []byte) []byte { + ciphertext := nh.key.Seal(nil, generateIV(nh.postIncrementCounter()), plaintext, nh.hash) + nh.Authenticate(ciphertext) + return ciphertext +} + +func (nh *NoiseHandshake) Decrypt(ciphertext []byte) (plaintext []byte, err error) { + plaintext, err = nh.key.Open(nil, generateIV(nh.postIncrementCounter()), ciphertext, nh.hash) + if err == nil { + nh.Authenticate(ciphertext) + } + return +} + +func (nh *NoiseHandshake) Finish(fs *FrameSocket, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) { + if write, read, err := nh.extractAndExpand(nh.salt, nil); err != nil { + return nil, fmt.Errorf("failed to extract final keys: %w", err) + } else if writeKey, err := newCipher(write); err != nil { + return nil, fmt.Errorf("failed to create final write cipher: %w", err) + } else if readKey, err := newCipher(read); err != nil { + return nil, fmt.Errorf("failed to create final read cipher: %w", err) + } else if ns, err := newNoiseSocket(fs, writeKey, readKey, frameHandler, disconnectHandler); err != nil { + return nil, fmt.Errorf("failed to create noise socket: %w", err) + } else { + return ns, nil + } +} + +func (nh *NoiseHandshake) MixSharedSecretIntoKey(priv, pub [32]byte) error { + secret, err := curve25519.X25519(priv[:], pub[:]) + if err != nil { + return fmt.Errorf("failed to do x25519 scalar multiplication: %w", err) + } + return nh.MixIntoKey(secret) +} + +func (nh *NoiseHandshake) MixIntoKey(data []byte) error { + nh.counter = 0 + write, read, err := nh.extractAndExpand(nh.salt, data) + if err != nil { + return fmt.Errorf("failed to extract keys for mixing: %w", err) + } + nh.salt = write + nh.key, err = newCipher(read) + if err != nil { + return fmt.Errorf("failed to create new cipher while mixing keys: %w", err) + } + return nil +} + +func (nh *NoiseHandshake) extractAndExpand(salt, data []byte) (write []byte, read []byte, err error) { + h := hkdf.New(sha256.New, data, salt, nil) + write = make([]byte, 32) + read = make([]byte, 32) + + if _, err = io.ReadFull(h, write); err != nil { + err = fmt.Errorf("failed to read write key: %w", err) + } else if _, err = io.ReadFull(h, read); err != nil { + err = fmt.Errorf("failed to read read key: %w", err) + } + return +} |