1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
|
// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package socket
import (
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"fmt"
"io"
"sync/atomic"
"golang.org/x/crypto/curve25519"
"golang.org/x/crypto/hkdf"
)
// NoiseHandshake holds the symmetric state of an in-progress Noise handshake.
// The zero value has no cipher key; Start (or MixIntoKey) must run before
// Encrypt or Decrypt are used.
type NoiseHandshake struct {
	// hash is the running handshake hash, used as associated data by
	// Encrypt/Decrypt. salt is the chaining value passed to HKDF as its salt.
	hash, salt []byte
	// key is the current AES-GCM cipher built by newCipher.
	key cipher.AEAD
	// counter is the nonce counter; postIncrementCounter advances it atomically
	// and MixIntoKey resets it to zero.
	counter uint32
}
// NewNoiseHandshake allocates an empty handshake state. The returned value
// has no cipher key yet; call Start before encrypting or decrypting.
func NewNoiseHandshake() *NoiseHandshake {
	return new(NoiseHandshake)
}
// newCipher wraps key in an AES block cipher and returns it in GCM mode.
// Errors from aes.NewCipher (e.g. an unsupported key length) or
// cipher.NewGCM are returned unchanged with a nil AEAD.
func newCipher(key []byte) (cipher.AEAD, error) {
	var (
		block cipher.Block
		gcm   cipher.AEAD
		err   error
	)
	if block, err = aes.NewCipher(key); err != nil {
		return nil, err
	}
	if gcm, err = cipher.NewGCM(block); err != nil {
		return nil, err
	}
	return gcm, nil
}
// sha256Slice returns the SHA-256 digest of data as a newly allocated
// 32-byte slice.
func sha256Slice(data []byte) []byte {
	hasher := sha256.New()
	hasher.Write(data) // hash.Hash.Write is documented to never return an error
	return hasher.Sum(nil)
}
// Start initializes the handshake state from the protocol name in pattern and
// then authenticates header.
//
// If pattern is exactly 32 bytes long it is used verbatim as the initial hash;
// otherwise its SHA-256 digest is used. The salt starts out as that same
// slice, and the initial cipher key is built from it.
//
// Start panics if the cipher cannot be constructed; with a 32-byte key this
// is not expected to happen.
func (nh *NoiseHandshake) Start(pattern string, header []byte) {
	initial := []byte(pattern)
	if len(initial) != 32 {
		initial = sha256Slice(initial)
	}
	nh.hash, nh.salt = initial, initial

	var err error
	if nh.key, err = newCipher(initial); err != nil {
		panic(err)
	}
	nh.Authenticate(header)
}
// Authenticate mixes data into the running handshake hash, replacing it with
// SHA-256(hash || data). The bytes are streamed into the hasher, so nothing is
// appended to the existing hash slice.
func (nh *NoiseHandshake) Authenticate(data []byte) {
	digest := sha256.New()
	digest.Write(nh.hash)
	digest.Write(data)
	nh.hash = digest.Sum(nil)
}
// postIncrementCounter atomically adds one to the nonce counter and returns
// the value it held before the increment (x++ semantics).
func (nh *NoiseHandshake) postIncrementCounter() uint32 {
	return atomic.AddUint32(&nh.counter, 1) - 1
}
// Encrypt seals plaintext with the current handshake key and returns the
// ciphertext.
//
// The nonce is generateIV applied to the next counter value (generateIV is
// defined elsewhere; presumably it encodes the counter into a GCM nonce). The
// running hash is used as associated data, and the ciphertext is then mixed
// into the hash.
func (nh *NoiseHandshake) Encrypt(plaintext []byte) []byte {
	iv := generateIV(nh.postIncrementCounter())
	sealed := nh.key.Seal(nil, iv, plaintext, nh.hash)
	nh.Authenticate(sealed)
	return sealed
}
// Decrypt opens ciphertext with the current handshake key, using the next
// counter value (via generateIV) as nonce and the running hash as associated
// data.
//
// The counter advances even when Open fails. The hash is updated with the
// ciphertext only on success. On failure, the result of Open is returned
// together with its error.
func (nh *NoiseHandshake) Decrypt(ciphertext []byte) (plaintext []byte, err error) {
	iv := generateIV(nh.postIncrementCounter())
	plaintext, err = nh.key.Open(nil, iv, ciphertext, nh.hash)
	if err != nil {
		return plaintext, err
	}
	nh.Authenticate(ciphertext)
	return plaintext, nil
}
// Finish completes the handshake. It expands the current salt with HKDF (no
// extra input) into a write key and a read key. It then builds an AES-GCM
// cipher for each and wraps fs in a NoiseSocket that uses them with the given
// handlers.
//
// Each failure is wrapped with a message naming the step that failed.
func (nh *NoiseHandshake) Finish(fs *FrameSocket, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) {
	write, read, err := nh.extractAndExpand(nh.salt, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to extract final keys: %w", err)
	}

	writeKey, err := newCipher(write)
	if err != nil {
		return nil, fmt.Errorf("failed to create final write cipher: %w", err)
	}

	readKey, err := newCipher(read)
	if err != nil {
		return nil, fmt.Errorf("failed to create final read cipher: %w", err)
	}

	ns, err := newNoiseSocket(fs, writeKey, readKey, frameHandler, disconnectHandler)
	if err != nil {
		return nil, fmt.Errorf("failed to create noise socket: %w", err)
	}
	return ns, nil
}
// MixSharedSecretIntoKey computes the X25519 shared secret between priv and
// pub and mixes it into the handshake key via MixIntoKey.
//
// An error from curve25519.X25519 is wrapped and returned, and no mixing takes
// place in that case.
func (nh *NoiseHandshake) MixSharedSecretIntoKey(priv, pub [32]byte) error {
	shared, err := curve25519.X25519(priv[:], pub[:])
	if err == nil {
		return nh.MixIntoKey(shared)
	}
	return fmt.Errorf("failed to do x25519 scalar multiplication: %w", err)
}
// MixIntoKey mixes data into the handshake's chaining state.
//
// HKDF-SHA256 is run with data as input keying material and the current salt
// as salt. The first 32 output bytes become the new salt. The second 32 bytes
// become the new cipher key, and the nonce counter is reset to zero for that
// key.
//
// The handshake state is modified only after every derivation step has
// succeeded. On error, the previous salt, key and counter are left intact.
func (nh *NoiseHandshake) MixIntoKey(data []byte) error {
	write, read, err := nh.extractAndExpand(nh.salt, data)
	if err != nil {
		return fmt.Errorf("failed to extract keys for mixing: %w", err)
	}
	key, err := newCipher(read)
	if err != nil {
		return fmt.Errorf("failed to create new cipher while mixing keys: %w", err)
	}

	nh.salt = write
	nh.key = key
	// counter is advanced through sync/atomic in postIncrementCounter; reset it
	// the same way rather than mixing atomic and plain accesses to one field.
	atomic.StoreUint32(&nh.counter, 0)
	return nil
}
// extractAndExpand runs HKDF-SHA256 with data as input keying material, salt
// as the salt and no info string. It returns the first 32 bytes of output as
// write and the next 32 bytes as read.
//
// On error both key slices are nil, so callers can never pick up a partially
// filled key alongside the error.
func (nh *NoiseHandshake) extractAndExpand(salt, data []byte) (write []byte, read []byte, err error) {
	h := hkdf.New(sha256.New, data, salt, nil)

	write = make([]byte, 32)
	if _, err = io.ReadFull(h, write); err != nil {
		return nil, nil, fmt.Errorf("failed to read write key: %w", err)
	}

	read = make([]byte, 32)
	if _, err = io.ReadFull(h, read); err != nil {
		return nil, nil, fmt.Errorf("failed to read read key: %w", err)
	}
	return write, read, nil
}
|