package middleware import ( "fmt" "io" "math/rand" "net" "net/http" "net/http/httputil" "net/url" "sync/atomic" "time" "github.com/labstack/echo" ) // TODO: Handle TLS proxy type ( // ProxyConfig defines the config for Proxy middleware. ProxyConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // Balancer defines a load balancing technique. // Required. // Possible values: // - RandomBalancer // - RoundRobinBalancer Balancer ProxyBalancer } // ProxyTarget defines the upstream target. ProxyTarget struct { URL *url.URL } // RandomBalancer implements a random load balancing technique. RandomBalancer struct { Targets []*ProxyTarget random *rand.Rand } // RoundRobinBalancer implements a round-robin load balancing technique. RoundRobinBalancer struct { Targets []*ProxyTarget i uint32 } // ProxyBalancer defines an interface to implement a load balancing technique. ProxyBalancer interface { Next() *ProxyTarget } ) var ( // DefaultProxyConfig is the default Proxy middleware config. DefaultProxyConfig = ProxyConfig{ Skipper: DefaultSkipper, } ) func proxyHTTP(t *ProxyTarget) http.Handler { return httputil.NewSingleHostReverseProxy(t.URL) } func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { in, _, err := c.Response().Hijack() if err != nil { c.Error(fmt.Errorf("proxy raw, hijack error=%v, url=%s", t.URL, err)) return } defer in.Close() out, err := net.Dial("tcp", t.URL.Host) if err != nil { he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", t.URL, err)) c.Error(he) return } defer out.Close() // Write header err = r.Write(out) if err != nil { he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request header copy error=%v, url=%s", t.URL, err)) c.Error(he) return } errc := make(chan error, 2) cp := func(dst io.Writer, src io.Reader) { _, err := io.Copy(dst, src) errc <- err } go cp(out, in) go cp(in, out) err = <-errc if err != nil && err != io.EOF { c.Logger().Errorf("proxy raw, copy body error=%v, url=%s", t.URL, err) } }) } // Next randomly returns an upstream target. func (r *RandomBalancer) Next() *ProxyTarget { if r.random == nil { r.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) } return r.Targets[r.random.Intn(len(r.Targets))] } // Next returns an upstream target using round-robin technique. func (r *RoundRobinBalancer) Next() *ProxyTarget { r.i = r.i % uint32(len(r.Targets)) t := r.Targets[r.i] atomic.AddUint32(&r.i, 1) return t } // Proxy returns a Proxy middleware. // // Proxy middleware forwards the request to upstream server using a configured load balancing technique. func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { c := DefaultProxyConfig c.Balancer = balancer return ProxyWithConfig(c) } // ProxyWithConfig returns a Proxy middleware with config. // See: `Proxy()` func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // Defaults if config.Skipper == nil { config.Skipper = DefaultLoggerConfig.Skipper } if config.Balancer == nil { panic("echo: proxy middleware requires balancer") } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (err error) { if config.Skipper(c) { return next(c) } req := c.Request() res := c.Response() tgt := config.Balancer.Next() // Fix header if req.Header.Get(echo.HeaderXRealIP) == "" { req.Header.Set(echo.HeaderXRealIP, c.RealIP()) } if req.Header.Get(echo.HeaderXForwardedProto) == "" { req.Header.Set(echo.HeaderXForwardedProto, c.Scheme()) } if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy. req.Header.Set(echo.HeaderXForwardedFor, c.RealIP()) } // Proxy switch { case c.IsWebSocket(): proxyRaw(tgt, c).ServeHTTP(res, req) case req.Header.Get(echo.HeaderAccept) == "text/event-stream": default: proxyHTTP(tgt).ServeHTTP(res, req) } return } } }