Skip to content
Snippets Groups Projects
Commit b931c496 authored by Jono Wenger's avatar Jono Wenger
Browse files

Merge branch 'jono/multiStateVector' into 'release'

Add MultiStateVector

See merge request !99
parents 3eef0c38 c18d2fa2
No related branches found
No related tags found
2 merge requests!117Release,!99Add MultiStateVector
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
////////////////////////////////////////////////////////////////////////////////
package utility
import (
"bytes"
"encoding/binary"
"github.com/pkg/errors"
"gitlab.com/elixxir/client/storage/versioned"
"gitlab.com/xx_network/primitives/netTime"
"math"
"sync"
)
// Storage key and version.
const (
multiStateVectorKey = "multiStateVector/"
multiStateVectorVersion = 0
)
// Minimum and maximum allowed number of states.
const (
minNumStates = 2
maxNumStates = 8
)
// Error messages.
const (
keyNumMaxErr = "key number %d greater than max %d"
stateMaxErr = "given state %d greater than max %d"
// NewMultiStateVector
numStateMinErr = "number of received states %d must be greater than %d"
numStateMaxErr = "number of received states %d greater than max %d"
// MultiStateVector.Set
setStateErr = "failed to set state of key %d: %+v"
saveSetStateErr = "failed to save MultiStateVector after setting key %d state to %d: %+v"
// MultiStateVector.SetMany
setManyStateErr = "failed to set state of key %d (%d/%d): %+v"
saveManySetStateErr = "failed to save MultiStateVector after setting keys %d state to %d: %+v"
// MultiStateVector.set
setGetStateErr = "could not get state of key %d: %+v"
oldStateMaxErr = "stored state %d greater than max %d (this should not happen; maybe a storage failure)"
stateChangeErr = "cannot change state of key %d from %d to %d"
// LoadMultiStateVector
loadGetMsvErr = "failed to load MultiStateVector from storage: %+v"
loadUnmarshalMsvErr = "failed to unmarshal MultiStateVector loaded from storage: %+v"
// MultiStateVector.save
saveUnmarshalMsvErr = "failed to marshal MultiStateVector for storage: %+v"
// MultiStateVector.marshal
buffWriteNumKeysErr = "numKeys: %+v"
buffWriteNumStatesErr = "numStates: %+v"
buffWriteVectLenErr = "length of vector: %+v"
buffWriteVectBlockErr = "vector block %d/%d: %+v"
buffWriteStateCountBlockErr = "state use counter %d: %+v"
// MultiStateVector.unmarshal
buffReadNumKeysErr = "numKeys: %+v"
buffReadNumStatesErr = "numStates: %+v"
buffReadVectLenErr = "length of vector: %+v"
buffReadVectBlockErr = "vector block %d/%d: %+v"
buffReadStateCountBlockErr = "state use counter %d: %+v"
// checkStateMap
stateMapLenErr = "state map has %d states when %d are required"
stateMapStateLenErr = "state %d in state map has %d state transitions when %d are required"
)
// MultiStateVector stores a list of a set number of keys and their state. It
// supports any number of states, unlike the StateVector. It is storage backed.
type MultiStateVector struct {
numKeys uint16 // Total number of keys
numStates uint8 // Number of state per key
bitSize uint8 // Number of state bits per key
numKeysPerBlock uint8 // Number of keys that fit in a block
wordSize uint8 // Size of each word
// Bitfield for key states
vect []uint64
// A list of states and which states they can move to
stateMap [][]bool
// Stores the number of keys per state
stateUseCount []uint16
key string // Unique string used to save/load object from storage
kv *versioned.KV
mux sync.RWMutex
}
// NewMultiStateVector generates a new MultiStateVector with the specified
// number of keys and bits pr state per keys. The max number of states in 64.
func NewMultiStateVector(numKeys uint16, numStates uint8, stateMap [][]bool,
key string, kv *versioned.KV) (*MultiStateVector, error) {
// Return an error if the number of states is out of range
if numStates < minNumStates {
return nil, errors.Errorf(numStateMinErr, numStates, minNumStates)
} else if numStates > maxNumStates {
return nil, errors.Errorf(numStateMaxErr, numStates, maxNumStates)
}
// Return an error if the state map is of the wrong length
err := checkStateMap(numStates, stateMap)
if err != nil {
return nil, err
}
// Calculate the number of bits needed to represent all the states
numBits := uint8(math.Ceil(math.Log2(float64(numStates))))
// Calculate number of 64-bit blocks needed to store bitSize for numKeys
numBlocks := ((numKeys * uint16(numBits)) + 63) / 64
msv := &MultiStateVector{
numKeys: numKeys,
numStates: numStates,
bitSize: numBits,
numKeysPerBlock: 64 / numBits,
wordSize: (64 / numBits) * numBits,
vect: make([]uint64, numBlocks),
stateMap: stateMap,
stateUseCount: make([]uint16, numStates),
key: makeMultiStateVectorKey(key),
kv: kv,
}
msv.stateUseCount[0] = numKeys
return msv, msv.save()
}
// Get returns the state of the key.
func (msv *MultiStateVector) Get(keyNum uint16) (uint8, error) {
msv.mux.RLock()
defer msv.mux.RUnlock()
return msv.get(keyNum)
}
// get returns the state of the specified key. This function is not thread-safe.
func (msv *MultiStateVector) get(keyNum uint16) (uint8, error) {
// Return an error if the key number is greater than the number of keys
if keyNum > msv.numKeys-1 {
return 0, errors.Errorf(keyNumMaxErr, keyNum, msv.numKeys-1)
}
// Calculate block and position of the keyNum
block, pos := msv.getMultiBlockAndPos(keyNum)
bitMask := uint64(math.MaxUint64) >> (64 - msv.bitSize)
shift := 64 - pos - uint16(msv.bitSize)
return uint8((msv.vect[block] >> shift) & bitMask), nil
}
// Set marks the key with the given state. Returns an error for an invalid state
// or if the state change is not allowed.
func (msv *MultiStateVector) Set(keyNum uint16, state uint8) error {
msv.mux.Lock()
defer msv.mux.Unlock()
// Set state of the key
err := msv.set(keyNum, state)
if err != nil {
return errors.Errorf(setStateErr, keyNum, err)
}
// Save changes to storage
err = msv.save()
if err != nil {
return errors.Errorf(saveSetStateErr, keyNum, state, err)
}
return nil
}
// SetMany marks each of the keys with the given state. Returns an error for an
// invalid state or if the state change is not allowed.
func (msv *MultiStateVector) SetMany(keyNums []uint16, state uint8) error {
msv.mux.Lock()
defer msv.mux.Unlock()
// Set state of each key
for i, keyNum := range keyNums {
err := msv.set(keyNum, state)
if err != nil {
return errors.Errorf(setManyStateErr, keyNum, i+1, len(keyNums), err)
}
}
// Save changes to storage
err := msv.save()
if err != nil {
return errors.Errorf(saveManySetStateErr, keyNums, state, err)
}
return nil
}
// set sets the specified key to the specified state. An error is returned if
// the given state is invalid or the state change is not allowed by the state
// map.
func (msv *MultiStateVector) set(keyNum uint16, state uint8) error {
// Return an error if the key number is greater than the number of keys
if keyNum > msv.numKeys-1 {
return errors.Errorf(keyNumMaxErr, keyNum, msv.numKeys-1)
}
// Return an error if the state is larger than the max states
if state > msv.numStates-1 {
return errors.Errorf(stateMaxErr, state, msv.numStates-1)
}
// Get the current state
oldState, err := msv.get(keyNum)
if err != nil {
return errors.Errorf(setGetStateErr, keyNum, err)
}
// Check that the current state is within the allowed states
if oldState > msv.numStates-1 {
return errors.Errorf(oldStateMaxErr, oldState, msv.numStates-1)
}
// Check that the state change is allowed (only if state map was supplied)
if msv.stateMap != nil && !msv.stateMap[oldState][state] {
return errors.Errorf(stateChangeErr, keyNum, oldState, state)
}
// Calculate block and position of the key
block, _ := msv.getMultiBlockAndPos(keyNum)
// Clear the key's state
msv.vect[block] &= ^getSelectionMask(keyNum, msv.bitSize)
// Set the state
msv.vect[block] |= shiftBitsOut(uint64(state), keyNum, msv.bitSize)
// Increment/decrement state counters
msv.stateUseCount[oldState]--
msv.stateUseCount[state]++
return nil
}
// GetNumKeys returns the total number of keys.
func (msv *MultiStateVector) GetNumKeys() uint16 {
msv.mux.RLock()
defer msv.mux.RUnlock()
return msv.numKeys
}
// GetCount returns the number of keys with the given state.
func (msv *MultiStateVector) GetCount(state uint8) (uint16, error) {
msv.mux.RLock()
defer msv.mux.RUnlock()
// Return an error if the state is larger than the max states
if int(state) >= len(msv.stateUseCount) {
return 0, errors.Errorf(stateMaxErr, state, msv.numStates-1)
}
return msv.stateUseCount[state], nil
}
// GetKeys returns a list of all keys with the specified status.
func (msv *MultiStateVector) GetKeys(state uint8) ([]uint16, error) {
msv.mux.RLock()
defer msv.mux.RUnlock()
// Return an error if the state is larger than the max states
if state > msv.numStates-1 {
return nil, errors.Errorf(stateMaxErr, state, msv.numStates-1)
}
// Initialise list with capacity set to number of keys in the state
keys := make([]uint16, 0, msv.stateUseCount[state])
// Loop through each key and add any unused to the list
for keyNum := uint16(0); keyNum < msv.numKeys; keyNum++ {
keyState, err := msv.get(keyNum)
if err != nil {
return nil, errors.Errorf(keyNumMaxErr, keyNum, err)
}
if keyState == state {
keys = append(keys, keyNum)
}
}
return keys, nil
}
// DeepCopy creates a deep copy of the MultiStateVector without a storage
// backend. The deep copy can only be used for functions that do not access
// storage.
func (msv *MultiStateVector) DeepCopy() *MultiStateVector {
msv.mux.RLock()
defer msv.mux.RUnlock()
// Copy all primitive values into the new MultiStateVector
newMSV := &MultiStateVector{
numKeys: msv.numKeys,
numStates: msv.numStates,
bitSize: msv.bitSize,
numKeysPerBlock: msv.numKeysPerBlock,
wordSize: msv.wordSize,
vect: make([]uint64, len(msv.vect)),
stateMap: make([][]bool, len(msv.stateMap)),
stateUseCount: make([]uint16, len(msv.stateUseCount)),
key: msv.key,
}
// Copy over all values in the vector
copy(newMSV.vect, msv.vect)
// Copy over all values in the state map
if msv.stateMap == nil {
newMSV.stateMap = nil
} else {
for state, stateTransitions := range msv.stateMap {
newMSV.stateMap[state] = make([]bool, len(stateTransitions))
copy(newMSV.stateMap[state], stateTransitions)
}
}
// Copy over all values in the state use counter
copy(newMSV.stateUseCount, msv.stateUseCount)
return newMSV
}
// getMultiBlockAndPos calculates the block index and the position within that
// block of the key.
func (msv *MultiStateVector) getMultiBlockAndPos(keyNum uint16) (
block, pos uint16) {
block = keyNum / uint16(msv.numKeysPerBlock)
pos = (keyNum % uint16(msv.numKeysPerBlock)) * uint16(msv.bitSize)
return block, pos
}
// checkStateMap checks that the state map has the correct number of states and
// correct number of state transitions per state. Returns an error if any of the
// lengths are incorrect. Returns nil if they are all correct or if the stateMap
// is nil.
func checkStateMap(numStates uint8, stateMap [][]bool) error {
if stateMap == nil {
return nil
}
// Checks the length of the first dimension of the state map
if len(stateMap) != int(numStates) {
return errors.Errorf(stateMapLenErr, len(stateMap), numStates)
}
// Checks the length of each transition slice for each state
for i, state := range stateMap {
if len(state) != int(numStates) {
return errors.Errorf(stateMapStateLenErr, i, len(state), numStates)
}
}
return nil
}
////////////////////////////////////////////////////////////////////////////////
// Selection Bit Mask //
////////////////////////////////////////////////////////////////////////////////
// getSelectionMask returns the selection bit mask for the given key number and
// state bit size. The masks for each state is found by right shifting
// bitSize * keyNum.
func getSelectionMask(keyNum uint16, bitSize uint8) uint64 {
// Get the mask at the zeroth position at the bit size; these masks look
// like the following for bit sizes 1 through 4
// 0b1000000000000000000000000000000000000000000000000000000000000000
// 0b1100000000000000000000000000000000000000000000000000000000000000
// 0b1110000000000000000000000000000000000000000000000000000000000000
// 0b1111000000000000000000000000000000000000000000000000000000000000
initialMask := uint64(math.MaxUint64) << (64 - bitSize)
// Shift the mask to the keyNum location; for example, a mask of size 3
// in position 3 would be:
// 0b1110000000000000000000000000000000000000000000000000000000000000 ->
// 0b0000000001110000000000000000000000000000000000000000000000000000
shiftedMask := shiftBitsIn(initialMask, keyNum, bitSize)
return shiftedMask
}
// shiftBitsOut shifts the most significant bits of size bitSize to the left to
// the position of keyNum.
func shiftBitsOut(bits uint64, keyNum uint16, bitSize uint8) uint64 {
// Return original bits when the bit size is zero
if bitSize == 0 {
return bits
}
// Calculate the number of keys stored in each block; blocks are designed to
// only the states that cleaning fit fully within the block. All reminder
// trailing bits are unused. The below code calculates the number of actual
// keys in a block.
keysPerBlock := 64 / uint16(bitSize)
// Calculate the index of the key in the local block
keyIndex := keyNum % keysPerBlock
// The starting bit position of the key in the block
pos := keyIndex * uint16(bitSize)
// Shift the bits left to account for the size of the key
bitSizeShift := 64 - uint64(bitSize)
leftShift := bits << bitSizeShift
// Shift the initial mask based upon the bit position of the key
shifted := leftShift >> pos
return shifted
}
// shiftBitsIn shifts the least significant bits of size bitSize to the right to
// the position of keyNum.
func shiftBitsIn(bits uint64, keyNum uint16, bitSize uint8) uint64 {
// Return original bits when the bit size is zero
if bitSize == 0 {
return bits
}
// Calculate the number of keys stored in each block; blocks are designed to
// only the states that cleaning fit fully within the block. All reminder
// trailing bits are unused. The below code calculates the number of actual
// keys in a block.
keysPerBlock := 64 / uint16(bitSize)
// Calculate the index of the key in the local block
keyIndex := keyNum % keysPerBlock
// The starting bit position of the key in the block
pos := keyIndex * uint16(bitSize)
// Shift the initial mask based upon the bit position of the key
shifted := bits >> pos
return shifted
}
////////////////////////////////////////////////////////////////////////////////
// Storage Functions //
////////////////////////////////////////////////////////////////////////////////
// LoadMultiStateVector loads a MultiStateVector with the specified key from the
// given versioned storage.
func LoadMultiStateVector(stateMap [][]bool, key string, kv *versioned.KV) (
*MultiStateVector, error) {
msv := &MultiStateVector{
stateMap: stateMap,
key: makeMultiStateVectorKey(key),
kv: kv,
}
// Load MultiStateVector data from storage
obj, err := kv.Get(msv.key, multiStateVectorVersion)
if err != nil {
return nil, errors.Errorf(loadGetMsvErr, err)
}
// Unmarshal data
err = msv.unmarshal(obj.Data)
if err != nil {
return nil, errors.Errorf(loadUnmarshalMsvErr, err)
}
return msv, nil
}
// save stores the MultiStateVector in storage.
func (msv *MultiStateVector) save() error {
// Marshal the MultiStateVector
data, err := msv.marshal()
if err != nil {
return errors.Errorf(saveUnmarshalMsvErr, err)
}
// Create the versioned object
obj := &versioned.Object{
Version: multiStateVectorVersion,
Timestamp: netTime.Now(),
Data: data,
}
return msv.kv.Set(msv.key, multiStateVectorVersion, obj)
}
// Delete removes the MultiStateVector from storage.
func (msv *MultiStateVector) Delete() error {
msv.mux.Lock()
defer msv.mux.Unlock()
return msv.kv.Delete(msv.key, multiStateVectorVersion)
}
// marshal serialises the MultiStateVector into a byte slice.
func (msv *MultiStateVector) marshal() ([]byte, error) {
var buff bytes.Buffer
buff.Grow((4 * 4) + 8 + (8 * len(msv.vect)))
// Write numKeys to buffer
err := binary.Write(&buff, binary.LittleEndian, msv.numKeys)
if err != nil {
return nil, errors.Errorf(buffWriteNumKeysErr, err)
}
// Write numStates to buffer
err = binary.Write(&buff, binary.LittleEndian, msv.numStates)
if err != nil {
return nil, errors.Errorf(buffWriteNumStatesErr, err)
}
// Write length of vect to buffer
err = binary.Write(&buff, binary.LittleEndian, uint64(len(msv.vect)))
if err != nil {
return nil, errors.Errorf(buffWriteVectLenErr, err)
}
// Write vector to buffer
for i, block := range msv.vect {
err = binary.Write(&buff, binary.LittleEndian, block)
if err != nil {
return nil, errors.Errorf(
buffWriteVectBlockErr, i, len(msv.vect), err)
}
}
// Write state use counter slice to buffer
for state, count := range msv.stateUseCount {
err = binary.Write(&buff, binary.LittleEndian, count)
if err != nil {
return nil, errors.Errorf(buffWriteStateCountBlockErr, state, err)
}
}
return buff.Bytes(), nil
}
// unmarshal deserializes a byte slice into a MultiStateVector.
func (msv *MultiStateVector) unmarshal(b []byte) error {
buff := bytes.NewReader(b)
// Write numKeys to buffer
err := binary.Read(buff, binary.LittleEndian, &msv.numKeys)
if err != nil {
return errors.Errorf(buffReadNumKeysErr, err)
}
// Write numStates to buffer
err = binary.Read(buff, binary.LittleEndian, &msv.numStates)
if err != nil {
return errors.Errorf(buffReadNumStatesErr, err)
}
// Calculate the number of bits needed to represent all the states
msv.bitSize = uint8(math.Ceil(math.Log2(float64(msv.numStates))))
// Calculate numbers of keys per block
msv.numKeysPerBlock = 64 / msv.bitSize
// Calculate the word size
msv.wordSize = msv.numKeysPerBlock * msv.bitSize
// Write vect to buffer
var vectLen uint64
err = binary.Read(buff, binary.LittleEndian, &vectLen)
if err != nil {
return errors.Errorf(buffReadVectLenErr, err)
}
// Create new vector
msv.vect = make([]uint64, vectLen)
// Write vect to buffer
for i := range msv.vect {
err = binary.Read(buff, binary.LittleEndian, &msv.vect[i])
if err != nil {
return errors.Errorf(buffReadVectBlockErr, i, vectLen, err)
}
}
// Create new state use counter
msv.stateUseCount = make([]uint16, msv.numStates)
for i := range msv.stateUseCount {
err = binary.Read(buff, binary.LittleEndian, &msv.stateUseCount[i])
if err != nil {
return errors.Errorf(buffReadStateCountBlockErr, i, err)
}
}
return nil
}
// makeMultiStateVectorKey generates the unique key used to save a
// MultiStateVector to storage.
func makeMultiStateVectorKey(key string) string {
return multiStateVectorKey + key
}
////////////////////////////////////////////////////////////////////////////////
// Copyright © 2020 xx network SEZC //
// //
// Use of this source code is governed by a license that can be found in the //
// LICENSE file //
////////////////////////////////////////////////////////////////////////////////
package utility
import (
"bytes"
"encoding/base64"
"fmt"
"gitlab.com/elixxir/client/storage/versioned"
"gitlab.com/elixxir/ekv"
"gitlab.com/xx_network/primitives/netTime"
"math"
"math/rand"
"reflect"
"strconv"
"strings"
"testing"
)
// Tests that NewMultiStateVector returns a new MultiStateVector with the
// expected values.
func TestNewMultiStateVector(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
key := "testKey"
expected := &MultiStateVector{
numKeys: 189,
numStates: 5,
bitSize: 3,
numKeysPerBlock: 21,
wordSize: 63,
vect: make([]uint64, 9),
stateMap: [][]bool{
{true, true, true, true, true},
{true, false, true, true, true},
{false, false, false, true, true},
{false, false, false, false, true},
{false, false, false, false, false},
},
stateUseCount: []uint16{189, 0, 0, 0, 0},
key: makeMultiStateVectorKey(key),
kv: kv,
}
newMSV, err := NewMultiStateVector(expected.numKeys, expected.numStates, expected.stateMap, key, kv)
if err != nil {
t.Errorf("NewMultiStateVector returned an error: %+v", err)
}
if !reflect.DeepEqual(expected, newMSV) {
t.Errorf("New MultiStateVector does not match expected."+
"\nexpected: %+v\nreceived: %+v", expected, newMSV)
}
}
// Error path: tests that NewMultiStateVector returns the expected error when
// the number of states is too small.
func TestNewMultiStateVector_NumStatesMinError(t *testing.T) {
expectedErr := fmt.Sprintf(numStateMinErr, minNumStates-1, minNumStates)
_, err := NewMultiStateVector(5, minNumStates-1, nil, "", nil)
if err == nil || err.Error() != expectedErr {
t.Errorf("NewMultiStateVector did not return the expected error when "+
"the number of states is too small.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Error path: tests that NewMultiStateVector returns the expected error when
// the number of states is too large.
func TestNewMultiStateVector_NumStatesMaxError(t *testing.T) {
expectedErr := fmt.Sprintf(numStateMaxErr, maxNumStates+1, maxNumStates)
_, err := NewMultiStateVector(5, maxNumStates+1, nil, "", nil)
if err == nil || err.Error() != expectedErr {
t.Errorf("NewMultiStateVector did not return the expected error when "+
"the number of states is too large.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Error path: tests that NewMultiStateVector returns the expected error when
// the state map is of the wrong size.
func TestNewMultiStateVector_StateMapError(t *testing.T) {
expectedErr := fmt.Sprintf(stateMapLenErr, 1, 5)
_, err := NewMultiStateVector(5, 5, [][]bool{{true}}, "", nil)
if err == nil || err.Error() != expectedErr {
t.Errorf("NewMultiStateVector did not return the expected error when "+
"the state map is too small.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Tests that MultiStateVector.Get returns the expected states for all keys in
// two situations. First, it tests that all states are zero on creation. Second,
// random states are generated and manually inserted into the vector and then
// each key is checked for the expected vector
func TestMultiStateVector_Get(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
// Check that all states are zero
for keyNum := uint16(0); keyNum < msv.numKeys; keyNum++ {
state, err := msv.Get(keyNum)
if err != nil {
t.Errorf("Get returned an error for key %d: %+v", keyNum, err)
}
if state != 0 {
t.Errorf("Key %d has unexpected state.\nexpected: %d\nreceived: %d",
keyNum, 0, state)
}
}
// Generate slice of expected states and set vector to expected values. Each
// state is generated randomly. A binary string is built for each block
// using a lookup table (note the LUT can only be used for 3-bit states).
// When the string is 64 characters long, it is converted from a binary
// string to an uint64 and inserted into the vector.
expectedStates := make([]uint8, msv.numKeys)
prng := rand.New(rand.NewSource(42))
stateLUT := map[uint8]string{0: "000", 1: "001", 2: "010", 3: "011",
4: "100", 5: "101", 6: "110", 7: "111"}
block, blockString := 0, ""
keysInVect := uint16(int(msv.numKeysPerBlock) * len(msv.vect))
for keyNum := uint16(0); keyNum < keysInVect; keyNum++ {
if keyNum < msv.numKeys {
state := uint8(prng.Intn(int(msv.numStates)))
blockString += stateLUT[state]
expectedStates[keyNum] = state
} else {
blockString += stateLUT[0]
}
if (keyNum+1)%uint16(msv.numKeysPerBlock) == 0 {
val, err := strconv.ParseUint(blockString+"0", 2, 64)
if err != nil {
t.Errorf("Failed to parse vector string %d: %+v", block, err)
}
msv.vect[block] = val
block++
blockString = ""
}
}
// Check that each key has the expected state
for keyNum, expectedState := range expectedStates {
state, err := msv.Get(uint16(keyNum))
if err != nil {
t.Errorf("Get returned an error for key %d: %+v", keyNum, err)
}
if expectedState != state {
t.Errorf("Key %d has unexpected state.\nexpected: %d\nreceived: %d",
keyNum, expectedState, state)
}
}
}
// Error path: tests that MultiStateVector.get returns the expected error when
// the key number is greater than the max key number.
func TestMultiStateVector_get_KeyNumMaxError(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
keyNum := msv.numKeys
expectedErr := fmt.Sprintf(keyNumMaxErr, keyNum, msv.numKeys-1)
_, err = msv.get(keyNum)
if err == nil || err.Error() != expectedErr {
t.Errorf("get did not return the expected error when the key number "+
"is larger than the max key number.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Tests that MultiStateVector.Set sets the correct key to the correct status
// and that the state use counter is set correctly.
func TestMultiStateVector_Set(t *testing.T) {
stateMap := [][]bool{
{false, true, false, false, false},
{true, false, true, false, false},
{false, true, false, true, false},
{false, false, true, false, true},
{false, false, false, false, false},
}
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, stateMap, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
testValues := []struct {
keyNum uint16
state uint8
newCount, oldCount uint16
}{
{22, 1, 1, msv.numKeys - 1},
{22, 2, 1, 0},
{2, 1, 1, msv.numKeys - 2},
{2, 2, 2, 0},
{22, 3, 1, 1},
{16, 1, 1, msv.numKeys - 3},
{22, 4, 1, 0},
{16, 2, 2, 0},
{2, 3, 1, 1},
{16, 1, 1, 0},
{2, 4, 2, 0},
{16, 2, 1, 0},
{16, 3, 1, 0},
{16, 4, 3, 0},
}
for i, v := range testValues {
oldStatus, _ := msv.Get(v.keyNum)
if err = msv.Set(v.keyNum, v.state); err != nil {
t.Errorf("Failed to move key %d to state %d (%d): %+v",
v.keyNum, v.state, i, err)
} else if received, _ := msv.Get(v.keyNum); received != v.state {
t.Errorf("Key %d has state %d instead of %d (%d).",
v.keyNum, received, v.state, i)
}
if msv.stateUseCount[v.state] != v.newCount {
t.Errorf("Count for new state %d is %d instead of %d (%d).",
v.state, msv.stateUseCount[v.state], v.newCount, i)
}
if msv.stateUseCount[oldStatus] != v.oldCount {
t.Errorf("Count for old state %d is %d instead of %d (%d).",
oldStatus, msv.stateUseCount[oldStatus], v.oldCount, i)
}
}
}
// Error path: tests that MultiStateVector.Set returns the expected error when
// the key number is greater than the last key number.
func TestMultiStateVector_Set_KeyNumMaxError(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
keyNum := msv.numKeys
expectedErr := fmt.Sprintf(
setStateErr, keyNum, fmt.Sprintf(keyNumMaxErr, keyNum, msv.numKeys-1))
err = msv.Set(keyNum, 1)
if err == nil || !strings.Contains(err.Error(), expectedErr) {
t.Errorf("Set did not return the expected error when the key number "+
"is larger than the max key number.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Tests that MultiStateVector.SetMany
func TestMultiStateVector_SetMany(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
testValues := []struct {
keyNums []uint16
state uint8
newCount uint16
}{
{[]uint16{2, 22, 12, 19}, 1, 4},
{[]uint16{1, 5, 154, 7, 6}, 4, 5},
}
for i, v := range testValues {
if err = msv.SetMany(v.keyNums, v.state); err != nil {
t.Errorf("Failed to move keys %d to state %d (%d): %+v",
v.keyNums, v.state, i, err)
} else {
for _, keyNum := range v.keyNums {
if received, _ := msv.Get(keyNum); received != v.state {
t.Errorf("Key %d has state %d instead of %d (%d).",
keyNum, received, v.state, i)
}
}
}
if msv.stateUseCount[v.state] != v.newCount {
t.Errorf("Count for new state %d is %d instead of %d (%d).",
v.state, msv.stateUseCount[v.state], v.newCount, i)
}
}
}
// Error path: tests that MultiStateVector.SetMany returns the expected error
// when one of the keys is greater than the last key number.
func TestMultiStateVector_SetMany_KeyNumMaxError(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
keyNum := msv.numKeys
expectedErr := fmt.Sprintf(setManyStateErr, keyNum, 1, 1,
fmt.Sprintf(keyNumMaxErr, keyNum, msv.numKeys-1))
err = msv.SetMany([]uint16{keyNum}, 1)
if err == nil || !strings.Contains(err.Error(), expectedErr) {
t.Errorf("SetMany did not return the expected error when the key "+
"number is larger than the max key number."+
"\nexpected: %s\nreceived: %v", expectedErr, err)
}
}
// Error path: tests that MultiStateVector.set returns the expected error when
// the key number is greater than the last key number.
func TestMultiStateVector_set_KeyNumMaxError(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
keyNum := msv.numKeys
expectedErr := fmt.Sprintf(keyNumMaxErr, keyNum, msv.numKeys-1)
err = msv.set(keyNum, 1)
if err == nil || err.Error() != expectedErr {
t.Errorf("set did not return the expected error when the key number "+
"is larger than the max key number.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Error path: tests that MultiStateVector.set returns the expected error when
// the given state is greater than the last state.
func TestMultiStateVector_set_NewStateMaxError(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
state := msv.numStates
expectedErr := fmt.Sprintf(stateMaxErr, state, msv.numStates-1)
err = msv.set(0, state)
if err == nil || err.Error() != expectedErr {
t.Errorf("set did not return the expected error when the state is "+
"larger than the max number of states.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Error path: tests that MultiStateVector.set returns the expected error when
// the state read from the vector is greater than the last state.
func TestMultiStateVector_set_OldStateMaxError(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
msv.vect[0] = 0b1110000000000000000000000000000000000000000000000000000000000000
expectedErr := fmt.Sprintf(oldStateMaxErr, 7, msv.numStates-1)
err = msv.set(0, 1)
if err == nil || err.Error() != expectedErr {
t.Errorf("set did not return the expected error when the state is "+
"larger than the max number of states.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Error path: tests that MultiStateVector.set returns the expected error when
// the state change is not allowed by the state map.
func TestMultiStateVector_set_StateChangeError(t *testing.T) {
stateMap := [][]bool{{true, false}, {true, true}}
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 2, stateMap, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
keyNum, state := uint16(0), uint8(1)
expectedErr := fmt.Sprintf(stateChangeErr, keyNum, 0, state)
err = msv.set(keyNum, state)
if err == nil || err.Error() != expectedErr {
t.Errorf("set did not return the expected error when the state change "+
"should not be allowed.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Tests that MultiStateVector.GetNumKeys returns the expected number of keys.
func TestMultiStateVector_GetNumKeys(t *testing.T) {
numKeys := uint16(155)
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(numKeys, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
if numKeys != msv.GetNumKeys() {
t.Errorf("Got unexpected number of keys.\nexpected: %d\nreceived: %d",
numKeys, msv.GetNumKeys())
}
}
// Tests that MultiStateVector.GetCount returns the correct count for each state
// after each key has been set to a random state.
func TestMultiStateVector_GetCount(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(
156, 5, nil, "TestMultiStateVector_GetCount", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
expectedCounts := make([]uint16, msv.numStates)
prng := rand.New(rand.NewSource(42))
for keyNum := uint16(0); keyNum < msv.numKeys; keyNum++ {
state := uint8(prng.Intn(int(msv.numStates)))
err = msv.Set(keyNum, state)
if err != nil {
t.Errorf("Failed to set key %d to state %d.", keyNum, state)
}
expectedCounts[state]++
}
for state, expected := range expectedCounts {
count, err := msv.GetCount(uint8(state))
if err != nil {
t.Errorf("GetCount returned an error for state %d: %+v", state, err)
}
if expected != count {
t.Errorf("Incorrect count for state %d.\nexpected: %d\nreceived: %d",
state, expected, count)
}
}
}
// Error path: tests that MultiStateVector.GetCount returns the expected error
// when the given state is greater than the last state.
func TestMultiStateVector_GetCount_NewStateMaxError(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
state := msv.numStates
expectedErr := fmt.Sprintf(stateMaxErr, state, msv.numStates-1)
_, err = msv.GetCount(state)
if err == nil || err.Error() != expectedErr {
t.Errorf("GetCount did not return the expected error when the state is "+
"larger than the max number of states.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Tests that MultiStateVector.GetKeys returns the correct list of keys for each
// state after each key has been set to a random state.
func TestMultiStateVector_GetKeys(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(
156, 5, nil, "TestMultiStateVector_GetKeys", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
expectedKeys := make([][]uint16, msv.numStates)
prng := rand.New(rand.NewSource(42))
for keyNum := uint16(0); keyNum < msv.numKeys; keyNum++ {
state := uint8(prng.Intn(int(msv.numStates)))
err = msv.Set(keyNum, state)
if err != nil {
t.Errorf("Failed to set key %d to state %d.", keyNum, state)
}
expectedKeys[state] = append(expectedKeys[state], keyNum)
}
for state, expected := range expectedKeys {
keys, err := msv.GetKeys(uint8(state))
if err != nil {
t.Errorf("GetKeys returned an error: %+v", err)
}
if !reflect.DeepEqual(expected, keys) {
t.Errorf("Incorrect keys for state %d.\nexpected: %d\nreceived: %d",
state, expected, keys)
}
}
}
// Error path: tests that MultiStateVector.GetKeys returns the expected error
// when the given state is greater than the last state.
func TestMultiStateVector_GetKeys_NewStateMaxError(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
state := msv.numStates
expectedErr := fmt.Sprintf(stateMaxErr, state, msv.numStates-1)
_, err = msv.GetKeys(state)
if err == nil || err.Error() != expectedErr {
t.Errorf("GetKeys did not return the expected error when the state is "+
"larger than the max number of states.\nexpected: %s\nreceived: %v",
expectedErr, err)
}
}
// Tests that MultiStateVector.DeepCopy makes a copy of the values and not of
// the pointers.
func TestMultiStateVector_DeepCopy(t *testing.T) {
stateMap := [][]bool{
{false, true, false, false, false},
{true, false, true, false, false},
{false, true, false, true, false},
{false, false, true, false, true},
{false, false, false, false, false},
}
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, stateMap, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
newMSV := msv.DeepCopy()
msv.kv = nil
// Check that the values are the same
if !reflect.DeepEqual(msv, newMSV) {
t.Errorf("Original and copy do not match."+
"\nexpected: %+v\nreceived: %+v", msv, newMSV)
}
// Check that the pointers are different
if msv == newMSV {
t.Errorf("Pointers for original and copy match."+
"\nexpected: %p\nreceived: %p", msv, newMSV)
}
if &msv.vect == &newMSV.vect {
t.Errorf("Pointers for original and copy vect match."+
"\nexpected: %p\nreceived: %p", msv.vect, newMSV.vect)
}
if &msv.vect[0] == &newMSV.vect[0] {
t.Errorf("Pointers for original and copy vect[0] match."+
"\nexpected: %p\nreceived: %p", &msv.vect[0], &newMSV.vect[0])
}
if &msv.stateMap == &newMSV.stateMap {
t.Errorf("Pointers for original and copy stateMap match."+
"\nexpected: %p\nreceived: %p", msv.stateMap, newMSV.stateMap)
}
if &msv.stateMap[0] == &newMSV.stateMap[0] {
t.Errorf("Pointers for original and copy stateMap[0] match."+
"\nexpected: %p\nreceived: %p",
&msv.stateMap[0], &newMSV.stateMap[0])
}
if &msv.stateMap[0][0] == &newMSV.stateMap[0][0] {
t.Errorf("Pointers for original and copy stateMap[0][0] match."+
"\nexpected: %p\nreceived: %p",
&msv.stateMap[0][0], &newMSV.stateMap[0][0])
}
if &msv.stateUseCount == &newMSV.stateUseCount {
t.Errorf("Pointers for original and copy stateUseCount match."+
"\nexpected: %p\nreceived: %p",
msv.stateUseCount, newMSV.stateUseCount)
}
if &msv.stateUseCount[0] == &newMSV.stateUseCount[0] {
t.Errorf("Pointers for original and copy stateUseCount[0] match."+
"\nexpected: %p\nreceived: %p",
&msv.stateUseCount[0], &newMSV.stateUseCount[0])
}
}
// Tests that MultiStateVector.DeepCopy is able to make the expected copy when
// the state map is nil.
func TestMultiStateVector_DeepCopy_NilStateMap(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(155, 5, nil, "", kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
newMSV := msv.DeepCopy()
msv.kv = nil
// Check that the values are the same
if !reflect.DeepEqual(msv, newMSV) {
t.Errorf("Original and copy do not match."+
"\nexpected: %+v\nreceived: %+v", msv, newMSV)
}
}
// Tests that getMultiBlockAndPos returns the expected block and position for
// various key numbers, bit sizes, and word sizes.
func Test_getMultiBlockAndPos(t *testing.T) {
testValues := []struct {
keyNum, block, pos uint16
msv *MultiStateVector
}{
{0, 0, 0, getTestMSV(150, 2)},
{1, 0, 1, getTestMSV(150, 2)},
{24, 0, 24, getTestMSV(150, 2)},
{64, 1, 0, getTestMSV(150, 2)},
{0, 0, 0, getTestMSV(150, 5)},
{1, 0, 3, getTestMSV(150, 5)},
{9, 0, 27, getTestMSV(150, 5)},
{21, 1, 0, getTestMSV(150, 5)},
{0, 0, 0, getTestMSV(150, 255)},
{1, 0, 8, getTestMSV(150, 255)},
{2, 0, 16, getTestMSV(150, 255)},
{8, 1, 0, getTestMSV(150, 255)},
{0, 0, 0, getTestMSV(150, 127)},
{8, 0, 56, getTestMSV(150, 127)},
{9, 1, 0, getTestMSV(150, 127)},
}
for i, v := range testValues {
block, pos := v.msv.getMultiBlockAndPos(v.keyNum)
if v.block != block || v.pos != pos {
t.Errorf("Incorrect block/pos for key %2d (%d bits, %2d word size): "+
"expected %d/%-2d, got %d/%-2d (%d).", v.keyNum, v.msv.bitSize,
v.msv.wordSize, v.block, v.pos, block, pos, i)
}
}
}
func getTestMSV(numKeys uint16, numStates uint8) *MultiStateVector {
numBits := uint8(math.Ceil(math.Log2(float64(numStates))))
return &MultiStateVector{
numKeys: numKeys,
numStates: numStates,
bitSize: numBits,
numKeysPerBlock: 64 / numBits,
wordSize: (64 / numBits) * numBits,
}
}
// Tests that checkStateMap does not return an error for various state map
// sizes.
func Test_checkStateMap(tt *testing.T) {
const (
t = true
f = false
)
testValues := []struct {
numStates uint8
stateMap [][]bool
}{
{9, nil},
{0, [][]bool{}},
{1, [][]bool{{t}}},
{2, [][]bool{{t, t}, {f, f}}},
{3, [][]bool{{t, t, f}, {f, f, t}, {t, t, t}}},
{4, [][]bool{{t, t, f, f}, {f, f, t, f}, {t, t, t, f}, {f, f, t, t}}},
}
for i, v := range testValues {
err := checkStateMap(v.numStates, v.stateMap)
if err != nil {
tt.Errorf("Could not verify state map #%d with %d states: %+v",
i, v.numStates, err)
}
}
}
// Error path: tests that checkStateMap returns an error for various state map
// sizes and number of states mismatches.
func Test_checkStateMap_Error(tt *testing.T) {
const (
t = true
f = false
)
testValues := []struct {
numStates uint8
stateMap [][]bool
err string
}{
{1, [][]bool{},
fmt.Sprintf(stateMapLenErr, 0, 1)},
{0, [][]bool{{t}},
fmt.Sprintf(stateMapLenErr, 1, 0)},
{2, [][]bool{{t, t}, {f, f}, {f, f}},
fmt.Sprintf(stateMapLenErr, 3, 2)},
{3, [][]bool{{t, t, f}, {f, f, f, t}, {t, t, t}},
fmt.Sprintf(stateMapStateLenErr, 1, 4, 3)},
{4, [][]bool{{t, t, f, f}, {f, f, t, f}, {t, t, t, f}, {f, t, t}},
fmt.Sprintf(stateMapStateLenErr, 3, 3, 4)},
}
for i, v := range testValues {
err := checkStateMap(v.numStates, v.stateMap)
if err == nil || !strings.Contains(err.Error(), v.err) {
tt.Errorf("Verified invalid state map #%d with %d states."+
"\nexpected: %s\nreceived: %v",
i, v.numStates, v.err, err)
}
}
}
////////////////////////////////////////////////////////////////////////////////
// Selection Bit Mask //
////////////////////////////////////////////////////////////////////////////////
// Tests that getSelectionMask returns the expected bit mask for various key
// numbers and bit sizes.
func Test_getSelectionMask(t *testing.T) {
testValues := []struct {
keyNum uint16
bitSize uint8
mask uint64
}{
{0, 0, 0b0000000000000000000000000000000000000000000000000000000000000000},
{0, 1, 0b1000000000000000000000000000000000000000000000000000000000000000},
{1, 0, 0b0000000000000000000000000000000000000000000000000000000000000000},
{1, 1, 0b0100000000000000000000000000000000000000000000000000000000000000},
{63, 1, 0b0000000000000000000000000000000000000000000000000000000000000001},
{64, 1, 0b1000000000000000000000000000000000000000000000000000000000000000},
{0, 3, 0b1110000000000000000000000000000000000000000000000000000000000000},
{1, 3, 0b0001110000000000000000000000000000000000000000000000000000000000},
{20, 3, 0b0000000000000000000000000000000000000000000000000000000000001110},
{21, 3, 0b1110000000000000000000000000000000000000000000000000000000000000},
{0, 17, 0b1111111111111111100000000000000000000000000000000000000000000000},
{1, 17, 0b0000000000000000011111111111111111000000000000000000000000000000},
{2, 17, 0b0000000000000000000000000000000000111111111111111110000000000000},
{3, 17, 0b1111111111111111100000000000000000000000000000000000000000000000},
{0, 33, 0b1111111111111111111111111111111110000000000000000000000000000000},
{1, 33, 0b1111111111111111111111111111111110000000000000000000000000000000},
{2, 33, 0b1111111111111111111111111111111110000000000000000000000000000000},
{3, 33, 0b1111111111111111111111111111111110000000000000000000000000000000},
{0, 64, 0b1111111111111111111111111111111111111111111111111111111111111111},
{1, 64, 0b1111111111111111111111111111111111111111111111111111111111111111},
{2, 64, 0b1111111111111111111111111111111111111111111111111111111111111111},
{3, 64, 0b1111111111111111111111111111111111111111111111111111111111111111},
}
for i, v := range testValues {
mask := getSelectionMask(v.keyNum, v.bitSize)
if v.mask != mask {
t.Errorf("Unexpected bit mask for key %d of size %d (%d)."+
"\nexpected: %064b\nreceived: %064b",
v.keyNum, v.bitSize, i, v.mask, mask)
}
}
}
// Tests that shiftBitsOut returns the expected shifted bits for various key
// numbers and bit sizes.
func Test_shiftBitsOut(t *testing.T) {
testValues := []struct {
keyNum uint16
bitSize uint8
bits, expected uint64
}{
{0, 0,
0b1000000000000000000000000000000000000000000000000000000000000001,
0b1000000000000000000000000000000000000000000000000000000000000001},
{0, 1,
0b1000000000000000000000000000000000000000000000000000000000000001,
0b1000000000000000000000000000000000000000000000000000000000000000},
{1, 0,
0b1000000000000000000000000000000000000000000000000000000000000001,
0b1000000000000000000000000000000000000000000000000000000000000001},
{1, 2,
0b1000000000000000000000000000000000000000000000000000000000000001,
0b0001000000000000000000000000000000000000000000000000000000000000},
{1, 32,
0b0111111111111111111111111111111111111111111111111111111111111110,
0b0000000000000000000000000000000011111111111111111111111111111110},
{13, 7,
0b0111111111111111111111111111111111111111111111111111111111111110,
0b0000000000000000000000000000111111000000000000000000000000000000},
}
for i, v := range testValues {
bits := shiftBitsOut(v.bits, v.keyNum, v.bitSize)
if v.expected != bits {
t.Errorf("Unexpected shifted mask for key %d of size %d (%d)."+
"\nexpected: %064b\nreceived: %064b",
v.keyNum, v.bitSize, i, v.expected, bits)
}
}
}
// Tests that shiftBitsIn returns the expected shifted bits for various key
// numbers and bit sizes.
func Test_shiftBitsIn(t *testing.T) {
testValues := []struct {
keyNum uint16
bitSize uint8
bits, expected uint64
}{
{0, 0,
0b1000000000000000000000000000000000000000000000000000000000000001,
0b1000000000000000000000000000000000000000000000000000000000000001},
{0, 1,
0b1000000000000000000000000000000000000000000000000000000000000001,
0b1000000000000000000000000000000000000000000000000000000000000001},
{1, 0,
0b1000000000000000000000000000000000000000000000000000000000000001,
0b1000000000000000000000000000000000000000000000000000000000000001},
{1, 2,
0b1000000000000000000000000000000000000000000000000000000000000001,
0b0010000000000000000000000000000000000000000000000000000000000000},
{1, 32,
0b0111111111111111111111111111111111111111111111111111111111111110,
0b0000000000000000000000000000000001111111111111111111111111111111},
{13, 7,
0b0111111111111111111111111111111111111111111111111111111111111110,
0b0000000000000000000000000000011111111111111111111111111111111111},
}
for i, v := range testValues {
bits := shiftBitsIn(v.bits, v.keyNum, v.bitSize)
if v.expected != bits {
t.Errorf("Unexpected shifted mask for key %d of size %d (%d)."+
"\nexpected: %064b\nreceived: %064b",
v.keyNum, v.bitSize, i, v.expected, bits)
}
}
}
////////////////////////////////////////////////////////////////////////////////
// Storage Functions //
////////////////////////////////////////////////////////////////////////////////
// Tests that a MultiStateVector loaded from storage via LoadMultiStateVector
// matches the original.
func TestLoadMultiStateVector(t *testing.T) {
key := "TestLoadMultiStateVector"
msv, kv := newTestFilledMSV(156, 5, nil, "TestLoadMultiStateVector", t)
// Attempt to load MultiStateVector from storage
loadedSV, err := LoadMultiStateVector(nil, key, kv)
if err != nil {
t.Fatalf("LoadMultiStateVector returned an error: %+v", err)
}
if !reflect.DeepEqual(msv, loadedSV) {
t.Errorf("Loaded MultiStateVector does not match original saved."+
"\nexpected: %+v\nreceived: %+v", msv, loadedSV)
}
}
// Error path: tests that LoadMultiStateVector returns the expected error when
// no object is saved in storage.
func TestLoadMultiStateVector_GetFromStorageError(t *testing.T) {
key := "TestLoadMultiStateVector_GetFromStorageError"
kv := versioned.NewKV(make(ekv.Memstore))
expectedErr := strings.Split(loadGetMsvErr, "%")[0]
_, err := LoadMultiStateVector(nil, key, kv)
if err == nil || !strings.Contains(err.Error(), expectedErr) {
t.Fatalf("LoadMultiStateVector did not return the expected error "+
"when no object exists in storage.\nexpected: %s\nreceived: %+v",
expectedErr, err)
}
}
// Error path: tests that LoadMultiStateVector returns the expected error when
// the data in storage cannot be unmarshalled.
func TestLoadMultiStateVector_UnmarshalError(t *testing.T) {
kv := versioned.NewKV(make(ekv.Memstore))
key := "TestLoadMultiStateVector_MarshalError"
expectedErr := strings.Split(loadUnmarshalMsvErr, "%")[0]
// Save invalid data to storage
err := kv.Set(makeMultiStateVectorKey(key), multiStateVectorVersion,
&versioned.Object{
Version: multiStateVectorVersion,
Timestamp: netTime.Now(),
Data: []byte("?"),
})
if err != nil {
t.Errorf("Failed to save data to storage: %+v", err)
}
_, err = LoadMultiStateVector(nil, key, kv)
if err == nil || !strings.Contains(err.Error(), expectedErr) {
t.Fatalf("LoadMultiStateVector did not return the expected error "+
"when the object in storage should be invalid."+
"\nexpected: %s\nreceived: %+v", expectedErr, err)
}
}
// Tests that the data saved via MultiStateVector.save to storage matches the
// expected data.
func TestMultiStateVector_save(t *testing.T) {
msv := &MultiStateVector{
numKeys: 3 * 64,
numStates: 5,
bitSize: 3,
numKeysPerBlock: 21,
wordSize: 63,
vect: []uint64{0, 1, 2},
stateUseCount: []uint16{5, 12, 104, 0, 4000},
key: makeStateVectorKey("TestMultiStateVector_save"),
kv: versioned.NewKV(make(ekv.Memstore)),
}
expectedData, err := msv.marshal()
if err != nil {
t.Errorf("Failed to marshal MultiStateVector: %+v", err)
}
// Save to storage
err = msv.save()
if err != nil {
t.Errorf("save returned an error: %+v", err)
}
// Check that the object can be loaded
loadedData, err := msv.kv.Get(msv.key, multiStateVectorVersion)
if err != nil {
t.Errorf("Failed to load MultiStateVector from storage: %+v", err)
}
if !bytes.Equal(expectedData, loadedData.Data) {
t.Errorf("Loaded data does not match expected."+
"\nexpected: %v\nreceived: %v", expectedData, loadedData)
}
}
// Tests that MultiStateVector.Delete removes the MultiStateVector from storage.
func TestMultiStateVector_Delete(t *testing.T) {
msv, _ := newTestFilledMSV(156, 5, nil, "TestMultiStateVector_Delete", t)
err := msv.Delete()
if err != nil {
t.Errorf("Delete returned an error: %+v", err)
}
// Check that the object can be loaded
loadedData, err := msv.kv.Get(msv.key, multiStateVectorVersion)
if err == nil {
t.Errorf("Loaded MultiStateVector from storage when it should be "+
"deleted: %v", loadedData)
}
}
// Tests that a MultiStateVector marshalled with MultiStateVector.marshal and
// unmarshalled with MultiStateVector.unmarshal matches the original.
func TestMultiStateVector_marshal_unmarshal(t *testing.T) {
prng := rand.New(rand.NewSource(42))
msv := &MultiStateVector{
numKeys: 3 * 64,
numStates: 5,
bitSize: 3,
numKeysPerBlock: 21,
wordSize: 63,
vect: []uint64{prng.Uint64(), prng.Uint64(), prng.Uint64()},
stateUseCount: []uint16{5, 12, 104, 0, 4000},
}
marshalledBytes, err := msv.marshal()
if err != nil {
t.Errorf("marshal returned an error: %+v", err)
}
newMsv := &MultiStateVector{}
err = newMsv.unmarshal(marshalledBytes)
if err != nil {
t.Errorf("unmarshal returned an error: %+v", err)
}
if !reflect.DeepEqual(&msv, &newMsv) {
t.Errorf("Marshalled and unmarsalled MultiStateVector does not match "+
"the original.\nexpected: %+v\nreceived: %+v", msv, newMsv)
}
}
// Consistency test of makeMultiStateVectorKey.
func Test_makeMultiStateVectorKey(t *testing.T) {
prng := rand.New(rand.NewSource(42))
expectedStrings := []string{
multiStateVectorKey + "U4x/lrFkvxuXu59LtHLon1sU",
multiStateVectorKey + "hPJSCcnZND6SugndnVLf15tN",
multiStateVectorKey + "dkKbYXoMn58NO6VbDMDWFEyI",
multiStateVectorKey + "hTWEGsvgcJsHWAg/YdN1vAK0",
multiStateVectorKey + "HfT5GSnhj9qeb4LlTnSOgeee",
multiStateVectorKey + "S71v40zcuoQ+6NY+jE/+HOvq",
multiStateVectorKey + "VG2PrBPdGqwEzi6ih3xVec+i",
multiStateVectorKey + "x44bC6+uiBuCp1EQikLtPJA8",
multiStateVectorKey + "qkNGWnhiBhaXiu0M48bE8657",
multiStateVectorKey + "w+BJW1cS/v2+DBAoh+EA2s0t",
}
for i, expected := range expectedStrings {
b := make([]byte, 18)
prng.Read(b)
key := makeMultiStateVectorKey(base64.StdEncoding.EncodeToString(b))
if expected != key {
t.Errorf("New MultiStateVector key does not match expected (%d)."+
"\nexpected: %q\nreceived: %q", i, expected, key)
}
}
}
// newTestFilledMSV produces a new MultiStateVector and sets each key to a
// random state.
func newTestFilledMSV(numKeys uint16, numStates uint8, stateMap [][]bool,
key string, t *testing.T) (*MultiStateVector, *versioned.KV) {
kv := versioned.NewKV(make(ekv.Memstore))
msv, err := NewMultiStateVector(numKeys, numStates, stateMap, key, kv)
if err != nil {
t.Errorf("Failed to create new MultiStateVector: %+v", err)
}
prng := rand.New(rand.NewSource(42))
for keyNum := uint16(0); keyNum < msv.numKeys; keyNum++ {
state := uint8(prng.Intn(int(msv.numStates)))
err = msv.Set(keyNum, state)
if err != nil {
t.Errorf("Failed to set key %d to state %d.", keyNum, state)
}
}
return msv, kv
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment