///////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC                                          //
//                                                                           //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file                                                              //
///////////////////////////////////////////////////////////////////////////////

package single

import (
	"gitlab.com/xx_network/primitives/id"
	"reflect"
	"strings"
	"sync/atomic"
	"testing"
	"time"
)

// Happy path: trigger the exit channel.
//
// Adds a state, signals its quit channel, and then verifies that the callback
// is not invoked within the timeout window and that the state added to the map
// matches a freshly constructed one.
func Test_pending_addState_ExitChan(t *testing.T) {
	p := newPending()
	rid := id.NewIdFromString("test RID", id.User, t)
	dhKey := getGroup().NewInt(5)
	maxMsgs := uint8(6)
	timeout := 500 * time.Millisecond
	callback, callbackChan := createReplyComm()

	quitChan, quit, err := p.addState(rid, dhKey, maxMsgs, callback, timeout)
	if err != nil {
		// The returned channel and flag are required by every later step, so
		// there is nothing useful to check once setup has failed.
		t.Fatalf("addState() returned an error: %+v", err)
	}

	// The flag is expected to still be 0 here; a successful swap also sets it
	// to 1 for the final check below.
	if !atomic.CompareAndSwapInt32(quit, 0, 1) {
		t.Error("Quit atomic called.")
	}

	expectedState := newState(dhKey, maxMsgs, callback)

	quitChan <- struct{}{}

	timer := time.NewTimer(timeout + 1*time.Millisecond)
	defer timer.Stop()

	select {
	case results := <-callbackChan:
		t.Errorf("Callback called when the quit channel was used."+
			"\npayload: %+v\nerror:   %+v", results.payload, results.err)
	case <-timer.C:
	}

	// Read the map under the lock, as the other tests in this file do.
	p.Lock()
	received, exists := p.singleUse[*rid]
	p.Unlock()
	if !exists {
		// Stop here: dereferencing the missing entry below would panic.
		t.Fatal("State not found in map.")
	}
	if !equalState(*expectedState, *received, t) {
		t.Errorf("State in map is incorrect.\nexpected: %+v\nreceived: %+v",
			*expectedState, *received)
	}

	// The flag must no longer be 0, so a second swap from 0 must fail.
	if atomic.CompareAndSwapInt32(quit, 0, 1) {
		t.Error("Quit atomic not called.")
	}
}

// Happy path: the state is removed from the map right after being added.
//
// The test then waits for timeout + 1ms and fails if the callback is invoked
// during that window.
func Test_pending_addState_StateRemoved(t *testing.T) {
	const timeout = 5 * time.Millisecond

	pend := newPending()
	rid := id.NewIdFromString("test RID", id.User, t)
	key := getGroup().NewInt(5)
	cb, cbChan := createReplyComm()

	if _, _, err := pend.addState(rid, key, uint8(6), cb, timeout); err != nil {
		t.Errorf("addState() returned an error: %+v", err)
	}

	// Drop the entry manually while holding the lock.
	pend.Lock()
	delete(pend.singleUse, *rid)
	pend.Unlock()

	select {
	case res := <-cbChan:
		t.Errorf("Callback should not have been called.\npayload: %+v\nerror:   %+v",
			res.payload, res.err)
	case <-time.After(timeout + time.Millisecond):
		// No callback arrived in time, which is the expected outcome.
	}
}

// Error path: timeout occurs and deletes the entry from the map.
//
// Adds a state, confirms it is stored correctly, and then waits for the
// callback. When the callback fires, the entry must be gone from the map, the
// payload must be nil, and the error must mention "timed out". The test fails
// if no callback arrives within four times the configured timeout.
func Test_pending_addState_TimeoutError(t *testing.T) {
	p := newPending()
	rid := id.NewIdFromString("test RID", id.User, t)
	dhKey := getGroup().NewInt(5)
	maxMsgs := uint8(6)
	timeout := 5 * time.Millisecond
	callback, callbackChan := createReplyComm()

	_, _, err := p.addState(rid, dhKey, maxMsgs, callback, timeout)
	if err != nil {
		t.Errorf("addState() returned an error: %+v", err)
	}

	expectedState := newState(dhKey, maxMsgs, callback)
	p.Lock()
	received, exists := p.singleUse[*rid]
	p.Unlock()
	if !exists {
		// Stop here: dereferencing the missing entry below would panic.
		t.Fatal("State not found in map.")
	}
	if !equalState(*expectedState, *received, t) {
		t.Errorf("State in map is incorrect.\nexpected: %+v\nreceived: %+v",
			*expectedState, *received)
	}

	timerTimeout := timeout * 4
	timer := time.NewTimer(timerTimeout)
	defer timer.Stop()

	select {
	case results := <-callbackChan:
		// Take the lock for this read, matching the locked read above.
		p.Lock()
		received, exists = p.singleUse[*rid]
		p.Unlock()
		if exists {
			t.Errorf("State found in map when it should have been deleted."+
				"\nstate: %+v", received)
		}
		if results.payload != nil {
			t.Errorf("Payload not nil on timeout.\npayload: %+v", results.payload)
		}
		if !check(results.err, "timed out") {
			t.Errorf("Callback did not return a time out error on return: %+v", results.err)
		}
	case <-timer.C:
		t.Errorf("Failed to time out after %s.", timerTimeout)
	}
}

// Error path: state already exists.
//
// Adds a state, signals its quit channel, and then adds a state with the same
// ID again, expecting an error that reports the existing entry.
func Test_pending_addState_StateExistsError(t *testing.T) {
	p := newPending()
	rid := id.NewIdFromString("test RID", id.User, t)
	dhKey := getGroup().NewInt(5)
	maxMsgs := uint8(6)
	// Generous timeout: Test_pending_addState_TimeoutError shows that an
	// expired timeout deletes the entry, which would make the second addState
	// below succeed. The quit signal is sent immediately, so the test does not
	// wait for this duration.
	timeout := 500 * time.Millisecond
	callback, _ := createReplyComm()

	quitChan, _, err := p.addState(rid, dhKey, maxMsgs, callback, timeout)
	if err != nil {
		// Without a successful first insert the rest of the test is moot, and
		// the returned channel may not be usable for the send below.
		t.Fatalf("addState() returned an error: %+v", err)
	}
	quitChan <- struct{}{}

	// Only the error matters for the duplicate insert.
	_, _, err = p.addState(rid, dhKey, maxMsgs, callback, timeout)
	if !check(err, "a state already exists in the map") {
		t.Errorf("addState() did not return an error when the state already "+
			"exists: %+v", err)
	}
}

// replyCommData bundles the two arguments handed to a reply callback so they
// can be sent over a channel as a single value (see createReplyComm).
type replyCommData struct {
	payload []byte // payload argument the callback was invoked with
	err     error  // err argument the callback was invoked with
}

// createReplyComm returns a callback together with the channel it reports on.
// Each invocation of the callback sends its arguments, wrapped in a
// replyCommData, on the returned channel. The channel is unbuffered, so the
// callback blocks until the value is received.
func createReplyComm() (func(payload []byte, err error), chan replyCommData) {
	results := make(chan replyCommData)

	return func(payload []byte, err error) {
		results <- replyCommData{payload: payload, err: err}
	}, results
}

// equalState determines if the two states have equal values.
//
// It compares, in order, the DH keys, the fingerprint maps, the collators, and
// the callback function pointers. On the first mismatch it reports an error on
// t describing the difference (a is treated as the expected value, b as the
// received one) and returns false. It returns true only if all four match.
func equalState(a, b state, t *testing.T) bool {
	// Attribute reported failures to the calling test's line.
	t.Helper()

	if a.dhKey.Cmp(b.dhKey) != 0 {
		t.Errorf("DH Keys differ.\nexpected: %s\nreceived: %s",
			a.dhKey.Text(10), b.dhKey.Text(10))
		return false
	}
	if !reflect.DeepEqual(a.fpMap.fps, b.fpMap.fps) {
		t.Errorf("Fingerprint maps differ.\nexpected: %+v\nreceived: %+v",
			a.fpMap.fps, b.fpMap.fps)
		return false
	}
	// Compare a's collator against b's; comparing b.c with itself would make
	// this check pass unconditionally.
	if !reflect.DeepEqual(a.c, b.c) {
		t.Errorf("collators differ.\nexpected: %+v\nreceived: %+v",
			a.c, b.c)
		return false
	}
	// Func values cannot be compared with ==, so compare their pointers.
	if reflect.ValueOf(a.callback).Pointer() != reflect.ValueOf(b.callback).Pointer() {
		t.Errorf("callbackFuncs differ.\nexpected: %p\nreceived: %p",
			a.callback, b.callback)
		return false
	}
	return true
}

// check reports whether err is non-nil and its message contains subStr.
// A nil error always yields false.
func check(err error, subStr string) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), subStr)
}