summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/lrstanley/girc/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/lrstanley/girc/conn.go')
-rw-r--r--vendor/github.com/lrstanley/girc/conn.go108
1 files changed, 81 insertions, 27 deletions
diff --git a/vendor/github.com/lrstanley/girc/conn.go b/vendor/github.com/lrstanley/girc/conn.go
index d9ec6319..441c3e71 100644
--- a/vendor/github.com/lrstanley/girc/conn.go
+++ b/vendor/github.com/lrstanley/girc/conn.go
@@ -58,7 +58,7 @@ type Dialer interface {
}
// newConn sets up and returns a new connection to the server.
-func newConn(conf Config, dialer Dialer, addr string) (*ircConn, error) {
+func newConn(conf Config, dialer Dialer, addr string, sts *strictTransport) (*ircConn, error) {
if err := conf.isValid(); err != nil {
return nil, err
}
@@ -83,13 +83,29 @@ func newConn(conf Config, dialer Dialer, addr string) (*ircConn, error) {
}
if conn, err = dialer.Dial("tcp", addr); err != nil {
+ if sts.enabled() {
+ err = &ErrSTSUpgradeFailed{Err: err}
+ }
+
+ if sts.expired() && !conf.DisableSTSFallback {
+ sts.lastFailed = time.Now()
+ sts.reset()
+ }
return nil, err
}
- if conf.SSL {
+ if conf.SSL || sts.enabled() {
var tlsConn net.Conn
tlsConn, err = tlsHandshake(conn, conf.TLSConfig, conf.Server, true)
if err != nil {
+ if sts.enabled() {
+ err = &ErrSTSUpgradeFailed{Err: err}
+ }
+
+ if sts.expired() && !conf.DisableSTSFallback {
+ sts.lastFailed = time.Now()
+ sts.reset()
+ }
return nil, err
}
@@ -245,6 +261,7 @@ func (c *Client) MockConnect(conn net.Conn) error {
}
func (c *Client) internalConnect(mock net.Conn, dialer Dialer) error {
+startConn:
// We want to be the only one handling connects/disconnects right now.
c.mu.Lock()
@@ -253,13 +270,20 @@ func (c *Client) internalConnect(mock net.Conn, dialer Dialer) error {
}
// Reset the state.
- c.state.reset()
+ c.state.reset(false)
+
+ addr := c.server()
if mock == nil {
// Validate info, and actually make the connection.
- c.debug.Printf("connecting to %s...", c.Server())
- conn, err := newConn(c.Config, dialer, c.Server())
+ c.debug.Printf("connecting to %s... (sts: %v, config-ssl: %v)", addr, c.state.sts.enabled(), c.Config.SSL)
+ conn, err := newConn(c.Config, dialer, addr, &c.state.sts)
if err != nil {
+ if _, ok := err.(*ErrSTSUpgradeFailed); ok {
+ if !c.state.sts.enabled() {
+ c.RunHandlers(&Event{Command: STS_ERR_FALLBACK})
+ }
+ }
c.mu.Unlock()
return err
}
@@ -312,16 +336,18 @@ func (c *Client) internalConnect(mock net.Conn, dialer Dialer) error {
c.write(&Event{Command: USER, Params: []string{c.Config.User, "*", "*", c.Config.Name}})
// Send a virtual event allowing hooks for successful socket connection.
- c.RunHandlers(&Event{Command: INITIALIZED, Params: []string{c.Server()}})
+ c.RunHandlers(&Event{Command: INITIALIZED, Params: []string{addr}})
// Wait for the first error.
var result error
select {
case <-ctx.Done():
- c.debug.Print("received request to close, beginning clean up")
- c.RunHandlers(&Event{Command: CLOSED, Params: []string{c.Server()}})
+ if !c.state.sts.beginUpgrade {
+ c.debug.Print("received request to close, beginning clean up")
+ }
+ c.RunHandlers(&Event{Command: CLOSED, Params: []string{addr}})
case err := <-errs:
- c.debug.Print("received error, beginning clean up")
+ c.debug.Printf("received error, beginning cleanup: %v", err)
result = err
}
@@ -336,7 +362,7 @@ func (c *Client) internalConnect(mock net.Conn, dialer Dialer) error {
c.conn.mu.Unlock()
c.mu.RUnlock()
- c.RunHandlers(&Event{Command: DISCONNECTED, Params: []string{c.Server()}})
+ c.RunHandlers(&Event{Command: DISCONNECTED, Params: []string{addr}})
// Once we have our error/result, let all other functions know we're done.
c.debug.Print("waiting for all routines to finish")
@@ -350,6 +376,18 @@ func (c *Client) internalConnect(mock net.Conn, dialer Dialer) error {
// clients, not multiple instances of Connect().
c.mu.Lock()
c.conn = nil
+
+ if result == nil {
+ if c.state.sts.beginUpgrade {
+ c.state.sts.beginUpgrade = false
+ c.mu.Unlock()
+ goto startConn
+ }
+
+ if c.state.sts.enabled() {
+ c.state.sts.persistenceReceived = time.Now()
+ }
+ }
c.mu.Unlock()
return result
@@ -392,8 +430,23 @@ func (c *Client) readLoop(ctx context.Context, errs chan error, wg *sync.WaitGro
// Send sends an event to the server. Use Client.RunHandlers() if you are
// simply looking to trigger handlers with an event.
func (c *Client) Send(event *Event) {
+ var delay time.Duration
+
if !c.Config.AllowFlood {
- <-time.After(c.conn.rate(event.Len()))
+ c.mu.RLock()
+
+ // Drop the event early as we're disconnected, this way we don't have to wait
+ // the (potentially long) rate limit delay before dropping.
+ if c.conn == nil {
+ c.debugLogEvent(event, true)
+ c.mu.RUnlock()
+ return
+ }
+
+ c.conn.mu.Lock()
+ delay = c.conn.rate(event.Len())
+ c.conn.mu.Unlock()
+ c.mu.RUnlock()
}
if c.Config.GlobalFormat && len(event.Params) > 0 && event.Params[len(event.Params)-1] != "" &&
@@ -401,12 +454,21 @@ func (c *Client) Send(event *Event) {
event.Params[len(event.Params)-1] = Fmt(event.Params[len(event.Params)-1])
}
+ <-time.After(delay)
c.write(event)
}
// write is the lower level function to write an event. It does not have a
// write-delay when sending events.
func (c *Client) write(event *Event) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+
+ if c.conn == nil {
+ // Drop the event if disconnected.
+ c.debugLogEvent(event, true)
+ return
+ }
c.tx <- event
}
@@ -415,14 +477,10 @@ func (c *Client) write(event *Event) {
func (c *ircConn) rate(chars int) time.Duration {
_time := time.Second + ((time.Duration(chars) * time.Second) / 100)
- c.mu.Lock()
if c.writeDelay += _time - time.Now().Sub(c.lastWrite); c.writeDelay < 0 {
c.writeDelay = 0
}
- c.mu.Unlock()
- c.mu.RLock()
- defer c.mu.RUnlock()
if c.writeDelay > (8 * time.Second) {
return _time
}
@@ -445,7 +503,7 @@ func (c *Client) sendLoop(ctx context.Context, errs chan error, wg *sync.WaitGro
c.state.RLock()
var in bool
for i := 0; i < len(c.state.enabledCap); i++ {
- if c.state.enabledCap[i] == "message-tags" {
+ if _, ok := c.state.enabledCap["message-tags"]; ok {
in = true
break
}
@@ -457,17 +515,7 @@ func (c *Client) sendLoop(ctx context.Context, errs chan error, wg *sync.WaitGro
}
}
- // Log the event.
- if event.Sensitive {
- c.debug.Printf("> %s ***redacted***", event.Command)
- } else {
- c.debug.Print("> ", StripRaw(event.String()))
- }
- if c.Config.Out != nil {
- if pretty, ok := event.Pretty(); ok {
- fmt.Fprintln(c.Config.Out, StripRaw(pretty))
- }
- }
+ c.debugLogEvent(event, false)
c.conn.mu.Lock()
c.conn.lastWrite = time.Now()
@@ -488,6 +536,12 @@ func (c *Client) sendLoop(ctx context.Context, errs chan error, wg *sync.WaitGro
}
}
+ if event.Command == QUIT {
+ c.Close()
+ wg.Done()
+ return
+ }
+
if err != nil {
errs <- err
wg.Done()