///////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC                                          //
//                                                                           //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file                                                              //
///////////////////////////////////////////////////////////////////////////////

package partition

import (
	"bytes"
	"encoding/json"
	"gitlab.com/elixxir/client/interfaces/message"
	"gitlab.com/elixxir/client/storage/versioned"
	"gitlab.com/elixxir/crypto/e2e"
	"gitlab.com/elixxir/ekv"
	"gitlab.com/xx_network/primitives/id"
	"gitlab.com/xx_network/primitives/netTime"
	"math/rand"
	"reflect"
	"testing"
	"time"
)

// Tests the creation part of loadOrCreateMultiPartMessage(). It checks that the
// returned message matches an expected message whose counters, timestamp and
// type are all zero, and that the JSON encoding of that expected message is
// what was written to the key value store under messageKey.
func Test_loadOrCreateMultiPartMessage_Create(t *testing.T) {
	// Set up expected test value
	prng := rand.New(rand.NewSource(netTime.Now().UnixNano()))
	expectedMpm := &multiPartMessage{
		Sender:          id.NewIdFromUInt(prng.Uint64(), id.User, t),
		MessageID:       prng.Uint64(),
		NumParts:        0,
		PresentParts:    0,
		SenderTimestamp: time.Time{},
		MessageType:     0,
		kv:              versioned.NewKV(make(ekv.Memstore)),
	}
	expectedData, err := json.Marshal(expectedMpm)
	if err != nil {
		t.Fatalf("Failed to marshal expected multiPartMessage: %v", err)
	}

	// Make new multiPartMessage
	mpm := loadOrCreateMultiPartMessage(expectedMpm.Sender,
		expectedMpm.MessageID, expectedMpm.kv)

	CheckMultiPartMessages(expectedMpm, mpm, t)

	// Read the saved object back through the message's own kv
	obj, err := mpm.kv.Get(messageKey, 0)
	if err != nil {
		// Fatal rather than Error: obj is not usable when the lookup fails, so
		// the comparison below must not run
		t.Fatalf("Get() failed to get multiPartMessage from key value store: %v", err)
	}

	if !bytes.Equal(expectedData, obj.Data) {
		t.Errorf("loadOrCreateMultiPartMessage() did not save the "+
			"multiPartMessage correctly.\n\texpected: %+v\n\treceived: %+v",
			expectedData, obj.Data)
	}
}

// Tests the loading part of loadOrCreateMultiPartMessage(): a message is saved
// first, then requested again with the same sender, message ID and kv, and the
// result is compared against what was saved.
func Test_loadOrCreateMultiPartMessage_Load(t *testing.T) {
	rng := rand.New(rand.NewSource(netTime.Now().UnixNano()))

	// All fields not listed here are left at their zero value
	stored := &multiPartMessage{
		Sender:    id.NewIdFromUInt(rng.Uint64(), id.User, t),
		MessageID: rng.Uint64(),
		kv:        versioned.NewKV(make(ekv.Memstore)),
	}

	// NOTE(review): CheckMultiPartMessages remarks that the returned message's
	// kv has a prefix applied; confirm that save() on this un-prefixed kv
	// writes to the location the load below actually reads from.
	if err := stored.save(); err != nil {
		t.Fatalf("Failed to save multiPartMessage: %v", err)
	}

	// Request the message that was just saved
	loaded := loadOrCreateMultiPartMessage(stored.Sender, stored.MessageID,
		stored.kv)

	CheckMultiPartMessages(stored, loaded, t)
}

// CheckMultiPartMessages compares the SenderTimestamp, MessageType, MessageID,
// NumParts, PresentParts, Sender and parts of mpm against expectedMpm and
// reports every mismatch on t. The kv field is not compared.
func CheckMultiPartMessages(expectedMpm *multiPartMessage, mpm *multiPartMessage, t *testing.T) {
	// Attribute failures to the calling test rather than to this helper
	t.Helper()

	// The kv differs because it has prefix called, so we compare fields individually

	// Equal compares the instant only, ignoring location and monotonic clock
	// reading, which == on time.Time does not
	if !expectedMpm.SenderTimestamp.Equal(mpm.SenderTimestamp) {
		t.Errorf("timestamps mismatch: expected %v, got %v", expectedMpm.SenderTimestamp, mpm.SenderTimestamp)
	}
	if expectedMpm.MessageType != mpm.MessageType {
		t.Errorf("messagetype mismatch: expected %v, got %v", expectedMpm.MessageType, mpm.MessageType)
	}
	if expectedMpm.MessageID != mpm.MessageID {
		t.Errorf("messageid mismatch: expected %v, got %v", expectedMpm.MessageID, mpm.MessageID)
	}
	if expectedMpm.NumParts != mpm.NumParts {
		t.Errorf("numparts mismatch: expected %v, got %v", expectedMpm.NumParts, mpm.NumParts)
	}
	if expectedMpm.PresentParts != mpm.PresentParts {
		t.Errorf("presentparts mismatch: expected %v, got %v", expectedMpm.PresentParts, mpm.PresentParts)
	}
	if !expectedMpm.Sender.Cmp(mpm.Sender) {
		t.Errorf("sender mismatch: expected %v, got %v", expectedMpm.Sender, mpm.Sender)
	}
	if len(expectedMpm.parts) != len(mpm.parts) {
		t.Errorf("parts different length: expected %d, got %d",
			len(expectedMpm.parts), len(mpm.parts))
		// Comparing element by element would index past the shorter slice
		return
	}
	for i := range expectedMpm.parts {
		if !bytes.Equal(expectedMpm.parts[i], mpm.parts[i]) {
			t.Errorf("parts differed at index %v", i)
		}
	}
}

// Tests happy path of multiPartMessage.Add(). Every generated part is added,
// then the test checks that each part sits at its part number in mpm.parts,
// that PresentParts equals the number of parts added, and that the object
// stored under messageKey matches the JSON encoding of the message.
func TestMultiPartMessage_Add(t *testing.T) {
	// Generate test values
	prng := rand.New(rand.NewSource(netTime.Now().UnixNano()))
	mpm := loadOrCreateMultiPartMessage(id.NewIdFromUInt(prng.Uint64(), id.User, t),
		prng.Uint64(), versioned.NewKV(make(ekv.Memstore)))
	// A count of 0 lets generateParts pick the number of parts itself
	partNums, parts := generateParts(prng, 0)

	for i := range partNums {
		mpm.Add(partNums[i], parts[i])
	}

	for i, p := range partNums {
		// Report a missing slot as a failure instead of panicking on the index
		if int(p) >= len(mpm.parts) {
			t.Errorf("Missing part at index %d (%d): only %d part slots exist.",
				p, i, len(mpm.parts))
			continue
		}
		if !bytes.Equal(mpm.parts[p], parts[i]) {
			t.Errorf("Incorrect part at index %d (%d)."+
				"\n\texpected: %v\n\treceived: %v", p, i, parts[i], mpm.parts[p])
		}
	}

	if len(partNums) != int(mpm.PresentParts) {
		t.Errorf("Incorrect PresentParts.\n\texpected: %d\n\treceived: %d",
			len(partNums), int(mpm.PresentParts))
	}

	expectedData, err := json.Marshal(mpm)
	if err != nil {
		t.Fatalf("Failed to marshal expected multiPartMessage: %v", err)
	}

	obj, err := mpm.kv.Get(messageKey, 0)
	if err != nil {
		// Fatal rather than Error: obj is not usable when the lookup fails, so
		// the comparison below must not run
		t.Fatalf("Get() failed to get multiPartMessage from key value store: %v", err)
	}

	if !bytes.Equal(expectedData, obj.Data) {
		t.Errorf("Add() did not save the "+
			"multiPartMessage correctly.\n\texpected: %+v\n\treceived: %+v",
			expectedData, obj.Data)
	}
}

// Tests happy path of multiPartMessage.AddFirst(). A fresh message receives its
// first part at part number 2; the in-memory message is compared against the
// expected one and the part is read back with loadPart().
func TestMultiPartMessage_AddFirst(t *testing.T) {
	const partNum = 2
	firstPart := []byte{5, 8, 78, 9}

	// Build the message expected after AddFirst() has run
	rng := rand.New(rand.NewSource(netTime.Now().UnixNano()))
	want := &multiPartMessage{
		Sender:          id.NewIdFromUInt(rng.Uint64(), id.User, t),
		MessageID:       rng.Uint64(),
		NumParts:        uint8(rng.Uint32()),
		PresentParts:    1,
		SenderTimestamp: netTime.Now(),
		MessageType:     message.NoType,
		parts:           [][]byte{nil, nil, firstPart},
		kv:              versioned.NewKV(make(ekv.Memstore)),
	}

	got := loadOrCreateMultiPartMessage(want.Sender, want.MessageID, want.kv)
	got.AddFirst(want.MessageType, partNum, want.NumParts,
		want.SenderTimestamp, netTime.Now(), want.parts[partNum])

	CheckMultiPartMessages(want, got, t)

	// The part must also be retrievable from storage
	stored, err := loadPart(got.kv, partNum)
	if err != nil {
		t.Errorf("loadPart() produced an error: %v", err)
	}

	if !bytes.Equal(stored, want.parts[partNum]) {
		t.Errorf("AddFirst() did not save multiPartMessage correctly."+
			"\n\texpected: %#v\n\treceived: %#v", want.parts[partNum], stored)
	}
}

// Tests happy path of multiPartMessage.IsComplete(). The message must report
// incomplete before any part is added and complete once all 75 generated parts
// are present; the returned message is then compared to one assembled here from
// the parts held in memory.
func TestMultiPartMessage_IsComplete(t *testing.T) {
	const numParts = 75
	fingerprint := []byte{0}

	// Create multiPartMessage and the random parts to fill it with
	rng := rand.New(rand.NewSource(netTime.Now().UnixNano()))
	messageID := rng.Uint64()
	sender := id.NewIdFromUInt(rng.Uint64(), id.User, t)
	mpm := loadOrCreateMultiPartMessage(sender, messageID,
		versioned.NewKV(make(ekv.Memstore)))
	partNums, parts := generateParts(rng, numParts)

	// Check that IsComplete() is false where there are no parts
	if _, complete := mpm.IsComplete(fingerprint); complete {
		t.Error("IsComplete() returned true when NumParts == 0.")
	}

	// The first generated part goes in via AddFirst(), the rest via Add()
	mpm.AddFirst(message.XxMessage, partNums[0], numParts, netTime.Now(), netTime.Now(), parts[0])
	for i := 1; i < len(partNums); i++ {
		mpm.Add(partNums[i], parts[i])
	}

	received, complete := mpm.IsComplete(fingerprint)
	if !complete {
		t.Error("IsComplete() returned false when the message should be complete.")
	}

	// Join the parts in slot order to get the expected payload
	var joined []byte
	for _, part := range mpm.parts {
		joined = append(joined, part...)
	}

	// The timestamp is copied from the received message, so it is not
	// independently verified by the comparison below
	expected := message.Receive{
		Payload:     joined,
		MessageType: mpm.MessageType,
		Sender:      mpm.Sender,
		Timestamp:   received.Timestamp,
		Encryption:  0,
		ID:          e2e.NewMessageID([]byte{0}, messageID),
	}

	if !reflect.DeepEqual(expected, received) {
		t.Errorf("IsComplete() did not return the expected message."+
			"\n\texpected: %v\n\treceived: %v", expected, received)
	}
}

// Tests happy path of multiPartMessage.delete(). After delete() runs, looking
// up messageKey must no longer find an object.
func TestMultiPartMessage_delete(t *testing.T) {
	prng := rand.New(rand.NewSource(netTime.Now().UnixNano()))
	kv := versioned.NewKV(make(ekv.Memstore))
	mpm := loadOrCreateMultiPartMessage(id.NewIdFromUInt(prng.Uint64(), id.User, t),
		prng.Uint64(), kv)

	mpm.delete()

	// Look the key up through mpm.kv, the handle the other tests in this file
	// read the saved message from. The kv passed in above differs from it (the
	// message's kv has a prefix applied), so a lookup on the parent kv would
	// presumably miss the key whether or not delete() worked.
	obj, err := mpm.kv.Get(messageKey, 0)
	if ekv.Exists(err) {
		t.Errorf("delete() did not properly delete key %s."+
			"\n\tobject received: %+v", messageKey, obj)
	}
}

// generateParts generates a list of test part numbers and a list of test parts.
// The part numbers are the values 0 to numParts-1 in random order, each used
// exactly once, and parts[i] is the random payload (0 to 99 bytes) belonging to
// partNums[i]. When numParts is 0, a count between 25 and 174 is drawn from r.
func generateParts(r *rand.Rand, numParts uint8) ([]uint8, [][]byte) {
	if numParts == 0 {
		numParts = uint8(25 + r.Intn(150))
	}

	total := int(numParts)
	partNums := make([]uint8, 0, total)
	parts := make([][]byte, 0, total)
	used := make(map[uint8]struct{}, total)

	for len(partNums) < total {
		// Redraw until a part number comes up that has not been handed out yet
		n := uint8(r.Intn(total))
		for {
			if _, taken := used[n]; !taken {
				break
			}
			n = uint8(r.Intn(total))
		}
		used[n] = struct{}{}

		// Random length first, then random content, for this part
		payload := make([]byte, r.Int31n(100))
		_, _ = r.Read(payload) // documented to always return a nil error

		partNums = append(partNums, n)
		parts = append(parts, payload)
	}

	return partNums, parts
}