summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/gorilla/websocket/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/gorilla/websocket/conn.go')
-rw-r--r--vendor/github.com/gorilla/websocket/conn.go63
1 files changed, 46 insertions, 17 deletions
diff --git a/vendor/github.com/gorilla/websocket/conn.go b/vendor/github.com/gorilla/websocket/conn.go
index ca46d2f7..331eebc8 100644
--- a/vendor/github.com/gorilla/websocket/conn.go
+++ b/vendor/github.com/gorilla/websocket/conn.go
@@ -13,6 +13,7 @@ import (
"math/rand"
"net"
"strconv"
+ "strings"
"sync"
"time"
"unicode/utf8"
@@ -401,6 +402,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return nil
}
+func (c *Conn) writeBufs(bufs ...[]byte) error {
+ b := net.Buffers(bufs)
+ _, err := b.WriteTo(c.conn)
+ return err
+}
+
// WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
@@ -794,47 +801,69 @@ func (c *Conn) advanceFrame() (int, error) {
}
// 2. Read and parse first two bytes of frame header.
+ // To aid debugging, collect and report all errors in the first two bytes
+ // of the header.
+
+ var errors []string
p, err := c.read(2)
if err != nil {
return noFrame, err
}
- final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
+ final := p[0]&finalBit != 0
+ rsv1 := p[0]&rsv1Bit != 0
+ rsv2 := p[0]&rsv2Bit != 0
+ rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false
- if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
- c.readDecompress = true
- p[0] &^= rsv1Bit
+ if rsv1 {
+ if c.newDecompressionReader != nil {
+ c.readDecompress = true
+ } else {
+ errors = append(errors, "RSV1 set")
+ }
+ }
+
+ if rsv2 {
+ errors = append(errors, "RSV2 set")
}
- if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
- return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
+ if rsv3 {
+ errors = append(errors, "RSV3 set")
}
switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
- return noFrame, c.handleProtocolError("control frame length > 125")
+ errors = append(errors, "len > 125 for control")
}
if !final {
- return noFrame, c.handleProtocolError("control frame not final")
+ errors = append(errors, "FIN not set on control")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
- return noFrame, c.handleProtocolError("message start before final message frame")
+ errors = append(errors, "data before FIN")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
- return noFrame, c.handleProtocolError("continuation after final message frame")
+ errors = append(errors, "continuation after FIN")
}
c.readFinal = final
default:
- return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
+ errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
+ }
+
+ if mask != c.isServer {
+ errors = append(errors, "bad MASK")
+ }
+
+ if len(errors) > 0 {
+ return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
}
// 3. Read and parse frame length as per
@@ -872,10 +901,6 @@ func (c *Conn) advanceFrame() (int, error) {
// 4. Handle frame masking.
- if mask != c.isServer {
- return noFrame, c.handleProtocolError("incorrect mask flag")
- }
-
if mask {
c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey))
@@ -935,7 +960,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
- return noFrame, c.handleProtocolError("invalid close code")
+ return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
}
closeText = string(payload[2:])
if !utf8.ValidString(closeText) {
@@ -952,7 +977,11 @@ func (c *Conn) advanceFrame() (int, error) {
}
func (c *Conn) handleProtocolError(message string) error {
- c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
+ data := FormatCloseMessage(CloseProtocolError, message)
+ if len(data) > maxControlFramePayloadSize {
+ data = data[:maxControlFramePayloadSize]
+ }
+ c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}