package middleware

import (
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"sync"
	"sync/atomic"
	"time"

	"github.com/labstack/echo"
)

// TODO: Handle TLS proxy

// ProxyConfig holds the settings accepted by the Proxy middleware.
type ProxyConfig struct {
	// Skipper is the function used to decide whether the middleware
	// should be skipped.
	Skipper Skipper

	// Balancer selects the upstream target for a request.
	// Required: Proxy panics when it is nil.
	// Implementations provided here:
	// - RandomBalancer
	// - RoundRobinBalancer
	Balancer ProxyBalancer
}

// ProxyTarget describes a single upstream server by its URL.
type ProxyTarget struct {
	URL *url.URL
}

// ProxyBalancer is the interface a load balancing technique implements:
// each call to Next yields the target to use.
type ProxyBalancer interface {
	Next() *ProxyTarget
}

// RandomBalancer is a ProxyBalancer that picks a target at random.
type RandomBalancer struct {
	Targets []*ProxyTarget
	random  *rand.Rand
}

// RoundRobinBalancer is a ProxyBalancer that cycles through its targets
// in order.
type RoundRobinBalancer struct {
	Targets []*ProxyTarget
	i       uint32
}

// proxyHTTP builds the handler used for plain HTTP requests: a single-host
// reverse proxy pointed at the target's URL.
func proxyHTTP(t *ProxyTarget) http.Handler {
	upstream := t.URL
	var handler http.Handler = httputil.NewSingleHostReverseProxy(upstream)
	return handler
}

// proxyRaw returns a handler that tunnels the request to target t over a raw
// TCP connection to t.URL.Host. Proxy uses it for WebSocket requests.
//
// The handler dials the upstream before hijacking the client connection:
// once a connection has been hijacked the ResponseWriter can no longer be
// written to, so dialing first lets a dial failure still reach the client as
// 502 Bad Gateway. The request is then replayed to the upstream and bytes
// are copied in both directions until either direction finishes. Both
// connections are closed when the handler returns.
//
// Errors are reported through c.Error (or c.Logger for copy errors); the
// handler itself returns nothing.
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h, ok := w.(http.Hijacker)
		if !ok {
			c.Error(errors.New("proxy raw, not a hijacker"))
			return
		}

		// Dial first, while the response writer is still usable for a 502.
		out, err := net.Dial("tcp", t.URL.Host)
		if err != nil {
			he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, r.URL))
			c.Error(he)
			return
		}
		defer out.Close()

		in, _, err := h.Hijack()
		if err != nil {
			c.Error(fmt.Errorf("proxy raw, hijack error=%v, url=%s", err, r.URL))
			return
		}
		defer in.Close()

		// Replay the already-parsed request to the upstream.
		if err = r.Write(out); err != nil {
			he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, request copy error=%v, url=%s", err, r.URL))
			c.Error(he)
			return
		}

		// Buffered for both senders so the slower copier never blocks on send
		// after this handler has stopped receiving.
		errc := make(chan error, 2)
		cp := func(dst io.Writer, src io.Reader) {
			_, err := io.Copy(dst, src)
			errc <- err
		}

		go cp(out, in)
		go cp(in, out)

		// Wait for the first direction to finish; the deferred Close calls
		// then unblock the other copier.
		err = <-errc
		if err != nil && err != io.EOF {
			c.Logger().Errorf("proxy raw, error=%v, url=%s", err, r.URL)
		}
	})
}

// randomBalancerMu serializes access to RandomBalancer.random. A *rand.Rand
// created with rand.New is not safe for concurrent use, and Next is invoked
// from request handlers. The mutex lives at package level so the
// RandomBalancer struct itself is left unchanged.
var randomBalancerMu sync.Mutex

// Next randomly returns an upstream target.
// It returns nil when no targets are configured.
//
// The random source is created lazily on first use and seeded from the
// current time.
func (r *RandomBalancer) Next() *ProxyTarget {
	n := len(r.Targets)
	if n == 0 {
		// rand.Intn panics for n <= 0; report "no target" instead.
		return nil
	}

	randomBalancerMu.Lock()
	defer randomBalancerMu.Unlock()

	if r.random == nil {
		// UnixNano uses the full timestamp, not just the sub-second part.
		r.random = rand.New(rand.NewSource(time.Now().UnixNano()))
	}
	return r.Targets[r.random.Intn(n)]
}

// Next returns an upstream target using round-robin technique.
// It returns nil when no targets are configured.
//
// The counter is advanced with a single atomic add and is reduced modulo the
// number of targets only when indexing, so concurrent callers each get a
// distinct ticket and the index is always in range. When the uint32 counter
// wraps around, the rotation may restart out of sequence once if the number
// of targets is not a power of two.
func (r *RoundRobinBalancer) Next() *ProxyTarget {
	n := uint32(len(r.Targets))
	if n == 0 {
		return nil
	}
	// AddUint32 returns the incremented value; subtract one for this call's slot.
	ticket := atomic.AddUint32(&r.i, 1) - 1
	return r.Targets[ticket%n]
}

// Proxy returns an HTTP/WebSocket reverse proxy middleware.
//
// For each request the middleware first consults config.Skipper; skipped
// requests are handed to the next handler untouched. Otherwise it asks
// config.Balancer for a target, fills in the echo.HeaderXRealIP,
// echo.HeaderXForwardedProto and (for WebSocket) echo.HeaderXForwardedFor
// request headers when they are missing, and forwards the request:
// WebSocket requests via proxyRaw, other requests via proxyHTTP.
//
// If the balancer yields no target, the handler returns a
// 503 Service Unavailable HTTP error.
//
// When config.Skipper is nil, DefaultLoggerConfig.Skipper is used.
// Proxy panics if config.Balancer is nil.
func Proxy(config ProxyConfig) echo.MiddlewareFunc {
	// Defaults
	if config.Skipper == nil {
		config.Skipper = DefaultLoggerConfig.Skipper
	}
	if config.Balancer == nil {
		panic("echo: proxy middleware requires balancer")
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			// Honor the configured skipper.
			if config.Skipper(c) {
				return next(c)
			}

			req := c.Request()
			res := c.Response()
			tgt := config.Balancer.Next()
			if tgt == nil {
				// The proxy helpers dereference tgt.URL; fail cleanly instead.
				return echo.NewHTTPError(http.StatusServiceUnavailable, "proxy, no upstream target available")
			}

			// Fix header
			if req.Header.Get(echo.HeaderXRealIP) == "" {
				req.Header.Set(echo.HeaderXRealIP, c.RealIP())
			}
			if req.Header.Get(echo.HeaderXForwardedProto) == "" {
				req.Header.Set(echo.HeaderXForwardedProto, c.Scheme())
			}
			if c.IsWebSocket() && req.Header.Get(echo.HeaderXForwardedFor) == "" { // For HTTP, it is automatically set by Go HTTP reverse proxy.
				req.Header.Set(echo.HeaderXForwardedFor, c.RealIP())
			}

			// Proxy
			switch {
			case c.IsWebSocket():
				proxyRaw(tgt, c).ServeHTTP(res, req)
			case req.Header.Get(echo.HeaderAccept) == "text/event-stream":
				// NOTE(review): requests whose Accept header is exactly
				// "text/event-stream" are neither forwarded nor passed to
				// next; this looks like an unfinished placeholder for
				// server-sent events support. Kept as-is; confirm intent.
			default:
				proxyHTTP(tgt).ServeHTTP(res, req)
			}

			return nil
		}
	}
}