diff options
Diffstat (limited to 'vendor/github.com/lrstanley/girc/conn.go')
-rw-r--r-- | vendor/github.com/lrstanley/girc/conn.go | 108 |
1 files changed, 81 insertions, 27 deletions
diff --git a/vendor/github.com/lrstanley/girc/conn.go b/vendor/github.com/lrstanley/girc/conn.go index d9ec6319..441c3e71 100644 --- a/vendor/github.com/lrstanley/girc/conn.go +++ b/vendor/github.com/lrstanley/girc/conn.go @@ -58,7 +58,7 @@ type Dialer interface { } // newConn sets up and returns a new connection to the server. -func newConn(conf Config, dialer Dialer, addr string) (*ircConn, error) { +func newConn(conf Config, dialer Dialer, addr string, sts *strictTransport) (*ircConn, error) { if err := conf.isValid(); err != nil { return nil, err } @@ -83,13 +83,29 @@ func newConn(conf Config, dialer Dialer, addr string) (*ircConn, error) { } if conn, err = dialer.Dial("tcp", addr); err != nil { + if sts.enabled() { + err = &ErrSTSUpgradeFailed{Err: err} + } + + if sts.expired() && !conf.DisableSTSFallback { + sts.lastFailed = time.Now() + sts.reset() + } return nil, err } - if conf.SSL { + if conf.SSL || sts.enabled() { var tlsConn net.Conn tlsConn, err = tlsHandshake(conn, conf.TLSConfig, conf.Server, true) if err != nil { + if sts.enabled() { + err = &ErrSTSUpgradeFailed{Err: err} + } + + if sts.expired() && !conf.DisableSTSFallback { + sts.lastFailed = time.Now() + sts.reset() + } return nil, err } @@ -245,6 +261,7 @@ func (c *Client) MockConnect(conn net.Conn) error { } func (c *Client) internalConnect(mock net.Conn, dialer Dialer) error { +startConn: // We want to be the only one handling connects/disconnects right now. c.mu.Lock() @@ -253,13 +270,20 @@ func (c *Client) internalConnect(mock net.Conn, dialer Dialer) error { } // Reset the state. - c.state.reset() + c.state.reset(false) + + addr := c.server() if mock == nil { // Validate info, and actually make the connection. - c.debug.Printf("connecting to %s...", c.Server()) - conn, err := newConn(c.Config, dialer, c.Server()) + c.debug.Printf("connecting to %s... (sts: %v, config-ssl: %v)", addr, c.state.sts.enabled(), c.Config.SSL) + conn, err := newConn(c.Config, dialer, addr, &c.state.sts) if err != nil { + if _, ok := err.(*ErrSTSUpgradeFailed); ok { + if !c.state.sts.enabled() { + c.RunHandlers(&Event{Command: STS_ERR_FALLBACK}) + } + } c.mu.Unlock() return err } @@ -312,16 +336,18 @@ func (c *Client) internalConnect(mock net.Conn, dialer Dialer) error { c.write(&Event{Command: USER, Params: []string{c.Config.User, "*", "*", c.Config.Name}}) // Send a virtual event allowing hooks for successful socket connection. - c.RunHandlers(&Event{Command: INITIALIZED, Params: []string{c.Server()}}) + c.RunHandlers(&Event{Command: INITIALIZED, Params: []string{addr}}) // Wait for the first error. var result error select { case <-ctx.Done(): - c.debug.Print("received request to close, beginning clean up") - c.RunHandlers(&Event{Command: CLOSED, Params: []string{c.Server()}}) + if !c.state.sts.beginUpgrade { + c.debug.Print("received request to close, beginning clean up") + } + c.RunHandlers(&Event{Command: CLOSED, Params: []string{addr}}) case err := <-errs: - c.debug.Print("received error, beginning clean up") + c.debug.Printf("received error, beginning cleanup: %v", err) result = err } @@ -336,7 +362,7 @@ func (c *Client) internalConnect(mock net.Conn, dialer Dialer) error { c.conn.mu.Unlock() c.mu.RUnlock() - c.RunHandlers(&Event{Command: DISCONNECTED, Params: []string{c.Server()}}) + c.RunHandlers(&Event{Command: DISCONNECTED, Params: []string{addr}}) // Once we have our error/result, let all other functions know we're done. c.debug.Print("waiting for all routines to finish") @@ -350,6 +376,18 @@ func (c *Client) internalConnect(mock net.Conn, dialer Dialer) error { // clients, not multiple instances of Connect(). c.mu.Lock() c.conn = nil + + if result == nil { + if c.state.sts.beginUpgrade { + c.state.sts.beginUpgrade = false + c.mu.Unlock() + goto startConn + } + + if c.state.sts.enabled() { + c.state.sts.persistenceReceived = time.Now() + } + } c.mu.Unlock() return result @@ -392,8 +430,23 @@ func (c *Client) readLoop(ctx context.Context, errs chan error, wg *sync.WaitGro // Send sends an event to the server. Use Client.RunHandlers() if you are // simply looking to trigger handlers with an event. func (c *Client) Send(event *Event) { + var delay time.Duration + if !c.Config.AllowFlood { - <-time.After(c.conn.rate(event.Len())) + c.mu.RLock() + + // Drop the event early as we're disconnected, this way we don't have to wait + // the (potentially long) rate limit delay before dropping. + if c.conn == nil { + c.debugLogEvent(event, true) + c.mu.RUnlock() + return + } + + c.conn.mu.Lock() + delay = c.conn.rate(event.Len()) + c.conn.mu.Unlock() + c.mu.RUnlock() } if c.Config.GlobalFormat && len(event.Params) > 0 && event.Params[len(event.Params)-1] != "" && @@ -401,12 +454,21 @@ func (c *Client) Send(event *Event) { event.Params[len(event.Params)-1] = Fmt(event.Params[len(event.Params)-1]) } + <-time.After(delay) c.write(event) } // write is the lower level function to write an event. It does not have a // write-delay when sending events. func (c *Client) write(event *Event) { + c.mu.RLock() + defer c.mu.RUnlock() + + if c.conn == nil { + // Drop the event if disconnected. + c.debugLogEvent(event, true) + return + } c.tx <- event } @@ -415,14 +477,10 @@ func (c *Client) write(event *Event) { func (c *ircConn) rate(chars int) time.Duration { _time := time.Second + ((time.Duration(chars) * time.Second) / 100) - c.mu.Lock() if c.writeDelay += _time - time.Now().Sub(c.lastWrite); c.writeDelay < 0 { c.writeDelay = 0 } - c.mu.Unlock() - c.mu.RLock() - defer c.mu.RUnlock() if c.writeDelay > (8 * time.Second) { return _time } @@ -445,7 +503,7 @@ func (c *Client) sendLoop(ctx context.Context, errs chan error, wg *sync.WaitGro c.state.RLock() var in bool for i := 0; i < len(c.state.enabledCap); i++ { - if c.state.enabledCap[i] == "message-tags" { + if _, ok := c.state.enabledCap["message-tags"]; ok { in = true break } @@ -457,17 +515,7 @@ func (c *Client) sendLoop(ctx context.Context, errs chan error, wg *sync.WaitGro } } - // Log the event. - if event.Sensitive { - c.debug.Printf("> %s ***redacted***", event.Command) - } else { - c.debug.Print("> ", StripRaw(event.String())) - } - if c.Config.Out != nil { - if pretty, ok := event.Pretty(); ok { - fmt.Fprintln(c.Config.Out, StripRaw(pretty)) - } - } + c.debugLogEvent(event, false) c.conn.mu.Lock() c.conn.lastWrite = time.Now() @@ -488,6 +536,12 @@ func (c *Client) sendLoop(ctx context.Context, errs chan error, wg *sync.WaitGro } } + if event.Command == QUIT { + c.Close() + wg.Done() + return + } + if err != nil { errs <- err wg.Done() |