package middleware import ( "strings" "github.com/labstack/echo/v4" ) type ( // TrailingSlashConfig defines the config for TrailingSlash middleware. TrailingSlashConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // Status code to be used when redirecting the request. // Optional, but when provided the request is redirected using this code. RedirectCode int `yaml:"redirect_code"` } ) var ( // DefaultTrailingSlashConfig is the default TrailingSlash middleware config. DefaultTrailingSlashConfig = TrailingSlashConfig{ Skipper: DefaultSkipper, } ) // AddTrailingSlash returns a root level (before router) middleware which adds a // trailing slash to the request `URL#Path`. // // Usage `Echo#Pre(AddTrailingSlash())` func AddTrailingSlash() echo.MiddlewareFunc { return AddTrailingSlashWithConfig(DefaultTrailingSlashConfig) } // AddTrailingSlashWithConfig returns a AddTrailingSlash middleware with config. // See `AddTrailingSlash()`. func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc { // Defaults if config.Skipper == nil { config.Skipper = DefaultTrailingSlashConfig.Skipper } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() url := req.URL path := url.Path qs := c.QueryString() if !strings.HasSuffix(path, "/") { path += "/" uri := path if qs != "" { uri += "?" + qs } // Redirect if config.RedirectCode != 0 { return c.Redirect(config.RedirectCode, sanitizeURI(uri)) } // Forward req.RequestURI = uri url.Path = path } return next(c) } } } // RemoveTrailingSlash returns a root level (before router) middleware which removes // a trailing slash from the request URI. // // Usage `Echo#Pre(RemoveTrailingSlash())` func RemoveTrailingSlash() echo.MiddlewareFunc { return RemoveTrailingSlashWithConfig(TrailingSlashConfig{}) } // RemoveTrailingSlashWithConfig returns a RemoveTrailingSlash middleware with config. // See `RemoveTrailingSlash()`. func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc { // Defaults if config.Skipper == nil { config.Skipper = DefaultTrailingSlashConfig.Skipper } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() url := req.URL path := url.Path qs := c.QueryString() l := len(path) - 1 if l > 0 && strings.HasSuffix(path, "/") { path = path[:l] uri := path if qs != "" { uri += "?" + qs } // Redirect if config.RedirectCode != 0 { return c.Redirect(config.RedirectCode, sanitizeURI(uri)) } // Forward req.RequestURI = uri url.Path = path } return next(c) } } } func sanitizeURI(uri string) string { // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { uri = "/" + strings.TrimLeft(uri, `/\`) } return uri }