summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/pkg/sftp/packet-manager.go
blob: bf822e672a12d06940a64d0201c0373811725a5a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
package sftp

import (
	"encoding"
	"sort"
	"sync"
)

// The goal of the packetManager is to keep outgoing response packets in the
// same order as the incoming request packets. Some sftp clients (e.g. WinSCP)
// require this behavior.

// packetSender is the object packetManager sends finished responses through
// (the connection, per the packetManager.sender field). sendPacket receives a
// single marshalable packet and returns a non-nil error if sending failed.
type packetSender interface {
	sendPacket(encoding.BinaryMarshaler) error
}

// packetManager collects the IDs of outstanding requests and the responses
// that are ready, and sends responses only when they match the lowest
// outstanding request ID. In this file, incoming and outgoing are only read
// and written by the controller goroutine (directly or via maybeSendPackets);
// other goroutines interact with it through the channels and the WaitGroup.
type packetManager struct {
	requests  chan requestPacket  // requests registered via incomingPacket
	responses chan responsePacket // finished responses registered via readyPacket
	fini      chan struct{}       // closed by close() to stop the controller
	incoming  requestPacketIDs    // IDs of outstanding requests, kept sorted
	outgoing  responsePackets     // responses awaiting their turn, kept sorted by ID
	sender    packetSender        // connection object
	working   *sync.WaitGroup     // packets registered but not yet marked ready
}

// newPktMgr builds a packetManager that sends responses through sender.
// Every channel and queue is sized to SftpServerWorkerCount. The controller
// goroutine is started before the manager is returned.
func newPktMgr(sender packetSender) *packetManager {
	n := SftpServerWorkerCount
	mgr := &packetManager{
		sender:    sender,
		working:   new(sync.WaitGroup),
		fini:      make(chan struct{}),
		requests:  make(chan requestPacket, n),
		responses: make(chan responsePacket, n),
		incoming:  make(requestPacketIDs, 0, n),
		outgoing:  make(responsePackets, 0, n),
	}
	go mgr.controller()
	return mgr
}

// responsePackets is the queue of responses that are ready to be sent.
type responsePackets []responsePacket

// Sort orders the responses in place by ascending packet ID.
func (r responsePackets) Sort() {
	byID := func(a, b int) bool { return r[a].id() < r[b].id() }
	sort.Slice(r, byID)
}

// requestPacketIDs holds the IDs of requests that have not been answered yet.
type requestPacketIDs []uint32

// Sort orders the IDs in place, smallest first.
func (r requestPacketIDs) Sort() {
	less := func(a, b int) bool {
		return r[a] < r[b]
	}
	sort.Slice(r, less)
}

// incomingPacket registers an incoming request so its response can be
// ordered. Send an id of 0 for packets without an id.
//
// The WaitGroup is incremented before the packet is sent on s.requests, so a
// concurrent Wait (in close or workerChan) cannot see a zero count while this
// packet is outstanding. The send blocks only when the requests buffer is
// full.
func (s *packetManager) incomingPacket(pkt requestPacket) {
	s.working.Add(1)
	s.requests <- pkt // buffer == SftpServerWorkerCount
}

// readyPacket registers a finished response as ready to send; the controller
// sends it once its ID reaches the head of the outstanding-request queue.
//
// Done is called only after the response has been handed to s.responses, so
// once a Wait on s.working returns, every response is at least queued on the
// channel (though the controller may not have received it yet).
func (s *packetManager) readyPacket(pkt responsePacket) {
	s.responses <- pkt
	s.working.Done()
}

// close shuts down the packetManager controller. It blocks until every
// packet registered via incomingPacket has been marked ready via readyPacket,
// then closes fini to signal the controller goroutine to stop.
func (s *packetManager) close() {
	// pause until current packets are processed
	s.working.Wait()
	close(s.fini)
}

// workerChan is passed a worker function and returns a channel for incoming
// packets. The goal is to process packets in the order they are received, as
// is required by section 7 of the RFC, while maximizing the throughput of file
// transfers.
//
// runWorker is called SftpServerWorkerCount times with a shared buffered
// channel and once with an unbuffered channel; it presumably starts a worker
// goroutine and returns (the calls are made in a loop on this goroutine).
//
// Routing:
//   - read and write packets go to the worker pool and may run concurrently;
//   - every other packet goes to the single command worker, so those requests
//     execute one at a time in arrival order;
//   - a close packet first waits until every packet already registered has
//     been marked ready, so no read or write is still in flight on the handle.
//
// Response ordering is enforced separately by the controller. When the
// returned channel is closed, both worker channels are closed and the
// packetManager is shut down once outstanding packets are done.
func (s *packetManager) workerChan(runWorker func(requestChan)) requestChan {
	// pool of workers for concurrent reads/writes
	rwChan := make(chan requestPacket, SftpServerWorkerCount)
	for i := 0; i < SftpServerWorkerCount; i++ {
		runWorker(rwChan)
	}

	// single worker so all other commands run sequentially
	cmdChan := make(chan requestPacket)
	runWorker(cmdChan)

	pktChan := make(chan requestPacket, SftpServerWorkerCount)
	go func() {
		for pkt := range pktChan {
			switch pkt.(type) {
			case *sshFxpReadPacket, *sshFxpWritePacket:
				s.incomingPacket(pkt)
				rwChan <- pkt
				continue
			case *sshFxpClosePacket:
				// Wait for in-flight reads/writes to finish before the
				// handle is closed. The incomingPacket call below must
				// come after this Wait, or it would wait on itself.
				s.working.Wait()
			}
			s.incomingPacket(pkt)
			// everything except reads/writes is processed sequentially
			cmdChan <- pkt
		}
		close(rwChan)
		close(cmdChan)
		s.close()
	}()

	return pktChan
}

// controller is the packetManager event loop. It records the ID of every
// registered request and collects every ready response, keeping both queues
// sorted by packet ID. After each event it sends whatever responses are now
// at the head of the queue. It returns once fini is closed.
//
// close() closes fini as soon as every readyPacket call has handed its
// response to the buffered s.responses channel. Those responses (and their
// requests) may still sit in the channel buffers, and select picks randomly
// among ready cases. So on fini the remaining buffered items are drained and
// sent before returning, rather than being dropped.
func (s *packetManager) controller() {
	addRequest := func(pkt requestPacket) {
		debug("incoming id: %v", pkt.id())
		s.incoming = append(s.incoming, pkt.id())
		if len(s.incoming) > 1 {
			s.incoming.Sort()
		}
	}
	addResponse := func(pkt responsePacket) {
		debug("outgoing pkt: %v", pkt.id())
		s.outgoing = append(s.outgoing, pkt)
		if len(s.outgoing) > 1 {
			s.outgoing.Sort()
		}
	}

	for {
		select {
		case pkt := <-s.requests:
			addRequest(pkt)
		case pkt := <-s.responses:
			addResponse(pkt)
		case <-s.fini:
			// Non-blocking drain of anything still buffered, then a final
			// flush of every response that can now be sent.
			for {
				select {
				case pkt := <-s.requests:
					addRequest(pkt)
				case pkt := <-s.responses:
					addResponse(pkt)
				default:
					s.maybeSendPackets()
					return
				}
			}
		}
		s.maybeSendPackets()
	}
}

// maybeSendPackets sends as many packets as are ready. It repeatedly compares
// the lowest outstanding request ID with the lowest-ID ready response. While
// they match, it sends that response and pops both heads. It stops when
// either queue is empty or the head request has no response yet.
//
// Send errors are not returned from here (the caller has no way to act on
// them). They are recorded via debug instead of being silently discarded.
func (s *packetManager) maybeSendPackets() {
	for {
		if len(s.outgoing) == 0 || len(s.incoming) == 0 {
			debug("break! -- outgoing: %v; incoming: %v",
				len(s.outgoing), len(s.incoming))
			return
		}
		out := s.outgoing[0]
		if s.incoming[0] != out.id() {
			// the oldest request is still being processed; hold the rest
			return
		}
		if err := s.sender.sendPacket(out); err != nil {
			debug("sendPacket error for id %v: %v", out.id(), err)
		}

		// pop off heads
		nIn := len(s.incoming)
		copy(s.incoming, s.incoming[1:]) // shift left
		s.incoming = s.incoming[:nIn-1]  // remove last

		nOut := len(s.outgoing)
		copy(s.outgoing, s.outgoing[1:]) // shift left
		// Clear the vacated slot so the backing array does not keep the
		// just-sent response (and any payload it carries) reachable.
		s.outgoing[nOut-1] = nil
		s.outgoing = s.outgoing[:nOut-1] // remove last
	}
}

//func outfilter(o []responsePacket) []uint32 {
//	res := make([]uint32, 0, len(o))
//	for _, v := range o {
//		res = append(res, v.id())
//	}
//	return res
//}