summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/labstack/echo/v4/middleware/csrf.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/labstack/echo/v4/middleware/csrf.go')
-rw-r--r--vendor/github.com/labstack/echo/v4/middleware/csrf.go110
1 files changed, 47 insertions, 63 deletions
diff --git a/vendor/github.com/labstack/echo/v4/middleware/csrf.go b/vendor/github.com/labstack/echo/v4/middleware/csrf.go
index 7804997d..61299f5c 100644
--- a/vendor/github.com/labstack/echo/v4/middleware/csrf.go
+++ b/vendor/github.com/labstack/echo/v4/middleware/csrf.go
@@ -2,9 +2,7 @@ package middleware
import (
"crypto/subtle"
- "errors"
"net/http"
- "strings"
"time"
"github.com/labstack/echo/v4"
@@ -21,13 +19,15 @@ type (
TokenLength uint8 `yaml:"token_length"`
// Optional. Default value 32.
- // TokenLookup is a string in the form of "<source>:<key>" that is used
+ // TokenLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
// to extract token from the request.
// Optional. Default value "header:X-CSRF-Token".
// Possible values:
- // - "header:<name>"
- // - "form:<name>"
+ // - "header:<name>" or "header:<name>:<cut-prefix>"
// - "query:<name>"
+ // - "form:<name>"
+ // Multiple sources example:
+ // - "header:X-CSRF-Token,query:csrf"
TokenLookup string `yaml:"token_lookup"`
// Context key to store generated CSRF token into context.
@@ -62,12 +62,11 @@ type (
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
}
-
- // csrfTokenExtractor defines a function that takes `echo.Context` and returns
- // either a token or an error.
- csrfTokenExtractor func(echo.Context) (string, error)
)
+// ErrCSRFInvalid is returned when CSRF check fails
+var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
+
var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
@@ -114,14 +113,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
config.CookieSecure = true
}
- // Initialize
- parts := strings.Split(config.TokenLookup, ":")
- extractor := csrfTokenFromHeader(parts[1])
- switch parts[0] {
- case "form":
- extractor = csrfTokenFromForm(parts[1])
- case "query":
- extractor = csrfTokenFromQuery(parts[1])
+ extractors, err := createExtractors(config.TokenLookup, "")
+ if err != nil {
+ panic(err)
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@@ -130,28 +124,50 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
return next(c)
}
- req := c.Request()
- k, err := c.Cookie(config.CookieName)
token := ""
-
- // Generate token
- if err != nil {
- token = random.String(config.TokenLength)
+ if k, err := c.Cookie(config.CookieName); err != nil {
+ token = random.String(config.TokenLength) // Generate token
} else {
- // Reuse token
- token = k.Value
+ token = k.Value // Reuse token
}
- switch req.Method {
+ switch c.Request().Method {
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
default:
// Validate token only for requests which are not defined as 'safe' by RFC7231
- clientToken, err := extractor(c)
- if err != nil {
- return echo.NewHTTPError(http.StatusBadRequest, err.Error())
+ var lastExtractorErr error
+ var lastTokenErr error
+ outer:
+ for _, extractor := range extractors {
+ clientTokens, err := extractor(c)
+ if err != nil {
+ lastExtractorErr = err
+ continue
+ }
+
+ for _, clientToken := range clientTokens {
+ if validateCSRFToken(token, clientToken) {
+ lastTokenErr = nil
+ lastExtractorErr = nil
+ break outer
+ }
+ lastTokenErr = ErrCSRFInvalid
+ }
}
- if !validateCSRFToken(token, clientToken) {
- return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token")
+ if lastTokenErr != nil {
+ return lastTokenErr
+ } else if lastExtractorErr != nil {
+ // ugly part to preserve backwards compatible errors. someone could rely on them
+ if lastExtractorErr == errQueryExtractorValueMissing {
+ lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string")
+ } else if lastExtractorErr == errFormExtractorValueMissing {
+ lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter")
+ } else if lastExtractorErr == errHeaderExtractorValueMissing {
+ lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header")
+ } else {
+ lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error())
+ }
+ return lastExtractorErr
}
}
@@ -184,38 +200,6 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
}
}
-// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
-// provided request header.
-func csrfTokenFromHeader(header string) csrfTokenExtractor {
- return func(c echo.Context) (string, error) {
- return c.Request().Header.Get(header), nil
- }
-}
-
-// csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the
-// provided form parameter.
-func csrfTokenFromForm(param string) csrfTokenExtractor {
- return func(c echo.Context) (string, error) {
- token := c.FormValue(param)
- if token == "" {
- return "", errors.New("missing csrf token in the form parameter")
- }
- return token, nil
- }
-}
-
-// csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the
-// provided query parameter.
-func csrfTokenFromQuery(param string) csrfTokenExtractor {
- return func(c echo.Context) (string, error) {
- token := c.QueryParam(param)
- if token == "" {
- return "", errors.New("missing csrf token in the query string")
- }
- return token, nil
- }
-}
-
func validateCSRFToken(token, clientToken string) bool {
return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1
}