package sshd import ( "io" "net" "time" "github.com/shazow/rateio" ) type limitedConn struct { net.Conn io.Reader // Our rate-limited io.Reader for net.Conn } func (r *limitedConn) Read(p []byte) (n int, err error) { return r.Reader.Read(p) } // ReadLimitConn returns a net.Conn whose io.Reader interface is rate-limited by limiter. func ReadLimitConn(conn net.Conn, limiter rateio.Limiter) net.Conn { return &limitedConn{ Conn: conn, Reader: rateio.NewReader(conn, limiter), } } // Count each read as 1 unless it exceeds some number of bytes. type inputLimiter struct { // TODO: Could do all kinds of fancy things here, like be more forgiving of // connections that have been around for a while. Amount int Frequency time.Duration remaining int readCap int numRead int timeRead time.Time } // NewInputLimiter returns a rateio.Limiter with sensible defaults for // differentiating between humans typing and bots spamming. func NewInputLimiter() rateio.Limiter { grace := time.Second * 3 return &inputLimiter{ Amount: 2 << 14, // ~16kb, should be plenty for a high typing rate/copypasta/large key handshakes. Frequency: time.Minute * 1, readCap: 128, // Allow up to 128 bytes per read (anecdotally, 1 character = 52 bytes over ssh) numRead: -1024 * 1024, // Start with a 1mb grace timeRead: time.Now().Add(grace), } } // Count applies 1 if n<readCap, else n func (limit *inputLimiter) Count(n int) error { now := time.Now() if now.After(limit.timeRead) { limit.numRead = 0 limit.timeRead = now.Add(limit.Frequency) } if n <= limit.readCap { limit.numRead += 1 } else { limit.numRead += n } if limit.numRead > limit.Amount { return rateio.ErrRateExceeded } return nil }