diff options
Diffstat (limited to 'vendor/github.com/labstack/echo/v4/middleware')
16 files changed, 780 insertions, 65 deletions
diff --git a/vendor/github.com/labstack/echo/v4/middleware/basic_auth.go b/vendor/github.com/labstack/echo/v4/middleware/basic_auth.go index 76ba2420..8cf1ed9f 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/basic_auth.go +++ b/vendor/github.com/labstack/echo/v4/middleware/basic_auth.go @@ -73,7 +73,7 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { auth := c.Request().Header.Get(echo.HeaderAuthorization) l := len(basic) - if len(auth) > l+1 && strings.ToLower(auth[:l]) == basic { + if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { b, err := base64.StdEncoding.DecodeString(auth[l+1:]) if err != nil { return err diff --git a/vendor/github.com/labstack/echo/v4/middleware/compress.go b/vendor/github.com/labstack/echo/v4/middleware/compress.go index dd97d983..6ae19745 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/compress.go +++ b/vendor/github.com/labstack/echo/v4/middleware/compress.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "strings" + "sync" "github.com/labstack/echo/v4" ) @@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { config.Level = DefaultGzipConfig.Level } + pool := gzipCompressPool(config) + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { @@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding) if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) { res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806 - rw := res.Writer - w, err := gzip.NewWriterLevel(rw, config.Level) - if err != nil { - return err + i := pool.Get() + w, ok := i.(*gzip.Writer) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) } + rw := res.Writer + w.Reset(rw) defer func() { if res.Size == 0 { if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme { @@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { w.Reset(ioutil.Discard) } w.Close() + pool.Put(w) }() grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw} res.Writer = grw @@ -126,3 +132,15 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { } return http.ErrNotSupported } + +func gzipCompressPool(config GzipConfig) sync.Pool { + return sync.Pool{ + New: func() interface{} { + w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level) + if err != nil { + return err + } + return w + }, + } +} diff --git a/vendor/github.com/labstack/echo/v4/middleware/cors.go b/vendor/github.com/labstack/echo/v4/middleware/cors.go index 5dfe31f9..d6ef8964 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/cors.go +++ b/vendor/github.com/labstack/echo/v4/middleware/cors.go @@ -2,6 +2,7 @@ package middleware import ( "net/http" + "regexp" "strconv" "strings" @@ -18,6 +19,13 @@ type ( // Optional. Default value []string{"*"}. AllowOrigins []string `yaml:"allow_origins"` + // AllowOriginFunc is a custom function to validate the origin. It takes the + // origin as an argument and returns true if allowed or false otherwise. If + // an error is returned, it is returned by the handler. If this option is + // set, AllowOrigins is ignored. + // Optional. + AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"` + // AllowMethods defines a list methods allowed when accessing the resource. // This is used in response to a preflight request. // Optional. Default value DefaultCORSConfig.AllowMethods. @@ -76,6 +84,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { config.AllowMethods = DefaultCORSConfig.AllowMethods } + allowOriginPatterns := []string{} + for _, origin := range config.AllowOrigins { + pattern := regexp.QuoteMeta(origin) + pattern = strings.Replace(pattern, "\\*", ".*", -1) + pattern = strings.Replace(pattern, "\\?", ".", -1) + pattern = "^" + pattern + "$" + allowOriginPatterns = append(allowOriginPatterns, pattern) + } + allowMethods := strings.Join(config.AllowMethods, ",") allowHeaders := strings.Join(config.AllowHeaders, ",") exposeHeaders := strings.Join(config.ExposeHeaders, ",") @@ -92,25 +109,73 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { origin := req.Header.Get(echo.HeaderOrigin) allowOrigin := "" - // Check allowed origins - for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials { - allowOrigin = origin - break + preflight := req.Method == http.MethodOptions + res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + + // No Origin provided + if origin == "" { + if !preflight { + return next(c) } - if o == "*" || o == origin { - allowOrigin = o - break + return c.NoContent(http.StatusNoContent) + } + + if config.AllowOriginFunc != nil { + allowed, err := config.AllowOriginFunc(origin) + if err != nil { + return err } - if matchSubdomain(origin, o) { + if allowed { allowOrigin = origin - break + } + } else { + // Check allowed origins + for _, o := range config.AllowOrigins { + if o == "*" && config.AllowCredentials { + allowOrigin = origin + break + } + if o == "*" || o == origin { + allowOrigin = o + break + } + if matchSubdomain(origin, o) { + allowOrigin = origin + break + } + } + + // Check allowed origin patterns + for _, re := range allowOriginPatterns { + if allowOrigin == "" { + didx := strings.Index(origin, "://") + if didx == -1 { + continue + } + domAuth := origin[didx+3:] + // to avoid regex cost by invalid long domain + if len(domAuth) > 253 { + break + } + + if match, _ := regexp.MatchString(re, origin); match { + allowOrigin = origin + break + } + } } } + // Origin not allowed + if allowOrigin == "" { + if !preflight { + return next(c) + } + return c.NoContent(http.StatusNoContent) + } + // Simple request - if req.Method != http.MethodOptions { - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) + if !preflight { res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) if config.AllowCredentials { res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") @@ -122,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } // Preflight request - res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) diff --git a/vendor/github.com/labstack/echo/v4/middleware/csrf.go b/vendor/github.com/labstack/echo/v4/middleware/csrf.go index 09a66bb6..60f809a0 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/csrf.go +++ b/vendor/github.com/labstack/echo/v4/middleware/csrf.go @@ -57,6 +57,10 @@ type ( // Indicates if CSRF cookie is HTTP only. // Optional. Default value false. CookieHTTPOnly bool `yaml:"cookie_http_only"` + + // Indicates SameSite mode of the CSRF cookie. + // Optional. Default value SameSiteDefaultMode. + CookieSameSite http.SameSite `yaml:"cookie_same_site"` } // csrfTokenExtractor defines a function that takes `echo.Context` and returns @@ -67,12 +71,13 @@ type ( var ( // DefaultCSRFConfig is the default CSRF middleware config. DefaultCSRFConfig = CSRFConfig{ - Skipper: DefaultSkipper, - TokenLength: 32, - TokenLookup: "header:" + echo.HeaderXCSRFToken, - ContextKey: "csrf", - CookieName: "_csrf", - CookieMaxAge: 86400, + Skipper: DefaultSkipper, + TokenLength: 32, + TokenLookup: "header:" + echo.HeaderXCSRFToken, + ContextKey: "csrf", + CookieName: "_csrf", + CookieMaxAge: 86400, + CookieSameSite: http.SameSiteDefaultMode, } ) @@ -105,6 +110,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.CookieMaxAge == 0 { config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge } + if config.CookieSameSite == SameSiteNoneMode { + config.CookieSecure = true + } // Initialize parts := strings.Split(config.TokenLookup, ":") @@ -157,6 +165,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.CookieDomain != "" { cookie.Domain = config.CookieDomain } + if config.CookieSameSite != http.SameSiteDefaultMode { + cookie.SameSite = config.CookieSameSite + } cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) cookie.Secure = config.CookieSecure cookie.HttpOnly = config.CookieHTTPOnly diff --git a/vendor/github.com/labstack/echo/v4/middleware/csrf_samesite.go b/vendor/github.com/labstack/echo/v4/middleware/csrf_samesite.go new file mode 100644 index 00000000..9a27dc43 --- /dev/null +++ b/vendor/github.com/labstack/echo/v4/middleware/csrf_samesite.go @@ -0,0 +1,12 @@ +// +build go1.13 + +package middleware + +import ( + "net/http" +) + +const ( + // SameSiteNoneMode required to be redefined for Go 1.12 support (see #1524) + SameSiteNoneMode http.SameSite = http.SameSiteNoneMode +) diff --git a/vendor/github.com/labstack/echo/v4/middleware/csrf_samesite_1.12.go b/vendor/github.com/labstack/echo/v4/middleware/csrf_samesite_1.12.go new file mode 100644 index 00000000..22076dd6 --- /dev/null +++ b/vendor/github.com/labstack/echo/v4/middleware/csrf_samesite_1.12.go @@ -0,0 +1,12 @@ +// +build !go1.13 + +package middleware + +import ( + "net/http" +) + +const ( + // SameSiteNoneMode required to be redefined for Go 1.12 support (see #1524) + SameSiteNoneMode http.SameSite = 4 +) diff --git a/vendor/github.com/labstack/echo/v4/middleware/decompress.go b/vendor/github.com/labstack/echo/v4/middleware/decompress.go new file mode 100644 index 00000000..c046359a --- /dev/null +++ b/vendor/github.com/labstack/echo/v4/middleware/decompress.go @@ -0,0 +1,120 @@ +package middleware + +import ( + "bytes" + "compress/gzip" + "io" + "io/ioutil" + "net/http" + "sync" + + "github.com/labstack/echo/v4" +) + +type ( + // DecompressConfig defines the config for Decompress middleware. + DecompressConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers + GzipDecompressPool Decompressor + } +) + +//GZIPEncoding content-encoding header if set to "gzip", decompress body contents. +const GZIPEncoding string = "gzip" + +// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers +type Decompressor interface { + gzipDecompressPool() sync.Pool +} + +var ( + //DefaultDecompressConfig defines the config for decompress middleware + DefaultDecompressConfig = DecompressConfig{ + Skipper: DefaultSkipper, + GzipDecompressPool: &DefaultGzipDecompressPool{}, + } +) + +// DefaultGzipDecompressPool is the default implementation of Decompressor interface +type DefaultGzipDecompressPool struct { +} + +func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { + return sync.Pool{ + New: func() interface{} { + // create with an empty reader (but with GZIP header) + w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed) + if err != nil { + return err + } + + b := new(bytes.Buffer) + w.Reset(b) + w.Flush() + w.Close() + + r, err := gzip.NewReader(bytes.NewReader(b.Bytes())) + if err != nil { + return err + } + return r + }, + } +} + +//Decompress decompresses request body based if content encoding type is set to "gzip" with default config +func Decompress() echo.MiddlewareFunc { + return DecompressWithConfig(DefaultDecompressConfig) +} + +//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config +func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultGzipConfig.Skipper + } + if config.GzipDecompressPool == nil { + config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + pool := config.GzipDecompressPool.gzipDecompressPool() + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + switch c.Request().Header.Get(echo.HeaderContentEncoding) { + case GZIPEncoding: + b := c.Request().Body + + i := pool.Get() + gr, ok := i.(*gzip.Reader) + if !ok { + return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) + } + + if err := gr.Reset(b); err != nil { + pool.Put(gr) + if err == io.EOF { //ignore if body is empty + return next(c) + } + return err + } + var buf bytes.Buffer + io.Copy(&buf, gr) + + gr.Close() + pool.Put(gr) + + b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here + + r := ioutil.NopCloser(&buf) + c.Request().Body = r + } + return next(c) + } + } +} diff --git a/vendor/github.com/labstack/echo/v4/middleware/jwt.go b/vendor/github.com/labstack/echo/v4/middleware/jwt.go index 3c7c4868..da00ea56 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/jwt.go +++ b/vendor/github.com/labstack/echo/v4/middleware/jwt.go @@ -57,6 +57,7 @@ type ( // - "query:<name>" // - "param:<name>" // - "cookie:<name>" + // - "form:<name>" TokenLookup string // AuthScheme to be used in the Authorization header. @@ -86,6 +87,7 @@ const ( // Errors var ( ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt") + ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") ) var ( @@ -166,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { extractor = jwtFromParam(parts[1]) case "cookie": extractor = jwtFromCookie(parts[1]) + case "form": + extractor = jwtFromForm(parts[1]) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -213,8 +217,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc { return config.ErrorHandlerWithContext(err, c) } return &echo.HTTPError{ - Code: http.StatusUnauthorized, - Message: "invalid or expired jwt", + Code: ErrJWTInvalid.Code, + Message: ErrJWTInvalid.Message, Internal: err, } } @@ -265,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor { return cookie.Value, nil } } + +// jwtFromForm returns a `jwtExtractor` that extracts token from the form field. +func jwtFromForm(name string) jwtExtractor { + return func(c echo.Context) (string, error) { + field := c.FormValue(name) + if field == "" { + return "", ErrJWTMissing + } + return field, nil + } +} diff --git a/vendor/github.com/labstack/echo/v4/middleware/middleware.go b/vendor/github.com/labstack/echo/v4/middleware/middleware.go index d0b7153c..6bdb0eb7 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/middleware.go +++ b/vendor/github.com/labstack/echo/v4/middleware/middleware.go @@ -1,6 +1,8 @@ package middleware import ( + "net/http" + "net/url" "regexp" "strconv" "strings" @@ -32,6 +34,47 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { return strings.NewReplacer(replace...) } +func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string { + // Initialize + rulesRegex := map[*regexp.Regexp]string{} + for k, v := range rewrite { + k = regexp.QuoteMeta(k) + k = strings.Replace(k, `\*`, "(.*?)", -1) + if strings.HasPrefix(k, `\^`) { + k = strings.Replace(k, `\^`, "^", -1) + } + k = k + "$" + rulesRegex[regexp.MustCompile(k)] = v + } + return rulesRegex +} + +func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) { + for k, v := range rewriteRegex { + rawPath := req.URL.RawPath + if rawPath != "" { + // RawPath is only set when there has been escaping done. In that case Path must be deduced from rewritten RawPath + // because encoded Path could match rules that RawPath did not + if replacer := captureTokens(k, rawPath); replacer != nil { + rawPath = replacer.Replace(v) + + req.URL.RawPath = rawPath + req.URL.Path, _ = url.PathUnescape(rawPath) + + return // rewrite only once + } + + continue + } + + if replacer := captureTokens(k, req.URL.Path); replacer != nil { + req.URL.Path = replacer.Replace(v) + + return // rewrite only once + } + } +} + // DefaultSkipper returns false which processes the middleware. func DefaultSkipper(echo.Context) bool { return false diff --git a/vendor/github.com/labstack/echo/v4/middleware/proxy.go b/vendor/github.com/labstack/echo/v4/middleware/proxy.go index a9b91f6c..63eec5a2 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/proxy.go +++ b/vendor/github.com/labstack/echo/v4/middleware/proxy.go @@ -8,7 +8,6 @@ import ( "net/http" "net/url" "regexp" - "strings" "sync" "sync/atomic" "time" @@ -37,6 +36,13 @@ type ( // "/users/*/orders/*": "/user/$1/order/$2", Rewrite map[string]string + // RegexRewrite defines rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRewrite map[*regexp.Regexp]string + // Context key to store selected ProxyTarget into context. // Optional. Default value "target". ContextKey string @@ -47,8 +53,6 @@ type ( // ModifyResponse defines function to modify response from ProxyTarget. ModifyResponse func(*http.Response) error - - rewriteRegex map[*regexp.Regexp]string } // ProxyTarget defines the upstream target. @@ -206,12 +210,14 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { if config.Balancer == nil { panic("echo: proxy middleware requires balancer") } - config.rewriteRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rewrite { - k = strings.Replace(k, "*", "(\\S*)", -1) - config.rewriteRegex[regexp.MustCompile(k)] = v + if config.Rewrite != nil { + if config.RegexRewrite == nil { + config.RegexRewrite = make(map[*regexp.Regexp]string) + } + for k, v := range rewriteRulesRegex(config.Rewrite) { + config.RegexRewrite[k] = v + } } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -225,13 +231,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { tgt := config.Balancer.Next(c) c.Set(config.ContextKey, tgt) - // Rewrite - for k, v := range config.rewriteRegex { - replacer := captureTokens(k, echo.GetPath(req)) - if replacer != nil { - req.URL.Path = replacer.Replace(v) - } - } + // Set rewrite path and raw path + rewritePath(config.RegexRewrite, req) // Fix header // Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream. diff --git a/vendor/github.com/labstack/echo/v4/middleware/proxy_1_11.go b/vendor/github.com/labstack/echo/v4/middleware/proxy_1_11.go index a4392781..17d142d8 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/proxy_1_11.go +++ b/vendor/github.com/labstack/echo/v4/middleware/proxy_1_11.go @@ -3,13 +3,22 @@ package middleware import ( + "context" "fmt" "net/http" "net/http/httputil" + "strings" "github.com/labstack/echo/v4" ) +// StatusCodeContextCanceled is a custom HTTP status code for situations +// where a client unexpectedly closed the connection to the server. +// As there is no standard error code for "client closed connection", but +// various well-known HTTP clients and server implement this HTTP code we use +// 499 too instead of the more problematic 5xx, which does not allow to detect this situation +const StatusCodeContextCanceled = 499 + func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { @@ -17,7 +26,20 @@ func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handle if tgt.Name != "" { desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String()) } - c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err))) + // If the client canceled the request (usually by closing the connection), we can report a + // client error (4xx) instead of a server error (5xx) to correctly identify the situation. + // The Go standard library (at of late 2020) wraps the exported, standard + // context.Canceled error with unexported garbage value requiring a substring check, see + // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430 + if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") { + httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err)) + httpError.Internal = err + c.Set("_error", httpError) + } else { + httpError := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)) + httpError.Internal = err + c.Set("_error", httpError) + } } proxy.Transport = config.Transport proxy.ModifyResponse = config.ModifyResponse diff --git a/vendor/github.com/labstack/echo/v4/middleware/rate_limiter.go b/vendor/github.com/labstack/echo/v4/middleware/rate_limiter.go new file mode 100644 index 00000000..46a310d9 --- /dev/null +++ b/vendor/github.com/labstack/echo/v4/middleware/rate_limiter.go @@ -0,0 +1,266 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/labstack/echo/v4" + "golang.org/x/time/rate" +) + +type ( + // RateLimiterStore is the interface to be implemented by custom stores. + RateLimiterStore interface { + // Stores for the rate limiter have to implement the Allow method + Allow(identifier string) (bool, error) + } +) + +type ( + // RateLimiterConfig defines the configuration for the rate limiter + RateLimiterConfig struct { + Skipper Skipper + BeforeFunc BeforeFunc + // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + IdentifierExtractor Extractor + // Store defines a store for the rate limiter + Store RateLimiterStore + // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error + ErrorHandler func(context echo.Context, err error) error + // DenyHandler provides a handler to be called when RateLimiter denies access + DenyHandler func(context echo.Context, identifier string, err error) error + } + // Extractor is used to extract data from echo.Context + Extractor func(context echo.Context) (string, error) +) + +// errors +var ( + // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded + ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") + // ErrExtractorError denotes an error raised when extractor function is unsuccessful + ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") +) + +// DefaultRateLimiterConfig defines default values for RateLimiterConfig +var DefaultRateLimiterConfig = RateLimiterConfig{ + Skipper: DefaultSkipper, + IdentifierExtractor: func(ctx echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(context echo.Context, err error) error { + return &echo.HTTPError{ + Code: ErrExtractorError.Code, + Message: ErrExtractorError.Message, + Internal: err, + } + }, + DenyHandler: func(context echo.Context, identifier string, err error) error { + return &echo.HTTPError{ + Code: ErrRateLimitExceeded.Code, + Message: ErrRateLimitExceeded.Message, + Internal: err, + } + }, +} + +/* +RateLimiter returns a rate limiting middleware + + e := echo.New() + + limiterStore := middleware.NewRateLimiterMemoryStore(20) + + e.GET("/rate-limited", func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }, RateLimiter(limiterStore)) +*/ +func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc { + config := DefaultRateLimiterConfig + config.Store = store + + return RateLimiterWithConfig(config) +} + +/* +RateLimiterWithConfig returns a rate limiting middleware + + e := echo.New() + + config := middleware.RateLimiterConfig{ + Skipper: DefaultSkipper, + Store: middleware.NewRateLimiterMemoryStore( + middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} + ) + IdentifierExtractor: func(ctx echo.Context) (string, error) { + id := ctx.RealIP() + return id, nil + }, + ErrorHandler: func(context echo.Context, err error) error { + return context.JSON(http.StatusTooManyRequests, nil) + }, + DenyHandler: func(context echo.Context, identifier string) error { + return context.JSON(http.StatusForbidden, nil) + }, + } + + e.GET("/rate-limited", func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }, middleware.RateLimiterWithConfig(config)) +*/ +func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { + if config.Skipper == nil { + config.Skipper = DefaultRateLimiterConfig.Skipper + } + if config.IdentifierExtractor == nil { + config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor + } + if config.ErrorHandler == nil { + config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler + } + if config.DenyHandler == nil { + config.DenyHandler = DefaultRateLimiterConfig.DenyHandler + } + if config.Store == nil { + panic("Store configuration must be provided") + } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) { + return next(c) + } + if config.BeforeFunc != nil { + config.BeforeFunc(c) + } + + identifier, err := config.IdentifierExtractor(c) + if err != nil { + c.Error(config.ErrorHandler(c, err)) + return nil + } + + if allow, err := config.Store.Allow(identifier); !allow { + c.Error(config.DenyHandler(c, identifier, err)) + return nil + } + return next(c) + } + } +} + +type ( + // RateLimiterMemoryStore is the built-in store implementation for RateLimiter + RateLimiterMemoryStore struct { + visitors map[string]*Visitor + mutex sync.Mutex + rate rate.Limit + burst int + expiresIn time.Duration + lastCleanup time.Time + } + // Visitor signifies a unique user's limiter details + Visitor struct { + *rate.Limiter + lastSeen time.Time + } +) + +/* +NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with +the provided rate (as req/s). Burst and ExpiresIn will be set to default values. + +Example (with 20 requests/sec): + + limiterStore := middleware.NewRateLimiterMemoryStore(20) + +*/ +func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { + return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ + Rate: rate, + }) +} + +/* +NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore +with the provided configuration. Rate must be provided. Burst will be set to the value of +the configured rate if not provided or set to 0. + +The build-in memory store is usually capable for modest loads. For higher loads other +store implementations should be considered. + +Characteristics: +* Concurrency above 100 parallel requests may causes measurable lock contention +* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map +* A high number of requests from a single IP address may cause lock contention + +Example: + + limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( + middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minutes}, + ) +*/ +func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) { + store = &RateLimiterMemoryStore{} + + store.rate = config.Rate + store.burst = config.Burst + store.expiresIn = config.ExpiresIn + if config.ExpiresIn == 0 { + store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn + } + if config.Burst == 0 { + store.burst = int(config.Rate) + } + store.visitors = make(map[string]*Visitor) + store.lastCleanup = now() + return +} + +// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore +type RateLimiterMemoryStoreConfig struct { + Rate rate.Limit // Rate of requests allowed to pass as req/s + Burst int // Burst additionally allows a number of requests to pass when rate limit is reached + ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up +} + +// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore +var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ + ExpiresIn: 3 * time.Minute, +} + +// Allow implements RateLimiterStore.Allow +func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { + store.mutex.Lock() + limiter, exists := store.visitors[identifier] + if !exists { + limiter = new(Visitor) + limiter.Limiter = rate.NewLimiter(store.rate, store.burst) + store.visitors[identifier] = limiter + } + limiter.lastSeen = now() + if now().Sub(store.lastCleanup) > store.expiresIn { + store.cleanupStaleVisitors() + } + store.mutex.Unlock() + return limiter.AllowN(now(), 1), nil +} + +/* +cleanupStaleVisitors helps manage the size of the visitors map by removing stale records +of users who haven't visited again after the configured expiry time has elapsed +*/ +func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { + for id, visitor := range store.visitors { + if now().Sub(visitor.lastSeen) > store.expiresIn { + delete(store.visitors, id) + } + } + store.lastCleanup = now() +} + +/* +actual time method which is mocked in test file +*/ +var now = time.Now diff --git a/vendor/github.com/labstack/echo/v4/middleware/rewrite.go b/vendor/github.com/labstack/echo/v4/middleware/rewrite.go index d1387af0..c05d5d84 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/rewrite.go +++ b/vendor/github.com/labstack/echo/v4/middleware/rewrite.go @@ -2,7 +2,6 @@ package middleware import ( "regexp" - "strings" "github.com/labstack/echo/v4" ) @@ -23,7 +22,12 @@ type ( // Required. Rules map[string]string `yaml:"rules"` - rulesRegex map[*regexp.Regexp]string + // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures + // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. + // Example: + // "^/old/[0.9]+/": "/new", + // "^/api/.+?/(.*)": "/v2/$1", + RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"` } ) @@ -47,20 +51,19 @@ func Rewrite(rules map[string]string) echo.MiddlewareFunc { // See: `Rewrite()`. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { // Defaults - if config.Rules == nil { - panic("echo: rewrite middleware requires url path rewrite rules") + if config.Rules == nil && config.RegexRules == nil { + panic("echo: rewrite middleware requires url path rewrite rules or regex rules") } + if config.Skipper == nil { config.Skipper = DefaultBodyDumpConfig.Skipper } - config.rulesRegex = map[*regexp.Regexp]string{} - // Initialize - for k, v := range config.Rules { - k = regexp.QuoteMeta(k) - k = strings.Replace(k, `\*`, "(.*)", -1) - k = k + "$" - config.rulesRegex[regexp.MustCompile(k)] = v + if config.RegexRules == nil { + config.RegexRules = make(map[*regexp.Regexp]string) + } + for k, v := range rewriteRulesRegex(config.Rules) { + config.RegexRules[k] = v } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -70,15 +73,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } req := c.Request() - - // Rewrite - for k, v := range config.rulesRegex { - replacer := captureTokens(k, req.URL.Path) - if replacer != nil { - req.URL.Path = replacer.Replace(v) - break - } - } + // Set rewrite path and raw path + rewritePath(config.RegexRules, req) return next(c) } } diff --git a/vendor/github.com/labstack/echo/v4/middleware/slash.go b/vendor/github.com/labstack/echo/v4/middleware/slash.go index 0492b334..4188675b 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/slash.go +++ b/vendor/github.com/labstack/echo/v4/middleware/slash.go @@ -60,7 +60,7 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc // Redirect if config.RedirectCode != 0 { - return c.Redirect(config.RedirectCode, uri) + return c.Redirect(config.RedirectCode, sanitizeURI(uri)) } // Forward @@ -108,7 +108,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu // Redirect if config.RedirectCode != 0 { - return c.Redirect(config.RedirectCode, uri) + return c.Redirect(config.RedirectCode, sanitizeURI(uri)) } // Forward @@ -119,3 +119,12 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu } } } + +func sanitizeURI(uri string) string { + // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri + // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash + if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { + uri = "/" + strings.TrimLeft(uri, `/\`) + } + return uri +} diff --git a/vendor/github.com/labstack/echo/v4/middleware/static.go b/vendor/github.com/labstack/echo/v4/middleware/static.go index bc2087a7..ae79cb5f 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/static.go +++ b/vendor/github.com/labstack/echo/v4/middleware/static.go @@ -36,6 +36,12 @@ type ( // Enable directory browsing. // Optional. Default value false. Browse bool `yaml:"browse"` + + // Enable ignoring of the base of the URL path. + // Example: when assigning a static middleware to a non root path group, + // the filesystem path is not doubled + // Optional. Default value false. + IgnoreBase bool `yaml:"ignoreBase"` } ) @@ -161,7 +167,16 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc { if err != nil { return } - name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security + name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security + + if config.IgnoreBase { + routePath := path.Base(strings.TrimRight(c.Path(), "/*")) + baseURLPath := path.Base(p) + if baseURLPath == routePath { + i := strings.LastIndex(name, routePath) + name = name[:i] + strings.Replace(name[i:], routePath, "", 1) + } + } fi, err := os.Stat(name) if err != nil { diff --git a/vendor/github.com/labstack/echo/v4/middleware/timeout.go b/vendor/github.com/labstack/echo/v4/middleware/timeout.go new file mode 100644 index 00000000..68f464e4 --- /dev/null +++ b/vendor/github.com/labstack/echo/v4/middleware/timeout.go @@ -0,0 +1,111 @@ +// +build go1.13 + +package middleware + +import ( + "context" + "github.com/labstack/echo/v4" + "net/http" + "time" +) + +type ( + // TimeoutConfig defines the config for Timeout middleware. + TimeoutConfig struct { + // Skipper defines a function to skip middleware. + Skipper Skipper + + // ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code + // It can be used to define a custom timeout error message + ErrorMessage string + + // OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after + // request timeouted and we already had sent the error code (503) and message response to the client. + // NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer + // will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()` + OnTimeoutRouteErrorHandler func(err error, c echo.Context) + + // Timeout configures a timeout for the middleware, defaults to 0 for no timeout + // NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds) + // the result of timeout does not seem to be reliable - could respond timeout, could respond handler output + // difference over 500microseconds (0.5millisecond) response seems to be reliable + Timeout time.Duration + } +) + +var ( + // DefaultTimeoutConfig is the default Timeout middleware config. + DefaultTimeoutConfig = TimeoutConfig{ + Skipper: DefaultSkipper, + Timeout: 0, + ErrorMessage: "", + } +) + +// Timeout returns a middleware which recovers from panics anywhere in the chain +// and handles the control to the centralized HTTPErrorHandler. +func Timeout() echo.MiddlewareFunc { + return TimeoutWithConfig(DefaultTimeoutConfig) +} + +// TimeoutWithConfig returns a Timeout middleware with config. +// See: `Timeout()`. +func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc { + // Defaults + if config.Skipper == nil { + config.Skipper = DefaultTimeoutConfig.Skipper + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + if config.Skipper(c) || config.Timeout == 0 { + return next(c) + } + + handlerWrapper := echoHandlerFuncWrapper{ + ctx: c, + handler: next, + errChan: make(chan error, 1), + errHandler: config.OnTimeoutRouteErrorHandler, + } + handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage) + handler.ServeHTTP(c.Response().Writer, c.Request()) + + select { + case err := <-handlerWrapper.errChan: + return err + default: + return nil + } + } + } +} + +type echoHandlerFuncWrapper struct { + ctx echo.Context + handler echo.HandlerFunc + errHandler func(err error, c echo.Context) + errChan chan error +} + +func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + // replace writer with TimeoutHandler custom one. This will guarantee that + // `writes by h to its ResponseWriter will return ErrHandlerTimeout.` + originalWriter := t.ctx.Response().Writer + t.ctx.Response().Writer = rw + + err := t.handler(t.ctx) + if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded { + if err != nil && t.errHandler != nil { + t.errHandler(err, t.ctx) + } + return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers + } + // we restore original writer only for cases we did not timeout. On timeout we have already sent response to client + // and should not anymore send additional headers/data + // so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body + t.ctx.Response().Writer = originalWriter + if err != nil { + t.errChan <- err + } +} |