diff --git a/cmd/poll_test.go b/cmd/poll_test.go index 67353d019719e50519dc48c62b455fd699c7d7ff..757185d24300c53f8dd8024e2faac414ee9fa9e2 100644 --- a/cmd/poll_test.go +++ b/cmd/poll_test.go @@ -865,7 +865,7 @@ func TestVerifyError(t *testing.T) { _ = nsm.AddNode(errNodeId, "", "", "") n := nsm.GetNode(errNodeId) rsm := round.NewStateMap() - s, _ := rsm.AddRound(id.Round(0), 4, 5*time.Minute, connect.NewCircuit([]*id.ID{errNodeId}), nil) + s, _ := rsm.AddRound(id.Round(0), 4, 5*time.Minute, connect.NewCircuit([]*id.ID{errNodeId})) _ = n.SetRound(s) err = verifyError(msg, n, impl) diff --git a/scheduling/nodeStateChange.go b/scheduling/nodeStateChange.go index 22092bfa1cd406e41851d3768ffcef8fce5fe71d..dfb0f91204e1a125f52aeb21ab2fbdfd790e0a9f 100644 --- a/scheduling/nodeStateChange.go +++ b/scheduling/nodeStateChange.go @@ -125,8 +125,10 @@ func HandleNodeUpdates(update node.UpdateNotification, pool *waitingPool, return false, errors.Errorf("Node %s without round should "+ "not be moving to the %s state", update.Node, states.REALTIME) } - stateComplete := r.NodeIsReadyForTransition() - if stateComplete { + // REALTIME does not use the state complete handler because it + // increments on the first report, not when every node reports in + // order to avoid distributed synchronicity issues + if r.GetRoundState() != states.REALTIME { err := r.Update(states.REALTIME, time.Now()) if err != nil { return false, errors.WithMessagef(err, diff --git a/scheduling/nodeStateChange_test.go b/scheduling/nodeStateChange_test.go index c2fdadc16f10ce2ebc2f1354e8d9a9a14a65ce14..255e5ccfbe57796162b02d32eab827d773c61a8a 100644 --- a/scheduling/nodeStateChange_test.go +++ b/scheduling/nodeStateChange_test.go @@ -101,7 +101,7 @@ func TestHandleNodeStateChance_Standby(t *testing.T) { roundID := testState.GetRoundID() - roundState, err := testState.GetRoundMap().AddRound(roundID, testParams.BatchSize, 5*time.Minute, circuit, nil) + roundState, err := testState.GetRoundMap().AddRound(roundID, testParams.BatchSize, 5*time.Minute, circuit) if err != nil { t.Errorf("Failed to add round: %v", err) } @@ -245,7 +245,7 @@ func TestHandleNodeUpdates_Completed(t *testing.T) { roundID := testState.GetRoundID() - roundState, err := testState.GetRoundMap().AddRound(roundID, testParams.BatchSize, 5*time.Minute, circuit, nil) + roundState, err := testState.GetRoundMap().AddRound(roundID, testParams.BatchSize, 5*time.Minute, circuit) if err != nil { t.Errorf("Failed to add round: %v", err) } diff --git a/scheduling/startRound_test.go b/scheduling/startRound_test.go index fb9ad06e0b682bd831a0a808e2777b608eaf791d..f63323b3ad3093804a7d0bd722b76e302500208e 100644 --- a/scheduling/startRound_test.go +++ b/scheduling/startRound_test.go @@ -66,9 +66,7 @@ func TestStartRound(t *testing.T) { t.Errorf("Happy path of createSimpleRound failed: %v", err) } - errorChan := make(chan error, 1) - - err = startRound(testProtoRound, testState, errorChan) + err = startRound(testProtoRound, testState) if err != nil { t.Errorf("Received error from startRound(): %v", err) } @@ -198,9 +196,7 @@ func TestStartRound_BadNode(t *testing.T) { // Manually set the round of a node testProtoRound.NodeStateList[0].SetRound(badState) - errorChan := make(chan error, 1) - - err = startRound(testProtoRound, testState, errorChan) + err = startRound(testProtoRound, testState) if err == nil { t.Log(err) t.Errorf("Expected error. Artificially created round " + diff --git a/storage/node/state.go b/storage/node/state.go index 8184c7d9e2c8a9522cc2588a7acf7b906d9026f0..206649f095b359af866d5fa7cfc032708f66f741 100644 --- a/storage/node/state.go +++ b/storage/node/state.go @@ -148,13 +148,14 @@ func (n *State) Update(newActivity current.Activity) (bool, UpdateNotification, newActivity) } - if n.currentRound.GetRoundState() != transition.Node.RequiredRoundState(newActivity) { + if !transition.Node.IsValidRoundState(newActivity, n.currentRound.GetRoundState()) { + return false, UpdateNotification{}, errors.Errorf("Node update from %s to %s failed, "+ "requires the Node's be assigned a round to be in the "+ "correct state; Assigned: %s, Expected: %s", oldActivity, newActivity, n.currentRound.GetRoundState(), - transition.Node.RequiredRoundState(oldActivity)) + transition.Node.GetValidRoundStateStrings(newActivity)) } } diff --git a/storage/node/state_test.go b/storage/node/state_test.go index 799d56d6ecf72ae69f08ac8d30fa3ed7c110b5bf..659c110148b9573b93c266f6cfd0bb289cb8cf90 100644 --- a/storage/node/state_test.go +++ b/storage/node/state_test.go @@ -219,7 +219,7 @@ func TestNodeState_Update_Valid_RequiresRound_Round_ValidState(t *testing.T) { updated, old, err := ns.Update(current.PRECOMPUTING) if err != nil { - t.Errorf("Node state update returned no error on valid state change: %s", err) + t.Errorf("Node state update returned error on valid state change: %s", err) } timeDelta := ns.lastPoll.Sub(before) diff --git a/storage/round/map_test.go b/storage/round/map_test.go index 25016681323a1cc220ff54d366a95bf740f0be92..74191738f6d39ea6a6533e1c1a8b9fba5c3277e2 100644 --- a/storage/round/map_test.go +++ b/storage/round/map_test.go @@ -34,7 +34,7 @@ func TestStateMap_AddRound_Happy(t *testing.T) { const numNodes = 5 - rRtn, err := sm.AddRound(rid, 32, 5*time.Minute, buildMockTopology(numNodes, t), nil) + rRtn, err := sm.AddRound(rid, 32, 5*time.Minute, buildMockTopology(numNodes, t)) if err != nil { t.Errorf("Error returned on valid addition of node: %s", err) @@ -69,7 +69,7 @@ func TestStateMap_AddNode_Invalid(t *testing.T) { sm.rounds[rid] = &State{state: states.FAILED} - rRtn, err := sm.AddRound(rid, 32, 5*time.Minute, buildMockTopology(numNodes, t), nil) + rRtn, err := sm.AddRound(rid, 32, 5*time.Minute, buildMockTopology(numNodes, t)) if err == nil { t.Errorf("Error not returned on invalid addition of node: %s", err) diff --git a/transition/ControlState.go b/transition/ControlState.go index 5c70b4982f7490a690859384ce3a16d885abb788..92722bdb36c5e660201ccb3dd6943785a08c6cc5 100644 --- a/transition/ControlState.go +++ b/transition/ControlState.go @@ -6,6 +6,7 @@ package transition import ( + "fmt" "gitlab.com/elixxir/primitives/current" "gitlab.com/elixxir/primitives/states" "math" @@ -29,13 +30,13 @@ type Transitions [current.NUM_STATES]transitionValidation // on state transitions func newTransitions() Transitions { t := Transitions{} - t[current.NOT_STARTED] = NewTransitionValidation(No, nilRoundState) - t[current.WAITING] = NewTransitionValidation(No, nilRoundState, current.NOT_STARTED, current.COMPLETED, current.ERROR) - t[current.PRECOMPUTING] = NewTransitionValidation(Yes, states.PRECOMPUTING, current.WAITING) - t[current.STANDBY] = NewTransitionValidation(Yes, states.PRECOMPUTING, current.PRECOMPUTING) - t[current.REALTIME] = NewTransitionValidation(Yes, states.QUEUED, current.STANDBY) - t[current.COMPLETED] = NewTransitionValidation(Yes, states.REALTIME, current.REALTIME) - t[current.ERROR] = NewTransitionValidation(Maybe, nilRoundState, current.NOT_STARTED, + t[current.NOT_STARTED] = NewTransitionValidation(No, nil) + t[current.WAITING] = NewTransitionValidation(No, nil, current.NOT_STARTED, current.COMPLETED, current.ERROR) + t[current.PRECOMPUTING] = NewTransitionValidation(Yes, []states.Round{states.PRECOMPUTING}, current.WAITING) + t[current.STANDBY] = NewTransitionValidation(Yes, []states.Round{states.PRECOMPUTING}, current.PRECOMPUTING) + t[current.REALTIME] = NewTransitionValidation(Yes, []states.Round{states.QUEUED, states.REALTIME}, current.STANDBY) + t[current.COMPLETED] = NewTransitionValidation(Yes, []states.Round{states.REALTIME}, current.REALTIME) + t[current.ERROR] = NewTransitionValidation(Maybe, nil, current.NOT_STARTED, current.WAITING, current.PRECOMPUTING, current.STANDBY, current.REALTIME, current.COMPLETED) @@ -54,22 +55,43 @@ func (t Transitions) NeedsRound(to current.Activity) int { return t[to].needsRound } -// RequiredRoundState looks up the required round needed prior to transition -func (t Transitions) RequiredRoundState(to current.Activity) states.Round { - return t[to].roundState +// Checks if the passed round state is valid for the transition +func (t Transitions) IsValidRoundState(to current.Activity, st states.Round) bool { + for _, s := range t[to].roundState { + if st == s { + return true + } + } + return false +} + +// returns a string describing valid transitions for error messages +func (t Transitions) GetValidRoundStateStrings(to current.Activity) string { + if len(t[to].roundState) == 0 { + return "NO VALID TRANSITIONS" + } + + var rtnStr string + + for i := 0; i < len(t[to].roundState)-1; i++ { + rtnStr = fmt.Sprintf("%s, ", t[to].roundState[i]) + } + + rtnStr += fmt.Sprintf("%s", t[to].roundState[len(t[to].roundState)-1]) + return rtnStr } // Transitional information used for each state type transitionValidation struct { from [current.NUM_STATES]bool needsRound int - roundState states.Round + roundState []states.Round } // NewTransitionValidation sets the from attribute, // denoting whether going from that to the objects current state // is valid -func NewTransitionValidation(needsRound int, roundState states.Round, from ...current.Activity) transitionValidation { +func NewTransitionValidation(needsRound int, roundState []states.Round, from ...current.Activity) transitionValidation { tv := transitionValidation{} tv.needsRound = needsRound diff --git a/transition/controlState_test.go b/transition/controlState_test.go index 034fba4058e5047ecb807ffd94c503fe5b841094..0f013ed103afe8aa371a4e9005fb3c21204d42ea 100644 --- a/transition/controlState_test.go +++ b/transition/controlState_test.go @@ -65,11 +65,13 @@ func TestTransitions_NeedsRound(t *testing.T) { func TestTransitions_RequiredRoundState(t *testing.T) { testTransition := newTransitions() - expectedRoundState := []states.Round{nilRoundState, nilRoundState, 1, 1, 3, 4, nilRoundState} - receivedRoundState := make([]states.Round, len(expectedRoundState)) + input := []states.Round{states.PENDING, states.PENDING, states.PRECOMPUTING, + states.PRECOMPUTING, states.QUEUED, states.REALTIME, states.PENDING} + receivedRoundState := make([]bool, len(input)) + expectedRoundState := []bool{false, false, true, true, true, true, false} for i := uint32(0); i < uint32(current.CRASH); i++ { - receivedRoundState[i] = testTransition.RequiredRoundState(current.Activity(i)) + receivedRoundState[i] = testTransition.IsValidRoundState(current.Activity(i), input[i]) } if !reflect.DeepEqual(expectedRoundState, receivedRoundState) {