1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
|
// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package socket
import (
"context"
"crypto/cipher"
"encoding/binary"
"sync"
"sync/atomic"
"github.com/gorilla/websocket"
)
// NoiseSocket wraps a FrameSocket and encrypts every outgoing frame and
// decrypts every incoming frame with a pair of AEAD keys. Each direction keeps
// its own 32-bit message counter, which generateIV turns into the AEAD nonce.
type NoiseSocket struct {
	// fs is the underlying frame socket that carries the ciphertext frames.
	fs *FrameSocket
	// onFrame receives each successfully decrypted incoming frame.
	onFrame FrameHandler
	// writeKey seals outgoing frames; readKey opens incoming ones.
	writeKey cipher.AEAD
	readKey  cipher.AEAD
	// writeCounter is the nonce counter for outgoing frames; guarded by writeLock.
	writeCounter uint32
	// readCounter is the nonce counter for incoming frames; updated atomically.
	readCounter uint32
	// writeLock serializes SendFrame so that nonce order matches send order.
	writeLock sync.Mutex
	// destroyed is flipped from 0 to 1 exactly once, by the first Stop call.
	destroyed uint32
	// stopConsumer is closed by Stop to end the frame-consuming goroutine.
	stopConsumer chan struct{}
}
// Callback types used by NoiseSocket.
type (
	// DisconnectHandler is called with the NoiseSocket whose underlying frame
	// socket reported a disconnect; remote is forwarded unchanged from the
	// frame socket's OnDisconnect callback.
	DisconnectHandler func(socket *NoiseSocket, remote bool)

	// FrameHandler receives the plaintext of each decrypted incoming frame.
	FrameHandler func([]byte)
)
// newNoiseSocket wraps fs in a NoiseSocket. Outgoing frames are sealed with
// writeKey, and incoming frames are opened with readKey and passed to
// frameHandler. Disconnects reported by fs are forwarded to disconnectHandler
// together with the new socket.
//
// A goroutine that consumes fs.Frames is started before returning. The
// returned error is currently always nil.
func newNoiseSocket(fs *FrameSocket, writeKey, readKey cipher.AEAD, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) {
	sock := new(NoiseSocket)
	sock.fs = fs
	sock.writeKey = writeKey
	sock.readKey = readKey
	sock.onFrame = frameHandler
	sock.stopConsumer = make(chan struct{})

	// Attach this socket to every disconnect notification from the frame layer.
	fs.OnDisconnect = func(remote bool) { disconnectHandler(sock, remote) }

	go sock.consumeFrames(fs.ctx, fs.Frames)
	return sock, nil
}
// consumeFrames reads ciphertext frames from frames and hands each one to
// receiveEncryptedFrame. It stops when any of the following happens:
//   - ctx is done
//   - stopConsumer is closed (see Stop)
//   - the frames channel is closed
//
// A nil ctx makes it return immediately.
func (ns *NoiseSocket) consumeFrames(ctx context.Context, frames <-chan []byte) {
	if ctx == nil {
		// ctx being nil implies the connection already closed somehow
		return
	}
	ctxDone := ctx.Done()
	for {
		select {
		case frame, ok := <-frames:
			if !ok {
				// A closed channel would otherwise yield nil forever, spinning
				// this loop and consuming a read nonce on every bogus frame.
				return
			}
			ns.receiveEncryptedFrame(frame)
		case <-ctxDone:
			return
		case <-ns.stopConsumer:
			return
		}
	}
}
// generateIV builds the 12-byte AEAD nonce for a message counter: eight zero
// bytes followed by count in big-endian byte order.
func generateIV(count uint32) []byte {
	var nonce [12]byte
	binary.BigEndian.PutUint32(nonce[len(nonce)-4:], count)
	return nonce[:]
}
// Context returns the context of the underlying frame socket, as reported
// by FrameSocket.Context.
func (ns *NoiseSocket) Context() context.Context {
	fsCtx := ns.fs.Context()
	return fsCtx
}
// Stop shuts the socket down. Only the first call has any effect.
//
// That call:
//   - closes stopConsumer, which ends the frame-consuming goroutine
//   - clears the frame socket's OnDisconnect callback
//   - if disconnect is true, closes the frame socket with a
//     normal-closure code
//
// Later calls are no-ops.
func (ns *NoiseSocket) Stop(disconnect bool) {
	if !atomic.CompareAndSwapUint32(&ns.destroyed, 0, 1) {
		// Already stopped; closing stopConsumer a second time would panic.
		return
	}
	close(ns.stopConsumer)
	ns.fs.OnDisconnect = nil
	if !disconnect {
		return
	}
	ns.fs.Close(websocket.CloseNormalClosure)
}
// SendFrame encrypts plaintext with the write key and sends the resulting
// ciphertext as one frame on the underlying frame socket. It returns the
// error from the frame socket's SendFrame.
//
// The nonce comes from writeCounter, which advances once per call even if
// the send fails. writeLock is held across both sealing and sending, so
// concurrent callers are serialized and frames go out in nonce order.
//
// NOTE(review): writeCounter is a uint32, so after 2^32 frames it wraps and
// generateIV would repeat an earlier nonce under the same key.
func (ns *NoiseSocket) SendFrame(plaintext []byte) error {
	ns.writeLock.Lock()
	// defer keeps the lock from staying held if Seal or the send panics.
	defer ns.writeLock.Unlock()

	ciphertext := ns.writeKey.Seal(nil, generateIV(ns.writeCounter), plaintext, nil)
	ns.writeCounter++
	return ns.fs.SendFrame(ciphertext)
}
// receiveEncryptedFrame decrypts one incoming frame with the read key. The
// nonce is built from the next read-counter value, and the counter advances
// whether or not decryption succeeds.
//
// On success the plaintext is handed to onFrame. On failure a warning is
// logged and the frame is dropped.
func (ns *NoiseSocket) receiveEncryptedFrame(ciphertext []byte) {
	// AddUint32 returns the incremented value; the frame was sealed with the
	// value before the increment.
	nonce := atomic.AddUint32(&ns.readCounter, 1) - 1
	plaintext, err := ns.readKey.Open(nil, generateIV(nonce), ciphertext, nil)
	if err == nil {
		ns.onFrame(plaintext)
		return
	}
	ns.fs.log.Warnf("Failed to decrypt frame: %v", err)
}
// IsConnected reports the connection state of the underlying frame socket,
// as returned by FrameSocket.IsConnected.
func (ns *NoiseSocket) IsConnected() bool {
	connected := ns.fs.IsConnected()
	return connected
}
|