package graceful

import (
	"crypto/tls"
	"log"
	"net"
	"net/http"
	"os"
	"sync"
	"time"
)

// Server wraps an http.Server with graceful connection handling.
// It may be used directly in the same way as http.Server, or may
// be constructed with the global functions in this package.
//
// Example:
//
//	srv := &graceful.Server{
//		Timeout: 5 * time.Second,
//		Server:  &http.Server{Addr: ":1234", Handler: handler},
//	}
//	srv.ListenAndServe()
type Server struct {
	*http.Server

	// Timeout is the duration to allow outstanding requests to survive
	// before forcefully terminating them.
	Timeout time.Duration

	// ListenLimit limits the number of outstanding requests.
	// Zero means the listener is not wrapped with a limit.
	ListenLimit int

	// TCPKeepAlive sets the TCP keep-alive timeouts on accepted
	// connections. It prunes dead TCP connections (e.g. closing
	// laptop mid-download).
	TCPKeepAlive time.Duration

	// ConnState specifies an optional callback function that is
	// called when a client connection changes state. This is a proxy
	// to the underlying http.Server's ConnState, and the original
	// must not be set directly.
	ConnState func(net.Conn, http.ConnState)

	// BeforeShutdown is an optional callback function that is called
	// before the listener is closed. Returns true if shutdown is allowed.
	BeforeShutdown func() bool

	// ShutdownInitiated is an optional callback function that is called
	// when shutdown is initiated. It can be used to notify the client
	// side of long lived connections (e.g. websockets) to reconnect.
	ShutdownInitiated func()

	// NoSignalHandling prevents graceful from automatically shutting down
	// on SIGINT and SIGTERM. If set to true, you must shut down the server
	// manually with Stop().
	NoSignalHandling bool

	// Logger is used to notify of errors on startup and on stop.
	// It is only consulted when LogFunc is nil.
	Logger *log.Logger

	// LogFunc can be assigned with a logging function of your choice, allowing
	// you to use whatever logging approach you would like. When set it takes
	// precedence over Logger.
	LogFunc func(format string, args ...interface{})

	// Interrupted is true if the server is handling a SIGINT or SIGTERM
	// signal and is thus shutting down.
	Interrupted bool

	// interrupt signals the listener to stop serving connections,
	// and the server to shut down.
	interrupt chan os.Signal

	// stopLock is used to protect against concurrent calls to Stop.
	stopLock sync.Mutex

	// stopChan is the channel on which callers may block while waiting for
	// the server to stop.
	stopChan chan struct{}

	// chanLock is used to protect access to the various channel constructors.
	chanLock sync.RWMutex

	// connections holds all connections managed by graceful.
	connections map[net.Conn]struct{}

	// idleConnections holds all idle connections managed by graceful.
	idleConnections map[net.Conn]struct{}
}

// Run serves the http.Handler with graceful shutdown enabled.
//
// timeout is the duration to wait until killing active requests and stopping the server.
// If timeout is 0, the server never times out. It waits for all active requests to finish.
//
// Run does not return an error: any failure other than a *net.OpError whose
// Op is "accept" is logged through the default logger and terminates the
// process with exit status 1.
func Run(addr string, timeout time.Duration, n http.Handler) {
	srv := &Server{
		Timeout:      timeout,
		TCPKeepAlive: 3 * time.Minute,
		Server:       &http.Server{Addr: addr, Handler: n},
		// Without a logger srv.logf is a no-op and the fatal error below
		// would be dropped silently right before os.Exit.
		Logger: DefaultLogger(),
	}

	err := srv.ListenAndServe()
	if err == nil {
		return
	}

	// "accept" operation errors are tolerated (presumably they are what a
	// listener closed during shutdown reports); everything else is fatal.
	if opErr, ok := err.(*net.OpError); ok && opErr.Op == "accept" {
		return
	}

	srv.logf("%s", err)
	os.Exit(1)
}

// RunWithErr is an alternative version of Run function which can return error.
//
// Unlike Run this version will not exit the program if an error is encountered but will
// return it instead.
func RunWithErr(addr string, timeout time.Duration, n http.Handler) error { srv := &Server{ Timeout: timeout, TCPKeepAlive: 3 * time.Minute, Server: &http.Server{Addr: addr, Handler: n}, Logger: DefaultLogger(), } return srv.ListenAndServe() } // ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled. // // timeout is the duration to wait until killing active requests and stopping the server. // If timeout is 0, the server never times out. It waits for all active requests to finish. func ListenAndServe(server *http.Server, timeout time.Duration) error { srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()} return srv.ListenAndServe() } // ListenAndServe is equivalent to http.Server.ListenAndServe with graceful shutdown enabled. func (srv *Server) ListenAndServe() error { // Create the listener so we can control their lifetime addr := srv.Addr if addr == "" { addr = ":http" } conn, err := srv.newTCPListener(addr) if err != nil { return err } return srv.Serve(conn) } // ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled. // // timeout is the duration to wait until killing active requests and stopping the server. // If timeout is 0, the server never times out. It waits for all active requests to finish. func ListenAndServeTLS(server *http.Server, certFile, keyFile string, timeout time.Duration) error { srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()} return srv.ListenAndServeTLS(certFile, keyFile) } // ListenTLS is a convenience method that creates an https listener using the // provided cert and key files. Use this method if you need access to the // listener object directly. When ready, pass it to the Serve method. 
func (srv *Server) ListenTLS(certFile, keyFile string) (net.Listener, error) { // Create the listener ourselves so we can control its lifetime addr := srv.Addr if addr == "" { addr = ":https" } config := &tls.Config{} if srv.TLSConfig != nil { *config = *srv.TLSConfig } var err error if certFile != "" && keyFile != "" { config.Certificates = make([]tls.Certificate, 1) config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err } } // Enable http2 enableHTTP2ForTLSConfig(config) conn, err := srv.newTCPListener(addr) if err != nil { return nil, err } srv.TLSConfig = config tlsListener := tls.NewListener(conn, config) return tlsListener, nil } // Enable HTTP2ForTLSConfig explicitly enables http/2 for a TLS Config. This is due to changes in Go 1.7 where // http servers are no longer automatically configured to enable http/2 if the server's TLSConfig is set. // See https://github.com/golang/go/issues/15908 func enableHTTP2ForTLSConfig(t *tls.Config) { if TLSConfigHasHTTP2Enabled(t) { return } t.NextProtos = append(t.NextProtos, "h2") } // TLSConfigHasHTTP2Enabled checks to see if a given TLS Config has http2 enabled. func TLSConfigHasHTTP2Enabled(t *tls.Config) bool { for _, value := range t.NextProtos { if value == "h2" { return true } } return false } // ListenAndServeTLS is equivalent to http.Server.ListenAndServeTLS with graceful shutdown enabled. 
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { l, err := srv.ListenTLS(certFile, keyFile) if err != nil { return err } return srv.Serve(l) } // ListenAndServeTLSConfig can be used with an existing TLS config and is equivalent to // http.Server.ListenAndServeTLS with graceful shutdown enabled, func (srv *Server) ListenAndServeTLSConfig(config *tls.Config) error { addr := srv.Addr if addr == "" { addr = ":https" } conn, err := srv.newTCPListener(addr) if err != nil { return err } srv.TLSConfig = config tlsListener := tls.NewListener(conn, config) return srv.Serve(tlsListener) } // Serve is equivalent to http.Server.Serve with graceful shutdown enabled. // // timeout is the duration to wait until killing active requests and stopping the server. // If timeout is 0, the server never times out. It waits for all active requests to finish. func Serve(server *http.Server, l net.Listener, timeout time.Duration) error { srv := &Server{Timeout: timeout, Server: server, Logger: DefaultLogger()} return srv.Serve(l) } // Serve is equivalent to http.Server.Serve with graceful shutdown enabled. 
func (srv *Server) Serve(listener net.Listener) error { if srv.ListenLimit != 0 { listener = LimitListener(listener, srv.ListenLimit) } // Make our stopchan srv.StopChan() // Track connection state add := make(chan net.Conn) idle := make(chan net.Conn) active := make(chan net.Conn) remove := make(chan net.Conn) srv.Server.ConnState = func(conn net.Conn, state http.ConnState) { switch state { case http.StateNew: add <- conn case http.StateActive: active <- conn case http.StateIdle: idle <- conn case http.StateClosed, http.StateHijacked: remove <- conn } srv.stopLock.Lock() defer srv.stopLock.Unlock() if srv.ConnState != nil { srv.ConnState(conn, state) } } // Manage open connections shutdown := make(chan chan struct{}) kill := make(chan struct{}) go srv.manageConnections(add, idle, active, remove, shutdown, kill) interrupt := srv.interruptChan() // Set up the interrupt handler if !srv.NoSignalHandling { signalNotify(interrupt) } quitting := make(chan struct{}) go srv.handleInterrupt(interrupt, quitting, listener) // Serve with graceful listener. // Execution blocks here until listener.Close() is called, above. err := srv.Server.Serve(listener) if err != nil { // If the underlying listening is closed, Serve returns an error // complaining about listening on a closed socket. This is expected, so // let's ignore the error if we are the ones who explicitly closed the // socket. select { case <-quitting: err = nil default: } } srv.shutdown(shutdown, kill) return err } // Stop instructs the type to halt operations and close // the stop channel when it is finished. // // timeout is grace period for which to wait before shutting // down the server. The timeout value passed here will override the // timeout given when constructing the server, as this is an explicit // command to stop the server. 
func (srv *Server) Stop(timeout time.Duration) { srv.stopLock.Lock() defer srv.stopLock.Unlock() srv.Timeout = timeout sendSignalInt(srv.interruptChan()) } // StopChan gets the stop channel which will block until // stopping has completed, at which point it is closed. // Callers should never close the stop channel. func (srv *Server) StopChan() <-chan struct{} { srv.chanLock.Lock() defer srv.chanLock.Unlock() if srv.stopChan == nil { srv.stopChan = make(chan struct{}) } return srv.stopChan } // DefaultLogger returns the logger used by Run, RunWithErr, ListenAndServe, ListenAndServeTLS and Serve. // The logger outputs to STDERR by default. func DefaultLogger() *log.Logger { return log.New(os.Stderr, "[graceful] ", 0) } func (srv *Server) manageConnections(add, idle, active, remove chan net.Conn, shutdown chan chan struct{}, kill chan struct{}) { var done chan struct{} srv.connections = map[net.Conn]struct{}{} srv.idleConnections = map[net.Conn]struct{}{} for { select { case conn := <-add: srv.connections[conn] = struct{}{} srv.idleConnections[conn] = struct{}{} // Newly-added connections are considered idle until they become active. case conn := <-idle: srv.idleConnections[conn] = struct{}{} case conn := <-active: delete(srv.idleConnections, conn) case conn := <-remove: delete(srv.connections, conn) delete(srv.idleConnections, conn) if done != nil && len(srv.connections) == 0 { done <- struct{}{} return } case done = <-shutdown: if len(srv.connections) == 0 && len(srv.idleConnections) == 0 { done <- struct{}{} return } // a shutdown request has been received. if we have open idle // connections, we must close all of them now. this prevents idle // connections from holding the server open while waiting for them to // hit their idle timeout. 
for k := range srv.idleConnections { if err := k.Close(); err != nil { srv.logf("[ERROR] %s", err) } } case <-kill: srv.stopLock.Lock() defer srv.stopLock.Unlock() srv.Server.ConnState = nil for k := range srv.connections { if err := k.Close(); err != nil { srv.logf("[ERROR] %s", err) } } return } } } func (srv *Server) interruptChan() chan os.Signal { srv.chanLock.Lock() defer srv.chanLock.Unlock() if srv.interrupt == nil { srv.interrupt = make(chan os.Signal, 1) } return srv.interrupt } func (srv *Server) handleInterrupt(interrupt chan os.Signal, quitting chan struct{}, listener net.Listener) { for _ = range interrupt { if srv.Interrupted { srv.logf("already shutting down") continue } srv.logf("shutdown initiated") srv.Interrupted = true if srv.BeforeShutdown != nil { if !srv.BeforeShutdown() { srv.Interrupted = false continue } } close(quitting) srv.SetKeepAlivesEnabled(false) if err := listener.Close(); err != nil { srv.logf("[ERROR] %s", err) } if srv.ShutdownInitiated != nil { srv.ShutdownInitiated() } } } func (srv *Server) logf(format string, args ...interface{}) { if srv.LogFunc != nil { srv.LogFunc(format, args...) } else if srv.Logger != nil { srv.Logger.Printf(format, args...) } } func (srv *Server) shutdown(shutdown chan chan struct{}, kill chan struct{}) { // Request done notification done := make(chan struct{}) shutdown <- done srv.stopLock.Lock() defer srv.stopLock.Unlock() if srv.Timeout > 0 { select { case <-done: case <-time.After(srv.Timeout): close(kill) } } else { <-done } // Close the stopChan to wake up any blocked goroutines. srv.chanLock.Lock() if srv.stopChan != nil { close(srv.stopChan) } srv.chanLock.Unlock() } func (srv *Server) newTCPListener(addr string) (net.Listener, error) { conn, err := net.Listen("tcp", addr) if err != nil { return conn, err } if srv.TCPKeepAlive != 0 { conn = keepAliveListener{conn, srv.TCPKeepAlive} } return conn, nil }