summaryrefslogtreecommitdiffstats
path: root/vendor/github.com/mattermost/ldap/conn.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/mattermost/ldap/conn.go')
-rw-r--r--vendor/github.com/mattermost/ldap/conn.go522
1 files changed, 522 insertions, 0 deletions
diff --git a/vendor/github.com/mattermost/ldap/conn.go b/vendor/github.com/mattermost/ldap/conn.go
new file mode 100644
index 00000000..5b7ac478
--- /dev/null
+++ b/vendor/github.com/mattermost/ldap/conn.go
@@ -0,0 +1,522 @@
+package ldap
+
+import (
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "log"
+ "net"
+ "net/url"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ ber "github.com/go-asn1-ber/asn1-ber"
+)
+
+const (
+ // MessageQuit causes the processMessages loop to exit
+ MessageQuit = 0
+ // MessageRequest sends a request to the server
+ MessageRequest = 1
+ // MessageResponse receives a response from the server
+ MessageResponse = 2
+ // MessageFinish indicates the client considers a particular message ID to be finished
+ MessageFinish = 3
+ // MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
+ MessageTimeout = 4
+)
+
+const (
+ // DefaultLdapPort default ldap port for pure TCP connection
+ DefaultLdapPort = "389"
+ // DefaultLdapsPort default ldap port for SSL connection
+ DefaultLdapsPort = "636"
+)
+
+// PacketResponse contains the packet or error encountered reading a response
+type PacketResponse struct {
+ // Packet is the packet read from the server
+ Packet *ber.Packet
+ // Error is an error encountered while reading
+ Error error
+}
+
+// ReadPacket returns the packet or an error
+func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
+ if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
+ return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
+ }
+ return pr.Packet, pr.Error
+}
+
+type messageContext struct {
+ id int64
+ // close(done) should only be called from finishMessage()
+ done chan struct{}
+ // close(responses) should only be called from processMessages(), and only sent to from sendResponse()
+ responses chan *PacketResponse
+}
+
+// sendResponse should only be called within the processMessages() loop which
+// is also responsible for closing the responses channel.
+func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
+ select {
+ case msgCtx.responses <- packet:
+ // Successfully sent packet to message handler.
+ case <-msgCtx.done:
+ // The request handler is done and will not receive more
+ // packets.
+ }
+}
+
+type messagePacket struct {
+ Op int
+ MessageID int64
+ Packet *ber.Packet
+ Context *messageContext
+}
+
+type sendMessageFlags uint
+
+const (
+ startTLS sendMessageFlags = 1 << iota
+)
+
+// Conn represents an LDAP Connection
+type Conn struct {
+ // requestTimeout is loaded atomically
+ // so we need to ensure 64-bit alignment on 32-bit platforms.
+ requestTimeout int64
+ conn net.Conn
+ isTLS bool
+ closing uint32
+ closeErr atomic.Value
+ isStartingTLS bool
+ Debug debugging
+ chanConfirm chan struct{}
+ messageContexts map[int64]*messageContext
+ chanMessage chan *messagePacket
+ chanMessageID chan int64
+ wgClose sync.WaitGroup
+ outstandingRequests uint
+ messageMutex sync.Mutex
+}
+
+var _ Client = &Conn{}
+
+// DefaultTimeout is a package-level variable that sets the timeout value
+// used for the Dial and DialTLS methods.
+//
+// WARNING: since this is a package-level variable, setting this value from
+// multiple places will probably result in undesired behaviour.
+var DefaultTimeout = 60 * time.Second
+
+// Dial connects to the given address on the given network using net.Dial
+// and then returns a new Conn for the connection.
+func Dial(network, addr string) (*Conn, error) {
+ c, err := net.DialTimeout(network, addr, DefaultTimeout)
+ if err != nil {
+ return nil, NewError(ErrorNetwork, err)
+ }
+ conn := NewConn(c, false)
+ conn.Start()
+ return conn, nil
+}
+
+// DialTLS connects to the given address on the given network using tls.Dial
+// and then returns a new Conn for the connection.
+func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
+ c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config)
+ if err != nil {
+ return nil, NewError(ErrorNetwork, err)
+ }
+ conn := NewConn(c, true)
+ conn.Start()
+ return conn, nil
+}
+
+// DialURL connects to the given ldap URL vie TCP using tls.Dial or net.Dial if ldaps://
+// or ldap:// specified as protocol. On success a new Conn for the connection
+// is returned.
+func DialURL(addr string) (*Conn, error) {
+ lurl, err := url.Parse(addr)
+ if err != nil {
+ return nil, NewError(ErrorNetwork, err)
+ }
+
+ host, port, err := net.SplitHostPort(lurl.Host)
+ if err != nil {
+ // we asume that error is due to missing port
+ host = lurl.Host
+ port = ""
+ }
+
+ switch lurl.Scheme {
+ case "ldapi":
+ if lurl.Path == "" || lurl.Path == "/" {
+ lurl.Path = "/var/run/slapd/ldapi"
+ }
+ return Dial("unix", lurl.Path)
+ case "ldap":
+ if port == "" {
+ port = DefaultLdapPort
+ }
+ return Dial("tcp", net.JoinHostPort(host, port))
+ case "ldaps":
+ if port == "" {
+ port = DefaultLdapsPort
+ }
+ tlsConf := &tls.Config{
+ ServerName: host,
+ }
+ return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf)
+ }
+
+ return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme))
+}
+
+// NewConn returns a new Conn using conn for network I/O.
+func NewConn(conn net.Conn, isTLS bool) *Conn {
+ return &Conn{
+ conn: conn,
+ chanConfirm: make(chan struct{}),
+ chanMessageID: make(chan int64),
+ chanMessage: make(chan *messagePacket, 10),
+ messageContexts: map[int64]*messageContext{},
+ requestTimeout: 0,
+ isTLS: isTLS,
+ }
+}
+
+// Start initializes goroutines to read responses and process messages
+func (l *Conn) Start() {
+ l.wgClose.Add(1)
+ go l.reader()
+ go l.processMessages()
+}
+
+// IsClosing returns whether or not we're currently closing.
+func (l *Conn) IsClosing() bool {
+ return atomic.LoadUint32(&l.closing) == 1
+}
+
+// setClosing sets the closing value to true
+func (l *Conn) setClosing() bool {
+ return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
+}
+
+// Close closes the connection.
+func (l *Conn) Close() {
+ l.messageMutex.Lock()
+ defer l.messageMutex.Unlock()
+
+ if l.setClosing() {
+ l.Debug.Printf("Sending quit message and waiting for confirmation")
+ l.chanMessage <- &messagePacket{Op: MessageQuit}
+ <-l.chanConfirm
+ close(l.chanMessage)
+
+ l.Debug.Printf("Closing network connection")
+ if err := l.conn.Close(); err != nil {
+ log.Println(err)
+ }
+
+ l.wgClose.Done()
+ }
+ l.wgClose.Wait()
+}
+
+// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
+func (l *Conn) SetTimeout(timeout time.Duration) {
+ if timeout > 0 {
+ atomic.StoreInt64(&l.requestTimeout, int64(timeout))
+ }
+}
+
+// Returns the next available messageID
+func (l *Conn) nextMessageID() int64 {
+ if messageID, ok := <-l.chanMessageID; ok {
+ return messageID
+ }
+ return 0
+}
+
+// StartTLS sends the command to start a TLS session and then creates a new TLS Client
+func (l *Conn) StartTLS(config *tls.Config) error {
+ if l.isTLS {
+ return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
+ }
+
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
+ packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
+ request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
+ request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
+ packet.AppendChild(request)
+ l.Debug.PrintPacket(packet)
+
+ msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
+ if err != nil {
+ return err
+ }
+ defer l.finishMessage(msgCtx)
+
+ l.Debug.Printf("%d: waiting for response", msgCtx.id)
+
+ packetResponse, ok := <-msgCtx.responses
+ if !ok {
+ return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
+ }
+ packet, err = packetResponse.ReadPacket()
+ l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
+ if err != nil {
+ return err
+ }
+
+ if l.Debug {
+ if err := addLDAPDescriptions(packet); err != nil {
+ l.Close()
+ return err
+ }
+ l.Debug.PrintPacket(packet)
+ }
+
+ if err := GetLDAPError(packet); err == nil {
+ conn := tls.Client(l.conn, config)
+
+ if connErr := conn.Handshake(); connErr != nil {
+ l.Close()
+ return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr))
+ }
+
+ l.isTLS = true
+ l.conn = conn
+ } else {
+ return err
+ }
+ go l.reader()
+
+ return nil
+}
+
+// TLSConnectionState returns the client's TLS connection state.
+// The return values are their zero values if StartTLS did
+// not succeed.
+func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
+ tc, ok := l.conn.(*tls.Conn)
+ if !ok {
+ return
+ }
+ return tc.ConnectionState(), true
+}
+
+func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
+ return l.sendMessageWithFlags(packet, 0)
+}
+
+func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
+ if l.IsClosing() {
+ return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
+ }
+ l.messageMutex.Lock()
+ l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
+ if l.isStartingTLS {
+ l.messageMutex.Unlock()
+ return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
+ }
+ if flags&startTLS != 0 {
+ if l.outstandingRequests != 0 {
+ l.messageMutex.Unlock()
+ return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
+ }
+ l.isStartingTLS = true
+ }
+ l.outstandingRequests++
+
+ l.messageMutex.Unlock()
+
+ responses := make(chan *PacketResponse)
+ messageID := packet.Children[0].Value.(int64)
+ message := &messagePacket{
+ Op: MessageRequest,
+ MessageID: messageID,
+ Packet: packet,
+ Context: &messageContext{
+ id: messageID,
+ done: make(chan struct{}),
+ responses: responses,
+ },
+ }
+ l.sendProcessMessage(message)
+ return message.Context, nil
+}
+
+func (l *Conn) finishMessage(msgCtx *messageContext) {
+ close(msgCtx.done)
+
+ if l.IsClosing() {
+ return
+ }
+
+ l.messageMutex.Lock()
+ l.outstandingRequests--
+ if l.isStartingTLS {
+ l.isStartingTLS = false
+ }
+ l.messageMutex.Unlock()
+
+ message := &messagePacket{
+ Op: MessageFinish,
+ MessageID: msgCtx.id,
+ }
+ l.sendProcessMessage(message)
+}
+
+func (l *Conn) sendProcessMessage(message *messagePacket) bool {
+ l.messageMutex.Lock()
+ defer l.messageMutex.Unlock()
+ if l.IsClosing() {
+ return false
+ }
+ l.chanMessage <- message
+ return true
+}
+
+func (l *Conn) processMessages() {
+ defer func() {
+ if err := recover(); err != nil {
+ log.Printf("ldap: recovered panic in processMessages: %v", err)
+ }
+ for messageID, msgCtx := range l.messageContexts {
+ // If we are closing due to an error, inform anyone who
+ // is waiting about the error.
+ if l.IsClosing() && l.closeErr.Load() != nil {
+ msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
+ }
+ l.Debug.Printf("Closing channel for MessageID %d", messageID)
+ close(msgCtx.responses)
+ delete(l.messageContexts, messageID)
+ }
+ close(l.chanMessageID)
+ close(l.chanConfirm)
+ }()
+
+ var messageID int64 = 1
+ for {
+ select {
+ case l.chanMessageID <- messageID:
+ messageID++
+ case message := <-l.chanMessage:
+ switch message.Op {
+ case MessageQuit:
+ l.Debug.Printf("Shutting down - quit message received")
+ return
+ case MessageRequest:
+ // Add to message list and write to network
+ l.Debug.Printf("Sending message %d", message.MessageID)
+
+ buf := message.Packet.Bytes()
+ _, err := l.conn.Write(buf)
+ if err != nil {
+ l.Debug.Printf("Error Sending Message: %s", err.Error())
+ message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
+ close(message.Context.responses)
+ break
+ }
+
+ // Only add to messageContexts if we were able to
+ // successfully write the message.
+ l.messageContexts[message.MessageID] = message.Context
+
+ // Add timeout if defined
+ requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
+ if requestTimeout > 0 {
+ go func() {
+ defer func() {
+ if err := recover(); err != nil {
+ log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
+ }
+ }()
+ time.Sleep(requestTimeout)
+ timeoutMessage := &messagePacket{
+ Op: MessageTimeout,
+ MessageID: message.MessageID,
+ }
+ l.sendProcessMessage(timeoutMessage)
+ }()
+ }
+ case MessageResponse:
+ l.Debug.Printf("Receiving message %d", message.MessageID)
+ if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
+ msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
+ } else {
+ log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
+ l.Debug.PrintPacket(message.Packet)
+ }
+ case MessageTimeout:
+ // Handle the timeout by closing the channel
+ // All reads will return immediately
+ if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
+ l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
+ msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
+ delete(l.messageContexts, message.MessageID)
+ close(msgCtx.responses)
+ }
+ case MessageFinish:
+ l.Debug.Printf("Finished message %d", message.MessageID)
+ if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
+ delete(l.messageContexts, message.MessageID)
+ close(msgCtx.responses)
+ }
+ }
+ }
+ }
+}
+
+func (l *Conn) reader() {
+ cleanstop := false
+ defer func() {
+ if err := recover(); err != nil {
+ log.Printf("ldap: recovered panic in reader: %v", err)
+ }
+ if !cleanstop {
+ l.Close()
+ }
+ }()
+
+ for {
+ if cleanstop {
+ l.Debug.Printf("reader clean stopping (without closing the connection)")
+ return
+ }
+ packet, err := ber.ReadPacket(l.conn)
+ if err != nil {
+ // A read error is expected here if we are closing the connection...
+ if !l.IsClosing() {
+ l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
+ l.Debug.Printf("reader error: %s", err)
+ }
+ return
+ }
+ if err := addLDAPDescriptions(packet); err != nil {
+ l.Debug.Printf("descriptions error: %s", err)
+ }
+ if len(packet.Children) == 0 {
+ l.Debug.Printf("Received bad ldap packet")
+ continue
+ }
+ l.messageMutex.Lock()
+ if l.isStartingTLS {
+ cleanstop = true
+ }
+ l.messageMutex.Unlock()
+ message := &messagePacket{
+ Op: MessageResponse,
+ MessageID: packet.Children[0].Value.(int64),
+ Packet: packet,
+ }
+ if !l.sendProcessMessage(message) {
+ return
+ }
+ }
+}