summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/gorilla/websocket/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/gorilla/websocket/client.go')
-rw-r--r--vendor/github.com/gorilla/websocket/client.go243
1 files changed, 123 insertions, 120 deletions
diff --git a/vendor/github.com/gorilla/websocket/client.go b/vendor/github.com/gorilla/websocket/client.go
index 43a87c75..2e32fd50 100644
--- a/vendor/github.com/gorilla/websocket/client.go
+++ b/vendor/github.com/gorilla/websocket/client.go
@@ -5,15 +5,15 @@
package websocket
import (
- "bufio"
"bytes"
+ "context"
"crypto/tls"
- "encoding/base64"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
+ "net/http/httptrace"
"net/url"
"strings"
"time"
@@ -53,6 +53,10 @@ type Dialer struct {
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
+ // NetDialContext specifies the dial function for creating TCP connections. If
+ // NetDialContext is nil, net.DialContext is used.
+ NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
+
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
@@ -71,6 +75,17 @@ type Dialer struct {
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
+ // WriteBufferPool is a pool of buffers for write operations. If the value
+ // is not set, then write buffers are allocated to the connection for the
+ // lifetime of the connection.
+ //
+ // A pool is most useful when the application has a modest volume of writes
+ // across a large number of connections.
+ //
+ // Applications should use a single pool for each unique value of
+ // WriteBufferSize.
+ WriteBufferPool BufferPool
+
// Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
@@ -86,52 +101,13 @@ type Dialer struct {
Jar http.CookieJar
}
-var errMalformedURL = errors.New("malformed ws or wss URL")
-
-// parseURL parses the URL.
-//
-// This function is a replacement for the standard library url.Parse function.
-// In Go 1.4 and earlier, url.Parse loses information from the path.
-func parseURL(s string) (*url.URL, error) {
- // From the RFC:
- //
- // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
- // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
- var u url.URL
- switch {
- case strings.HasPrefix(s, "ws://"):
- u.Scheme = "ws"
- s = s[len("ws://"):]
- case strings.HasPrefix(s, "wss://"):
- u.Scheme = "wss"
- s = s[len("wss://"):]
- default:
- return nil, errMalformedURL
- }
-
- if i := strings.Index(s, "?"); i >= 0 {
- u.RawQuery = s[i+1:]
- s = s[:i]
- }
-
- if i := strings.Index(s, "/"); i >= 0 {
- u.Opaque = s[i:]
- s = s[:i]
- } else {
- u.Opaque = "/"
- }
-
- u.Host = s
-
- if strings.Contains(u.Host, "@") {
- // Don't bother parsing user information because user information is
- // not allowed in websocket URIs.
- return nil, errMalformedURL
- }
-
- return &u, nil
+// Dial creates a new client connection by calling DialContext with a background context.
+func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
+ return d.DialContext(context.Background(), urlStr, requestHeader)
}
+var errMalformedURL = errors.New("malformed ws or wss URL")
+
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
@@ -150,26 +126,29 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
return hostPort, hostNoPort
}
-// DefaultDialer is a dialer with all fields set to the default zero values.
+// DefaultDialer is a dialer with all fields set to the default values.
var DefaultDialer = &Dialer{
- Proxy: http.ProxyFromEnvironment,
+ Proxy: http.ProxyFromEnvironment,
+ HandshakeTimeout: 45 * time.Second,
}
-// Dial creates a new client connection. Use requestHeader to specify the
+// nilDialer is dialer to use when receiver is nil.
+var nilDialer = *DefaultDialer
+
+// DialContext creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
// Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
+// The context will be used in the request and in the Dialer
+//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etcetera. The response body may not contain the entire response and does not
// need to be closed by the application.
-func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
-
+func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil {
- d = &Dialer{
- Proxy: http.ProxyFromEnvironment,
- }
+ d = &nilDialer
}
challengeKey, err := generateChallengeKey()
@@ -177,7 +156,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
return nil, nil, err
}
- u, err := parseURL(urlStr)
+ u, err := url.Parse(urlStr)
if err != nil {
return nil, nil, err
}
@@ -205,6 +184,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
Header: make(http.Header),
Host: u.Host,
}
+ req = req.WithContext(ctx)
// Set the cookies present in the cookie jar of the dialer
if d.Jar != nil {
@@ -237,45 +217,83 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
+ case k == "Sec-Websocket-Protocol":
+ req.Header["Sec-WebSocket-Protocol"] = vs
default:
req.Header[k] = vs
}
}
if d.EnableCompression {
- req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
+ req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
}
- hostPort, hostNoPort := hostPortNoPort(u)
-
- var proxyURL *url.URL
- // Check wether the proxy method has been configured
- if d.Proxy != nil {
- proxyURL, err = d.Proxy(req)
- }
- if err != nil {
- return nil, nil, err
+ if d.HandshakeTimeout != 0 {
+ var cancel func()
+ ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
+ defer cancel()
}
- var targetHostPort string
- if proxyURL != nil {
- targetHostPort, _ = hostPortNoPort(proxyURL)
+ // Get network dial function.
+ var netDial func(network, add string) (net.Conn, error)
+
+ if d.NetDialContext != nil {
+ netDial = func(network, addr string) (net.Conn, error) {
+ return d.NetDialContext(ctx, network, addr)
+ }
+ } else if d.NetDial != nil {
+ netDial = d.NetDial
} else {
- targetHostPort = hostPort
+ netDialer := &net.Dialer{}
+ netDial = func(network, addr string) (net.Conn, error) {
+ return netDialer.DialContext(ctx, network, addr)
+ }
}
- var deadline time.Time
- if d.HandshakeTimeout != 0 {
- deadline = time.Now().Add(d.HandshakeTimeout)
+ // If needed, wrap the dial function to set the connection deadline.
+ if deadline, ok := ctx.Deadline(); ok {
+ forwardDial := netDial
+ netDial = func(network, addr string) (net.Conn, error) {
+ c, err := forwardDial(network, addr)
+ if err != nil {
+ return nil, err
+ }
+ err = c.SetDeadline(deadline)
+ if err != nil {
+ c.Close()
+ return nil, err
+ }
+ return c, nil
+ }
}
- netDial := d.NetDial
- if netDial == nil {
- netDialer := &net.Dialer{Deadline: deadline}
- netDial = netDialer.Dial
+ // If needed, wrap the dial function to connect through a proxy.
+ if d.Proxy != nil {
+ proxyURL, err := d.Proxy(req)
+ if err != nil {
+ return nil, nil, err
+ }
+ if proxyURL != nil {
+ dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
+ if err != nil {
+ return nil, nil, err
+ }
+ netDial = dialer.Dial
+ }
+ }
+
+ hostPort, hostNoPort := hostPortNoPort(u)
+ trace := httptrace.ContextClientTrace(ctx)
+ if trace != nil && trace.GetConn != nil {
+ trace.GetConn(hostPort)
}
- netConn, err := netDial("tcp", targetHostPort)
+ netConn, err := netDial("tcp", hostPort)
+ if trace != nil && trace.GotConn != nil {
+ trace.GotConn(httptrace.GotConnInfo{
+ Conn: netConn,
+ })
+ }
if err != nil {
return nil, nil, err
}
@@ -286,42 +304,6 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
}
}()
- if err := netConn.SetDeadline(deadline); err != nil {
- return nil, nil, err
- }
-
- if proxyURL != nil {
- connectHeader := make(http.Header)
- if user := proxyURL.User; user != nil {
- proxyUser := user.Username()
- if proxyPassword, passwordSet := user.Password(); passwordSet {
- credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
- connectHeader.Set("Proxy-Authorization", "Basic "+credential)
- }
- }
- connectReq := &http.Request{
- Method: "CONNECT",
- URL: &url.URL{Opaque: hostPort},
- Host: hostPort,
- Header: connectHeader,
- }
-
- connectReq.Write(netConn)
-
- // Read response.
- // Okay to use and discard buffered reader here, because
- // TLS server will not speak until spoken to.
- br := bufio.NewReader(netConn)
- resp, err := http.ReadResponse(br, connectReq)
- if err != nil {
- return nil, nil, err
- }
- if resp.StatusCode != 200 {
- f := strings.SplitN(resp.Status, " ", 2)
- return nil, nil, errors.New(f[1])
- }
- }
-
if u.Scheme == "https" {
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
@@ -329,22 +311,31 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
}
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
- if err := tlsConn.Handshake(); err != nil {
- return nil, nil, err
+
+ var err error
+ if trace != nil {
+ err = doHandshakeWithTrace(trace, tlsConn, cfg)
+ } else {
+ err = doHandshake(tlsConn, cfg)
}
- if !cfg.InsecureSkipVerify {
- if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
- return nil, nil, err
- }
+
+ if err != nil {
+ return nil, nil, err
}
}
- conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
+ conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
if err := req.Write(netConn); err != nil {
return nil, nil, err
}
+ if trace != nil && trace.GotFirstResponseByte != nil {
+ if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
+ trace.GotFirstResponseByte()
+ }
+ }
+
resp, err := http.ReadResponse(conn.br, req)
if err != nil {
return nil, nil, err
@@ -390,3 +381,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
netConn = nil // to avoid close in defer.
return conn, resp, nil
}
+
+func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
+ if err := tlsConn.Handshake(); err != nil {
+ return err
+ }
+ if !cfg.InsecureSkipVerify {
+ if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
+ return err
+ }
+ }
+ return nil
+}