diff options
Diffstat (limited to 'vendor/go.mau.fi/whatsmeow/socket/noisesocket.go')
-rw-r--r-- | vendor/go.mau.fi/whatsmeow/socket/noisesocket.go | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go b/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go new file mode 100644 index 00000000..23bb44e8 --- /dev/null +++ b/vendor/go.mau.fi/whatsmeow/socket/noisesocket.go @@ -0,0 +1,104 @@ +// Copyright (c) 2021 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package socket + +import ( + "context" + "crypto/cipher" + "encoding/binary" + "sync" + "sync/atomic" + + "github.com/gorilla/websocket" +) + +type NoiseSocket struct { + fs *FrameSocket + onFrame FrameHandler + writeKey cipher.AEAD + readKey cipher.AEAD + writeCounter uint32 + readCounter uint32 + writeLock sync.Mutex + destroyed uint32 + stopConsumer chan struct{} +} + +type DisconnectHandler func(socket *NoiseSocket, remote bool) +type FrameHandler func([]byte) + +func newNoiseSocket(fs *FrameSocket, writeKey, readKey cipher.AEAD, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) { + ns := &NoiseSocket{ + fs: fs, + writeKey: writeKey, + readKey: readKey, + onFrame: frameHandler, + stopConsumer: make(chan struct{}), + } + fs.OnDisconnect = func(remote bool) { + disconnectHandler(ns, remote) + } + go ns.consumeFrames(fs.ctx, fs.Frames) + return ns, nil +} + +func (ns *NoiseSocket) consumeFrames(ctx context.Context, frames <-chan []byte) { + ctxDone := ctx.Done() + for { + select { + case frame := <-frames: + ns.receiveEncryptedFrame(frame) + case <-ctxDone: + return + case <-ns.stopConsumer: + return + } + } +} + +func generateIV(count uint32) []byte { + iv := make([]byte, 12) + binary.BigEndian.PutUint32(iv[8:], count) + return iv +} + +func (ns *NoiseSocket) Context() context.Context { + return ns.fs.Context() +} + +func (ns *NoiseSocket) Stop(disconnect bool) { + if atomic.CompareAndSwapUint32(&ns.destroyed, 0, 1) { + close(ns.stopConsumer) + ns.fs.OnDisconnect = nil + if disconnect { + ns.fs.Close(websocket.CloseNormalClosure) + } + } +} + +func (ns *NoiseSocket) SendFrame(plaintext []byte) error { + ns.writeLock.Lock() + ciphertext := ns.writeKey.Seal(nil, generateIV(ns.writeCounter), plaintext, nil) + ns.writeCounter++ + err := ns.fs.SendFrame(ciphertext) + ns.writeLock.Unlock() + return err +} + +func (ns *NoiseSocket) receiveEncryptedFrame(ciphertext []byte) { + count := atomic.AddUint32(&ns.readCounter, 1) - 1 + plaintext, err := ns.readKey.Open(nil, generateIV(count), ciphertext, nil) + if err != nil { + ns.fs.log.Warnf("Failed to decrypt frame: %v", err) + return + } + ns.onFrame(plaintext) +} + +func (ns *NoiseSocket) IsConnected() bool { + return ns.fs.IsConnected() +} |