diff options
Diffstat (limited to 'vendor/github.com/mattermost/ldap/conn.go')
-rw-r--r-- | vendor/github.com/mattermost/ldap/conn.go | 522 |
1 files changed, 522 insertions, 0 deletions
diff --git a/vendor/github.com/mattermost/ldap/conn.go b/vendor/github.com/mattermost/ldap/conn.go new file mode 100644 index 00000000..5b7ac478 --- /dev/null +++ b/vendor/github.com/mattermost/ldap/conn.go @@ -0,0 +1,522 @@ +package ldap + +import ( + "crypto/tls" + "errors" + "fmt" + "log" + "net" + "net/url" + "sync" + "sync/atomic" + "time" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +const ( + // MessageQuit causes the processMessages loop to exit + MessageQuit = 0 + // MessageRequest sends a request to the server + MessageRequest = 1 + // MessageResponse receives a response from the server + MessageResponse = 2 + // MessageFinish indicates the client considers a particular message ID to be finished + MessageFinish = 3 + // MessageTimeout indicates the client-specified timeout for a particular message ID has been reached + MessageTimeout = 4 +) + +const ( + // DefaultLdapPort default ldap port for pure TCP connection + DefaultLdapPort = "389" + // DefaultLdapsPort default ldap port for SSL connection + DefaultLdapsPort = "636" +) + +// PacketResponse contains the packet or error encountered reading a response +type PacketResponse struct { + // Packet is the packet read from the server + Packet *ber.Packet + // Error is an error encountered while reading + Error error +} + +// ReadPacket returns the packet or an error +func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) { + if (pr == nil) || (pr.Packet == nil && pr.Error == nil) { + return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response")) + } + return pr.Packet, pr.Error +} + +type messageContext struct { + id int64 + // close(done) should only be called from finishMessage() + done chan struct{} + // close(responses) should only be called from processMessages(), and only sent to from sendResponse() + responses chan *PacketResponse +} + +// sendResponse should only be called within the processMessages() loop which +// is also responsible for closing the responses channel. +func (msgCtx *messageContext) sendResponse(packet *PacketResponse) { + select { + case msgCtx.responses <- packet: + // Successfully sent packet to message handler. + case <-msgCtx.done: + // The request handler is done and will not receive more + // packets. + } +} + +type messagePacket struct { + Op int + MessageID int64 + Packet *ber.Packet + Context *messageContext +} + +type sendMessageFlags uint + +const ( + startTLS sendMessageFlags = 1 << iota +) + +// Conn represents an LDAP Connection +type Conn struct { + // requestTimeout is loaded atomically + // so we need to ensure 64-bit alignment on 32-bit platforms. + requestTimeout int64 + conn net.Conn + isTLS bool + closing uint32 + closeErr atomic.Value + isStartingTLS bool + Debug debugging + chanConfirm chan struct{} + messageContexts map[int64]*messageContext + chanMessage chan *messagePacket + chanMessageID chan int64 + wgClose sync.WaitGroup + outstandingRequests uint + messageMutex sync.Mutex +} + +var _ Client = &Conn{} + +// DefaultTimeout is a package-level variable that sets the timeout value +// used for the Dial and DialTLS methods. +// +// WARNING: since this is a package-level variable, setting this value from +// multiple places will probably result in undesired behaviour. +var DefaultTimeout = 60 * time.Second + +// Dial connects to the given address on the given network using net.Dial +// and then returns a new Conn for the connection. +func Dial(network, addr string) (*Conn, error) { + c, err := net.DialTimeout(network, addr, DefaultTimeout) + if err != nil { + return nil, NewError(ErrorNetwork, err) + } + conn := NewConn(c, false) + conn.Start() + return conn, nil +} + +// DialTLS connects to the given address on the given network using tls.Dial +// and then returns a new Conn for the connection. +func DialTLS(network, addr string, config *tls.Config) (*Conn, error) { + c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config) + if err != nil { + return nil, NewError(ErrorNetwork, err) + } + conn := NewConn(c, true) + conn.Start() + return conn, nil +} + +// DialURL connects to the given ldap URL vie TCP using tls.Dial or net.Dial if ldaps:// +// or ldap:// specified as protocol. On success a new Conn for the connection +// is returned. +func DialURL(addr string) (*Conn, error) { + lurl, err := url.Parse(addr) + if err != nil { + return nil, NewError(ErrorNetwork, err) + } + + host, port, err := net.SplitHostPort(lurl.Host) + if err != nil { + // we asume that error is due to missing port + host = lurl.Host + port = "" + } + + switch lurl.Scheme { + case "ldapi": + if lurl.Path == "" || lurl.Path == "/" { + lurl.Path = "/var/run/slapd/ldapi" + } + return Dial("unix", lurl.Path) + case "ldap": + if port == "" { + port = DefaultLdapPort + } + return Dial("tcp", net.JoinHostPort(host, port)) + case "ldaps": + if port == "" { + port = DefaultLdapsPort + } + tlsConf := &tls.Config{ + ServerName: host, + } + return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf) + } + + return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme)) +} + +// NewConn returns a new Conn using conn for network I/O. +func NewConn(conn net.Conn, isTLS bool) *Conn { + return &Conn{ + conn: conn, + chanConfirm: make(chan struct{}), + chanMessageID: make(chan int64), + chanMessage: make(chan *messagePacket, 10), + messageContexts: map[int64]*messageContext{}, + requestTimeout: 0, + isTLS: isTLS, + } +} + +// Start initializes goroutines to read responses and process messages +func (l *Conn) Start() { + l.wgClose.Add(1) + go l.reader() + go l.processMessages() +} + +// IsClosing returns whether or not we're currently closing. +func (l *Conn) IsClosing() bool { + return atomic.LoadUint32(&l.closing) == 1 +} + +// setClosing sets the closing value to true +func (l *Conn) setClosing() bool { + return atomic.CompareAndSwapUint32(&l.closing, 0, 1) +} + +// Close closes the connection. +func (l *Conn) Close() { + l.messageMutex.Lock() + defer l.messageMutex.Unlock() + + if l.setClosing() { + l.Debug.Printf("Sending quit message and waiting for confirmation") + l.chanMessage <- &messagePacket{Op: MessageQuit} + <-l.chanConfirm + close(l.chanMessage) + + l.Debug.Printf("Closing network connection") + if err := l.conn.Close(); err != nil { + log.Println(err) + } + + l.wgClose.Done() + } + l.wgClose.Wait() +} + +// SetTimeout sets the time after a request is sent that a MessageTimeout triggers +func (l *Conn) SetTimeout(timeout time.Duration) { + if timeout > 0 { + atomic.StoreInt64(&l.requestTimeout, int64(timeout)) + } +} + +// Returns the next available messageID +func (l *Conn) nextMessageID() int64 { + if messageID, ok := <-l.chanMessageID; ok { + return messageID + } + return 0 +} + +// StartTLS sends the command to start a TLS session and then creates a new TLS Client +func (l *Conn) StartTLS(config *tls.Config) error { + if l.isTLS { + return NewError(ErrorNetwork, errors.New("ldap: already encrypted")) + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") + request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) + packet.AppendChild(request) + l.Debug.PrintPacket(packet) + + msgCtx, err := l.sendMessageWithFlags(packet, startTLS) + if err != nil { + return err + } + defer l.finishMessage(msgCtx) + + l.Debug.Printf("%d: waiting for response", msgCtx.id) + + packetResponse, ok := <-msgCtx.responses + if !ok { + return NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + } + packet, err = packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + return err + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + l.Close() + return err + } + l.Debug.PrintPacket(packet) + } + + if err := GetLDAPError(packet); err == nil { + conn := tls.Client(l.conn, config) + + if connErr := conn.Handshake(); connErr != nil { + l.Close() + return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr)) + } + + l.isTLS = true + l.conn = conn + } else { + return err + } + go l.reader() + + return nil +} + +// TLSConnectionState returns the client's TLS connection state. +// The return values are their zero values if StartTLS did +// not succeed. +func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) { + tc, ok := l.conn.(*tls.Conn) + if !ok { + return + } + return tc.ConnectionState(), true +} + +func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) { + return l.sendMessageWithFlags(packet, 0) +} + +func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) { + if l.IsClosing() { + return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed")) + } + l.messageMutex.Lock() + l.Debug.Printf("flags&startTLS = %d", flags&startTLS) + if l.isStartingTLS { + l.messageMutex.Unlock() + return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase")) + } + if flags&startTLS != 0 { + if l.outstandingRequests != 0 { + l.messageMutex.Unlock() + return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests")) + } + l.isStartingTLS = true + } + l.outstandingRequests++ + + l.messageMutex.Unlock() + + responses := make(chan *PacketResponse) + messageID := packet.Children[0].Value.(int64) + message := &messagePacket{ + Op: MessageRequest, + MessageID: messageID, + Packet: packet, + Context: &messageContext{ + id: messageID, + done: make(chan struct{}), + responses: responses, + }, + } + l.sendProcessMessage(message) + return message.Context, nil +} + +func (l *Conn) finishMessage(msgCtx *messageContext) { + close(msgCtx.done) + + if l.IsClosing() { + return + } + + l.messageMutex.Lock() + l.outstandingRequests-- + if l.isStartingTLS { + l.isStartingTLS = false + } + l.messageMutex.Unlock() + + message := &messagePacket{ + Op: MessageFinish, + MessageID: msgCtx.id, + } + l.sendProcessMessage(message) +} + +func (l *Conn) sendProcessMessage(message *messagePacket) bool { + l.messageMutex.Lock() + defer l.messageMutex.Unlock() + if l.IsClosing() { + return false + } + l.chanMessage <- message + return true +} + +func (l *Conn) processMessages() { + defer func() { + if err := recover(); err != nil { + log.Printf("ldap: recovered panic in processMessages: %v", err) + } + for messageID, msgCtx := range l.messageContexts { + // If we are closing due to an error, inform anyone who + // is waiting about the error. + if l.IsClosing() && l.closeErr.Load() != nil { + msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}) + } + l.Debug.Printf("Closing channel for MessageID %d", messageID) + close(msgCtx.responses) + delete(l.messageContexts, messageID) + } + close(l.chanMessageID) + close(l.chanConfirm) + }() + + var messageID int64 = 1 + for { + select { + case l.chanMessageID <- messageID: + messageID++ + case message := <-l.chanMessage: + switch message.Op { + case MessageQuit: + l.Debug.Printf("Shutting down - quit message received") + return + case MessageRequest: + // Add to message list and write to network + l.Debug.Printf("Sending message %d", message.MessageID) + + buf := message.Packet.Bytes() + _, err := l.conn.Write(buf) + if err != nil { + l.Debug.Printf("Error Sending Message: %s", err.Error()) + message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}) + close(message.Context.responses) + break + } + + // Only add to messageContexts if we were able to + // successfully write the message. + l.messageContexts[message.MessageID] = message.Context + + // Add timeout if defined + requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout)) + if requestTimeout > 0 { + go func() { + defer func() { + if err := recover(); err != nil { + log.Printf("ldap: recovered panic in RequestTimeout: %v", err) + } + }() + time.Sleep(requestTimeout) + timeoutMessage := &messagePacket{ + Op: MessageTimeout, + MessageID: message.MessageID, + } + l.sendProcessMessage(timeoutMessage) + }() + } + case MessageResponse: + l.Debug.Printf("Receiving message %d", message.MessageID) + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { + msgCtx.sendResponse(&PacketResponse{message.Packet, nil}) + } else { + log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing()) + l.Debug.PrintPacket(message.Packet) + } + case MessageTimeout: + // Handle the timeout by closing the channel + // All reads will return immediately + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { + l.Debug.Printf("Receiving message timeout for %d", message.MessageID) + msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")}) + delete(l.messageContexts, message.MessageID) + close(msgCtx.responses) + } + case MessageFinish: + l.Debug.Printf("Finished message %d", message.MessageID) + if msgCtx, ok := l.messageContexts[message.MessageID]; ok { + delete(l.messageContexts, message.MessageID) + close(msgCtx.responses) + } + } + } + } +} + +func (l *Conn) reader() { + cleanstop := false + defer func() { + if err := recover(); err != nil { + log.Printf("ldap: recovered panic in reader: %v", err) + } + if !cleanstop { + l.Close() + } + }() + + for { + if cleanstop { + l.Debug.Printf("reader clean stopping (without closing the connection)") + return + } + packet, err := ber.ReadPacket(l.conn) + if err != nil { + // A read error is expected here if we are closing the connection... + if !l.IsClosing() { + l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err)) + l.Debug.Printf("reader error: %s", err) + } + return + } + if err := addLDAPDescriptions(packet); err != nil { + l.Debug.Printf("descriptions error: %s", err) + } + if len(packet.Children) == 0 { + l.Debug.Printf("Received bad ldap packet") + continue + } + l.messageMutex.Lock() + if l.isStartingTLS { + cleanstop = true + } + l.messageMutex.Unlock() + message := &messagePacket{ + Op: MessageResponse, + MessageID: packet.Children[0].Value.(int64), + Packet: packet, + } + if !l.sendProcessMessage(message) { + return + } + } +} |