// Vendored source: root/vendor/github.com/facebookgo/httpdown/httpdown.go
// (upstream blob 34c5dea9ff89f51a59b866dae6f5af1b5bca02f3).

// Package httpdown provides http.ConnState enabled graceful termination of
// http.Server.
package httpdown

import (
	"crypto/tls"
	"fmt"
	"net"
	"net/http"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"

	"github.com/facebookgo/clock"
	"github.com/facebookgo/stats"
)

// defaultStopTimeout is used by Serve when HTTP.StopTimeout is left at zero.
const defaultStopTimeout = time.Minute

// defaultKillTimeout is used by Serve when HTTP.KillTimeout is left at zero.
const defaultKillTimeout = time.Minute

// Server is an http.Server being served in the background. It accepts new
// connections and serves them, and it can gracefully shut down the listener
// without dropping active connections.
type Server interface {
	// Stop closes the listener and waits for in-flight connections to finish.
	// With the implementation returned by HTTP.Serve, idle connections are
	// closed right away. Connections still open after StopTimeout are force
	// closed, and Stop stops waiting after a further KillTimeout, so it can
	// return while connections remain. Calling Stop more than once is safe;
	// every call returns the same error.
	Stop() error

	// Wait blocks until the serving loop has finished. If the loop ended
	// because Stop was called, Wait returns nil; otherwise it returns the
	// serving error. You must call Wait after calling Serve or ListenAndServe.
	Wait() error
}

// HTTP defines the configuration for serving an http.Server. Multiple calls to
// Serve or ListenAndServe can be made on the same HTTP instance. With the
// default timeouts of 1 minute each, Stop waits at most about 2 minutes for
// connections after the serving loop has exited (StopTimeout followed by
// KillTimeout).
type HTTP struct {
	// StopTimeout is how long Stop waits for connections to finish on their
	// own before it begins force closing them. Idle connections are closed as
	// soon as Stop is called regardless of this value. A zero value means the
	// default of 1 minute.
	StopTimeout time.Duration

	// KillTimeout is how long Stop waits, after StopTimeout expires, for the
	// force close of all connections to complete before giving up and
	// returning even though clients may still be connected. This is useful
	// when a large number of client connections exist and closing them can
	// take a long time. It is in addition to StopTimeout. A zero value means
	// the default of 1 minute.
	KillTimeout time.Duration

	// Stats is optional. If provided, it is used to record metrics such as
	// per-state connection counts (reported every minute), stop/kill counts
	// and durations. When nil, the periodic per-state report is disabled.
	// NOTE(review): the other metric calls go through stats.BumpSum/BumpTime
	// with a possibly nil client; presumably those helpers treat nil as a
	// no-op. Confirm against the stats package.
	Stats stats.Client

	// Clock allows for testing timing related functionality. Do not specify
	// this in production code; when nil, clock.New() is used.
	Clock clock.Clock
}

// Serve provides the low-level API which is useful if you're creating your own
// net.Listener.
//
// It replaces s.ConnState with its own hook, which still forwards every event
// to the previously installed hook. It then starts the connection-tracking
// goroutine and the serving goroutine and returns immediately. A zero
// StopTimeout or KillTimeout falls back to defaultStopTimeout or
// defaultKillTimeout respectively, and a nil Clock falls back to clock.New().
func (h HTTP) Serve(s *http.Server, l net.Listener) Server {
	orDefault := func(configured, fallback time.Duration) time.Duration {
		if configured != 0 {
			return configured
		}
		return fallback
	}

	clk := h.Clock
	if clk == nil {
		clk = clock.New()
	}

	srv := &server{
		stopTimeout: orDefault(h.StopTimeout, defaultStopTimeout),
		killTimeout: orDefault(h.KillTimeout, defaultKillTimeout),
		stats:       h.Stats,
		clock:       clk,

		// Capture the caller's hook before it is overwritten below.
		oldConnState: s.ConnState,
		listener:     l,
		server:       s,

		serveDone: make(chan struct{}),
		serveErr:  make(chan error, 1),

		// Unbuffered: each ConnState callback blocks until manage has
		// processed the transition.
		new:    make(chan net.Conn),
		active: make(chan net.Conn),
		idle:   make(chan net.Conn),
		closed: make(chan net.Conn),
		stop:   make(chan chan struct{}),
		kill:   make(chan chan struct{}),
	}

	// The hook must be in place before serving starts so no state change is
	// missed.
	s.ConnState = srv.connState
	go srv.manage()
	go srv.serve()
	return srv
}

// ListenAndServe returns a Server for the given http.Server. It is equivalent
// to ListenAndServe from the standard library, but returns immediately.
// Requests will be accepted in a background goroutine. If the http.Server has
// a non-nil TLSConfig, a TLS enabled listener will be setup.
//
// An empty s.Addr listens on ":https" when TLSConfig is set and ":http"
// otherwise. A listen failure is counted as "listen.error" in h.Stats and
// returned with a nil Server.
func (h HTTP) ListenAndServe(s *http.Server) (Server, error) {
	addr := s.Addr
	switch {
	case addr != "":
		// explicit address wins
	case s.TLSConfig != nil:
		addr = ":https"
	default:
		addr = ":http"
	}

	ln, err := net.Listen("tcp", addr)
	if err != nil {
		stats.BumpSum(h.Stats, "listen.error", 1)
		return nil, err
	}

	if cfg := s.TLSConfig; cfg != nil {
		ln = tls.NewListener(ln, cfg)
	}
	return h.Serve(s, ln), nil
}

// server is the Server implementation returned by HTTP.Serve. It runs the
// wrapped http.Server in one goroutine and tracks connection states in
// another (manage), so that Stop can shut down gracefully.
type server struct {
	// Settings copied from HTTP by Serve, with defaults applied.
	stopTimeout, killTimeout time.Duration
	stats                    stats.Client
	clock                    clock.Clock

	// oldConnState is the ConnState hook that was installed before Serve
	// replaced it. serveErr receives the result of server.Serve, and
	// serveDone is closed once that call has returned.
	oldConnState func(net.Conn, http.ConnState)
	server       *http.Server
	serveDone    chan struct{}
	serveErr     chan error
	listener     net.Listener

	// State transitions sent by connState and consumed by manage.
	new, active, idle, closed chan net.Conn

	// Requests sent by Stop to manage. Each carries a channel that manage
	// closes once the request has been completed.
	stop, kill chan chan struct{}

	// stopOnce makes Stop run its shutdown sequence only once; stopErr holds
	// the result returned to every caller.
	stopOnce sync.Once
	stopErr  error
}

// connState is the ConnState hook that Serve installs on the http.Server.
// It first forwards the event to the hook that was installed before. It then
// hands the connection to the manage goroutine on the channel matching the
// new state. Hijacked and closed connections share the closed channel,
// because manage stops tracking the connection in both cases. Any other
// state is ignored.
func (s *server) connState(c net.Conn, cs http.ConnState) {
	if prev := s.oldConnState; prev != nil {
		prev(c, cs)
	}

	var dst chan net.Conn
	switch cs {
	case http.StateNew:
		dst = s.new
	case http.StateActive:
		dst = s.active
	case http.StateIdle:
		dst = s.idle
	case http.StateHijacked, http.StateClosed:
		dst = s.closed
	default:
		return
	}
	// Serve creates these channels unbuffered, so this send blocks until
	// manage receives it.
	dst <- c
}

// manage is the single goroutine that owns connection-state bookkeeping. It
// does three things:
//   - records every ConnState transition reported by connState;
//   - reports per-state connection counts to s.stats once a minute, when a
//     Stats client is configured;
//   - services the stop and kill requests sent by Stop.
//
// It returns once a stop has been requested and no tracked connections
// remain; it closes the stop request's done channel just before returning.
// On return it closes the connection-state channels, so any later send from
// connState would panic. It deliberately does not close s.stop or s.kill.
// Stop is the sender on those channels, and it may still try to send a kill
// request if its stop timeout fires just as manage finishes. Sending on a
// closed channel would panic.
func (s *server) manage() {
	defer func() {
		close(s.new)
		close(s.active)
		close(s.idle)
		close(s.closed)
	}()

	// stopDone stays nil until Stop requests a graceful stop. Once it is
	// set, connections that become idle are closed, and manage exits when
	// the last tracked connection goes away.
	var stopDone chan struct{}

	// conns maps every tracked connection to its latest state. The counters
	// mirror how many entries are in each state, for the periodic stats.
	conns := map[net.Conn]http.ConnState{}
	var countNew, countActive, countIdle float64

	// decConn decrements the count associated with the current state of the
	// given connection. http.StateNew is the zero ConnState, so a plain
	// lookup cannot tell an untracked connection from a new one. Check
	// membership explicitly so the invariant violation is reported.
	decConn := func(c net.Conn) {
		prev, ok := conns[c]
		if !ok {
			panic(fmt.Errorf("unknown existing connection: %s", c))
		}
		switch prev {
		case http.StateNew:
			countNew--
		case http.StateActive:
			countActive--
		case http.StateIdle:
			countIdle--
		}
	}

	// Set up a ticker to report various values every minute. If no Stats
	// implementation was provided, stop it right away so it never ticks.
	// Either way, release it when manage returns.
	statsTicker := s.clock.Ticker(time.Minute)
	defer statsTicker.Stop()
	if s.stats == nil {
		statsTicker.Stop()
	}

	for {
		select {
		case <-statsTicker.C:
			// we'll only get here when s.stats is not nil
			s.stats.BumpAvg("http-state.new", countNew)
			s.stats.BumpAvg("http-state.active", countActive)
			s.stats.BumpAvg("http-state.idle", countIdle)
			s.stats.BumpAvg("http-state.total", countNew+countActive+countIdle)
		case c := <-s.new:
			conns[c] = http.StateNew
			countNew++
		case c := <-s.active:
			decConn(c)
			countActive++

			conns[c] = http.StateActive
		case c := <-s.idle:
			decConn(c)
			countIdle++

			conns[c] = http.StateIdle

			// if we're already stopping, close it
			if stopDone != nil {
				c.Close()
			}
		case c := <-s.closed:
			stats.BumpSum(s.stats, "conn.closed", 1)
			decConn(c)
			delete(conns, c)

			// if we're waiting to stop and are all empty, we just closed the last
			// connection and we're done.
			if stopDone != nil && len(conns) == 0 {
				close(stopDone)
				return
			}
		case stopDone = <-s.stop:
			// if we're already all empty, we're already done
			if len(conns) == 0 {
				close(stopDone)
				return
			}

			// close current idle connections right away
			for c, cs := range conns {
				if cs == http.StateIdle {
					c.Close()
				}
			}

			// continue the loop and wait for all the ConnState updates which will
			// eventually close(stopDone) and return from this goroutine.

		case killDone := <-s.kill:
			// force close all connections
			stats.BumpSum(s.stats, "kill.conn.count", float64(len(conns)))
			for c := range conns {
				c.Close()
			}

			// don't block the kill.
			close(killDone)

			// continue the loop and we wait for all the ConnState updates and will
			// return from this goroutine when we're all done. otherwise we'll try to
			// send those ConnState updates on closed channels.

		}
	}
}

// serve runs the wrapped http.Server on the listener and publishes the result
// on serveErr (buffered, so this never blocks). It then closes serveDone and
// serveErr so that Stop and Wait can observe completion.
func (s *server) serve() {
	stats.BumpSum(s.stats, "serve", 1)
	s.serveErr <- s.server.Serve(s.listener)
	close(s.serveDone)
	close(s.serveErr)
}

// Wait blocks until serve has produced its result. The "use of closed network
// connection" error caused by Stop closing the listener is reported as nil.
// Later calls read from the closed serveErr channel and also return nil.
func (s *server) Wait() error {
	if err := <-s.serveErr; !isUseOfClosedError(err) {
		return err
	}
	return nil
}

// Stop gracefully shuts the server down. Only the first call does any work;
// every call returns the error, if any, from closing the listener ("use of
// closed network connection" is ignored). The sequence is:
//  1. disable HTTP keep-alives;
//  2. close the listener and wait for the Serve loop to return;
//  3. ask manage to stop; manage closes idle connections immediately and
//     finishes once no tracked connections remain;
//  4. if that takes longer than stopTimeout, ask manage to force close all
//     connections, and wait up to killTimeout for that to complete.
func (s *server) Stop() error {
	s.stopOnce.Do(func() {
		defer stats.BumpTime(s.stats, "stop.time").End()
		stats.BumpSum(s.stats, "stop", 1)

		// first disable keep-alive for new connections
		s.server.SetKeepAlivesEnabled(false)

		// then close the listener so new connections can't come through
		closeErr := s.listener.Close()
		<-s.serveDone

		// then trigger the background goroutine to stop and wait for it
		stopDone := make(chan struct{})
		s.stop <- stopDone

		// wait for stop
		select {
		case <-stopDone:
		case <-s.clock.After(s.stopTimeout):
			defer stats.BumpTime(s.stats, "kill.time").End()
			stats.BumpSum(s.stats, "kill", 1)

			// Stop timed out, so ask manage to force close everything. manage
			// may have finished (closing stopDone) just as the timer fired;
			// then nobody will ever receive the kill request. Watch stopDone
			// as well so the send cannot block forever.
			killDone := make(chan struct{})
			select {
			case s.kill <- killDone:
				select {
				case <-killDone:
				case <-s.clock.After(s.killTimeout):
					// kill timed out, give up
					stats.BumpSum(s.stats, "kill.timeout", 1)
				}
			case <-stopDone:
				// all connections closed on their own; nothing left to kill
			}
		}

		if closeErr != nil && !isUseOfClosedError(closeErr) {
			stats.BumpSum(s.stats, "listener.close.error", 1)
			s.stopErr = closeErr
		}
	})
	return s.stopErr
}

// isUseOfClosedError reports whether err is the error returned by an
// operation on a network connection or listener that has already been closed.
// A *net.OpError is unwrapped one level before its message is compared;
// a nil error reports false.
func isUseOfClosedError(err error) bool {
	const closedMsg = "use of closed network connection"
	switch e := err.(type) {
	case nil:
		return false
	case *net.OpError:
		return e.Err.Error() == closedMsg
	default:
		return e.Error() == closedMsg
	}
}

// ListenAndServe is a convenience function to serve and wait for a SIGTERM
// or SIGINT before shutting down.
//
// It starts s using hd; a nil hd means a zero HTTP, so all defaults apply.
// It then blocks until one of two things happens:
//   - The serving loop ends on its own. Its error, if any, is returned.
//   - SIGTERM or SIGINT arrives. The server is stopped gracefully, and the
//     first non-nil error from Stop or Wait is returned.
//
// Signal delivery is always unregistered before returning, so the process
// does not keep intercepting SIGTERM/SIGINT after this function is done.
func ListenAndServe(s *http.Server, hd *HTTP) error {
	if hd == nil {
		hd = &HTTP{}
	}
	hs, err := hd.ListenAndServe(s)
	if err != nil {
		return err
	}

	waiterr := make(chan error, 1)
	go func() {
		defer close(waiterr)
		waiterr <- hs.Wait()
	}()

	signals := make(chan os.Signal, 10)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
	// Without this, returning via the Wait branch would leave the signals
	// captured, and SIGTERM/SIGINT would no longer terminate the process.
	// Calling signal.Stop a second time is harmless.
	defer signal.Stop(signals)

	select {
	case err := <-waiterr:
		return err
	case <-signals:
		// Restore default handling before a possibly slow graceful stop, so
		// a second signal can still terminate the process.
		signal.Stop(signals)
		if err := hs.Stop(); err != nil {
			return err
		}
		return <-waiterr
	}
}