diff options
Diffstat (limited to 'vendor/go.mau.fi/whatsmeow/socket')
-rw-r--r-- | vendor/go.mau.fi/whatsmeow/socket/constants.go | 40 | ||||
-rw-r--r-- | vendor/go.mau.fi/whatsmeow/socket/framesocket.go | 228 | ||||
-rw-r--r-- | vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go | 135 | ||||
-rw-r--r-- | vendor/go.mau.fi/whatsmeow/socket/noisesocket.go | 104 |
4 files changed, 507 insertions, 0 deletions
diff --git a/vendor/go.mau.fi/whatsmeow/socket/constants.go b/vendor/go.mau.fi/whatsmeow/socket/constants.go new file mode 100644 index 00000000..88e9a90b --- /dev/null +++ b/vendor/go.mau.fi/whatsmeow/socket/constants.go @@ -0,0 +1,40 @@ +// Copyright (c) 2021 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package socket implements a subset of the Noise protocol framework on top of websockets as used by WhatsApp. +// +// There shouldn't be any need to manually interact with this package. +// The Client struct in the top-level whatsmeow package handles everything. +package socket + +import "errors" + +const ( + // Origin is the Origin header for all WhatsApp websocket connections + Origin = "https://web.whatsapp.com" + // URL is the websocket URL for the new multidevice protocol + URL = "wss://web.whatsapp.com/ws/chat" +) + +const ( + NoiseStartPattern = "Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" + + WADictVersion = 2 + WAMagicValue = 5 +) + +var WAConnHeader = []byte{'W', 'A', WAMagicValue, WADictVersion} + +const ( + FrameMaxSize = 2 << 23 + FrameLengthSize = 3 +) + +var ( + ErrFrameTooLarge = errors.New("frame too large") + ErrSocketClosed = errors.New("frame socket is closed") + ErrSocketAlreadyOpen = errors.New("frame socket is already open") +) diff --git a/vendor/go.mau.fi/whatsmeow/socket/framesocket.go b/vendor/go.mau.fi/whatsmeow/socket/framesocket.go new file mode 100644 index 00000000..2bcb21b5 --- /dev/null +++ b/vendor/go.mau.fi/whatsmeow/socket/framesocket.go @@ -0,0 +1,228 @@ +// Copyright (c) 2021 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package socket + +import ( + "context" + "errors" + "fmt" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + + waLog "go.mau.fi/whatsmeow/util/log" +) + +type FrameSocket struct { + conn *websocket.Conn + ctx context.Context + cancel func() + log waLog.Logger + lock sync.Mutex + + Frames chan []byte + OnDisconnect func(remote bool) + WriteTimeout time.Duration + + Header []byte + + incomingLength int + receivedLength int + incoming []byte + partialHeader []byte +} + +func NewFrameSocket(log waLog.Logger, header []byte) *FrameSocket { + return &FrameSocket{ + conn: nil, + log: log, + Header: header, + Frames: make(chan []byte), + } +} + +func (fs *FrameSocket) IsConnected() bool { + return fs.conn != nil +} + +func (fs *FrameSocket) Context() context.Context { + return fs.ctx +} + +func (fs *FrameSocket) Close(code int) { + fs.lock.Lock() + defer fs.lock.Unlock() + + if fs.conn == nil { + return + } + + if code > 0 { + message := websocket.FormatCloseMessage(code, "") + err := fs.conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second)) + if err != nil { + fs.log.Warnf("Error sending close message: %v", err) + } + } + + fs.cancel() + err := fs.conn.Close() + if err != nil { + fs.log.Errorf("Error closing websocket: %v", err) + } + fs.conn = nil + fs.ctx = nil + fs.cancel = nil + if fs.OnDisconnect != nil { + go fs.OnDisconnect(code == 0) + } +} + +func (fs *FrameSocket) Connect() error { + fs.lock.Lock() + defer fs.lock.Unlock() + + if fs.conn != nil { + return ErrSocketAlreadyOpen + } + ctx, cancel := context.WithCancel(context.Background()) + dialer := websocket.Dialer{} + + headers := http.Header{"Origin": []string{Origin}} + fs.log.Debugf("Dialing %s", URL) + conn, _, err := dialer.Dial(URL, headers) + if err != nil { + cancel() + return fmt.Errorf("couldn't dial whatsapp web websocket: %w", err) + } + + fs.ctx, fs.cancel = ctx, cancel + fs.conn = conn + conn.SetCloseHandler(func(code int, text string) error { + fs.log.Debugf("Server closed websocket with status %d/%s", code, text) + cancel() + // from default CloseHandler + message := websocket.FormatCloseMessage(code, "") + _ = conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second)) + return nil + }) + + go fs.readPump(conn, ctx) + return nil +} + +func (fs *FrameSocket) SendFrame(data []byte) error { + conn := fs.conn + if conn == nil { + return ErrSocketClosed + } + dataLength := len(data) + if dataLength >= FrameMaxSize { + return fmt.Errorf("%w (got %d bytes, max %d bytes)", ErrFrameTooLarge, len(data), FrameMaxSize) + } + + headerLength := len(fs.Header) + // Whole frame is header + 3 bytes for length + data + wholeFrame := make([]byte, headerLength+FrameLengthSize+dataLength) + + // Copy the header if it's there + if fs.Header != nil { + copy(wholeFrame[:headerLength], fs.Header) + // We only want to send the header once + fs.Header = nil + } + + // Encode length of frame + wholeFrame[headerLength] = byte(dataLength >> 16) + wholeFrame[headerLength+1] = byte(dataLength >> 8) + wholeFrame[headerLength+2] = byte(dataLength) + + // Copy actual frame data + copy(wholeFrame[headerLength+FrameLengthSize:], data) + + if fs.WriteTimeout > 0 { + err := conn.SetWriteDeadline(time.Now().Add(fs.WriteTimeout)) + if err != nil { + fs.log.Warnf("Failed to set write deadline: %v", err) + } + } + return conn.WriteMessage(websocket.BinaryMessage, wholeFrame) +} + +func (fs *FrameSocket) frameComplete() { + data := fs.incoming + fs.incoming = nil + fs.partialHeader = nil + fs.incomingLength = 0 + fs.receivedLength = 0 + fs.Frames <- data +} + +func (fs *FrameSocket) processData(msg []byte) { + for len(msg) > 0 { + // This probably doesn't happen a lot (if at all), so the code is unoptimized + if fs.partialHeader != nil { + msg = append(fs.partialHeader, msg...) + fs.partialHeader = nil + } + if fs.incoming == nil { + if len(msg) >= FrameLengthSize { + length := (int(msg[0]) << 16) + (int(msg[1]) << 8) + int(msg[2]) + fs.incomingLength = length + fs.receivedLength = len(msg) + msg = msg[FrameLengthSize:] + if len(msg) >= length { + fs.incoming = msg[:length] + msg = msg[length:] + fs.frameComplete() + } else { + fs.incoming = make([]byte, length) + copy(fs.incoming, msg) + msg = nil + } + } else { + fs.log.Warnf("Received partial header (report if this happens often)") + fs.partialHeader = msg + msg = nil + } + } else { + if len(fs.incoming)+len(msg) >= fs.incomingLength { + copy(fs.incoming[fs.receivedLength:], msg[:fs.incomingLength-fs.receivedLength]) + msg = msg[fs.incomingLength-fs.receivedLength:] + fs.frameComplete() + } else { + copy(fs.incoming[fs.receivedLength:], msg) + fs.receivedLength += len(msg) + msg = nil + } + } + } +} + +func (fs *FrameSocket) readPump(conn *websocket.Conn, ctx context.Context) { + fs.log.Debugf("Frame websocket read pump starting %p", fs) + defer func() { + fs.log.Debugf("Frame websocket read pump exiting %p", fs) + go fs.Close(0) + }() + for { + msgType, data, err := conn.ReadMessage() + if err != nil { + // Ignore the error if the context has been closed + if !errors.Is(ctx.Err(), context.Canceled) { + fs.log.Errorf("Error reading from websocket: %v", err) + } + return + } else if msgType != websocket.BinaryMessage { + fs.log.Warnf("Got unexpected websocket message type %d", msgType) + continue + } + fs.processData(data) + } +} diff --git a/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go b/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go new file mode 100644 index 00000000..26f7f263 --- /dev/null +++ b/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go @@ -0,0 +1,135 @@ +// Copyright (c) 2021 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package socket + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha256" + "fmt" + "io" + "sync/atomic" + + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/hkdf" +) + +type NoiseHandshake struct { + hash []byte + salt []byte + key cipher.AEAD + counter uint32 +} + +func NewNoiseHandshake() *NoiseHandshake { + return &NoiseHandshake{} +} + +func newCipher(key []byte) (cipher.AEAD, error) { + aesCipher, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + aesGCM, err := cipher.NewGCM(aesCipher) + if err != nil { + return nil, err + } + return aesGCM, nil +} + +func sha256Slice(data []byte) []byte { + hash := sha256.Sum256(data) + return hash[:] +} + +func (nh *NoiseHandshake) Start(pattern string, header []byte) { + data := []byte(pattern) + if len(data) == 32 { + nh.hash = data + } else { + nh.hash = sha256Slice(data) + } + nh.salt = nh.hash + var err error + nh.key, err = newCipher(nh.hash) + if err != nil { + panic(err) + } + nh.Authenticate(header) +} + +func (nh *NoiseHandshake) Authenticate(data []byte) { + nh.hash = sha256Slice(append(nh.hash, data...)) +} + +func (nh *NoiseHandshake) postIncrementCounter() uint32 { + count := atomic.AddUint32(&nh.counter, 1) + return count - 1 +} + +func (nh *NoiseHandshake) Encrypt(plaintext []byte) []byte { + ciphertext := nh.key.Seal(nil, generateIV(nh.postIncrementCounter()), plaintext, nh.hash) + nh.Authenticate(ciphertext) + return ciphertext +} + +func (nh *NoiseHandshake) Decrypt(ciphertext []byte) (plaintext []byte, err error) { + plaintext, err = nh.key.Open(nil, generateIV(nh.postIncrementCounter()), ciphertext, nh.hash) + if err == nil { + nh.Authenticate(ciphertext) + } + return +} + +func (nh *NoiseHandshake) Finish(fs *FrameSocket, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) { + if write, read, err := nh.extractAndExpand(nh.salt, nil); err != nil { + return nil, fmt.Errorf("failed to extract final keys: %w", err) + } else if writeKey, err := newCipher(write); err != nil { + return nil, fmt.Errorf("failed to create final write cipher: %w", err) + } else if readKey, err := newCipher(read); err != nil { + return nil, fmt.Errorf("failed to create final read cipher: %w", err) + } else if ns, err := newNoiseSocket(fs, writeKey, readKey, frameHandler, disconnectHandler); err != nil { + return nil, fmt.Errorf("failed to create noise socket: %w", err) + } else { + return ns, nil + } +} + +func (nh *NoiseHandshake) MixSharedSecretIntoKey(priv, pub [32]byte) error { + secret, err := curve25519.X25519(priv[:], pub[:]) + if err != nil { + return fmt.Errorf("failed to do x25519 scalar multiplication: %w", err) + } + return nh.MixIntoKey(secret) +} + +func (nh *NoiseHandshake) MixIntoKey(data []byte) error { + nh.counter = 0 + write, read, err := nh.extractAndExpand(nh.salt, data) + if err != nil { + return fmt.Errorf("failed to extract keys for mixing: %w", err) + } + nh.salt = write + nh.key, err = newCipher(read) + if err != nil { + return fmt.Errorf("failed to create new cipher while mixing keys: %w", err) + } + return nil +} + +func (nh *NoiseHandshake) extractAndExpand(salt, data []byte) (write []byte, read []byte, err error) { + h := hkdf.New(sha256.New, data, salt, nil) + write = make([]byte, 32) + read = make([]byte, 32) + + if _, err = io.ReadFull(h, write); err != nil { + err = fmt.Errorf("failed to read write key: %w", err) + } else if _, err = io.ReadFull(h, read); err != nil { + err = fmt.Errorf("failed to read read key: %w", err) + } + return +} diff --git a/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go b/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go new file mode 100644 index 00000000..23bb44e8 --- /dev/null +++ b/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go @@ -0,0 +1,104 @@ +// Copyright (c) 2021 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package socket + +import ( + "context" + "crypto/cipher" + "encoding/binary" + "sync" + "sync/atomic" + + "github.com/gorilla/websocket" +) + +type NoiseSocket struct { + fs *FrameSocket + onFrame FrameHandler + writeKey cipher.AEAD + readKey cipher.AEAD + writeCounter uint32 + readCounter uint32 + writeLock sync.Mutex + destroyed uint32 + stopConsumer chan struct{} +} + +type DisconnectHandler func(socket *NoiseSocket, remote bool) +type FrameHandler func([]byte) + +func newNoiseSocket(fs *FrameSocket, writeKey, readKey cipher.AEAD, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) { + ns := &NoiseSocket{ + fs: fs, + writeKey: writeKey, + readKey: readKey, + onFrame: frameHandler, + stopConsumer: make(chan struct{}), + } + fs.OnDisconnect = func(remote bool) { + disconnectHandler(ns, remote) + } + go ns.consumeFrames(fs.ctx, fs.Frames) + return ns, nil +} + +func (ns *NoiseSocket) consumeFrames(ctx context.Context, frames <-chan []byte) { + ctxDone := ctx.Done() + for { + select { + case frame := <-frames: + ns.receiveEncryptedFrame(frame) + case <-ctxDone: + return + case <-ns.stopConsumer: + return + } + } +} + +func generateIV(count uint32) []byte { + iv := make([]byte, 12) + binary.BigEndian.PutUint32(iv[8:], count) + return iv +} + +func (ns *NoiseSocket) Context() context.Context { + return ns.fs.Context() +} + +func (ns *NoiseSocket) Stop(disconnect bool) { + if atomic.CompareAndSwapUint32(&ns.destroyed, 0, 1) { + close(ns.stopConsumer) + ns.fs.OnDisconnect = nil + if disconnect { + ns.fs.Close(websocket.CloseNormalClosure) + } + } +} + +func (ns *NoiseSocket) SendFrame(plaintext []byte) error { + ns.writeLock.Lock() + ciphertext := ns.writeKey.Seal(nil, generateIV(ns.writeCounter), plaintext, nil) + ns.writeCounter++ + err := ns.fs.SendFrame(ciphertext) + ns.writeLock.Unlock() + return err +} + +func (ns *NoiseSocket) receiveEncryptedFrame(ciphertext []byte) { + count := atomic.AddUint32(&ns.readCounter, 1) - 1 + plaintext, err := ns.readKey.Open(nil, generateIV(count), ciphertext, nil) + if err != nil { + ns.fs.log.Warnf("Failed to decrypt frame: %v", err) + return + } + ns.onFrame(plaintext) +} + +func (ns *NoiseSocket) IsConnected() bool { + return ns.fs.IsConnected() +} |