diff options
Diffstat (limited to 'vendor/github.com/labstack/echo/v4/middleware/csrf.go')
-rw-r--r-- | vendor/github.com/labstack/echo/v4/middleware/csrf.go | 110 |
1 files changed, 47 insertions, 63 deletions
diff --git a/vendor/github.com/labstack/echo/v4/middleware/csrf.go b/vendor/github.com/labstack/echo/v4/middleware/csrf.go index 7804997d..61299f5c 100644 --- a/vendor/github.com/labstack/echo/v4/middleware/csrf.go +++ b/vendor/github.com/labstack/echo/v4/middleware/csrf.go @@ -2,9 +2,7 @@ package middleware import ( "crypto/subtle" - "errors" "net/http" - "strings" "time" "github.com/labstack/echo/v4" @@ -21,13 +19,15 @@ type ( TokenLength uint8 `yaml:"token_length"` // Optional. Default value 32. - // TokenLookup is a string in the form of "<source>:<key>" that is used + // TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used // to extract token from the request. // Optional. Default value "header:X-CSRF-Token". // Possible values: - // - "header:<name>" - // - "form:<name>" + // - "header:<name>" or "header:<name>:<cut-prefix>" // - "query:<name>" + // - "form:<name>" + // Multiple sources example: + // - "header:X-CSRF-Token,query:csrf" TokenLookup string `yaml:"token_lookup"` // Context key to store generated CSRF token into context. @@ -62,12 +62,11 @@ type ( // Optional. Default value SameSiteDefaultMode. CookieSameSite http.SameSite `yaml:"cookie_same_site"` } - - // csrfTokenExtractor defines a function that takes `echo.Context` and returns - // either a token or an error. - csrfTokenExtractor func(echo.Context) (string, error) ) +// ErrCSRFInvalid is returned when CSRF check fails +var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") + var ( // DefaultCSRFConfig is the default CSRF middleware config. DefaultCSRFConfig = CSRFConfig{ @@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { config.CookieSecure = true } - // Initialize - parts := strings.Split(config.TokenLookup, ":") - extractor := csrfTokenFromHeader(parts[1]) - switch parts[0] { - case "form": - extractor = csrfTokenFromForm(parts[1]) - case "query": - extractor = csrfTokenFromQuery(parts[1]) + extractors, err := createExtractors(config.TokenLookup, "") + if err != nil { + panic(err) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return next(c) } - req := c.Request() - k, err := c.Cookie(config.CookieName) token := "" - - // Generate token - if err != nil { - token = random.String(config.TokenLength) + if k, err := c.Cookie(config.CookieName); err != nil { + token = random.String(config.TokenLength) // Generate token } else { - // Reuse token - token = k.Value + token = k.Value // Reuse token } - switch req.Method { + switch c.Request().Method { case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: default: // Validate token only for requests which are not defined as 'safe' by RFC7231 - clientToken, err := extractor(c) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + var lastExtractorErr error + var lastTokenErr error + outer: + for _, extractor := range extractors { + clientTokens, err := extractor(c) + if err != nil { + lastExtractorErr = err + continue + } + + for _, clientToken := range clientTokens { + if validateCSRFToken(token, clientToken) { + lastTokenErr = nil + lastExtractorErr = nil + break outer + } + lastTokenErr = ErrCSRFInvalid + } } - if !validateCSRFToken(token, clientToken) { - return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") + if lastTokenErr != nil { + return lastTokenErr + } else if lastExtractorErr != nil { + // ugly part to preserve backwards compatible errors. someone could rely on them + if lastExtractorErr == errQueryExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string") + } else if lastExtractorErr == errFormExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter") + } else if lastExtractorErr == errHeaderExtractorValueMissing { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header") + } else { + lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) + } + return lastExtractorErr } } @@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { } } -// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the -// provided request header. -func csrfTokenFromHeader(header string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - return c.Request().Header.Get(header), nil - } -} - -// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the -// provided form parameter. -func csrfTokenFromForm(param string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - token := c.FormValue(param) - if token == "" { - return "", errors.New("missing csrf token in the form parameter") - } - return token, nil - } -} - -// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the -// provided query parameter. -func csrfTokenFromQuery(param string) csrfTokenExtractor { - return func(c echo.Context) (string, error) { - token := c.QueryParam(param) - if token == "" { - return "", errors.New("missing csrf token in the query string") - } - return token, nil - } -} - func validateCSRFToken(token, clientToken string) bool { return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 } |