summaryrefslogblamecommitdiffstats
path: root/vendor/github.com/pkg/sftp/packet-manager.go
blob: bf822e672a12d06940a64d0201c0373811725a5a (plain) (tree)











































































































































































                                                                                  
package sftp

import (
	"encoding"
	"sort"
	"sync"
)

// The goal of the packetManager is to keep the outgoing packets in the same
// order as the incoming ones. This is because some sftp clients require this
// behavior (e.g. WinSCP).

// packetSender is the only thing the packetManager needs from its transport:
// the ability to send one marshalable packet. The packetManager stores it in
// its sender field.
type packetSender interface {
	// sendPacket transmits a single packet and returns any error from doing so.
	sendPacket(encoding.BinaryMarshaler) error
}

// packetManager holds finished responses back until it is their turn, so that
// packets are handed to the sender in sorted request-id order. In this file
// the incoming and outgoing lists are modified only by the controller
// goroutine; other goroutines talk to it through the channels.
type packetManager struct {
	// requests carries packets registered through incomingPacket.
	requests  chan requestPacket
	// responses carries finished packets handed over by readyPacket.
	responses chan responsePacket
	// fini is closed by close() to make the controller return.
	fini      chan struct{}
	// incoming lists the ids of registered requests whose response has not
	// been sent yet; the controller keeps it sorted ascending.
	incoming  requestPacketIDs
	// outgoing lists finished responses that have not been sent yet; the
	// controller keeps it sorted by id.
	outgoing  responsePackets
	sender    packetSender // connection object
	// working counts packets that were registered but not yet reported ready.
	working   *sync.WaitGroup
}

// newPktMgr builds a packetManager that delivers through sender and launches
// its controller goroutine before returning. Both channels are buffered to
// SftpServerWorkerCount and both pending lists start empty with that same
// capacity.
func newPktMgr(sender packetSender) *packetManager {
	mgr := new(packetManager)

	// Collaborators.
	mgr.sender = sender
	mgr.working = new(sync.WaitGroup)

	// Channels feeding the controller.
	mgr.requests = make(chan requestPacket, SftpServerWorkerCount)
	mgr.responses = make(chan responsePacket, SftpServerWorkerCount)
	mgr.fini = make(chan struct{})

	// Pending lists owned by the controller.
	mgr.incoming = make([]uint32, 0, SftpServerWorkerCount)
	mgr.outgoing = make([]responsePacket, 0, SftpServerWorkerCount)

	go mgr.controller()
	return mgr
}

// responsePackets is a list of finished responses waiting to be sent.
type responsePackets []responsePacket

// Sort reorders the responses in place by ascending packet id.
func (r responsePackets) Sort() {
	byID := func(a, b int) bool {
		return r[a].id() < r[b].id()
	}
	sort.Slice(r, byID)
}

// requestPacketIDs is a list of request ids still awaiting a response.
type requestPacketIDs []uint32

// Sort reorders the ids in place, smallest first.
func (r requestPacketIDs) Sort() {
	ascending := func(a, b int) bool {
		return r[a] < r[b]
	}
	sort.Slice(r, ascending)
}

// incomingPacket registers a request that is about to be handled. It
// increments the working counter and then passes the packet to the controller
// over s.requests, where its id is appended to the incoming list.
// Send an id of 0 for packets without an id.
func (s *packetManager) incomingPacket(pkt requestPacket) {
	// Count the packet as in flight; readyPacket performs the matching Done.
	s.working.Add(1)
	s.requests <- pkt // buffer == SftpServerWorkerCount
}

// readyPacket registers a finished response as ready to be sent. The packet
// is queued for the controller on s.responses and the working counter is then
// decremented, balancing the Add performed in incomingPacket.
func (s *packetManager) readyPacket(pkt responsePacket) {
	s.responses <- pkt
	// Done comes after the send, so a caller returning from working.Wait()
	// knows the response has already been placed on s.responses.
	s.working.Done()
}

// close shuts down the packetManager controller. It blocks until the working
// counter drops to zero and then closes fini, which makes controller return.
func (s *packetManager) close() {
	// pause until current packets are processed
	s.working.Wait()
	// NOTE(review): responses can still be sitting in the s.responses buffer
	// here, and controller's select may pick fini ahead of them — confirm
	// that dropping them at shutdown is intended.
	close(s.fini)
}

// workerChan starts the workers through runWorker and returns the channel on
// which incoming packets are to be delivered. The goal is to process packets
// in the order they are received, as required by section 7 of the RFC, while
// maximizing throughput of file transfers.
//
// runWorker is invoked SftpServerWorkerCount times with one shared buffered
// channel and once more with an unbuffered command channel. Dispatch starts
// on the command channel. An open packet, and everything after it, goes to
// the shared channel. A close packet first waits until every registered
// packet has been reported ready; it and the packets after it then go back to
// the command channel. Once the returned channel is closed and drained, both
// worker channels are closed and the manager is shut down.
func (s *packetManager) workerChan(runWorker func(requestChan)) requestChan {
	// Shared by the pool of workers that serve an open file.
	rwChan := make(chan requestPacket, SftpServerWorkerCount)
	for w := 0; w < SftpServerWorkerCount; w++ {
		runWorker(rwChan)
	}

	// A single worker handles everything else, one packet at a time.
	cmdChan := make(chan requestPacket)
	runWorker(cmdChan)

	pktChan := make(chan requestPacket, SftpServerWorkerCount)
	go func() {
		// Dispatch begins on the command channel.
		target := cmdChan
		for pkt := range pktChan {
			if _, opening := pkt.(*sshFxpOpenPacket); opening {
				// A file was opened: route to the worker pool from now on.
				target = rwChan
			} else if _, closing := pkt.(*sshFxpClosePacket); closing {
				// A file is being closed: let outstanding work finish,
				// then fall back to the command channel.
				s.working.Wait()
				target = cmdChan
			}
			s.incomingPacket(pkt)
			target <- pkt
		}
		// Input is exhausted: release the workers and stop the controller.
		close(rwChan)
		close(cmdChan)
		s.close()
	}()

	return pktChan
}

// controller is the manager's event loop; newPktMgr runs it in its own
// goroutine. It records request ids and finished responses as they arrive,
// keeps both pending lists sorted by id, and after every arrival tries to
// flush whatever is ready. It returns as soon as it observes fini closed,
// without draining anything still buffered in the channels.
func (s *packetManager) controller() {
	for {
		select {
		case <-s.fini:
			return
		case req := <-s.requests:
			id := req.id()
			debug("incoming id: %v", id)
			s.incoming = append(s.incoming, id)
			// A list of one is already sorted.
			if len(s.incoming) >= 2 {
				s.incoming.Sort()
			}
		case resp := <-s.responses:
			debug("outgoing pkt: %v", resp.id())
			s.outgoing = append(s.outgoing, resp)
			if len(s.outgoing) >= 2 {
				s.outgoing.Sort()
			}
		}
		s.maybeSendPackets()
	}
}

// maybeSendPackets sends as many packets as are ready. A response is ready
// when it sits at the head of the outgoing list and its id equals the id at
// the head of the incoming list; each such pair is sent and removed, and the
// loop stops at the first head pair that does not match or when either list
// runs empty.
//
// A failed send is reported through debug rather than silently discarded.
// The pair is still removed, as before, so one bad send does not stall the
// responses queued behind it.
//
// NOTE(review): both lists are kept sorted by numeric id, so responses leave
// in ascending-id order; that equals arrival order only if the client issues
// increasing request ids — confirm this assumption holds for all clients.
func (s *packetManager) maybeSendPackets() {
	for {
		if len(s.outgoing) == 0 || len(s.incoming) == 0 {
			debug("break! -- outgoing: %v; incoming: %v",
				len(s.outgoing), len(s.incoming))
			return
		}
		out := s.outgoing[0]
		in := s.incoming[0]
		// 		debug("incoming: %v", s.incoming)
		// 		debug("outgoing: %v", outfilter(s.outgoing))
		if in != out.id() {
			// The response for the lowest pending request is not ready yet.
			return
		}
		if err := s.sender.sendPacket(out); err != nil {
			// Nothing can be retried at this layer; surface the failure and
			// keep the queues moving.
			debug("send error for pkt %v: %v", in, err)
		}
		// pop off heads
		copy(s.incoming, s.incoming[1:])            // shift left
		s.incoming = s.incoming[:len(s.incoming)-1] // remove last
		copy(s.outgoing, s.outgoing[1:])            // shift left
		s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
	}
}

//func outfilter(o []responsePacket) []uint32 {
//	res := make([]uint32, 0, len(o))
//	for _, v := range o {
//		res = append(res, v.id())
//	}
//	return res
//}