summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/labstack/echo/v4/middleware/decompress.go
blob: c046359a2086f6443d19ae3de57c59b6a9bf6217 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package middleware

import (
	"bytes"
	"compress/gzip"
	"io"
	"io/ioutil"
	"net/http"
	"sync"

	"github.com/labstack/echo/v4"
)

// DecompressConfig holds the settings for the Decompress middleware.
type DecompressConfig struct {
	// Skipper decides, per request, whether the middleware is bypassed.
	Skipper Skipper

	// GzipDecompressPool supplies the sync.Pool from which gzip readers are
	// taken and to which they are returned.
	GzipDecompressPool Decompressor
}

// GZIPEncoding is the Content-Encoding header value that makes the Decompress
// middleware decompress the request body.
const (
	GZIPEncoding string = "gzip"
)

// Decompressor hands the Decompress middleware the sync.Pool it draws
// *gzip.Reader values from. Because its only method is unexported, the
// interface can be implemented only inside this package.
type Decompressor interface {
	// gzipDecompressPool builds the pool of gzip readers used by the middleware.
	gzipDecompressPool() sync.Pool
}

// DefaultDecompressConfig is the configuration used by Decompress.
// DecompressWithConfig also falls back to its GzipDecompressPool when the
// supplied config leaves that field nil.
var DefaultDecompressConfig = DecompressConfig{
	GzipDecompressPool: &DefaultGzipDecompressPool{},
	Skipper:            DefaultSkipper,
}

// DefaultGzipDecompressPool is the default implementation of the Decompressor
// interface.
type DefaultGzipDecompressPool struct{}

// gzipDecompressPool returns a sync.Pool whose New function builds a
// *gzip.Reader. gzip.NewReader has to parse a gzip header, so New first encodes
// an empty gzip stream into an in-memory buffer and opens the reader over that.
// The caller is expected to Reset the reader onto real input before use. If
// either the writer or the reader cannot be constructed, New returns that error
// value instead of a reader.
func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
	newReader := func() interface{} {
		zw, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
		if err != nil {
			return err
		}

		// Redirect the writer into memory and finish an empty stream, leaving a
		// well-formed gzip header (and trailer) in the buffer.
		var primer bytes.Buffer
		zw.Reset(&primer)
		zw.Flush()
		zw.Close()

		zr, err := gzip.NewReader(bytes.NewReader(primer.Bytes()))
		if err != nil {
			return err
		}
		return zr
	}
	return sync.Pool{New: newReader}
}

// Decompress returns a middleware that decompresses gzip-encoded request
// bodies, configured with DefaultDecompressConfig.
func Decompress() echo.MiddlewareFunc {
	config := DefaultDecompressConfig
	return DecompressWithConfig(config)
}

// DecompressWithConfig returns a middleware that decompresses the request body
// when the request's Content-Encoding header equals GZIPEncoding ("gzip").
//
// The entire decompressed payload is read into memory and installed as the new
// request body before next is called. A request with an empty body (reading the
// gzip header hits io.EOF) is passed to next unchanged. If the body is not a
// valid gzip stream, the gzip error is returned and next is not called. This
// covers a bad header, corrupt or truncated data, and a checksum mismatch.
//
// Nil fields of config are filled from DefaultDecompressConfig.
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
	// Defaults
	if config.Skipper == nil {
		config.Skipper = DefaultDecompressConfig.Skipper
	}
	if config.GzipDecompressPool == nil {
		config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		pool := config.GzipDecompressPool.gzipDecompressPool()
		return func(c echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}
			if c.Request().Header.Get(echo.HeaderContentEncoding) != GZIPEncoding {
				return next(c)
			}

			i := pool.Get()
			gr, ok := i.(*gzip.Reader)
			if !ok || gr == nil {
				// The default pool's New func yields its construction error in
				// place of a reader; anything else is unexpected.
				if err, isErr := i.(error); isErr {
					return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
				}
				return echo.NewHTTPError(http.StatusInternalServerError)
			}

			b := c.Request().Body
			if err := gr.Reset(b); err != nil {
				pool.Put(gr)
				if err == io.EOF { // ignore if body is empty
					return next(c)
				}
				return err
			}

			// NOTE(review): the decompressed size is unbounded, so a small
			// compressed body can expand to a very large buffer. Consider capping
			// it (e.g. io.LimitReader) if request sizes are not limited upstream.
			var buf bytes.Buffer
			_, copyErr := io.Copy(&buf, gr)

			gr.Close()
			pool.Put(gr)

			// http.Request.Body is closed by the Server, but because we are
			// replacing it, it must be closed here.
			b.Close()

			if copyErr != nil {
				// Corrupt or truncated stream, or CRC/size mismatch at the trailer:
				// never hand a partially decompressed body to the handler.
				return copyErr
			}

			c.Request().Body = ioutil.NopCloser(&buf)
			return next(c)
		}
	}
}