package http

import (
	"net"
	"net/http"
	"reflect"
	"testing"

	"golang.org/x/net/nettest"

	"github.com/go-chi/chi/v5"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestGetOptions checks that GetOptions reports the package-level default
// server options.
func TestGetOptions(t *testing.T) {
	cases := []struct {
		name string
		want Options
	}{
		{name: "basic", want: defaultServerOptions},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := GetOptions()
			if reflect.DeepEqual(got, tc.want) {
				return
			}
			t.Errorf("GetOptions() = %v, want %v", got, tc.want)
		})
	}
}

func TestMount(t *testing.T) {
	type args struct {
		pattern string
		h       http.Handler
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{name: "basic", args: args{
			pattern: "/",
			h:       http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}),
		}, wantErr: false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.wantErr {
				require.Error(t, Mount(tt.args.pattern, tt.args.h))
			} else {
				require.NoError(t, Mount(tt.args.pattern, tt.args.h))
			}
			assert.NotNil(t, defaultServer)
			assert.True(t, defaultServer.baseRouter.Match(chi.NewRouteContext(), "GET", tt.args.pattern), "Failed to match route after registering")
		})
		if err := Shutdown(); err != nil {
			t.Fatal(err)
		}
	}
}

// TestNewServer builds a server from explicit listeners and checks that it
// records each listener's address, wraps the base router only when a BaseURL
// is configured, and carries a TLS config only when SSL is enabled.
func TestNewServer(t *testing.T) {
	type args struct {
		listeners    []net.Listener
		tlsListeners []net.Listener
		opt          Options
	}
	listener, err := nettest.NewLocalListener("tcp")
	if err != nil {
		t.Fatal(err)
	}
	// The server built here is never served or shut down, so nothing else
	// would close the listener.
	t.Cleanup(func() { _ = listener.Close() })
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{name: "default http", args: args{
			listeners:    []net.Listener{listener},
			tlsListeners: []net.Listener{},
			opt:          defaultServerOptions,
		}, wantErr: false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := NewServer(tt.args.listeners, tt.args.tlsListeners, tt.args.opt)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				// Expected failure: there is no server to inspect.
				return
			}
			s, ok := got.(*server)
			require.True(t, ok, "NewServer returned unexpected type")
			// Compare against this case's own listeners, not a shared one, so
			// TLS-only or multi-listener cases are checked correctly.
			if len(tt.args.listeners) > 0 {
				require.NotEmpty(t, s.addrs)
				assert.Equal(t, tt.args.listeners[0].Addr(), s.addrs[0])
			} else {
				assert.Empty(t, s.addrs)
			}
			if len(tt.args.tlsListeners) > 0 {
				require.NotEmpty(t, s.tlsAddrs)
				assert.Equal(t, tt.args.tlsListeners[0].Addr(), s.tlsAddrs[0])
			} else {
				assert.Empty(t, s.tlsAddrs)
			}
			if tt.args.opt.BaseURL != "" {
				assert.NotSame(t, s.baseRouter, s.httpServer.Handler, "should have wrapped baseRouter")
			} else {
				assert.Same(t, s.baseRouter, s.httpServer.Handler, "should be baseRouter")
			}
			if useSSL(tt.args.opt) {
				assert.NotNil(t, s.httpServer.TLSConfig, "missing SSL config")
			} else {
				assert.Nil(t, s.httpServer.TLSConfig, "unexpectedly has SSL config")
			}
		})
	}
}

func TestRestart(t *testing.T) {
	tests := []struct {
		name    string
		started bool
		wantErr bool
	}{
		{name: "started", started: true, wantErr: false},
		{name: "stopped", started: false, wantErr: false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.started {
				require.NoError(t, Restart()) // Call it twice basically
			} else {
				require.NoError(t, Shutdown())
			}
			current := defaultServer
			if err := Restart(); (err != nil) != tt.wantErr {
				t.Errorf("Restart() error = %v, wantErr %v", err, tt.wantErr)
			}
			assert.NotNil(t, defaultServer, "failed to start default server")
			assert.NotSame(t, current, defaultServer, "same server instance as before restart")
		})
	}
}

// TestRoute mounts a sub-router on the default server through the
// package-level Route and inspects the resulting route table of the base
// router. The default server is shut down after every case.
func TestRoute(t *testing.T) {
	type args struct {
		pattern string
		fn      func(r chi.Router)
	}
	cases := []struct {
		name string
		args args
		test func(t *testing.T, r chi.Router)
	}{
		{
			name: "basic",
			args: args{pattern: "/basic", fn: func(r chi.Router) {}},
			test: func(t *testing.T, r chi.Router) {
				routes := r.Routes()
				require.Len(t, routes, 1)
				assert.Equal(t, routes[0].Pattern, "/basic/*")
			},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require.NoError(t, Restart())
			_, routeErr := Route(tc.args.pattern, tc.args.fn)
			require.NoError(t, routeErr)
			tc.test(t, defaultServer.baseRouter)
		})

		if err := Shutdown(); err != nil {
			t.Fatal(err)
		}
	}
}

// TestSetOptions applies custom options, restarts the default server and
// checks that the running server picked them up. Afterwards the server is shut
// down, so it stops holding the custom listen address. The options it replaced
// are then restored, so later tests start from the same state.
func TestSetOptions(t *testing.T) {
	type args struct {
		opt Options
	}
	tests := []struct {
		name string
		args args
	}{
		{
			name: "basic",
			args: args{opt: Options{
				ListenAddr:         "127.0.0.1:9999",
				BaseURL:            "/basic",
				ServerReadTimeout:  1,
				ServerWriteTimeout: 1,
				MaxHeaderBytes:     1,
			}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			previous := GetOptions()
			// Cleanups run last-in first-out: the server is shut down first,
			// then the previous options are put back.
			t.Cleanup(func() { SetOptions(previous) })
			t.Cleanup(func() { assert.NoError(t, Shutdown()) })

			SetOptions(tt.args.opt)
			require.Equal(t, tt.args.opt, defaultServerOptions)
			require.NoError(t, Restart())
			if useSSL(tt.args.opt) {
				require.NotEmpty(t, defaultServer.tlsAddrs)
				assert.Equal(t, tt.args.opt.ListenAddr, defaultServer.tlsAddrs[0].String())
			} else {
				require.NotEmpty(t, defaultServer.addrs)
				assert.Equal(t, tt.args.opt.ListenAddr, defaultServer.addrs[0].String())
			}
			assert.Equal(t, tt.args.opt.ServerReadTimeout, defaultServer.httpServer.ReadTimeout)
			assert.Equal(t, tt.args.opt.ServerWriteTimeout, defaultServer.httpServer.WriteTimeout)
			assert.Equal(t, tt.args.opt.MaxHeaderBytes, defaultServer.httpServer.MaxHeaderBytes)
			if tt.args.opt.BaseURL != "" && tt.args.opt.BaseURL != "/" {
				assert.NotSame(t, defaultServer.httpServer.Handler, defaultServer.baseRouter, "BaseURL ignored")
			} else {
				assert.Same(t, defaultServer.httpServer.Handler, defaultServer.baseRouter, "should be baseRouter")
			}
		})
	}
}

// TestShutdown verifies that Shutdown succeeds and clears the default server,
// both when a server is running and when none is.
func TestShutdown(t *testing.T) {
	cases := []struct {
		name    string
		started bool
		wantErr bool
	}{
		{name: "started", started: true, wantErr: false},
		{name: "stopped", started: false, wantErr: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			switch {
			case tc.started:
				require.NoError(t, Restart())
			default:
				require.NoError(t, Shutdown()) // so Shutdown runs twice in a row
			}

			err := Shutdown()
			if failed := err != nil; failed != tc.wantErr {
				t.Errorf("Shutdown() error = %v, wantErr %v", err, tc.wantErr)
			}
			assert.Nil(t, defaultServer, "default server not deleted")
		})
	}
}

func TestURL(t *testing.T) {
	tests := []struct {
		name string
		want string
	}{
		{name: "basic", want: "http://127.0.0.1:8080/"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			require.NoError(t, Restart())
			if got := URL(); got != tt.want {
				t.Errorf("URL() = %v, want %v", got, tt.want)
			}
		})
	}
}

// Test_server_Mount mounts a handler directly on a standalone server, not the
// package default one, and checks that its base router matches the pattern.
func Test_server_Mount(t *testing.T) {
	type args struct {
		pattern string
		h       http.Handler
	}
	tests := []struct {
		name string
		args args
		opt  Options
	}{
		{name: "basic", args: args{
			pattern: "/",
			h:       http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}),
		}, opt: defaultServerOptions},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			listener, err := nettest.NewLocalListener("tcp")
			require.NoError(t, err)
			// The server is never served or shut down, so close the listener here.
			t.Cleanup(func() { _ = listener.Close() })
			s, err := NewServer([]net.Listener{listener}, []net.Listener{}, tt.opt)
			require.NoError(t, err)
			s.Mount(tt.args.pattern, tt.args.h)
			srv, ok := s.(*server)
			require.True(t, ok)
			assert.NotNil(t, srv)
			assert.True(t, srv.baseRouter.Match(chi.NewRouteContext(), "GET", tt.args.pattern), "Failed to Match() route after registering")
		})
	}
}

// Test_server_Route mounts a sub-router directly on a standalone server and
// inspects the resulting route table of its base router.
func Test_server_Route(t *testing.T) {
	type args struct {
		pattern string
		fn      func(r chi.Router)
	}
	tests := []struct {
		name string
		args args
		opt  Options
		test func(t *testing.T, r chi.Router)
	}{
		{
			name: "basic",
			args: args{
				pattern: "/basic",
				fn:      func(r chi.Router) {},
			},
			// Use the same options as the sibling standalone-server tests
			// rather than an accidental zero Options value.
			opt: defaultServerOptions,
			test: func(t *testing.T, r chi.Router) {
				require.Len(t, r.Routes(), 1)
				assert.Equal(t, r.Routes()[0].Pattern, "/basic/*")
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			listener, err := nettest.NewLocalListener("tcp")
			require.NoError(t, err)
			// The server is never served or shut down, so close the listener here.
			t.Cleanup(func() { _ = listener.Close() })
			s, err := NewServer([]net.Listener{listener}, []net.Listener{}, tt.opt)
			require.NoError(t, err)
			s.Route(tt.args.pattern, tt.args.fn)
			srv, ok := s.(*server)
			require.True(t, ok)
			assert.NotNil(t, srv)
			tt.test(t, srv.baseRouter)
		})
	}
}

// Test_server_Shutdown shuts down a standalone server that was never served
// and checks that a later attempt to serve is refused with
// http.ErrServerClosed.
func Test_server_Shutdown(t *testing.T) {
	cases := []struct {
		name    string
		opt     Options
		wantErr bool
	}{
		{name: "basic", opt: defaultServerOptions, wantErr: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			ln, listenErr := nettest.NewLocalListener("tcp")
			require.NoError(t, listenErr)
			created, newErr := NewServer([]net.Listener{ln}, []net.Listener{}, tc.opt)
			require.NoError(t, newErr)
			srv, ok := created.(*server)
			require.True(t, ok)

			shutdownErr := created.Shutdown()
			if (shutdownErr != nil) != tc.wantErr {
				t.Errorf("Shutdown() error = %v, wantErr %v", shutdownErr, tc.wantErr)
			}
			// http.Server.Serve closes the listener it is given before returning.
			serveErr := srv.httpServer.Serve(ln)
			assert.EqualError(t, serveErr, http.ErrServerClosed.Error())
		})
	}
}

// Test_start installs the given options and starts the default server with
// the unexported start helper. It checks that the server listens on the
// configured address and shuts the server down afterwards.
func Test_start(t *testing.T) {
	tests := []struct {
		name    string
		opt     Options
		wantErr bool
	}{
		{
			name:    "basic",
			opt:     defaultServerOptions,
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			SetOptions(tt.opt)
			err := start()
			if (err != nil) != tt.wantErr {
				t.Errorf("start() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				// Expected failure: there is no server to inspect.
				return
			}
			// Don't leave the default server holding its listen address for later tests.
			t.Cleanup(func() { assert.NoError(t, Shutdown()) })
			s := defaultServer
			require.NotNil(t, s, "start() did not set the default server")
			if useSSL(tt.opt) {
				require.NotEmpty(t, s.tlsAddrs)
				assert.Equal(t, tt.opt.ListenAddr, s.tlsAddrs[0].String())
			} else {
				require.NotEmpty(t, s.addrs)
				assert.Equal(t, tt.opt.ListenAddr, s.addrs[0].String())
			}
			/* Disabled: the server is already being served in the background, so
			reading s.httpServer.* here is a race condition.
			assert.Equal(t, tt.opt.ServerReadTimeout, defaultServer.httpServer.ReadTimeout)
			assert.Equal(t, tt.opt.ServerWriteTimeout, defaultServer.httpServer.WriteTimeout)
			assert.Equal(t, tt.opt.MaxHeaderBytes, defaultServer.httpServer.MaxHeaderBytes)
			if tt.opt.BaseURL != "" && tt.opt.BaseURL != "/" {
				assert.NotSame(t, s.baseRouter, s.httpServer.Handler, "should have wrapped baseRouter")
			} else {
				assert.Same(t, s.baseRouter, s.httpServer.Handler, "should be baseRouter")
			}
			if useSSL(tt.opt) {
				require.NotNil(t, s.httpServer.TLSConfig, "missing SSL config")
				assert.NotEmpty(t, s.httpServer.TLSConfig.Certificates, "missing SSL config")
			} else if s.httpServer.TLSConfig != nil {
				assert.Empty(t, s.httpServer.TLSConfig.Certificates, "unexpectedly has SSL config")
			}
			*/
		})
	}
}

// Test_useSSL checks which combinations of SSL options enable SSL. Each case
// has a distinct, descriptive name so a failure identifies the combination at
// fault.
func Test_useSSL(t *testing.T) {
	type args struct {
		opt Options
	}
	tests := []struct {
		name string
		args args
		want bool
	}{
		{
			name: "no cert or key",
			args: args{opt: Options{
				SslCert:  "",
				SslKey:   "",
				ClientCA: "",
			}},
			want: false,
		},
		{
			name: "key only",
			args: args{opt: Options{
				SslCert:  "",
				SslKey:   "test",
				ClientCA: "",
			}},
			want: true,
		},
		{
			name: "cert and key",
			args: args{opt: Options{
				SslCert:  "test",
				SslKey:   "test",
				ClientCA: "",
			}},
			want: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := useSSL(tt.args.opt); got != tt.want {
				t.Errorf("useSSL() = %v, want %v", got, tt.want)
			}
		})
	}
}