diff --git a/storage/utility/multiStateVector.go b/storage/utility/multiStateVector.go new file mode 100644 index 0000000000000000000000000000000000000000..0a882ada564d7044ae1a5b42e914b1b311692146 --- /dev/null +++ b/storage/utility/multiStateVector.go @@ -0,0 +1,617 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package utility + +import ( + "bytes" + "encoding/binary" + "github.com/pkg/errors" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/xx_network/primitives/netTime" + "math" + "sync" +) + +// Storage key and version. +const ( + multiStateVectorKey = "multiStateVector/" + multiStateVectorVersion = 0 +) + +// Minimum and maximum allowed number of states. +const ( + minNumStates = 2 + maxNumStates = 8 +) + +// Error messages. +const ( + keyNumMaxErr = "key number %d greater than max %d" + stateMaxErr = "given state %d greater than max %d" + + // NewMultiStateVector + numStateMinErr = "number of received states %d must be greater than %d" + numStateMaxErr = "number of received states %d greater than max %d" + + // MultiStateVector.Set + setStateErr = "failed to set state of key %d: %+v" + saveSetStateErr = "failed to save MultiStateVector after setting key %d state to %d: %+v" + + // MultiStateVector.SetMany + setManyStateErr = "failed to set state of key %d (%d/%d): %+v" + saveManySetStateErr = "failed to save MultiStateVector after setting keys %d state to %d: %+v" + + // MultiStateVector.set + setGetStateErr = "could not get state of key %d: %+v" + oldStateMaxErr = "stored state %d greater than max %d (this should not happen; maybe a storage failure)" + stateChangeErr = "cannot change state of key %d from %d to %d" + + // LoadMultiStateVector + loadGetMsvErr = "failed to load MultiStateVector from storage: %+v" + loadUnmarshalMsvErr = "failed to unmarshal MultiStateVector loaded from storage: %+v" + + // MultiStateVector.save + saveUnmarshalMsvErr = "failed to marshal MultiStateVector for storage: %+v" + + // MultiStateVector.marshal + buffWriteNumKeysErr = "numKeys: %+v" + buffWriteNumStatesErr = "numStates: %+v" + buffWriteVectLenErr = "length of vector: %+v" + buffWriteVectBlockErr = "vector block %d/%d: %+v" + buffWriteStateCountBlockErr = "state use counter %d: %+v" + + // MultiStateVector.unmarshal + buffReadNumKeysErr = "numKeys: %+v" + buffReadNumStatesErr = "numStates: %+v" + buffReadVectLenErr = "length of vector: %+v" + buffReadVectBlockErr = "vector block %d/%d: %+v" + buffReadStateCountBlockErr = "state use counter %d: %+v" + + // checkStateMap + stateMapLenErr = "state map has %d states when %d are required" + stateMapStateLenErr = "state %d in state map has %d state transitions when %d are required" +) + +// MultiStateVector stores a list of a set number of keys and their state. It +// supports any number of states, unlike the StateVector. It is storage backed. +type MultiStateVector struct { + numKeys uint16 // Total number of keys + numStates uint8 // Number of state per key + bitSize uint8 // Number of state bits per key + numKeysPerBlock uint8 // Number of keys that fit in a block + wordSize uint8 // Size of each word + + // Bitfield for key states + vect []uint64 + + // A list of states and which states they can move to + stateMap [][]bool + + // Stores the number of keys per state + stateUseCount []uint16 + + key string // Unique string used to save/load object from storage + kv *versioned.KV + mux sync.RWMutex +} + +// NewMultiStateVector generates a new MultiStateVector with the specified +// number of keys and bits pr state per keys. The max number of states in 64. +func NewMultiStateVector(numKeys uint16, numStates uint8, stateMap [][]bool, + key string, kv *versioned.KV) (*MultiStateVector, error) { + + // Return an error if the number of states is out of range + if numStates < minNumStates { + return nil, errors.Errorf(numStateMinErr, numStates, minNumStates) + } else if numStates > maxNumStates { + return nil, errors.Errorf(numStateMaxErr, numStates, maxNumStates) + } + + // Return an error if the state map is of the wrong length + err := checkStateMap(numStates, stateMap) + if err != nil { + return nil, err + } + + // Calculate the number of bits needed to represent all the states + numBits := uint8(math.Ceil(math.Log2(float64(numStates)))) + + // Calculate number of 64-bit blocks needed to store bitSize for numKeys + numBlocks := ((numKeys * uint16(numBits)) + 63) / 64 + + msv := &MultiStateVector{ + numKeys: numKeys, + numStates: numStates, + bitSize: numBits, + numKeysPerBlock: 64 / numBits, + wordSize: (64 / numBits) * numBits, + vect: make([]uint64, numBlocks), + stateMap: stateMap, + stateUseCount: make([]uint16, numStates), + key: makeMultiStateVectorKey(key), + kv: kv, + } + + msv.stateUseCount[0] = numKeys + + return msv, msv.save() +} + +// Get returns the state of the key. +func (msv *MultiStateVector) Get(keyNum uint16) (uint8, error) { + msv.mux.RLock() + defer msv.mux.RUnlock() + + return msv.get(keyNum) +} + +// get returns the state of the specified key. This function is not thread-safe. +func (msv *MultiStateVector) get(keyNum uint16) (uint8, error) { + // Return an error if the key number is greater than the number of keys + if keyNum > msv.numKeys-1 { + return 0, errors.Errorf(keyNumMaxErr, keyNum, msv.numKeys-1) + } + + // Calculate block and position of the keyNum + block, pos := msv.getMultiBlockAndPos(keyNum) + + bitMask := uint64(math.MaxUint64) >> (64 - msv.bitSize) + + shift := 64 - pos - uint16(msv.bitSize) + + return uint8((msv.vect[block] >> shift) & bitMask), nil +} + +// Set marks the key with the given state. Returns an error for an invalid state +// or if the state change is not allowed. +func (msv *MultiStateVector) Set(keyNum uint16, state uint8) error { + msv.mux.Lock() + defer msv.mux.Unlock() + + // Set state of the key + err := msv.set(keyNum, state) + if err != nil { + return errors.Errorf(setStateErr, keyNum, err) + } + + // Save changes to storage + err = msv.save() + if err != nil { + return errors.Errorf(saveSetStateErr, keyNum, state, err) + } + + return nil +} + +// SetMany marks each of the keys with the given state. Returns an error for an +// invalid state or if the state change is not allowed. +func (msv *MultiStateVector) SetMany(keyNums []uint16, state uint8) error { + msv.mux.Lock() + defer msv.mux.Unlock() + + // Set state of each key + for i, keyNum := range keyNums { + err := msv.set(keyNum, state) + if err != nil { + return errors.Errorf(setManyStateErr, keyNum, i+1, len(keyNums), err) + } + } + + // Save changes to storage + err := msv.save() + if err != nil { + return errors.Errorf(saveManySetStateErr, keyNums, state, err) + } + + return nil +} + +// set sets the specified key to the specified state. An error is returned if +// the given state is invalid or the state change is not allowed by the state +// map. +func (msv *MultiStateVector) set(keyNum uint16, state uint8) error { + // Return an error if the key number is greater than the number of keys + if keyNum > msv.numKeys-1 { + return errors.Errorf(keyNumMaxErr, keyNum, msv.numKeys-1) + } + + // Return an error if the state is larger than the max states + if state > msv.numStates-1 { + return errors.Errorf(stateMaxErr, state, msv.numStates-1) + } + + // Get the current state + oldState, err := msv.get(keyNum) + if err != nil { + return errors.Errorf(setGetStateErr, keyNum, err) + } + + // Check that the current state is within the allowed states + if oldState > msv.numStates-1 { + return errors.Errorf(oldStateMaxErr, oldState, msv.numStates-1) + } + + // Check that the state change is allowed (only if state map was supplied) + if msv.stateMap != nil && !msv.stateMap[oldState][state] { + return errors.Errorf(stateChangeErr, keyNum, oldState, state) + } + + // Calculate block and position of the key + block, _ := msv.getMultiBlockAndPos(keyNum) + + // Clear the key's state + msv.vect[block] &= ^getSelectionMask(keyNum, msv.bitSize) + + // Set the state + msv.vect[block] |= shiftBitsOut(uint64(state), keyNum, msv.bitSize) + + // Increment/decrement state counters + msv.stateUseCount[oldState]-- + msv.stateUseCount[state]++ + + return nil +} + +// GetNumKeys returns the total number of keys. +func (msv *MultiStateVector) GetNumKeys() uint16 { + msv.mux.RLock() + defer msv.mux.RUnlock() + return msv.numKeys +} + +// GetCount returns the number of keys with the given state. +func (msv *MultiStateVector) GetCount(state uint8) (uint16, error) { + msv.mux.RLock() + defer msv.mux.RUnlock() + + // Return an error if the state is larger than the max states + if int(state) >= len(msv.stateUseCount) { + return 0, errors.Errorf(stateMaxErr, state, msv.numStates-1) + } + + return msv.stateUseCount[state], nil +} + +// GetKeys returns a list of all keys with the specified status. +func (msv *MultiStateVector) GetKeys(state uint8) ([]uint16, error) { + msv.mux.RLock() + defer msv.mux.RUnlock() + + // Return an error if the state is larger than the max states + if state > msv.numStates-1 { + return nil, errors.Errorf(stateMaxErr, state, msv.numStates-1) + } + + // Initialise list with capacity set to number of keys in the state + keys := make([]uint16, 0, msv.stateUseCount[state]) + + // Loop through each key and add any unused to the list + for keyNum := uint16(0); keyNum < msv.numKeys; keyNum++ { + keyState, err := msv.get(keyNum) + if err != nil { + return nil, errors.Errorf(keyNumMaxErr, keyNum, err) + } + if keyState == state { + keys = append(keys, keyNum) + } + } + + return keys, nil +} + +// DeepCopy creates a deep copy of the MultiStateVector without a storage +// backend. The deep copy can only be used for functions that do not access +// storage. +func (msv *MultiStateVector) DeepCopy() *MultiStateVector { + msv.mux.RLock() + defer msv.mux.RUnlock() + + // Copy all primitive values into the new MultiStateVector + newMSV := &MultiStateVector{ + numKeys: msv.numKeys, + numStates: msv.numStates, + bitSize: msv.bitSize, + numKeysPerBlock: msv.numKeysPerBlock, + wordSize: msv.wordSize, + vect: make([]uint64, len(msv.vect)), + stateMap: make([][]bool, len(msv.stateMap)), + stateUseCount: make([]uint16, len(msv.stateUseCount)), + key: msv.key, + } + + // Copy over all values in the vector + copy(newMSV.vect, msv.vect) + + // Copy over all values in the state map + if msv.stateMap == nil { + newMSV.stateMap = nil + } else { + for state, stateTransitions := range msv.stateMap { + newMSV.stateMap[state] = make([]bool, len(stateTransitions)) + copy(newMSV.stateMap[state], stateTransitions) + } + } + + // Copy over all values in the state use counter + copy(newMSV.stateUseCount, msv.stateUseCount) + + return newMSV +} + +// getMultiBlockAndPos calculates the block index and the position within that +// block of the key. +func (msv *MultiStateVector) getMultiBlockAndPos(keyNum uint16) ( + block, pos uint16) { + block = keyNum / uint16(msv.numKeysPerBlock) + pos = (keyNum % uint16(msv.numKeysPerBlock)) * uint16(msv.bitSize) + + return block, pos +} + +// checkStateMap checks that the state map has the correct number of states and +// correct number of state transitions per state. Returns an error if any of the +// lengths are incorrect. Returns nil if they are all correct or if the stateMap +// is nil. +func checkStateMap(numStates uint8, stateMap [][]bool) error { + if stateMap == nil { + return nil + } + + // Checks the length of the first dimension of the state map + if len(stateMap) != int(numStates) { + return errors.Errorf(stateMapLenErr, len(stateMap), numStates) + } + + // Checks the length of each transition slice for each state + for i, state := range stateMap { + if len(state) != int(numStates) { + return errors.Errorf(stateMapStateLenErr, i, len(state), numStates) + } + } + + return nil +} + +//////////////////////////////////////////////////////////////////////////////// +// Selection Bit Mask // +//////////////////////////////////////////////////////////////////////////////// + +// getSelectionMask returns the selection bit mask for the given key number and +// state bit size. The masks for each state is found by right shifting +// bitSize * keyNum. +func getSelectionMask(keyNum uint16, bitSize uint8) uint64 { + // Get the mask at the zeroth position at the bit size; these masks look + // like the following for bit sizes 1 through 4 + // 0b1000000000000000000000000000000000000000000000000000000000000000 + // 0b1100000000000000000000000000000000000000000000000000000000000000 + // 0b1110000000000000000000000000000000000000000000000000000000000000 + // 0b1111000000000000000000000000000000000000000000000000000000000000 + initialMask := uint64(math.MaxUint64) << (64 - bitSize) + + // Shift the mask to the keyNum location; for example, a mask of size 3 + // in position 3 would be: + // 0b1110000000000000000000000000000000000000000000000000000000000000 -> + // 0b0000000001110000000000000000000000000000000000000000000000000000 + shiftedMask := shiftBitsIn(initialMask, keyNum, bitSize) + + return shiftedMask +} + +// shiftBitsOut shifts the most significant bits of size bitSize to the left to +// the position of keyNum. +func shiftBitsOut(bits uint64, keyNum uint16, bitSize uint8) uint64 { + // Return original bits when the bit size is zero + if bitSize == 0 { + return bits + } + + // Calculate the number of keys stored in each block; blocks are designed to + // only the states that cleaning fit fully within the block. All reminder + // trailing bits are unused. The below code calculates the number of actual + // keys in a block. + keysPerBlock := 64 / uint16(bitSize) + + // Calculate the index of the key in the local block + keyIndex := keyNum % keysPerBlock + + // The starting bit position of the key in the block + pos := keyIndex * uint16(bitSize) + + // Shift the bits left to account for the size of the key + bitSizeShift := 64 - uint64(bitSize) + leftShift := bits << bitSizeShift + + // Shift the initial mask based upon the bit position of the key + shifted := leftShift >> pos + + return shifted +} + +// shiftBitsIn shifts the least significant bits of size bitSize to the right to +// the position of keyNum. +func shiftBitsIn(bits uint64, keyNum uint16, bitSize uint8) uint64 { + // Return original bits when the bit size is zero + if bitSize == 0 { + return bits + } + + // Calculate the number of keys stored in each block; blocks are designed to + // only the states that cleaning fit fully within the block. All reminder + // trailing bits are unused. The below code calculates the number of actual + // keys in a block. + keysPerBlock := 64 / uint16(bitSize) + + // Calculate the index of the key in the local block + keyIndex := keyNum % keysPerBlock + + // The starting bit position of the key in the block + pos := keyIndex * uint16(bitSize) + + // Shift the initial mask based upon the bit position of the key + shifted := bits >> pos + + return shifted +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// LoadMultiStateVector loads a MultiStateVector with the specified key from the +// given versioned storage. +func LoadMultiStateVector(stateMap [][]bool, key string, kv *versioned.KV) ( + *MultiStateVector, error) { + msv := &MultiStateVector{ + stateMap: stateMap, + key: makeMultiStateVectorKey(key), + kv: kv, + } + + // Load MultiStateVector data from storage + obj, err := kv.Get(msv.key, multiStateVectorVersion) + if err != nil { + return nil, errors.Errorf(loadGetMsvErr, err) + } + + // Unmarshal data + err = msv.unmarshal(obj.Data) + if err != nil { + return nil, errors.Errorf(loadUnmarshalMsvErr, err) + } + + return msv, nil +} + +// save stores the MultiStateVector in storage. +func (msv *MultiStateVector) save() error { + // Marshal the MultiStateVector + data, err := msv.marshal() + if err != nil { + return errors.Errorf(saveUnmarshalMsvErr, err) + } + + // Create the versioned object + obj := &versioned.Object{ + Version: multiStateVectorVersion, + Timestamp: netTime.Now(), + Data: data, + } + + return msv.kv.Set(msv.key, multiStateVectorVersion, obj) +} + +// Delete removes the MultiStateVector from storage. +func (msv *MultiStateVector) Delete() error { + msv.mux.Lock() + defer msv.mux.Unlock() + return msv.kv.Delete(msv.key, multiStateVectorVersion) +} + +// marshal serialises the MultiStateVector into a byte slice. +func (msv *MultiStateVector) marshal() ([]byte, error) { + var buff bytes.Buffer + buff.Grow((4 * 4) + 8 + (8 * len(msv.vect))) + + // Write numKeys to buffer + err := binary.Write(&buff, binary.LittleEndian, msv.numKeys) + if err != nil { + return nil, errors.Errorf(buffWriteNumKeysErr, err) + } + + // Write numStates to buffer + err = binary.Write(&buff, binary.LittleEndian, msv.numStates) + if err != nil { + return nil, errors.Errorf(buffWriteNumStatesErr, err) + } + + // Write length of vect to buffer + err = binary.Write(&buff, binary.LittleEndian, uint64(len(msv.vect))) + if err != nil { + return nil, errors.Errorf(buffWriteVectLenErr, err) + } + + // Write vector to buffer + for i, block := range msv.vect { + err = binary.Write(&buff, binary.LittleEndian, block) + if err != nil { + return nil, errors.Errorf( + buffWriteVectBlockErr, i, len(msv.vect), err) + } + } + + // Write state use counter slice to buffer + for state, count := range msv.stateUseCount { + err = binary.Write(&buff, binary.LittleEndian, count) + if err != nil { + return nil, errors.Errorf(buffWriteStateCountBlockErr, state, err) + } + } + + return buff.Bytes(), nil +} + +// unmarshal deserializes a byte slice into a MultiStateVector. +func (msv *MultiStateVector) unmarshal(b []byte) error { + buff := bytes.NewReader(b) + + // Write numKeys to buffer + err := binary.Read(buff, binary.LittleEndian, &msv.numKeys) + if err != nil { + return errors.Errorf(buffReadNumKeysErr, err) + } + + // Write numStates to buffer + err = binary.Read(buff, binary.LittleEndian, &msv.numStates) + if err != nil { + return errors.Errorf(buffReadNumStatesErr, err) + } + + // Calculate the number of bits needed to represent all the states + msv.bitSize = uint8(math.Ceil(math.Log2(float64(msv.numStates)))) + + // Calculate numbers of keys per block + msv.numKeysPerBlock = 64 / msv.bitSize + + // Calculate the word size + msv.wordSize = msv.numKeysPerBlock * msv.bitSize + + // Write vect to buffer + var vectLen uint64 + err = binary.Read(buff, binary.LittleEndian, &vectLen) + if err != nil { + return errors.Errorf(buffReadVectLenErr, err) + } + + // Create new vector + msv.vect = make([]uint64, vectLen) + + // Write vect to buffer + for i := range msv.vect { + err = binary.Read(buff, binary.LittleEndian, &msv.vect[i]) + if err != nil { + return errors.Errorf(buffReadVectBlockErr, i, vectLen, err) + } + } + + // Create new state use counter + msv.stateUseCount = make([]uint16, msv.numStates) + for i := range msv.stateUseCount { + err = binary.Read(buff, binary.LittleEndian, &msv.stateUseCount[i]) + if err != nil { + return errors.Errorf(buffReadStateCountBlockErr, i, err) + } + } + + return nil +} + +// makeMultiStateVectorKey generates the unique key used to save a +// MultiStateVector to storage. +func makeMultiStateVectorKey(key string) string { + return multiStateVectorKey + key +} diff --git a/storage/utility/multiStateVector_test.go b/storage/utility/multiStateVector_test.go new file mode 100644 index 0000000000000000000000000000000000000000..080e6e3b594cd34b0418e1065cf2db4e0cf14f89 --- /dev/null +++ b/storage/utility/multiStateVector_test.go @@ -0,0 +1,1047 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package utility + +import ( + "bytes" + "encoding/base64" + "fmt" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/primitives/netTime" + "math" + "math/rand" + "reflect" + "strconv" + "strings" + "testing" +) + +// Tests that NewMultiStateVector returns a new MultiStateVector with the +// expected values. +func TestNewMultiStateVector(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + key := "testKey" + expected := &MultiStateVector{ + numKeys: 189, + numStates: 5, + bitSize: 3, + numKeysPerBlock: 21, + wordSize: 63, + vect: make([]uint64, 9), + stateMap: [][]bool{ + {true, true, true, true, true}, + {true, false, true, true, true}, + {false, false, false, true, true}, + {false, false, false, false, true}, + {false, false, false, false, false}, + }, + stateUseCount: []uint16{189, 0, 0, 0, 0}, + key: makeMultiStateVectorKey(key), + kv: kv, + } + + newMSV, err := NewMultiStateVector(expected.numKeys, expected.numStates, expected.stateMap, key, kv) + if err != nil { + t.Errorf("NewMultiStateVector returned an error: %+v", err) + } + + if !reflect.DeepEqual(expected, newMSV) { + t.Errorf("New MultiStateVector does not match expected."+ + "\nexpected: %+v\nreceived: %+v", expected, newMSV) + } +} + +// Error path: tests that NewMultiStateVector returns the expected error when +// the number of states is too small. +func TestNewMultiStateVector_NumStatesMinError(t *testing.T) { + expectedErr := fmt.Sprintf(numStateMinErr, minNumStates-1, minNumStates) + _, err := NewMultiStateVector(5, minNumStates-1, nil, "", nil) + if err == nil || err.Error() != expectedErr { + t.Errorf("NewMultiStateVector did not return the expected error when "+ + "the number of states is too small.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Error path: tests that NewMultiStateVector returns the expected error when +// the number of states is too large. +func TestNewMultiStateVector_NumStatesMaxError(t *testing.T) { + expectedErr := fmt.Sprintf(numStateMaxErr, maxNumStates+1, maxNumStates) + _, err := NewMultiStateVector(5, maxNumStates+1, nil, "", nil) + if err == nil || err.Error() != expectedErr { + t.Errorf("NewMultiStateVector did not return the expected error when "+ + "the number of states is too large.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Error path: tests that NewMultiStateVector returns the expected error when +// the state map is of the wrong size. +func TestNewMultiStateVector_StateMapError(t *testing.T) { + expectedErr := fmt.Sprintf(stateMapLenErr, 1, 5) + _, err := NewMultiStateVector(5, 5, [][]bool{{true}}, "", nil) + if err == nil || err.Error() != expectedErr { + t.Errorf("NewMultiStateVector did not return the expected error when "+ + "the state map is too small.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Tests that MultiStateVector.Get returns the expected states for all keys in +// two situations. First, it tests that all states are zero on creation. Second, +// random states are generated and manually inserted into the vector and then +// each key is checked for the expected vector +func TestMultiStateVector_Get(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + // Check that all states are zero + for keyNum := uint16(0); keyNum < msv.numKeys; keyNum++ { + state, err := msv.Get(keyNum) + if err != nil { + t.Errorf("Get returned an error for key %d: %+v", keyNum, err) + } + if state != 0 { + t.Errorf("Key %d has unexpected state.\nexpected: %d\nreceived: %d", + keyNum, 0, state) + } + } + + // Generate slice of expected states and set vector to expected values. Each + // state is generated randomly. A binary string is built for each block + // using a lookup table (note the LUT can only be used for 3-bit states). + // When the string is 64 characters long, it is converted from a binary + // string to an uint64 and inserted into the vector. + expectedStates := make([]uint8, msv.numKeys) + prng := rand.New(rand.NewSource(42)) + stateLUT := map[uint8]string{0: "000", 1: "001", 2: "010", 3: "011", + 4: "100", 5: "101", 6: "110", 7: "111"} + block, blockString := 0, "" + keysInVect := uint16(int(msv.numKeysPerBlock) * len(msv.vect)) + for keyNum := uint16(0); keyNum < keysInVect; keyNum++ { + if keyNum < msv.numKeys { + state := uint8(prng.Intn(int(msv.numStates))) + blockString += stateLUT[state] + expectedStates[keyNum] = state + } else { + blockString += stateLUT[0] + } + + if (keyNum+1)%uint16(msv.numKeysPerBlock) == 0 { + val, err := strconv.ParseUint(blockString+"0", 2, 64) + if err != nil { + t.Errorf("Failed to parse vector string %d: %+v", block, err) + } + msv.vect[block] = val + block++ + blockString = "" + } + } + + // Check that each key has the expected state + for keyNum, expectedState := range expectedStates { + state, err := msv.Get(uint16(keyNum)) + if err != nil { + t.Errorf("Get returned an error for key %d: %+v", keyNum, err) + } + + if expectedState != state { + t.Errorf("Key %d has unexpected state.\nexpected: %d\nreceived: %d", + keyNum, expectedState, state) + } + } +} + +// Error path: tests that MultiStateVector.get returns the expected error when +// the key number is greater than the max key number. +func TestMultiStateVector_get_KeyNumMaxError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + keyNum := msv.numKeys + expectedErr := fmt.Sprintf(keyNumMaxErr, keyNum, msv.numKeys-1) + _, err = msv.get(keyNum) + if err == nil || err.Error() != expectedErr { + t.Errorf("get did not return the expected error when the key number "+ + "is larger than the max key number.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Tests that MultiStateVector.Set sets the correct key to the correct status +// and that the state use counter is set correctly. +func TestMultiStateVector_Set(t *testing.T) { + stateMap := [][]bool{ + {false, true, false, false, false}, + {true, false, true, false, false}, + {false, true, false, true, false}, + {false, false, true, false, true}, + {false, false, false, false, false}, + } + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, stateMap, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + testValues := []struct { + keyNum uint16 + state uint8 + newCount, oldCount uint16 + }{ + {22, 1, 1, msv.numKeys - 1}, + {22, 2, 1, 0}, + {2, 1, 1, msv.numKeys - 2}, + {2, 2, 2, 0}, + {22, 3, 1, 1}, + {16, 1, 1, msv.numKeys - 3}, + {22, 4, 1, 0}, + {16, 2, 2, 0}, + {2, 3, 1, 1}, + {16, 1, 1, 0}, + {2, 4, 2, 0}, + {16, 2, 1, 0}, + {16, 3, 1, 0}, + {16, 4, 3, 0}, + } + + for i, v := range testValues { + oldStatus, _ := msv.Get(v.keyNum) + + if err = msv.Set(v.keyNum, v.state); err != nil { + t.Errorf("Failed to move key %d to state %d (%d): %+v", + v.keyNum, v.state, i, err) + } else if received, _ := msv.Get(v.keyNum); received != v.state { + t.Errorf("Key %d has state %d instead of %d (%d).", + v.keyNum, received, v.state, i) + } + + if msv.stateUseCount[v.state] != v.newCount { + t.Errorf("Count for new state %d is %d instead of %d (%d).", + v.state, msv.stateUseCount[v.state], v.newCount, i) + } + + if msv.stateUseCount[oldStatus] != v.oldCount { + t.Errorf("Count for old state %d is %d instead of %d (%d).", + oldStatus, msv.stateUseCount[oldStatus], v.oldCount, i) + } + } +} + +// Error path: tests that MultiStateVector.Set returns the expected error when +// the key number is greater than the last key number. +func TestMultiStateVector_Set_KeyNumMaxError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + keyNum := msv.numKeys + expectedErr := fmt.Sprintf( + setStateErr, keyNum, fmt.Sprintf(keyNumMaxErr, keyNum, msv.numKeys-1)) + err = msv.Set(keyNum, 1) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("Set did not return the expected error when the key number "+ + "is larger than the max key number.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Tests that MultiStateVector.SetMany +func TestMultiStateVector_SetMany(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + testValues := []struct { + keyNums []uint16 + state uint8 + newCount uint16 + }{ + {[]uint16{2, 22, 12, 19}, 1, 4}, + {[]uint16{1, 5, 154, 7, 6}, 4, 5}, + } + + for i, v := range testValues { + + if err = msv.SetMany(v.keyNums, v.state); err != nil { + t.Errorf("Failed to move keys %d to state %d (%d): %+v", + v.keyNums, v.state, i, err) + } else { + for _, keyNum := range v.keyNums { + if received, _ := msv.Get(keyNum); received != v.state { + t.Errorf("Key %d has state %d instead of %d (%d).", + keyNum, received, v.state, i) + } + } + } + + if msv.stateUseCount[v.state] != v.newCount { + t.Errorf("Count for new state %d is %d instead of %d (%d).", + v.state, msv.stateUseCount[v.state], v.newCount, i) + } + } +} + +// Error path: tests that MultiStateVector.SetMany returns the expected error +// when one of the keys is greater than the last key number. +func TestMultiStateVector_SetMany_KeyNumMaxError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + keyNum := msv.numKeys + expectedErr := fmt.Sprintf(setManyStateErr, keyNum, 1, 1, + fmt.Sprintf(keyNumMaxErr, keyNum, msv.numKeys-1)) + err = msv.SetMany([]uint16{keyNum}, 1) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Errorf("SetMany did not return the expected error when the key "+ + "number is larger than the max key number."+ + "\nexpected: %s\nreceived: %v", expectedErr, err) + } +} + +// Error path: tests that MultiStateVector.set returns the expected error when +// the key number is greater than the last key number. +func TestMultiStateVector_set_KeyNumMaxError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + keyNum := msv.numKeys + expectedErr := fmt.Sprintf(keyNumMaxErr, keyNum, msv.numKeys-1) + err = msv.set(keyNum, 1) + if err == nil || err.Error() != expectedErr { + t.Errorf("set did not return the expected error when the key number "+ + "is larger than the max key number.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Error path: tests that MultiStateVector.set returns the expected error when +// the given state is greater than the last state. +func TestMultiStateVector_set_NewStateMaxError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + state := msv.numStates + expectedErr := fmt.Sprintf(stateMaxErr, state, msv.numStates-1) + err = msv.set(0, state) + if err == nil || err.Error() != expectedErr { + t.Errorf("set did not return the expected error when the state is "+ + "larger than the max number of states.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Error path: tests that MultiStateVector.set returns the expected error when +// the state read from the vector is greater than the last state. +func TestMultiStateVector_set_OldStateMaxError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + msv.vect[0] = 0b1110000000000000000000000000000000000000000000000000000000000000 + + expectedErr := fmt.Sprintf(oldStateMaxErr, 7, msv.numStates-1) + + err = msv.set(0, 1) + if err == nil || err.Error() != expectedErr { + t.Errorf("set did not return the expected error when the state is "+ + "larger than the max number of states.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Error path: tests that MultiStateVector.set returns the expected error when +// the state change is not allowed by the state map. +func TestMultiStateVector_set_StateChangeError(t *testing.T) { + stateMap := [][]bool{{true, false}, {true, true}} + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 2, stateMap, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + keyNum, state := uint16(0), uint8(1) + expectedErr := fmt.Sprintf(stateChangeErr, keyNum, 0, state) + + err = msv.set(keyNum, state) + if err == nil || err.Error() != expectedErr { + t.Errorf("set did not return the expected error when the state change "+ + "should not be allowed.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Tests that MultiStateVector.GetNumKeys returns the expected number of keys. +func TestMultiStateVector_GetNumKeys(t *testing.T) { + numKeys := uint16(155) + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(numKeys, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + if numKeys != msv.GetNumKeys() { + t.Errorf("Got unexpected number of keys.\nexpected: %d\nreceived: %d", + numKeys, msv.GetNumKeys()) + } +} + +// Tests that MultiStateVector.GetCount returns the correct count for each state +// after each key has been set to a random state. +func TestMultiStateVector_GetCount(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector( + 156, 5, nil, "TestMultiStateVector_GetCount", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + expectedCounts := make([]uint16, msv.numStates) + + prng := rand.New(rand.NewSource(42)) + for keyNum := uint16(0); keyNum < msv.numKeys; keyNum++ { + state := uint8(prng.Intn(int(msv.numStates))) + + err = msv.Set(keyNum, state) + if err != nil { + t.Errorf("Failed to set key %d to state %d.", keyNum, state) + } + + expectedCounts[state]++ + } + + for state, expected := range expectedCounts { + count, err := msv.GetCount(uint8(state)) + if err != nil { + t.Errorf("GetCount returned an error for state %d: %+v", state, err) + } + if expected != count { + t.Errorf("Incorrect count for state %d.\nexpected: %d\nreceived: %d", + state, expected, count) + } + } +} + +// Error path: tests that MultiStateVector.GetCount returns the expected error +// when the given state is greater than the last state. +func TestMultiStateVector_GetCount_NewStateMaxError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + state := msv.numStates + expectedErr := fmt.Sprintf(stateMaxErr, state, msv.numStates-1) + _, err = msv.GetCount(state) + if err == nil || err.Error() != expectedErr { + t.Errorf("GetCount did not return the expected error when the state is "+ + "larger than the max number of states.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Tests that MultiStateVector.GetKeys returns the correct list of keys for each +// state after each key has been set to a random state. +func TestMultiStateVector_GetKeys(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector( + 156, 5, nil, "TestMultiStateVector_GetKeys", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + expectedKeys := make([][]uint16, msv.numStates) + + prng := rand.New(rand.NewSource(42)) + for keyNum := uint16(0); keyNum < msv.numKeys; keyNum++ { + state := uint8(prng.Intn(int(msv.numStates))) + + err = msv.Set(keyNum, state) + if err != nil { + t.Errorf("Failed to set key %d to state %d.", keyNum, state) + } + + expectedKeys[state] = append(expectedKeys[state], keyNum) + } + + for state, expected := range expectedKeys { + keys, err := msv.GetKeys(uint8(state)) + if err != nil { + t.Errorf("GetKeys returned an error: %+v", err) + } + if !reflect.DeepEqual(expected, keys) { + t.Errorf("Incorrect keys for state %d.\nexpected: %d\nreceived: %d", + state, expected, keys) + } + } +} + +// Error path: tests that MultiStateVector.GetKeys returns the expected error +// when the given state is greater than the last state. +func TestMultiStateVector_GetKeys_NewStateMaxError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + state := msv.numStates + expectedErr := fmt.Sprintf(stateMaxErr, state, msv.numStates-1) + _, err = msv.GetKeys(state) + if err == nil || err.Error() != expectedErr { + t.Errorf("GetKeys did not return the expected error when the state is "+ + "larger than the max number of states.\nexpected: %s\nreceived: %v", + expectedErr, err) + } +} + +// Tests that MultiStateVector.DeepCopy makes a copy of the values and not of +// the pointers. +func TestMultiStateVector_DeepCopy(t *testing.T) { + stateMap := [][]bool{ + {false, true, false, false, false}, + {true, false, true, false, false}, + {false, true, false, true, false}, + {false, false, true, false, true}, + {false, false, false, false, false}, + } + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, stateMap, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + newMSV := msv.DeepCopy() + msv.kv = nil + + // Check that the values are the same + if !reflect.DeepEqual(msv, newMSV) { + t.Errorf("Original and copy do not match."+ + "\nexpected: %+v\nreceived: %+v", msv, newMSV) + } + + // Check that the pointers are different + if msv == newMSV { + t.Errorf("Pointers for original and copy match."+ + "\nexpected: %p\nreceived: %p", msv, newMSV) + } + + if &msv.vect == &newMSV.vect { + t.Errorf("Pointers for original and copy vect match."+ + "\nexpected: %p\nreceived: %p", msv.vect, newMSV.vect) + } + + if &msv.vect[0] == &newMSV.vect[0] { + t.Errorf("Pointers for original and copy vect[0] match."+ + "\nexpected: %p\nreceived: %p", &msv.vect[0], &newMSV.vect[0]) + } + + if &msv.stateMap == &newMSV.stateMap { + t.Errorf("Pointers for original and copy stateMap match."+ + "\nexpected: %p\nreceived: %p", msv.stateMap, newMSV.stateMap) + } + + if &msv.stateMap[0] == &newMSV.stateMap[0] { + t.Errorf("Pointers for original and copy stateMap[0] match."+ + "\nexpected: %p\nreceived: %p", + &msv.stateMap[0], &newMSV.stateMap[0]) + } + + if &msv.stateMap[0][0] == &newMSV.stateMap[0][0] { + t.Errorf("Pointers for original and copy stateMap[0][0] match."+ + "\nexpected: %p\nreceived: %p", + &msv.stateMap[0][0], &newMSV.stateMap[0][0]) + } + + if &msv.stateUseCount == &newMSV.stateUseCount { + t.Errorf("Pointers for original and copy stateUseCount match."+ + "\nexpected: %p\nreceived: %p", + msv.stateUseCount, newMSV.stateUseCount) + } + + if &msv.stateUseCount[0] == &newMSV.stateUseCount[0] { + t.Errorf("Pointers for original and copy stateUseCount[0] match."+ + "\nexpected: %p\nreceived: %p", + &msv.stateUseCount[0], &newMSV.stateUseCount[0]) + } +} + +// Tests that MultiStateVector.DeepCopy is able to make the expected copy when +// the state map is nil. +func TestMultiStateVector_DeepCopy_NilStateMap(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + msv, err := NewMultiStateVector(155, 5, nil, "", kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + newMSV := msv.DeepCopy() + msv.kv = nil + + // Check that the values are the same + if !reflect.DeepEqual(msv, newMSV) { + t.Errorf("Original and copy do not match."+ + "\nexpected: %+v\nreceived: %+v", msv, newMSV) + } +} + +// Tests that getMultiBlockAndPos returns the expected block and position for +// various key numbers, bit sizes, and word sizes. +func Test_getMultiBlockAndPos(t *testing.T) { + testValues := []struct { + keyNum, block, pos uint16 + msv *MultiStateVector + }{ + {0, 0, 0, getTestMSV(150, 2)}, + {1, 0, 1, getTestMSV(150, 2)}, + {24, 0, 24, getTestMSV(150, 2)}, + {64, 1, 0, getTestMSV(150, 2)}, + {0, 0, 0, getTestMSV(150, 5)}, + {1, 0, 3, getTestMSV(150, 5)}, + {9, 0, 27, getTestMSV(150, 5)}, + {21, 1, 0, getTestMSV(150, 5)}, + {0, 0, 0, getTestMSV(150, 255)}, + {1, 0, 8, getTestMSV(150, 255)}, + {2, 0, 16, getTestMSV(150, 255)}, + {8, 1, 0, getTestMSV(150, 255)}, + {0, 0, 0, getTestMSV(150, 127)}, + {8, 0, 56, getTestMSV(150, 127)}, + {9, 1, 0, getTestMSV(150, 127)}, + } + + for i, v := range testValues { + block, pos := v.msv.getMultiBlockAndPos(v.keyNum) + if v.block != block || v.pos != pos { + t.Errorf("Incorrect block/pos for key %2d (%d bits, %2d word size): "+ + "expected %d/%-2d, got %d/%-2d (%d).", v.keyNum, v.msv.bitSize, + v.msv.wordSize, v.block, v.pos, block, pos, i) + } + } +} + +func getTestMSV(numKeys uint16, numStates uint8) *MultiStateVector { + numBits := uint8(math.Ceil(math.Log2(float64(numStates)))) + return &MultiStateVector{ + numKeys: numKeys, + numStates: numStates, + bitSize: numBits, + numKeysPerBlock: 64 / numBits, + wordSize: (64 / numBits) * numBits, + } +} + +// Tests that checkStateMap does not return an error for various state map +// sizes. +func Test_checkStateMap(tt *testing.T) { + const ( + t = true + f = false + ) + testValues := []struct { + numStates uint8 + stateMap [][]bool + }{ + {9, nil}, + {0, [][]bool{}}, + {1, [][]bool{{t}}}, + {2, [][]bool{{t, t}, {f, f}}}, + {3, [][]bool{{t, t, f}, {f, f, t}, {t, t, t}}}, + {4, [][]bool{{t, t, f, f}, {f, f, t, f}, {t, t, t, f}, {f, f, t, t}}}, + } + + for i, v := range testValues { + err := checkStateMap(v.numStates, v.stateMap) + if err != nil { + tt.Errorf("Could not verify state map #%d with %d states: %+v", + i, v.numStates, err) + } + } +} + +// Error path: tests that checkStateMap returns an error for various state map +// sizes and number of states mismatches. +func Test_checkStateMap_Error(tt *testing.T) { + const ( + t = true + f = false + ) + testValues := []struct { + numStates uint8 + stateMap [][]bool + err string + }{ + {1, [][]bool{}, + fmt.Sprintf(stateMapLenErr, 0, 1)}, + {0, [][]bool{{t}}, + fmt.Sprintf(stateMapLenErr, 1, 0)}, + {2, [][]bool{{t, t}, {f, f}, {f, f}}, + fmt.Sprintf(stateMapLenErr, 3, 2)}, + {3, [][]bool{{t, t, f}, {f, f, f, t}, {t, t, t}}, + fmt.Sprintf(stateMapStateLenErr, 1, 4, 3)}, + {4, [][]bool{{t, t, f, f}, {f, f, t, f}, {t, t, t, f}, {f, t, t}}, + fmt.Sprintf(stateMapStateLenErr, 3, 3, 4)}, + } + + for i, v := range testValues { + err := checkStateMap(v.numStates, v.stateMap) + if err == nil || !strings.Contains(err.Error(), v.err) { + tt.Errorf("Verified invalid state map #%d with %d states."+ + "\nexpected: %s\nreceived: %v", + i, v.numStates, v.err, err) + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Selection Bit Mask // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that getSelectionMask returns the expected bit mask for various key +// numbers and bit sizes. +func Test_getSelectionMask(t *testing.T) { + testValues := []struct { + keyNum uint16 + bitSize uint8 + mask uint64 + }{ + {0, 0, 0b0000000000000000000000000000000000000000000000000000000000000000}, + {0, 1, 0b1000000000000000000000000000000000000000000000000000000000000000}, + {1, 0, 0b0000000000000000000000000000000000000000000000000000000000000000}, + {1, 1, 0b0100000000000000000000000000000000000000000000000000000000000000}, + {63, 1, 0b0000000000000000000000000000000000000000000000000000000000000001}, + {64, 1, 0b1000000000000000000000000000000000000000000000000000000000000000}, + {0, 3, 0b1110000000000000000000000000000000000000000000000000000000000000}, + {1, 3, 0b0001110000000000000000000000000000000000000000000000000000000000}, + {20, 3, 0b0000000000000000000000000000000000000000000000000000000000001110}, + {21, 3, 0b1110000000000000000000000000000000000000000000000000000000000000}, + {0, 17, 0b1111111111111111100000000000000000000000000000000000000000000000}, + {1, 17, 0b0000000000000000011111111111111111000000000000000000000000000000}, + {2, 17, 0b0000000000000000000000000000000000111111111111111110000000000000}, + {3, 17, 0b1111111111111111100000000000000000000000000000000000000000000000}, + {0, 33, 0b1111111111111111111111111111111110000000000000000000000000000000}, + {1, 33, 0b1111111111111111111111111111111110000000000000000000000000000000}, + {2, 33, 0b1111111111111111111111111111111110000000000000000000000000000000}, + {3, 33, 0b1111111111111111111111111111111110000000000000000000000000000000}, + {0, 64, 0b1111111111111111111111111111111111111111111111111111111111111111}, + {1, 64, 0b1111111111111111111111111111111111111111111111111111111111111111}, + {2, 64, 0b1111111111111111111111111111111111111111111111111111111111111111}, + {3, 64, 0b1111111111111111111111111111111111111111111111111111111111111111}, + } + + for i, v := range testValues { + mask := getSelectionMask(v.keyNum, v.bitSize) + + if v.mask != mask { + t.Errorf("Unexpected bit mask for key %d of size %d (%d)."+ + "\nexpected: %064b\nreceived: %064b", + v.keyNum, v.bitSize, i, v.mask, mask) + } + } +} + +// Tests that shiftBitsOut returns the expected shifted bits for various key +// numbers and bit sizes. +func Test_shiftBitsOut(t *testing.T) { + testValues := []struct { + keyNum uint16 + bitSize uint8 + bits, expected uint64 + }{ + {0, 0, + 0b1000000000000000000000000000000000000000000000000000000000000001, + 0b1000000000000000000000000000000000000000000000000000000000000001}, + {0, 1, + 0b1000000000000000000000000000000000000000000000000000000000000001, + 0b1000000000000000000000000000000000000000000000000000000000000000}, + {1, 0, + 0b1000000000000000000000000000000000000000000000000000000000000001, + 0b1000000000000000000000000000000000000000000000000000000000000001}, + {1, 2, + 0b1000000000000000000000000000000000000000000000000000000000000001, + 0b0001000000000000000000000000000000000000000000000000000000000000}, + {1, 32, + 0b0111111111111111111111111111111111111111111111111111111111111110, + 0b0000000000000000000000000000000011111111111111111111111111111110}, + {13, 7, + 0b0111111111111111111111111111111111111111111111111111111111111110, + 0b0000000000000000000000000000111111000000000000000000000000000000}, + } + + for i, v := range testValues { + bits := shiftBitsOut(v.bits, v.keyNum, v.bitSize) + + if v.expected != bits { + t.Errorf("Unexpected shifted mask for key %d of size %d (%d)."+ + "\nexpected: %064b\nreceived: %064b", + v.keyNum, v.bitSize, i, v.expected, bits) + } + } +} + +// Tests that shiftBitsIn returns the expected shifted bits for various key +// numbers and bit sizes. +func Test_shiftBitsIn(t *testing.T) { + testValues := []struct { + keyNum uint16 + bitSize uint8 + bits, expected uint64 + }{ + {0, 0, + 0b1000000000000000000000000000000000000000000000000000000000000001, + 0b1000000000000000000000000000000000000000000000000000000000000001}, + {0, 1, + 0b1000000000000000000000000000000000000000000000000000000000000001, + 0b1000000000000000000000000000000000000000000000000000000000000001}, + {1, 0, + 0b1000000000000000000000000000000000000000000000000000000000000001, + 0b1000000000000000000000000000000000000000000000000000000000000001}, + {1, 2, + 0b1000000000000000000000000000000000000000000000000000000000000001, + 0b0010000000000000000000000000000000000000000000000000000000000000}, + {1, 32, + 0b0111111111111111111111111111111111111111111111111111111111111110, + 0b0000000000000000000000000000000001111111111111111111111111111111}, + {13, 7, + 0b0111111111111111111111111111111111111111111111111111111111111110, + 0b0000000000000000000000000000011111111111111111111111111111111111}, + } + + for i, v := range testValues { + bits := shiftBitsIn(v.bits, v.keyNum, v.bitSize) + + if v.expected != bits { + t.Errorf("Unexpected shifted mask for key %d of size %d (%d)."+ + "\nexpected: %064b\nreceived: %064b", + v.keyNum, v.bitSize, i, v.expected, bits) + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Storage Functions // +//////////////////////////////////////////////////////////////////////////////// + +// Tests that a MultiStateVector loaded from storage via LoadMultiStateVector +// matches the original. +func TestLoadMultiStateVector(t *testing.T) { + key := "TestLoadMultiStateVector" + msv, kv := newTestFilledMSV(156, 5, nil, "TestLoadMultiStateVector", t) + + // Attempt to load MultiStateVector from storage + loadedSV, err := LoadMultiStateVector(nil, key, kv) + if err != nil { + t.Fatalf("LoadMultiStateVector returned an error: %+v", err) + } + + if !reflect.DeepEqual(msv, loadedSV) { + t.Errorf("Loaded MultiStateVector does not match original saved."+ + "\nexpected: %+v\nreceived: %+v", msv, loadedSV) + } +} + +// Error path: tests that LoadMultiStateVector returns the expected error when +// no object is saved in storage. +func TestLoadMultiStateVector_GetFromStorageError(t *testing.T) { + key := "TestLoadMultiStateVector_GetFromStorageError" + kv := versioned.NewKV(make(ekv.Memstore)) + expectedErr := strings.Split(loadGetMsvErr, "%")[0] + + _, err := LoadMultiStateVector(nil, key, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Fatalf("LoadMultiStateVector did not return the expected error "+ + "when no object exists in storage.\nexpected: %s\nreceived: %+v", + expectedErr, err) + } +} + +// Error path: tests that LoadMultiStateVector returns the expected error when +// the data in storage cannot be unmarshalled. +func TestLoadMultiStateVector_UnmarshalError(t *testing.T) { + kv := versioned.NewKV(make(ekv.Memstore)) + key := "TestLoadMultiStateVector_MarshalError" + expectedErr := strings.Split(loadUnmarshalMsvErr, "%")[0] + + // Save invalid data to storage + err := kv.Set(makeMultiStateVectorKey(key), multiStateVectorVersion, + &versioned.Object{ + Version: multiStateVectorVersion, + Timestamp: netTime.Now(), + Data: []byte("?"), + }) + if err != nil { + t.Errorf("Failed to save data to storage: %+v", err) + } + + _, err = LoadMultiStateVector(nil, key, kv) + if err == nil || !strings.Contains(err.Error(), expectedErr) { + t.Fatalf("LoadMultiStateVector did not return the expected error "+ + "when the object in storage should be invalid."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } +} + +// Tests that the data saved via MultiStateVector.save to storage matches the +// expected data. +func TestMultiStateVector_save(t *testing.T) { + msv := &MultiStateVector{ + numKeys: 3 * 64, + numStates: 5, + bitSize: 3, + numKeysPerBlock: 21, + wordSize: 63, + vect: []uint64{0, 1, 2}, + stateUseCount: []uint16{5, 12, 104, 0, 4000}, + key: makeStateVectorKey("TestMultiStateVector_save"), + kv: versioned.NewKV(make(ekv.Memstore)), + } + + expectedData, err := msv.marshal() + if err != nil { + t.Errorf("Failed to marshal MultiStateVector: %+v", err) + } + + // Save to storage + err = msv.save() + if err != nil { + t.Errorf("save returned an error: %+v", err) + } + + // Check that the object can be loaded + loadedData, err := msv.kv.Get(msv.key, multiStateVectorVersion) + if err != nil { + t.Errorf("Failed to load MultiStateVector from storage: %+v", err) + } + + if !bytes.Equal(expectedData, loadedData.Data) { + t.Errorf("Loaded data does not match expected."+ + "\nexpected: %v\nreceived: %v", expectedData, loadedData) + } +} + +// Tests that MultiStateVector.Delete removes the MultiStateVector from storage. +func TestMultiStateVector_Delete(t *testing.T) { + msv, _ := newTestFilledMSV(156, 5, nil, "TestMultiStateVector_Delete", t) + + err := msv.Delete() + if err != nil { + t.Errorf("Delete returned an error: %+v", err) + } + + // Check that the object can be loaded + loadedData, err := msv.kv.Get(msv.key, multiStateVectorVersion) + if err == nil { + t.Errorf("Loaded MultiStateVector from storage when it should be "+ + "deleted: %v", loadedData) + } +} + +// Tests that a MultiStateVector marshalled with MultiStateVector.marshal and +// unmarshalled with MultiStateVector.unmarshal matches the original. +func TestMultiStateVector_marshal_unmarshal(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + msv := &MultiStateVector{ + numKeys: 3 * 64, + numStates: 5, + bitSize: 3, + numKeysPerBlock: 21, + wordSize: 63, + vect: []uint64{prng.Uint64(), prng.Uint64(), prng.Uint64()}, + stateUseCount: []uint16{5, 12, 104, 0, 4000}, + } + + marshalledBytes, err := msv.marshal() + if err != nil { + t.Errorf("marshal returned an error: %+v", err) + } + + newMsv := &MultiStateVector{} + err = newMsv.unmarshal(marshalledBytes) + if err != nil { + t.Errorf("unmarshal returned an error: %+v", err) + } + + if !reflect.DeepEqual(&msv, &newMsv) { + t.Errorf("Marshalled and unmarsalled MultiStateVector does not match "+ + "the original.\nexpected: %+v\nreceived: %+v", msv, newMsv) + } +} + +// Consistency test of makeMultiStateVectorKey. +func Test_makeMultiStateVectorKey(t *testing.T) { + prng := rand.New(rand.NewSource(42)) + expectedStrings := []string{ + multiStateVectorKey + "U4x/lrFkvxuXu59LtHLon1sU", + multiStateVectorKey + "hPJSCcnZND6SugndnVLf15tN", + multiStateVectorKey + "dkKbYXoMn58NO6VbDMDWFEyI", + multiStateVectorKey + "hTWEGsvgcJsHWAg/YdN1vAK0", + multiStateVectorKey + "HfT5GSnhj9qeb4LlTnSOgeee", + multiStateVectorKey + "S71v40zcuoQ+6NY+jE/+HOvq", + multiStateVectorKey + "VG2PrBPdGqwEzi6ih3xVec+i", + multiStateVectorKey + "x44bC6+uiBuCp1EQikLtPJA8", + multiStateVectorKey + "qkNGWnhiBhaXiu0M48bE8657", + multiStateVectorKey + "w+BJW1cS/v2+DBAoh+EA2s0t", + } + + for i, expected := range expectedStrings { + b := make([]byte, 18) + prng.Read(b) + key := makeMultiStateVectorKey(base64.StdEncoding.EncodeToString(b)) + + if expected != key { + t.Errorf("New MultiStateVector key does not match expected (%d)."+ + "\nexpected: %q\nreceived: %q", i, expected, key) + } + } +} + +// newTestFilledMSV produces a new MultiStateVector and sets each key to a +// random state. +func newTestFilledMSV(numKeys uint16, numStates uint8, stateMap [][]bool, + key string, t *testing.T) (*MultiStateVector, *versioned.KV) { + kv := versioned.NewKV(make(ekv.Memstore)) + + msv, err := NewMultiStateVector(numKeys, numStates, stateMap, key, kv) + if err != nil { + t.Errorf("Failed to create new MultiStateVector: %+v", err) + } + + prng := rand.New(rand.NewSource(42)) + for keyNum := uint16(0); keyNum < msv.numKeys; keyNum++ { + state := uint8(prng.Intn(int(msv.numStates))) + + err = msv.Set(keyNum, state) + if err != nil { + t.Errorf("Failed to set key %d to state %d.", keyNum, state) + } + } + + return msv, kv +}