package ldap

import (
	"crypto/tls"
	"errors"
	"fmt"
	"log"
	"net"
	"net/url"
	"sync"
	"sync/atomic"
	"time"

	ber "github.com/go-asn1-ber/asn1-ber"
)

// Operation codes carried in messagePacket.Op and dispatched on by the
// processMessages loop.
const (
	// MessageQuit tells the processMessages loop to exit.
	MessageQuit = iota
	// MessageRequest asks the loop to write a request to the server.
	MessageRequest
	// MessageResponse delivers a packet that was read from the server.
	MessageResponse
	// MessageFinish marks a message ID as finished from the client's side.
	MessageFinish
	// MessageTimeout reports that the client-specified timeout for a message ID has elapsed.
	MessageTimeout
)

const (
	// DefaultLdapPort is the port DialURL uses for ldap:// URLs that do not
	// specify one (plain TCP connection).
	DefaultLdapPort = "389"
	// DefaultLdapsPort is the port DialURL uses for ldaps:// URLs that do not
	// specify one (TLS connection).
	DefaultLdapsPort = "636"
)

// PacketResponse contains the packet or error encountered reading a response.
// Callers normally unpack it with ReadPacket, which also copes with a nil
// *PacketResponse.
type PacketResponse struct {
	// Packet is the packet read from the server
	Packet *ber.Packet
	// Error is the error to report to the caller instead of (or alongside)
	// a packet, e.g. a read failure, a failed send or a request timeout.
	Error error
}

// ReadPacket unpacks the response into its packet and error. A nil receiver,
// or a response that carries neither a packet nor an error, is reported as
// an ErrorNetwork error.
func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
	empty := pr == nil || (pr.Packet == nil && pr.Error == nil)
	if !empty {
		return pr.Packet, pr.Error
	}
	return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
}

// messageContext tracks one in-flight request: its message ID plus the
// channels used to hand responses to the request's caller and to learn that
// the caller has finished with the request.
type messageContext struct {
	// id is the message ID of the request this context belongs to.
	id int64
	// close(done) should only be called from finishMessage()
	done chan struct{}
	// close(responses) should only be called from processMessages(), and only sent to from sendResponse()
	responses chan *PacketResponse
}

// sendResponse hands packet to the request's caller, or drops it when the
// caller has already finished with the message. It must only be called from
// inside the processMessages() loop, which is also the only place allowed to
// close the responses channel.
func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
	select {
	case <-msgCtx.done:
		// The request handler is finished and no longer receiving, so
		// the packet is discarded.
	case msgCtx.responses <- packet:
		// The request handler accepted the packet.
	}
}

// messagePacket is the unit of work sent to the processMessages loop over
// Conn.chanMessage. Op is one of the Message* operation codes and selects how
// the loop handles the message. MessageID identifies the request the message
// concerns. Packet is the request to write (MessageRequest) or the response
// that was read (MessageResponse). Context is supplied with MessageRequest
// and is what the loop registers in order to route later responses.
type messagePacket struct {
	Op        int
	MessageID int64
	Packet    *ber.Packet
	Context   *messageContext
}

// sendMessageFlags is a bit set of options accepted by sendMessageWithFlags.
type sendMessageFlags uint

const (
	// startTLS marks the request as a StartTLS request. sendMessageWithFlags
	// rejects it while other requests are outstanding and, once accepted,
	// rejects further requests until the StartTLS message is finished.
	startTLS sendMessageFlags = 1 << iota
)

// Conn represents an LDAP Connection.
//
// Requests are funnelled through chanMessage to the processMessages
// goroutine, which writes them to conn and routes the responses read by the
// reader goroutine back to callers through messageContexts. chanMessageID
// hands out message IDs and chanConfirm signals that processMessages has
// exited. messageMutex guards isStartingTLS and outstandingRequests and is
// held while sending on chanMessage so that sends cannot race with Close.
// closing is a 0/1 flag accessed atomically; closeErr holds the read error
// that caused an unexpected shutdown, if any.
type Conn struct {
	// requestTimeout is loaded atomically
	// so we need to ensure 64-bit alignment on 32-bit platforms.
	requestTimeout      int64
	conn                net.Conn
	isTLS               bool
	closing             uint32
	closeErr            atomic.Value
	isStartingTLS       bool
	Debug               debugging
	chanConfirm         chan struct{}
	messageContexts     map[int64]*messageContext
	chanMessage         chan *messagePacket
	chanMessageID       chan int64
	wgClose             sync.WaitGroup
	outstandingRequests uint
	messageMutex        sync.Mutex
}

// Compile-time check that *Conn satisfies the Client interface.
var _ Client = (*Conn)(nil)

// DefaultTimeout is a package-level variable that sets the connect timeout
// used by the Dial and DialTLS functions (and therefore by DialURL, which
// delegates to them).
//
// WARNING: since this is a package-level variable, setting this value from
// multiple places will probably result in undesired behaviour.
var DefaultTimeout = 60 * time.Second

// Dial opens a connection to addr on the given network, waiting at most
// DefaultTimeout, and wraps it in a new non-TLS Conn. A dial failure is
// returned as an ErrorNetwork error.
func Dial(network, addr string) (*Conn, error) {
	netConn, dialErr := net.DialTimeout(network, addr, DefaultTimeout)
	if dialErr == nil {
		return NewConn(netConn, false), nil
	}
	return nil, NewError(ErrorNetwork, dialErr)
}

// DialTLS opens a TLS connection to addr on the given network using config,
// waiting at most DefaultTimeout, and wraps it in a new Conn marked as TLS.
// A dial or handshake failure is returned as an ErrorNetwork error.
func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
	dialer := &net.Dialer{Timeout: DefaultTimeout}
	tlsConn, dialErr := tls.DialWithDialer(dialer, network, addr, config)
	if dialErr == nil {
		return NewConn(tlsConn, true), nil
	}
	return nil, NewError(ErrorNetwork, dialErr)
}

// DialURL connects to the given LDAP URL and returns a new Conn on success.
// The URL scheme selects the transport: ldapi:// dials the Unix socket named
// by the URL path (falling back to /var/run/slapd/ldapi when the path is
// empty or "/"), ldap:// dials plain TCP via Dial, and ldaps:// dials TLS via
// DialTLS with the URL host as the TLS ServerName. When the URL carries no
// port, DefaultLdapPort or DefaultLdapsPort is used. An unparsable URL or an
// unsupported scheme yields an ErrorNetwork error.
func DialURL(addr string) (*Conn, error) {
	lurl, err := url.Parse(addr)
	if err != nil {
		return nil, NewError(ErrorNetwork, err)
	}

	// Hostname strips the square brackets from an IPv6 literal and Port is
	// empty when the URL has no port. Splitting lurl.Host by hand would leave
	// the brackets in place for a port-less IPv6 URL such as "ldap://[::1]",
	// and net.JoinHostPort would then produce the invalid "[[::1]]:389".
	host, port := lurl.Hostname(), lurl.Port()

	switch lurl.Scheme {
	case "ldapi":
		if lurl.Path == "" || lurl.Path == "/" {
			lurl.Path = "/var/run/slapd/ldapi"
		}
		return Dial("unix", lurl.Path)
	case "ldap":
		if port == "" {
			port = DefaultLdapPort
		}
		return Dial("tcp", net.JoinHostPort(host, port))
	case "ldaps":
		if port == "" {
			port = DefaultLdapsPort
		}
		tlsConf := &tls.Config{
			ServerName: host,
		}
		return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf)
	}

	// The message text is kept as-is because callers may match on it.
	return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme))
}

// NewConn wraps conn in a new Conn that uses it for network I/O. isTLS
// records whether conn is already encrypted. The request timeout starts at
// zero (disabled). No goroutines are started here; see Start.
func NewConn(conn net.Conn, isTLS bool) *Conn {
	l := new(Conn)
	l.conn = conn
	l.isTLS = isTLS
	l.chanConfirm = make(chan struct{})
	l.chanMessageID = make(chan int64)
	l.chanMessage = make(chan *messagePacket, 10)
	l.messageContexts = make(map[int64]*messageContext)
	return l
}

// Start initializes goroutines to read responses and process messages.
// Within this file nothing else services chanMessage or feeds chanMessageID,
// so requests cannot make progress until Start has been called.
func (l *Conn) Start() {
	// Balanced by the single wgClose.Done() in Close; every other Close
	// caller blocks in wgClose.Wait() until that shutdown has completed.
	l.wgClose.Add(1)
	go l.reader()
	go l.processMessages()
}

// IsClosing reports whether shutdown of this connection has begun, i.e.
// whether the closing flag has been set.
func (l *Conn) IsClosing() bool {
	state := atomic.LoadUint32(&l.closing)
	return state == 1
}

// setClosing atomically flips the closing flag from 0 to 1. It returns true
// only for the call that performed the transition, which lets Close run its
// shutdown sequence exactly once.
func (l *Conn) setClosing() bool {
	return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
}

// Close closes the connection. The first caller performs the shutdown: it
// asks processMessages to quit, waits for it to confirm, and closes the
// network connection. A failure to close the network connection is logged,
// not returned. Later callers only wait for that shutdown to finish.
func (l *Conn) Close() {
	// Holding messageMutex keeps sendProcessMessage from sending on
	// chanMessage while it is being closed below.
	l.messageMutex.Lock()
	defer l.messageMutex.Unlock()

	if l.setClosing() {
		l.Debug.Printf("Sending quit message and waiting for confirmation")
		l.chanMessage <- &messagePacket{Op: MessageQuit}
		// processMessages closes chanConfirm from its deferred cleanup, so
		// this receive returns once the loop has fully exited.
		<-l.chanConfirm
		close(l.chanMessage)

		l.Debug.Printf("Closing network connection")
		if err := l.conn.Close(); err != nil {
			log.Println(err)
		}

		// Matches the wgClose.Add(1) in Start.
		l.wgClose.Done()
	}
	l.wgClose.Wait()
}

// SetTimeout sets how long after a request is sent a MessageTimeout is
// triggered for it. Zero and negative durations are ignored, leaving the
// current setting unchanged.
func (l *Conn) SetTimeout(timeout time.Duration) {
	if timeout <= 0 {
		return
	}
	atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}

// nextMessageID returns the next available message ID, received from
// chanMessageID. It returns 0 once that channel has been closed.
func (l *Conn) nextMessageID() int64 {
	id, open := <-l.chanMessageID
	if !open {
		return 0
	}
	return id
}

// StartTLS sends the StartTLS extended request and, if the server accepts
// it, performs the TLS handshake using config and switches the connection
// over to the resulting TLS client.
//
// It returns an ErrorNetwork error when the connection is already encrypted,
// when the response channel is closed before a response arrives, or when the
// TLS handshake fails (in which case the connection is closed). If the server
// answers with an LDAP error, that error is returned and the connection stays
// open and usable without TLS.
func (l *Conn) StartTLS(config *tls.Config) error {
	if l.isTLS {
		return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
	}

	packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
	packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
	request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
	request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
	packet.AppendChild(request)
	l.Debug.PrintPacket(packet)

	msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
	if err != nil {
		return err
	}
	defer l.finishMessage(msgCtx)

	l.Debug.Printf("%d: waiting for response", msgCtx.id)

	packetResponse, ok := <-msgCtx.responses
	if !ok {
		return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
	}
	packet, err = packetResponse.ReadPacket()
	l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
	if err != nil {
		return err
	}

	if l.Debug {
		if err := addLDAPDescriptions(packet); err != nil {
			l.Close()
			return err
		}
		l.Debug.PrintPacket(packet)
	}

	// The reader goroutine stops itself, without closing the connection,
	// after delivering a response while a StartTLS exchange is in progress.
	// It therefore has to be restarted on every path below that leaves the
	// connection open.
	if err := GetLDAPError(packet); err != nil {
		// The server refused StartTLS, so the session continues without
		// TLS. Without a new reader no later response would ever be read.
		go l.reader()
		return err
	}

	conn := tls.Client(l.conn, config)
	if connErr := conn.Handshake(); connErr != nil {
		l.Close()
		return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr))
	}

	l.isTLS = true
	l.conn = conn
	go l.reader()

	return nil
}

// TLSConnectionState reports the TLS connection state of the underlying
// connection. When the underlying connection is not a *tls.Conn, the zero
// ConnectionState and false are returned.
func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
	if tlsConn, isTLS := l.conn.(*tls.Conn); isTLS {
		return tlsConn.ConnectionState(), true
	}
	return tls.ConnectionState{}, false
}

// sendMessage queues packet for sending with no special flags. It is a
// shorthand for sendMessageWithFlags(packet, 0) and returns whatever that
// call returns.
func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
	return l.sendMessageWithFlags(packet, 0)
}

// sendMessageWithFlags queues packet for writing by the processMessages loop
// and returns the messageContext on which the caller receives responses. The
// caller must pass that context to finishMessage when it is done.
//
// packet's first child must hold the message ID as an int64. With the
// startTLS flag the request is only accepted when no other requests are
// outstanding, and it blocks further requests until it is finished.
//
// An ErrorNetwork error is returned when the connection is closing, when a
// StartTLS exchange is in progress, or when StartTLS is requested while other
// requests are outstanding.
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
	if l.IsClosing() {
		return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
	}
	l.messageMutex.Lock()
	l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
	if l.isStartingTLS {
		l.messageMutex.Unlock()
		return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
	}
	if flags&startTLS != 0 {
		if l.outstandingRequests != 0 {
			l.messageMutex.Unlock()
			return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
		}
		l.isStartingTLS = true
	}
	l.outstandingRequests++

	l.messageMutex.Unlock()

	responses := make(chan *PacketResponse)
	messageID := packet.Children[0].Value.(int64)
	message := &messagePacket{
		Op:        MessageRequest,
		MessageID: messageID,
		Packet:    packet,
		Context: &messageContext{
			id:        messageID,
			done:      make(chan struct{}),
			responses: responses,
		},
	}
	if !l.sendProcessMessage(message) {
		// sendProcessMessage refuses a message only when the connection
		// started closing after the check at the top of this function.
		// The message was never queued, so nothing would ever write to or
		// close its responses channel, and a caller handed this context
		// would block forever. The counters are not rolled back because the
		// connection is going away (finishMessage skips them then, too).
		return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
	}
	return message.Context, nil
}

// finishMessage signals that the caller is done with msgCtx. It closes
// msgCtx.done so pending sendResponse calls stop waiting. Unless the
// connection is closing, it also drops the request from the outstanding
// count, clears the StartTLS-in-progress flag and tells the processMessages
// loop to forget the message. It must be called exactly once per context.
func (l *Conn) finishMessage(msgCtx *messageContext) {
	close(msgCtx.done)

	if !l.IsClosing() {
		l.messageMutex.Lock()
		l.outstandingRequests--
		l.isStartingTLS = false
		l.messageMutex.Unlock()

		l.sendProcessMessage(&messagePacket{
			Op:        MessageFinish,
			MessageID: msgCtx.id,
		})
	}
}

// sendProcessMessage queues message for the processMessages loop and reports
// whether it was queued. Nothing is queued, and false is returned, when the
// connection is closing. messageMutex is held for the whole call, so the
// send cannot overlap with Close closing chanMessage.
func (l *Conn) sendProcessMessage(message *messagePacket) bool {
	l.messageMutex.Lock()
	defer l.messageMutex.Unlock()

	queued := !l.IsClosing()
	if queued {
		l.chanMessage <- message
	}
	return queued
}

// processMessages is the connection's event loop and runs in its own
// goroutine (see Start). It hands out message IDs on chanMessageID and
// handles the messagePackets arriving on chanMessage: writing requests to the
// network, routing responses to the waiting messageContext, and tearing
// contexts down on finish or timeout. It is the only code that touches
// messageContexts and the only code that closes a context's responses
// channel. The loop exits on MessageQuit (or on a panic, which is recovered
// and logged).
func (l *Conn) processMessages() {
	// Cleanup on exit: notify and close every context still registered,
	// then close chanMessageID (nextMessageID then returns 0) and
	// chanConfirm (which unblocks Close).
	defer func() {
		if err := recover(); err != nil {
			log.Printf("ldap: recovered panic in processMessages: %v", err)
		}
		for messageID, msgCtx := range l.messageContexts {
			// If we are closing due to an error, inform anyone who
			// is waiting about the error.
			if l.IsClosing() && l.closeErr.Load() != nil {
				msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
			}
			l.Debug.Printf("Closing channel for MessageID %d", messageID)
			close(msgCtx.responses)
			delete(l.messageContexts, messageID)
		}
		close(l.chanMessageID)
		close(l.chanConfirm)
	}()

	// Message IDs start at 1 and increase by one each time one is taken.
	var messageID int64 = 1
	for {
		select {
		case l.chanMessageID <- messageID:
			messageID++
		case message := <-l.chanMessage:
			switch message.Op {
			case MessageQuit:
				l.Debug.Printf("Shutting down - quit message received")
				return
			case MessageRequest:
				// Add to message list and write to network
				l.Debug.Printf("Sending message %d", message.MessageID)

				buf := message.Packet.Bytes()
				_, err := l.conn.Write(buf)
				if err != nil {
					l.Debug.Printf("Error Sending Message: %s", err.Error())
					message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
					close(message.Context.responses)
					// This break leaves the switch only; the for loop
					// keeps running.
					break
				}

				// Only add to messageContexts if we were able to
				// successfully write the message.
				l.messageContexts[message.MessageID] = message.Context

				// Add timeout if defined
				requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
				if requestTimeout > 0 {
					// The goroutine always sleeps for the full timeout,
					// even if the request finishes earlier. A late
					// MessageTimeout is harmless because the handler
					// below ignores IDs no longer in messageContexts.
					go func() {
						defer func() {
							if err := recover(); err != nil {
								log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
							}
						}()
						time.Sleep(requestTimeout)
						timeoutMessage := &messagePacket{
							Op:        MessageTimeout,
							MessageID: message.MessageID,
						}
						l.sendProcessMessage(timeoutMessage)
					}()
				}
			case MessageResponse:
				l.Debug.Printf("Receiving message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
				} else {
					// No registered request for this ID: log and drop it.
					log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
					l.Debug.PrintPacket(message.Packet)
				}
			case MessageTimeout:
				// Handle the timeout by closing the channel
				// All reads will return immediately
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
					msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			case MessageFinish:
				// The caller is done with this request: forget the
				// context and close its responses channel.
				l.Debug.Printf("Finished message %d", message.MessageID)
				if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
					delete(l.messageContexts, message.MessageID)
					close(msgCtx.responses)
				}
			}
		}
	}
}

// reader runs in its own goroutine, reads packets from the network
// connection and forwards each one to the processMessages loop as a
// MessageResponse. It normally ends by closing the connection (on a read
// error, a recovered panic, or when the connection is already closing). The
// exception is a "clean stop": after forwarding a packet that arrived while
// a StartTLS exchange was in progress, it returns without closing, so that
// StartTLS can take over the connection and start a fresh reader.
func (l *Conn) reader() {
	cleanstop := false
	defer func() {
		if err := recover(); err != nil {
			log.Printf("ldap: recovered panic in reader: %v", err)
		}
		if !cleanstop {
			l.Close()
		}
	}()

	for {
		// Set on the previous iteration, after that iteration's packet
		// was forwarded.
		if cleanstop {
			l.Debug.Printf("reader clean stopping (without closing the connection)")
			return
		}
		packet, err := ber.ReadPacket(l.conn)
		if err != nil {
			// A read error is expected here if we are closing the connection...
			if !l.IsClosing() {
				// Record the error so processMessages can report it to
				// callers still waiting for a response.
				l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
				l.Debug.Printf("reader error: %s", err)
			}
			return
		}
		// A failure to annotate the packet is only logged via Debug; the
		// packet is still processed.
		if err := addLDAPDescriptions(packet); err != nil {
			l.Debug.Printf("descriptions error: %s", err)
		}
		if len(packet.Children) == 0 {
			l.Debug.Printf("Received bad ldap packet")
			continue
		}
		l.messageMutex.Lock()
		if l.isStartingTLS {
			cleanstop = true
		}
		l.messageMutex.Unlock()
		// The first child is expected to hold the message ID as an int64;
		// if it does not, the type assertion panics and the panic is
		// caught by the deferred recover above.
		message := &messagePacket{
			Op:        MessageResponse,
			MessageID: packet.Children[0].Value.(int64),
			Packet:    packet,
		}
		// false means the connection is closing, so stop reading.
		if !l.sendProcessMessage(message) {
			return
		}
	}
}