summaryrefslogtreecommitdiffstats
path: root/vendor/go.mau.fi/whatsmeow/socket/noisehandshake.go
blob: 26f7f2631ebe18dd614c1f61729dd4dee09597ab (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package socket

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/sha256"
	"fmt"
	"io"
	"sync/atomic"

	"golang.org/x/crypto/curve25519"
	"golang.org/x/crypto/hkdf"
)

// NoiseHandshake holds the symmetric state used while performing a Noise
// handshake: the running handshake hash, the HKDF salt that is re-derived on
// every key mix, the current AEAD cipher, and the counter used for nonces.
type NoiseHandshake struct {
	// hash is the running handshake hash (replaced by SHA-256(hash || data)
	// in Authenticate) and is passed as associated data to Seal/Open.
	// salt is the HKDF salt; MixIntoKey replaces it with fresh HKDF output,
	// so it plays the role of the Noise chaining key.
	hash, salt []byte
	// key is the AES-GCM cipher currently used by Encrypt and Decrypt.
	key cipher.AEAD
	// counter is the per-key message counter that is turned into the AEAD
	// nonce; it is reset to zero whenever a new key is mixed in.
	counter uint32
}

// NewNoiseHandshake returns an empty handshake state. Start must be called
// before any of the encryption or key-mixing methods are used.
func NewNoiseHandshake() *NoiseHandshake {
	return new(NoiseHandshake)
}

// newCipher builds an AES-GCM AEAD (standard nonce and tag sizes) keyed with
// key. The key must be a valid AES key length; otherwise the error from
// aes.NewCipher is returned together with a nil AEAD.
func newCipher(key []byte) (aead cipher.AEAD, err error) {
	var block cipher.Block
	if block, err = aes.NewCipher(key); err != nil {
		return nil, err
	}
	if aead, err = cipher.NewGCM(block); err != nil {
		return nil, err
	}
	return aead, nil
}

// sha256Slice returns the 32-byte SHA-256 digest of data as a freshly
// allocated slice.
func sha256Slice(data []byte) []byte {
	digest := sha256.New()
	digest.Write(data) // hash.Hash.Write never returns an error
	return digest.Sum(nil)
}

// Start initialises the handshake state for the given protocol pattern and
// mixes header into the handshake hash.
//
// A pattern that is exactly 32 bytes long is used verbatim as the initial
// hash; any other length is replaced by its SHA-256 digest. The initial hash
// also becomes the initial salt and is used directly as the AES-256 key of the
// first cipher. Because the hash is always 32 bytes here, cipher creation is
// not expected to fail; if it does, Start panics.
//
// NOTE(review): the Noise spec zero-pads protocol names shorter than 32 bytes
// instead of hashing them; callers appear to pass a pre-padded 32-byte name,
// so this only matters for other lengths — confirm before relying on it.
func (nh *NoiseHandshake) Start(pattern string, header []byte) {
	initial := []byte(pattern)
	if len(initial) != 32 {
		initial = sha256Slice(initial)
	}
	nh.hash = initial
	nh.salt = initial
	var err error
	if nh.key, err = newCipher(initial); err != nil {
		panic(err)
	}
	nh.Authenticate(header)
}

// Authenticate mixes data into the running handshake hash, replacing it with
// SHA-256(hash || data).
func (nh *NoiseHandshake) Authenticate(data []byte) {
	// Streaming both parts into one hasher yields the same digest as hashing
	// their concatenation, without appending into nh.hash's backing array.
	mixer := sha256.New()
	mixer.Write(nh.hash) // hash.Hash.Write never returns an error
	mixer.Write(data)
	nh.hash = mixer.Sum(nil)
}

// postIncrementCounter atomically increments the nonce counter and returns
// the value it held before the increment (counter++ semantics, wrapping on
// uint32 overflow).
func (nh *NoiseHandshake) postIncrementCounter() uint32 {
	return atomic.AddUint32(&nh.counter, 1) - 1
}

// Encrypt seals plaintext with the current handshake key. The nonce is built
// by generateIV from the next counter value, and the running handshake hash
// is used as associated data. The resulting ciphertext is then mixed into the
// handshake hash before being returned.
func (nh *NoiseHandshake) Encrypt(plaintext []byte) []byte {
	nonce := generateIV(nh.postIncrementCounter())
	sealed := nh.key.Seal(nil, nonce, plaintext, nh.hash)
	nh.Authenticate(sealed)
	return sealed
}

// Decrypt opens ciphertext with the current handshake key. It uses the next
// counter value (via generateIV) as the nonce and the running handshake hash
// as associated data. The counter is consumed even if authentication fails.
// The ciphertext is mixed into the handshake hash only when Open succeeds.
func (nh *NoiseHandshake) Decrypt(ciphertext []byte) (plaintext []byte, err error) {
	nonce := generateIV(nh.postIncrementCounter())
	plaintext, err = nh.key.Open(nil, nonce, ciphertext, nh.hash)
	if err != nil {
		return plaintext, err
	}
	nh.Authenticate(ciphertext)
	return plaintext, nil
}

// Finish derives the final transport keys from the current salt (HKDF with
// no extra input) and wraps fs in a NoiseSocket. The first derived key
// encrypts outgoing frames and the second decrypts incoming ones. Each failure
// is wrapped with a description of the step that failed.
func (nh *NoiseHandshake) Finish(fs *FrameSocket, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) {
	writeMaterial, readMaterial, err := nh.extractAndExpand(nh.salt, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to extract final keys: %w", err)
	}
	writeKey, err := newCipher(writeMaterial)
	if err != nil {
		return nil, fmt.Errorf("failed to create final write cipher: %w", err)
	}
	readKey, err := newCipher(readMaterial)
	if err != nil {
		return nil, fmt.Errorf("failed to create final read cipher: %w", err)
	}
	ns, err := newNoiseSocket(fs, writeKey, readKey, frameHandler, disconnectHandler)
	if err != nil {
		return nil, fmt.Errorf("failed to create noise socket: %w", err)
	}
	return ns, nil
}

// MixSharedSecretIntoKey computes the X25519 shared secret of priv and pub and
// mixes it into the handshake key via MixIntoKey. An error from X25519 is
// wrapped and returned without touching the handshake state.
func (nh *NoiseHandshake) MixSharedSecretIntoKey(priv, pub [32]byte) error {
	shared, err := curve25519.X25519(priv[:], pub[:])
	if err == nil {
		return nh.MixIntoKey(shared)
	}
	return fmt.Errorf("failed to do x25519 scalar multiplication: %w", err)
}

// MixIntoKey mixes data into the handshake key material. It runs HKDF-SHA256
// with the current salt and data as input keying material. The first 32 bytes
// of output become the new salt, and a fresh AES-GCM cipher is built from the
// next 32 bytes. The nonce counter is reset to zero for the new key.
//
// The handshake state is committed only after both derivation steps succeed.
// On error the previous salt, key and counter are left untouched, and the
// wrapped error is returned.
func (nh *NoiseHandshake) MixIntoKey(data []byte) error {
	write, read, err := nh.extractAndExpand(nh.salt, data)
	if err != nil {
		return fmt.Errorf("failed to extract keys for mixing: %w", err)
	}
	key, err := newCipher(read)
	if err != nil {
		return fmt.Errorf("failed to create new cipher while mixing keys: %w", err)
	}
	// Commit everything together: resetting the counter while the old key is
	// still installed would reuse nonces, and installing a nil key would make
	// the next Encrypt/Decrypt panic.
	nh.salt = write
	nh.key = key
	// postIncrementCounter updates counter atomically, so reset it the same
	// way rather than mixing atomic and plain accesses.
	atomic.StoreUint32(&nh.counter, 0)
	return nil
}

// extractAndExpand runs HKDF-SHA256 with data as the input keying material,
// salt as the salt and no info. It returns the first 32 bytes of output as
// write and the following 32 bytes as read. Both slices are allocated up
// front and are returned as-is even when an error occurs. The receiver is not
// used.
func (nh *NoiseHandshake) extractAndExpand(salt, data []byte) (write []byte, read []byte, err error) {
	kdf := hkdf.New(sha256.New, data, salt, nil)
	write, read = make([]byte, 32), make([]byte, 32)
	if _, err = io.ReadFull(kdf, write); err != nil {
		return write, read, fmt.Errorf("failed to read write key: %w", err)
	}
	if _, err = io.ReadFull(kdf, read); err != nil {
		return write, read, fmt.Errorf("failed to read read key: %w", err)
	}
	return write, read, nil
}