summaryrefslogtreecommitdiffstats
path: root/vendor/go.mau.fi/whatsmeow/socket
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/go.mau.fi/whatsmeow/socket')
-rw-r--r--vendor/go.mau.fi/whatsmeow/socket/constants.go40
-rw-r--r--vendor/go.mau.fi/whatsmeow/socket/framesocket.go228
-rw-r--r--vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go135
-rw-r--r--vendor/go.mau.fi/whatsmeow/socket/noisesocket.go104
4 files changed, 507 insertions, 0 deletions
diff --git a/vendor/go.mau.fi/whatsmeow/socket/constants.go b/vendor/go.mau.fi/whatsmeow/socket/constants.go
new file mode 100644
index 00000000..88e9a90b
--- /dev/null
+++ b/vendor/go.mau.fi/whatsmeow/socket/constants.go
@@ -0,0 +1,40 @@
+// Copyright (c) 2021 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// Package socket implements a subset of the Noise protocol framework on top of websockets as used by WhatsApp.
+//
+// There shouldn't be any need to manually interact with this package.
+// The Client struct in the top-level whatsmeow package handles everything.
+package socket
+
+import "errors"
+
+const (
+ // Origin is the Origin header for all WhatsApp websocket connections
+ Origin = "https://web.whatsapp.com"
+ // URL is the websocket URL for the new multidevice protocol
+ URL = "wss://web.whatsapp.com/ws/chat"
+)
+
+const (
+ NoiseStartPattern = "Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
+
+ WADictVersion = 2
+ WAMagicValue = 5
+)
+
+var WAConnHeader = []byte{'W', 'A', WAMagicValue, WADictVersion}
+
+const (
+ FrameMaxSize = 2 << 23
+ FrameLengthSize = 3
+)
+
+var (
+ ErrFrameTooLarge = errors.New("frame too large")
+ ErrSocketClosed = errors.New("frame socket is closed")
+ ErrSocketAlreadyOpen = errors.New("frame socket is already open")
+)
diff --git a/vendor/go.mau.fi/whatsmeow/socket/framesocket.go b/vendor/go.mau.fi/whatsmeow/socket/framesocket.go
new file mode 100644
index 00000000..2bcb21b5
--- /dev/null
+++ b/vendor/go.mau.fi/whatsmeow/socket/framesocket.go
@@ -0,0 +1,228 @@
+// Copyright (c) 2021 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package socket
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/gorilla/websocket"
+
+ waLog "go.mau.fi/whatsmeow/util/log"
+)
+
+type FrameSocket struct {
+ conn *websocket.Conn
+ ctx context.Context
+ cancel func()
+ log waLog.Logger
+ lock sync.Mutex
+
+ Frames chan []byte
+ OnDisconnect func(remote bool)
+ WriteTimeout time.Duration
+
+ Header []byte
+
+ incomingLength int
+ receivedLength int
+ incoming []byte
+ partialHeader []byte
+}
+
+func NewFrameSocket(log waLog.Logger, header []byte) *FrameSocket {
+ return &FrameSocket{
+ conn: nil,
+ log: log,
+ Header: header,
+ Frames: make(chan []byte),
+ }
+}
+
+func (fs *FrameSocket) IsConnected() bool {
+ return fs.conn != nil
+}
+
+func (fs *FrameSocket) Context() context.Context {
+ return fs.ctx
+}
+
+func (fs *FrameSocket) Close(code int) {
+ fs.lock.Lock()
+ defer fs.lock.Unlock()
+
+ if fs.conn == nil {
+ return
+ }
+
+ if code > 0 {
+ message := websocket.FormatCloseMessage(code, "")
+ err := fs.conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))
+ if err != nil {
+ fs.log.Warnf("Error sending close message: %v", err)
+ }
+ }
+
+ fs.cancel()
+ err := fs.conn.Close()
+ if err != nil {
+ fs.log.Errorf("Error closing websocket: %v", err)
+ }
+ fs.conn = nil
+ fs.ctx = nil
+ fs.cancel = nil
+ if fs.OnDisconnect != nil {
+ go fs.OnDisconnect(code == 0)
+ }
+}
+
+func (fs *FrameSocket) Connect() error {
+ fs.lock.Lock()
+ defer fs.lock.Unlock()
+
+ if fs.conn != nil {
+ return ErrSocketAlreadyOpen
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ dialer := websocket.Dialer{}
+
+ headers := http.Header{"Origin": []string{Origin}}
+ fs.log.Debugf("Dialing %s", URL)
+ conn, _, err := dialer.Dial(URL, headers)
+ if err != nil {
+ cancel()
+ return fmt.Errorf("couldn't dial whatsapp web websocket: %w", err)
+ }
+
+ fs.ctx, fs.cancel = ctx, cancel
+ fs.conn = conn
+ conn.SetCloseHandler(func(code int, text string) error {
+ fs.log.Debugf("Server closed websocket with status %d/%s", code, text)
+ cancel()
+ // from default CloseHandler
+ message := websocket.FormatCloseMessage(code, "")
+ _ = conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))
+ return nil
+ })
+
+ go fs.readPump(conn, ctx)
+ return nil
+}
+
+func (fs *FrameSocket) SendFrame(data []byte) error {
+ conn := fs.conn
+ if conn == nil {
+ return ErrSocketClosed
+ }
+ dataLength := len(data)
+ if dataLength >= FrameMaxSize {
+ return fmt.Errorf("%w (got %d bytes, max %d bytes)", ErrFrameTooLarge, len(data), FrameMaxSize)
+ }
+
+ headerLength := len(fs.Header)
+ // Whole frame is header + 3 bytes for length + data
+ wholeFrame := make([]byte, headerLength+FrameLengthSize+dataLength)
+
+ // Copy the header if it's there
+ if fs.Header != nil {
+ copy(wholeFrame[:headerLength], fs.Header)
+ // We only want to send the header once
+ fs.Header = nil
+ }
+
+ // Encode length of frame
+ wholeFrame[headerLength] = byte(dataLength >> 16)
+ wholeFrame[headerLength+1] = byte(dataLength >> 8)
+ wholeFrame[headerLength+2] = byte(dataLength)
+
+ // Copy actual frame data
+ copy(wholeFrame[headerLength+FrameLengthSize:], data)
+
+ if fs.WriteTimeout > 0 {
+ err := conn.SetWriteDeadline(time.Now().Add(fs.WriteTimeout))
+ if err != nil {
+ fs.log.Warnf("Failed to set write deadline: %v", err)
+ }
+ }
+ return conn.WriteMessage(websocket.BinaryMessage, wholeFrame)
+}
+
+func (fs *FrameSocket) frameComplete() {
+ data := fs.incoming
+ fs.incoming = nil
+ fs.partialHeader = nil
+ fs.incomingLength = 0
+ fs.receivedLength = 0
+ fs.Frames <- data
+}
+
+func (fs *FrameSocket) processData(msg []byte) {
+ for len(msg) > 0 {
+ // This probably doesn't happen a lot (if at all), so the code is unoptimized
+ if fs.partialHeader != nil {
+ msg = append(fs.partialHeader, msg...)
+ fs.partialHeader = nil
+ }
+ if fs.incoming == nil {
+ if len(msg) >= FrameLengthSize {
+ length := (int(msg[0]) << 16) + (int(msg[1]) << 8) + int(msg[2])
+ fs.incomingLength = length
+ fs.receivedLength = len(msg)
+ msg = msg[FrameLengthSize:]
+ if len(msg) >= length {
+ fs.incoming = msg[:length]
+ msg = msg[length:]
+ fs.frameComplete()
+ } else {
+ fs.incoming = make([]byte, length)
+ copy(fs.incoming, msg)
+ msg = nil
+ }
+ } else {
+ fs.log.Warnf("Received partial header (report if this happens often)")
+ fs.partialHeader = msg
+ msg = nil
+ }
+ } else {
+ if len(fs.incoming)+len(msg) >= fs.incomingLength {
+ copy(fs.incoming[fs.receivedLength:], msg[:fs.incomingLength-fs.receivedLength])
+ msg = msg[fs.incomingLength-fs.receivedLength:]
+ fs.frameComplete()
+ } else {
+ copy(fs.incoming[fs.receivedLength:], msg)
+ fs.receivedLength += len(msg)
+ msg = nil
+ }
+ }
+ }
+}
+
+func (fs *FrameSocket) readPump(conn *websocket.Conn, ctx context.Context) {
+ fs.log.Debugf("Frame websocket read pump starting %p", fs)
+ defer func() {
+ fs.log.Debugf("Frame websocket read pump exiting %p", fs)
+ go fs.Close(0)
+ }()
+ for {
+ msgType, data, err := conn.ReadMessage()
+ if err != nil {
+ // Ignore the error if the context has been closed
+ if !errors.Is(ctx.Err(), context.Canceled) {
+ fs.log.Errorf("Error reading from websocket: %v", err)
+ }
+ return
+ } else if msgType != websocket.BinaryMessage {
+ fs.log.Warnf("Got unexpected websocket message type %d", msgType)
+ continue
+ }
+ fs.processData(data)
+ }
+}
diff --git a/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go b/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go
new file mode 100644
index 00000000..26f7f263
--- /dev/null
+++ b/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go
@@ -0,0 +1,135 @@
+// Copyright (c) 2021 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package socket
+
+import (
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/sha256"
+ "fmt"
+ "io"
+ "sync/atomic"
+
+ "golang.org/x/crypto/curve25519"
+ "golang.org/x/crypto/hkdf"
+)
+
+type NoiseHandshake struct {
+ hash []byte
+ salt []byte
+ key cipher.AEAD
+ counter uint32
+}
+
+func NewNoiseHandshake() *NoiseHandshake {
+ return &NoiseHandshake{}
+}
+
+func newCipher(key []byte) (cipher.AEAD, error) {
+ aesCipher, err := aes.NewCipher(key)
+ if err != nil {
+ return nil, err
+ }
+ aesGCM, err := cipher.NewGCM(aesCipher)
+ if err != nil {
+ return nil, err
+ }
+ return aesGCM, nil
+}
+
+func sha256Slice(data []byte) []byte {
+ hash := sha256.Sum256(data)
+ return hash[:]
+}
+
+func (nh *NoiseHandshake) Start(pattern string, header []byte) {
+ data := []byte(pattern)
+ if len(data) == 32 {
+ nh.hash = data
+ } else {
+ nh.hash = sha256Slice(data)
+ }
+ nh.salt = nh.hash
+ var err error
+ nh.key, err = newCipher(nh.hash)
+ if err != nil {
+ panic(err)
+ }
+ nh.Authenticate(header)
+}
+
+func (nh *NoiseHandshake) Authenticate(data []byte) {
+ nh.hash = sha256Slice(append(nh.hash, data...))
+}
+
+func (nh *NoiseHandshake) postIncrementCounter() uint32 {
+ count := atomic.AddUint32(&nh.counter, 1)
+ return count - 1
+}
+
+func (nh *NoiseHandshake) Encrypt(plaintext []byte) []byte {
+ ciphertext := nh.key.Seal(nil, generateIV(nh.postIncrementCounter()), plaintext, nh.hash)
+ nh.Authenticate(ciphertext)
+ return ciphertext
+}
+
+func (nh *NoiseHandshake) Decrypt(ciphertext []byte) (plaintext []byte, err error) {
+ plaintext, err = nh.key.Open(nil, generateIV(nh.postIncrementCounter()), ciphertext, nh.hash)
+ if err == nil {
+ nh.Authenticate(ciphertext)
+ }
+ return
+}
+
+func (nh *NoiseHandshake) Finish(fs *FrameSocket, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) {
+ if write, read, err := nh.extractAndExpand(nh.salt, nil); err != nil {
+ return nil, fmt.Errorf("failed to extract final keys: %w", err)
+ } else if writeKey, err := newCipher(write); err != nil {
+ return nil, fmt.Errorf("failed to create final write cipher: %w", err)
+ } else if readKey, err := newCipher(read); err != nil {
+ return nil, fmt.Errorf("failed to create final read cipher: %w", err)
+ } else if ns, err := newNoiseSocket(fs, writeKey, readKey, frameHandler, disconnectHandler); err != nil {
+ return nil, fmt.Errorf("failed to create noise socket: %w", err)
+ } else {
+ return ns, nil
+ }
+}
+
+func (nh *NoiseHandshake) MixSharedSecretIntoKey(priv, pub [32]byte) error {
+ secret, err := curve25519.X25519(priv[:], pub[:])
+ if err != nil {
+ return fmt.Errorf("failed to do x25519 scalar multiplication: %w", err)
+ }
+ return nh.MixIntoKey(secret)
+}
+
+func (nh *NoiseHandshake) MixIntoKey(data []byte) error {
+ nh.counter = 0
+ write, read, err := nh.extractAndExpand(nh.salt, data)
+ if err != nil {
+ return fmt.Errorf("failed to extract keys for mixing: %w", err)
+ }
+ nh.salt = write
+ nh.key, err = newCipher(read)
+ if err != nil {
+ return fmt.Errorf("failed to create new cipher while mixing keys: %w", err)
+ }
+ return nil
+}
+
+func (nh *NoiseHandshake) extractAndExpand(salt, data []byte) (write []byte, read []byte, err error) {
+ h := hkdf.New(sha256.New, data, salt, nil)
+ write = make([]byte, 32)
+ read = make([]byte, 32)
+
+ if _, err = io.ReadFull(h, write); err != nil {
+ err = fmt.Errorf("failed to read write key: %w", err)
+ } else if _, err = io.ReadFull(h, read); err != nil {
+ err = fmt.Errorf("failed to read read key: %w", err)
+ }
+ return
+}
diff --git a/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go b/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go
new file mode 100644
index 00000000..23bb44e8
--- /dev/null
+++ b/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go
@@ -0,0 +1,104 @@
+// Copyright (c) 2021 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package socket
+
+import (
+ "context"
+ "crypto/cipher"
+ "encoding/binary"
+ "sync"
+ "sync/atomic"
+
+ "github.com/gorilla/websocket"
+)
+
+type NoiseSocket struct {
+ fs *FrameSocket
+ onFrame FrameHandler
+ writeKey cipher.AEAD
+ readKey cipher.AEAD
+ writeCounter uint32
+ readCounter uint32
+ writeLock sync.Mutex
+ destroyed uint32
+ stopConsumer chan struct{}
+}
+
+type DisconnectHandler func(socket *NoiseSocket, remote bool)
+type FrameHandler func([]byte)
+
+func newNoiseSocket(fs *FrameSocket, writeKey, readKey cipher.AEAD, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) {
+ ns := &NoiseSocket{
+ fs: fs,
+ writeKey: writeKey,
+ readKey: readKey,
+ onFrame: frameHandler,
+ stopConsumer: make(chan struct{}),
+ }
+ fs.OnDisconnect = func(remote bool) {
+ disconnectHandler(ns, remote)
+ }
+ go ns.consumeFrames(fs.ctx, fs.Frames)
+ return ns, nil
+}
+
+func (ns *NoiseSocket) consumeFrames(ctx context.Context, frames <-chan []byte) {
+ ctxDone := ctx.Done()
+ for {
+ select {
+ case frame := <-frames:
+ ns.receiveEncryptedFrame(frame)
+ case <-ctxDone:
+ return
+ case <-ns.stopConsumer:
+ return
+ }
+ }
+}
+
+func generateIV(count uint32) []byte {
+ iv := make([]byte, 12)
+ binary.BigEndian.PutUint32(iv[8:], count)
+ return iv
+}
+
+func (ns *NoiseSocket) Context() context.Context {
+ return ns.fs.Context()
+}
+
+func (ns *NoiseSocket) Stop(disconnect bool) {
+ if atomic.CompareAndSwapUint32(&ns.destroyed, 0, 1) {
+ close(ns.stopConsumer)
+ ns.fs.OnDisconnect = nil
+ if disconnect {
+ ns.fs.Close(websocket.CloseNormalClosure)
+ }
+ }
+}
+
+func (ns *NoiseSocket) SendFrame(plaintext []byte) error {
+ ns.writeLock.Lock()
+ ciphertext := ns.writeKey.Seal(nil, generateIV(ns.writeCounter), plaintext, nil)
+ ns.writeCounter++
+ err := ns.fs.SendFrame(ciphertext)
+ ns.writeLock.Unlock()
+ return err
+}
+
+func (ns *NoiseSocket) receiveEncryptedFrame(ciphertext []byte) {
+ count := atomic.AddUint32(&ns.readCounter, 1) - 1
+ plaintext, err := ns.readKey.Open(nil, generateIV(count), ciphertext, nil)
+ if err != nil {
+ ns.fs.log.Warnf("Failed to decrypt frame: %v", err)
+ return
+ }
+ ns.onFrame(plaintext)
+}
+
+func (ns *NoiseSocket) IsConnected() bool {
+ return ns.fs.IsConnected()
+}