diff options
Diffstat (limited to 'vendor/github.com/gorilla/websocket/client.go')
-rw-r--r-- | vendor/github.com/gorilla/websocket/client.go | 77 |
1 files changed, 52 insertions, 25 deletions
diff --git a/vendor/github.com/gorilla/websocket/client.go b/vendor/github.com/gorilla/websocket/client.go index 962c06a3..2efd8355 100644 --- a/vendor/github.com/gorilla/websocket/client.go +++ b/vendor/github.com/gorilla/websocket/client.go @@ -48,15 +48,23 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS } // A Dialer contains options for connecting to WebSocket server. +// +// It is safe to call Dialer's methods concurrently. type Dialer struct { // NetDial specifies the dial function for creating TCP connections. If // NetDial is nil, net.Dial is used. NetDial func(network, addr string) (net.Conn, error) // NetDialContext specifies the dial function for creating TCP connections. If - // NetDialContext is nil, net.DialContext is used. + // NetDialContext is nil, NetDial is used. NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) + // NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If + // NetDialTLSContext is nil, NetDialContext is used. + // If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and + // TLSClientConfig is ignored. + NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -65,6 +73,8 @@ type Dialer struct { // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. + // If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake + // is done there and TLSClientConfig is ignored. TLSClientConfig *tls.Config // HandshakeTimeout specifies the duration for the handshake to complete. @@ -176,7 +186,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } req := &http.Request{ - Method: "GET", + Method: http.MethodGet, URL: u, Proto: "HTTP/1.1", ProtoMajor: 1, @@ -237,13 +247,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h // Get network dial function. var netDial func(network, add string) (net.Conn, error) - if d.NetDialContext != nil { - netDial = func(network, addr string) (net.Conn, error) { - return d.NetDialContext(ctx, network, addr) + switch u.Scheme { + case "http": + if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial } - } else if d.NetDial != nil { - netDial = d.NetDial - } else { + case "https": + if d.NetDialTLSContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialTLSContext(ctx, network, addr) + } + } else if d.NetDialContext != nil { + netDial = func(network, addr string) (net.Conn, error) { + return d.NetDialContext(ctx, network, addr) + } + } else if d.NetDial != nil { + netDial = d.NetDial + } + default: + return nil, nil, errMalformedURL + } + + if netDial == nil { netDialer := &net.Dialer{} netDial = func(network, addr string) (net.Conn, error) { return netDialer.DialContext(ctx, network, addr) @@ -304,7 +333,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } }() - if u.Scheme == "https" { + if u.Scheme == "https" && d.NetDialTLSContext == nil { + // If NetDialTLSContext is set, assume that the TLS handshake has already been done + cfg := cloneTLSConfig(d.TLSClientConfig) if cfg.ServerName == "" { cfg.ServerName = hostNoPort @@ -312,11 +343,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h tlsConn := tls.Client(netConn, cfg) netConn = tlsConn - var err error - if trace != nil { - err = doHandshakeWithTrace(trace, tlsConn, cfg) - } else { - err = doHandshake(tlsConn, cfg) + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } + err := doHandshake(ctx, tlsConn, cfg) + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) } if err != nil { @@ -348,8 +380,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h } if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + !tokenListContainsValue(resp.Header, "Upgrade", "websocket") || + !tokenListContainsValue(resp.Header, "Connection", "upgrade") || resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { // Before closing the network connection on return from this // function, slurp up some of the response to aid application @@ -382,14 +414,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h return conn, resp, nil } -func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { - if err := tlsConn.Handshake(); err != nil { - return err - } - if !cfg.InsecureSkipVerify { - if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { - return err - } +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} } - return nil + return cfg.Clone() } |