diff options
Diffstat (limited to 'vendor/go.mau.fi/whatsmeow/socket')
-rw-r--r-- | vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go | 23 | ||||
-rw-r--r-- | vendor/go.mau.fi/whatsmeow/socket/noisesocket.go | 4 |
2 files changed, 10 insertions, 17 deletions
diff --git a/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go b/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go index 26f7f263..3add4705 100644 --- a/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go +++ b/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go @@ -7,7 +7,6 @@ package socket import ( - "crypto/aes" "crypto/cipher" "crypto/sha256" "fmt" @@ -16,6 +15,8 @@ import ( "golang.org/x/crypto/curve25519" "golang.org/x/crypto/hkdf" + + "go.mau.fi/whatsmeow/util/gcmutil" ) type NoiseHandshake struct { @@ -29,18 +30,6 @@ func NewNoiseHandshake() *NoiseHandshake { return &NoiseHandshake{} } -func newCipher(key []byte) (cipher.AEAD, error) { - aesCipher, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - aesGCM, err := cipher.NewGCM(aesCipher) - if err != nil { - return nil, err - } - return aesGCM, nil -} - func sha256Slice(data []byte) []byte { hash := sha256.Sum256(data) return hash[:] @@ -55,7 +44,7 @@ func (nh *NoiseHandshake) Start(pattern string, header []byte) { } nh.salt = nh.hash var err error - nh.key, err = newCipher(nh.hash) + nh.key, err = gcmutil.Prepare(nh.hash) if err != nil { panic(err) } @@ -88,9 +77,9 @@ func (nh *NoiseHandshake) Decrypt(ciphertext []byte) (plaintext []byte, err erro func (nh *NoiseHandshake) Finish(fs *FrameSocket, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) { if write, read, err := nh.extractAndExpand(nh.salt, nil); err != nil { return nil, fmt.Errorf("failed to extract final keys: %w", err) - } else if writeKey, err := newCipher(write); err != nil { + } else if writeKey, err := gcmutil.Prepare(write); err != nil { return nil, fmt.Errorf("failed to create final write cipher: %w", err) - } else if readKey, err := newCipher(read); err != nil { + } else if readKey, err := gcmutil.Prepare(read); err != nil { return nil, fmt.Errorf("failed to create final read cipher: %w", err) } else if ns, err := newNoiseSocket(fs, writeKey, readKey, frameHandler, disconnectHandler); err != nil { return nil, fmt.Errorf("failed to create noise socket: %w", err) @@ -114,7 +103,7 @@ func (nh *NoiseHandshake) MixIntoKey(data []byte) error { return fmt.Errorf("failed to extract keys for mixing: %w", err) } nh.salt = write - nh.key, err = newCipher(read) + nh.key, err = gcmutil.Prepare(read) if err != nil { return fmt.Errorf("failed to create new cipher while mixing keys: %w", err) } diff --git a/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go b/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go index 23bb44e8..85973d72 100644 --- a/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go +++ b/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go @@ -47,6 +47,10 @@ func newNoiseSocket(fs *FrameSocket, writeKey, readKey cipher.AEAD, frameHandler } func (ns *NoiseSocket) consumeFrames(ctx context.Context, frames <-chan []byte) { + if ctx == nil { + // ctx being nil implies the connection already closed somehow + return + } ctxDone := ctx.Done() for { select { |