// Copyright 2017 Frank Schroeder. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package assert provides helper functions for testing.
package assert

import (
	"fmt"
	"path/filepath"
	"reflect"
	"regexp"
	"runtime"
	"strings"
	"testing"
)

// skip is the default runtime.Caller depth handed to fail. Counted from
// inside fail, depth 2 identifies the code that called one of the
// lower-case helpers (equal, matches) directly.
const (
	skip = 2
)

// Equal asserts that got and want are equal as defined by
// reflect.DeepEqual. If they are not, it prints a report containing the
// file:line of the calling test, both values and msg (joined with
// spaces) to standard output and marks t as failed.
func Equal(t *testing.T, got, want interface{}, msg ...string) {
	// The default depth (skip) targets direct callers of equal. Equal adds
	// one more frame, so ask for one more; otherwise the reported location
	// would be this function rather than the test that called it.
	if x := equal(skip+1, got, want, msg...); x != "" {
		fmt.Println(x)
		t.Fail()
	}
}

// equal compares got and want with reflect.DeepEqual. It returns "" when
// they are deeply equal; otherwise it returns a report built by fail at
// call depth skip showing both values followed by the space-joined msg.
func equal(skip int, got, want interface{}, msg ...string) string {
	if reflect.DeepEqual(got, want) {
		return ""
	}
	note := strings.Join(msg, " ")
	return fail(skip, "got %v want %v %s", got, want, note)
}

// Panic asserts that calling fn panics and that the panic message (the
// string itself, or the Error() text of an error value) matches the
// regular expression in matches. On failure it prints a report to
// standard output and marks t as failed.
func Panic(t *testing.T, fn func(), matches string) {
	// NOTE(review): the depth is not adjusted for this wrapper. doesPanic
	// reaches fail through a deferred closure (and, when fn panics, through
	// the runtime's panic machinery), so the frame distance to the test
	// differs between paths. The reported file:line may therefore point
	// inside this package — confirm before relying on it.
	report := doesPanic(skip, fn, matches)
	if report == "" {
		return
	}
	fmt.Println(report)
	t.Fail()
}

// doesPanic calls fn and checks that it panics with a message matching
// the regular expression expr. It returns "" on success and a failure
// report otherwise; skip is forwarded to fail/matches as the
// runtime.Caller depth.
//
// The recovered value is turned into text as follows: an error
// contributes its Error() text, a string is used as-is, and any other
// value is formatted with fmt.Sprint. This lets e.g. panic(42) be matched
// against "42" instead of silently becoming "".
func doesPanic(skip int, fn func(), expr string) (err string) {
	// recover only works directly inside a deferred function; the named
	// result lets the deferred closure overwrite the return value.
	defer func() {
		r := recover()
		if r == nil {
			err = fail(skip, "did not panic")
			return
		}
		var v string
		switch x := r.(type) {
		case error:
			v = x.Error()
		case string:
			v = x
		default:
			v = fmt.Sprint(x)
		}
		err = matches(skip, v, expr)
	}()
	fn()
	return ""
}

// Matches asserts that value matches the regular expression expr. If the
// pattern is invalid or does not match, it prints a report containing the
// file:line of the calling test to standard output and marks t as failed.
func Matches(t *testing.T, value, expr string) {
	// The default depth (skip) targets direct callers of matches. Matches
	// adds one more frame, so ask for one more; otherwise the reported
	// location would be this function rather than the calling test.
	if x := matches(skip+1, value, expr); x != "" {
		fmt.Println(x)
		t.Fail()
	}
}

// matches checks value against the regular expression expr. It returns
// "" on a match; otherwise it returns a report built by fail at call
// depth skip, describing either the invalid pattern or the mismatch.
func matches(skip int, value, expr string) string {
	matched, compileErr := regexp.MatchString(expr, value)
	switch {
	case compileErr != nil:
		return fail(skip, "invalid pattern %q. %s", expr, compileErr)
	case !matched:
		return fail(skip, "got %s which does not match %s", value, expr)
	default:
		return ""
	}
}

// fail formats a failure report of the form "\t<file>:<line>: <message>\n".
// The location comes from runtime.Caller(skip), where 0 is fail itself,
// 1 is its caller, and so on; only the base name of the file is kept.
// If that frame cannot be resolved (e.g. skip exceeds the stack depth),
// the location is reported as "???:1" instead of the meaningless ".:0".
func fail(skip int, format string, args ...interface{}) string {
	_, file, line, ok := runtime.Caller(skip)
	if !ok {
		// Same placeholder the testing package prints for unknown frames.
		file, line = "???", 1
	}
	return fmt.Sprintf("\t%s:%d: %s\n", filepath.Base(file), line, fmt.Sprintf(format, args...))
}