diff --git a/api/client.go b/api/client.go index 118615bdcb7d4dc23b6a7fde682a3fdc224b4a6a..73da1de10e6b04a0f74d5e63cc885c391255612d 100644 --- a/api/client.go +++ b/api/client.go @@ -8,10 +8,6 @@ package api import ( - "gitlab.com/xx_network/comms/connect" - "gitlab.com/xx_network/primitives/id" - "time" - "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/auth" @@ -28,12 +24,17 @@ import ( "gitlab.com/elixxir/crypto/cyclic" "gitlab.com/elixxir/crypto/fastRNG" "gitlab.com/elixxir/primitives/version" + "gitlab.com/xx_network/comms/connect" "gitlab.com/xx_network/crypto/csprng" "gitlab.com/xx_network/crypto/large" "gitlab.com/xx_network/crypto/signature/rsa" + "gitlab.com/xx_network/primitives/id" "gitlab.com/xx_network/primitives/ndf" + "time" ) +const followerStoppableName = "client" + type Client struct { //generic RNG for client rng *fastRNG.StreamGenerator @@ -181,7 +182,7 @@ func OpenClient(storageDir string, password []byte, parameters params.Network) ( rng: rngStreamGen, comms: nil, network: nil, - runner: stoppable.NewMulti("client"), + runner: stoppable.NewMulti(followerStoppableName), status: newStatusTracker(), parameters: parameters, } @@ -348,6 +349,7 @@ func (c *Client) initPermissioning(def *ndf.NetworkDefinition) error { } // ----- Client Functions ----- + // StartNetworkFollower kicks off the tracking of the network. It starts // long running network client threads and returns an object for checking // state and stopping those threads. 
@@ -378,7 +380,7 @@ func (c *Client) initPermissioning(def *ndf.NetworkDefinition) error { // Responds to confirmations of successful rekey operations // - Auth Callback (/auth/callback.go) // Handles both auth confirm and requests -func (c *Client) StartNetworkFollower() (<-chan interfaces.ClientError, error) { +func (c *Client) StartNetworkFollower(timeout time.Duration) (<-chan interfaces.ClientError, error) { u := c.GetUser() jww.INFO.Printf("StartNetworkFollower() \n\tTransmisstionID: %s "+ "\n\tReceptionID: %s", u.TransmissionID, u.ReceptionID) @@ -397,12 +399,21 @@ func (c *Client) StartNetworkFollower() (<-chan interfaces.ClientError, error) { } } - err := c.status.toStarting() + // Wait for any threads from the previous follower to close and then create + // a new stoppable + err := stoppable.WaitForStopped(c.runner, timeout) + if err != nil { + return nil, err + } else { + c.runner = stoppable.NewMulti(followerStoppableName) + } + + err = c.status.toStarting() if err != nil { return nil, errors.WithMessage(err, "Failed to Start the Network Follower") } - stopAuth := c.auth.StartProcessies() + stopAuth := c.auth.StartProcesses() c.runner.Add(stopAuth) stopFollow, err := c.network.Follow(cer) @@ -429,19 +440,18 @@ func (c *Client) StartNetworkFollower() (<-chan interfaces.ClientError, error) { // fails to stop it. // if the network follower is running and this fails, the client object will // most likely be in an unrecoverable state and need to be trashed. 
-func (c *Client) StopNetworkFollower(timeout time.Duration) error { +func (c *Client) StopNetworkFollower() error { err := c.status.toStopping() if err != nil { return errors.WithMessage(err, "Failed to Stop the Network Follower") } - err = c.runner.Close(timeout) - c.runner = stoppable.NewMulti("client") + err = c.runner.Close() err2 := c.status.toStopped() if err2 != nil { - if err ==nil{ + if err == nil { err = err2 - }else{ - err = errors.WithMessage(err,err2.Error()) + } else { + err = errors.WithMessage(err, err2.Error()) } } return err diff --git a/api/results.go b/api/results.go index 590634b9646f2b92f3b56e29d422145806807d66..d87c538c38ecb9f8baebe489815cca148870d616 100644 --- a/api/results.go +++ b/api/results.go @@ -12,7 +12,6 @@ import ( jww "github.com/spf13/jwalterweatherman" pb "gitlab.com/elixxir/comms/mixmessages" - "gitlab.com/elixxir/comms/network" ds "gitlab.com/elixxir/comms/network/dataStructures" "gitlab.com/elixxir/primitives/states" "gitlab.com/xx_network/comms/connect" @@ -131,7 +130,7 @@ func (c *Client) getRoundResults(roundList []id.Round, timeout time.Duration, // Find out what happened to old (historical) rounds if any are needed if len(historicalRequest.Rounds) > 0 { - go c.getHistoricalRounds(historicalRequest, networkInstance, sendResults, commsInterface) + go c.getHistoricalRounds(historicalRequest, sendResults, commsInterface) } // Determine the results of all rounds requested @@ -180,7 +179,7 @@ func (c *Client) getRoundResults(roundList []id.Round, timeout time.Duration, // Helper function which asynchronously pings a random gateway until // it gets information on it's requested historical rounds func (c *Client) getHistoricalRounds(msg *pb.HistoricalRounds, - instance *network.Instance, sendResults chan ds.EventReturn, comms historicalRoundsComm) { + sendResults chan ds.EventReturn, comms historicalRoundsComm) { var resp *pb.HistoricalRoundsResponse @@ -189,7 +188,7 @@ func (c *Client) getHistoricalRounds(msg 
*pb.HistoricalRounds, // Find a gateway to request about the roundRequests result, err := c.GetNetworkInterface().GetSender().SendToAny(func(host *connect.Host) (interface{}, error) { return comms.RequestHistoricalRounds(host, msg) - }) + }, nil) // If an error, retry with (potentially) a different gw host. // If no error from received gateway request, exit loop diff --git a/api/send.go b/api/send.go index 19d6a54b853f911e47af7d2d9bcafa46e2b7bd4d..820261023bf0f8f3a0caff47d26857df9fd1f807 100644 --- a/api/send.go +++ b/api/send.go @@ -27,7 +27,7 @@ func (c *Client) SendE2E(m message.Send, param params.E2E) ([]id.Round, e2e.MessageID, error) { jww.INFO.Printf("SendE2E(%s, %d. %v)", m.Recipient, m.MessageType, m.Payload) - return c.network.SendE2E(m, param) + return c.network.SendE2E(m, param, nil) } // SendUnsafe sends an unencrypted payload to the provided recipient diff --git a/api/utilsInterfaces_test.go b/api/utilsInterfaces_test.go index 259349681d4c3466fcf5af3bddc84be63215d076..b603066be89d019f2feb0643552c17b4a4f8b6a0 100644 --- a/api/utilsInterfaces_test.go +++ b/api/utilsInterfaces_test.go @@ -93,7 +93,7 @@ func (t *testNetworkManagerGeneric) Follow(report interfaces.ClientErrorReport) func (t *testNetworkManagerGeneric) CheckGarbledMessages() { return } -func (t *testNetworkManagerGeneric) SendE2E(m message.Send, p params.E2E) ( +func (t *testNetworkManagerGeneric) SendE2E(message.Send, params.E2E, *stoppable.Single) ( []id.Round, cE2e.MessageID, error) { rounds := []id.Round{id.Round(0), id.Round(1), id.Round(2)} return rounds, cE2e.MessageID{}, nil diff --git a/auth/callback.go b/auth/callback.go index 7ece838249e97d3fadbb7fb6f6bcae46e059b458..3d175c0623d6c8d982c944c8ee2ad990542bdb43 100644 --- a/auth/callback.go +++ b/auth/callback.go @@ -23,20 +23,21 @@ import ( "strings" ) -func (m *Manager) StartProcessies() stoppable.Stoppable { - +func (m *Manager) StartProcesses() stoppable.Stoppable { stop := stoppable.NewSingle("Auth") go func() { for { select { 
case <-stop.Quit(): + stop.ToStopped() return case msg := <-m.rawMessages: m.processAuthMessage(msg) } } }() + return stop } diff --git a/bindings/client.go b/bindings/client.go index 33ef281983799bab25a50663a71b151749ba3535..aaaf1ac8faccce2855b491a3d5f04764e7d0cc83 100644 --- a/bindings/client.go +++ b/bindings/client.go @@ -198,8 +198,9 @@ func UnmarshalSendReport(b []byte) (*SendReport, error) { // Responds to sent rekeys and executes them // - KeyExchange Confirm (/keyExchange/confirm.go) // Responds to confirmations of successful rekey operations -func (c *Client) StartNetworkFollower(clientError ClientError) error { - errChan, err := c.api.StartNetworkFollower() +func (c *Client) StartNetworkFollower(clientError ClientError, timeoutMS int) error { + timeout := time.Duration(timeoutMS) * time.Millisecond + errChan, err := c.api.StartNetworkFollower(timeout) if err != nil { return errors.New(fmt.Sprintf("Failed to start the "+ "network follower: %+v", err)) @@ -218,9 +219,8 @@ func (c *Client) StartNetworkFollower(clientError ClientError) error { // fails to stop it. // if the network follower is running and this fails, the client object will // most likely be in an unrecoverable state and need to be trashed. 
-func (c *Client) StopNetworkFollower(timeoutMS int) error { - timeout := time.Duration(timeoutMS) * time.Millisecond - if err := c.api.StopNetworkFollower(timeout); err != nil { +func (c *Client) StopNetworkFollower() error { + if err := c.api.StopNetworkFollower(); err != nil { return errors.New(fmt.Sprintf("Failed to stop the "+ "network follower: %+v", err)) } diff --git a/cmd/root.go b/cmd/root.go index 608dc55760698889f6c8c23310143d7cf5d6fe43..a7161259a2898c1ab9383c7febbd24e221820719 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -115,7 +115,7 @@ var rootCmd = &cobra.Command{ }) } - _, err := client.StartNetworkFollower() + _, err := client.StartNetworkFollower(5 * time.Second) if err != nil { jww.FATAL.Panicf("%+v", err) } @@ -265,7 +265,7 @@ var rootCmd = &cobra.Command{ } fmt.Printf("Received %d\n", receiveCnt) - err = client.StopNetworkFollower(5 * time.Second) + err = client.StopNetworkFollower() if err != nil { jww.WARN.Printf( "Failed to cleanly close threads: %+v\n", diff --git a/cmd/single.go b/cmd/single.go index 15f803b11bec0f775845fdfdfb9f8292c75ea2a7..b827627fa2af0be3f5ec0fa338170f7af6a36493 100644 --- a/cmd/single.go +++ b/cmd/single.go @@ -62,7 +62,7 @@ var singleCmd = &cobra.Command{ }) } - _, err := client.StartNetworkFollower() + _, err := client.StartNetworkFollower(5 * time.Second) if err != nil { jww.FATAL.Panicf("%+v", err) } diff --git a/cmd/ud.go b/cmd/ud.go index 38cbecb2cd1c6572cd140b158f435fe4291500cd..2a3b3d31927ee62522444c179c7dab50c859f4cd 100644 --- a/cmd/ud.go +++ b/cmd/ud.go @@ -62,7 +62,7 @@ var udCmd = &cobra.Command{ }) } - _, err := client.StartNetworkFollower() + _, err := client.StartNetworkFollower(50 * time.Millisecond) if err != nil { jww.FATAL.Panicf("%+v", err) } @@ -176,7 +176,7 @@ var udCmd = &cobra.Command{ } if len(facts) == 0 { - err = client.StopNetworkFollower(10 * time.Second) + err = client.StopNetworkFollower() if err != nil { jww.WARN.Print(err) } @@ -196,7 +196,7 @@ var udCmd = &cobra.Command{ 
jww.FATAL.Panicf("%+v", err) } time.Sleep(91 * time.Second) - err = client.StopNetworkFollower(90 * time.Second) + err = client.StopNetworkFollower() if err != nil { jww.WARN.Print(err) } diff --git a/interfaces/IsRunning.go b/interfaces/IsRunning.go deleted file mode 100644 index 950b842864712858b3d1dcf769f4765e5af86159..0000000000000000000000000000000000000000 --- a/interfaces/IsRunning.go +++ /dev/null @@ -1,8 +0,0 @@ -package interfaces - -// this interface is used to allow the follower to to be stopped later if it -// fails - -type Running interface{ - IsRunning()bool -} diff --git a/interfaces/networkManager.go b/interfaces/networkManager.go index ad9e03caf6f53e13e9d8e9190f2417749c99f490..69bd8aee8e65aaaa6f382788e611d3ff9d9e2c8c 100644 --- a/interfaces/networkManager.go +++ b/interfaces/networkManager.go @@ -20,7 +20,8 @@ import ( ) type NetworkManager interface { - SendE2E(m message.Send, p params.E2E) ([]id.Round, e2e.MessageID, error) + // The stoppable can be nil. + SendE2E(m message.Send, p params.E2E, stop *stoppable.Single) ([]id.Round, e2e.MessageID, error) SendUnsafe(m message.Send, p params.Unsafe) ([]id.Round, error) SendCMIX(message format.Message, recipient *id.ID, p params.CMIX) (id.Round, ephemeral.Id, error) GetInstance() *network.Instance @@ -45,4 +46,4 @@ type NetworkManager interface { } //for use in key exchange which needs to be callable inside of network -type SendE2E func(m message.Send, p params.E2E) ([]id.Round, e2e.MessageID, error) +type SendE2E func(m message.Send, p params.E2E, stop *stoppable.Single) ([]id.Round, e2e.MessageID, error) diff --git a/keyExchange/confirm.go b/keyExchange/confirm.go index 01c43c5d42dd602fb314815d993b20ab7d16e401..ce1bec2baf28d92502c494ec94e659ba1e7c48c1 100644 --- a/keyExchange/confirm.go +++ b/keyExchange/confirm.go @@ -23,6 +23,7 @@ func startConfirm(sess *storage.Session, c chan message.Receive, select { case <-stop.Quit(): cleanup() + stop.ToStopped() return case confirmation := <-c: 
handleConfirm(sess, confirmation) diff --git a/keyExchange/rekey.go b/keyExchange/rekey.go index f2aab4c42ebedebd258e83168461b719821e1ef2..f4caacd2586bd7a8c0c8c296ba7f8ae19377db4f 100644 --- a/keyExchange/rekey.go +++ b/keyExchange/rekey.go @@ -15,6 +15,7 @@ import ( "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/interfaces/utility" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/client/storage/e2e" "gitlab.com/elixxir/comms/network" @@ -25,10 +26,11 @@ import ( ) func CheckKeyExchanges(instance *network.Instance, sendE2E interfaces.SendE2E, - sess *storage.Session, manager *e2e.Manager, sendTimeout time.Duration) { + sess *storage.Session, manager *e2e.Manager, sendTimeout time.Duration, + stop *stoppable.Single) { sessions := manager.TriggerNegotiations() for _, session := range sessions { - go trigger(instance, sendE2E, sess, manager, session, sendTimeout) + go trigger(instance, sendE2E, sess, manager, session, sendTimeout, stop) } } @@ -38,7 +40,7 @@ func CheckKeyExchanges(instance *network.Instance, sendE2E interfaces.SendE2E, // session while the latter on an extant session func trigger(instance *network.Instance, sendE2E interfaces.SendE2E, sess *storage.Session, manager *e2e.Manager, session *e2e.Session, - sendTimeout time.Duration) { + sendTimeout time.Duration, stop *stoppable.Single) { var negotiatingSession *e2e.Session jww.INFO.Printf("Negotation triggered for session %s with "+ "status: %s", session, session.NegotiationStatus()) @@ -61,7 +63,7 @@ func trigger(instance *network.Instance, sendE2E interfaces.SendE2E, } // send the rekey notification to the partner - err := negotiate(instance, sendE2E, sess, negotiatingSession, sendTimeout) + err := negotiate(instance, sendE2E, sess, negotiatingSession, sendTimeout, stop) // if sending the negotiation fails, revert the state of the session to // unconfirmed so it will be triggered in 
the future if err != nil { @@ -71,8 +73,8 @@ func trigger(instance *network.Instance, sendE2E interfaces.SendE2E, } func negotiate(instance *network.Instance, sendE2E interfaces.SendE2E, - sess *storage.Session, session *e2e.Session, - sendTimeout time.Duration) error { + sess *storage.Session, session *e2e.Session, sendTimeout time.Duration, + stop *stoppable.Single) error { e2eStore := sess.E2e() //generate public key @@ -102,7 +104,7 @@ func negotiate(instance *network.Instance, sendE2E interfaces.SendE2E, e2eParams := params.GetDefaultE2E() e2eParams.Type = params.KeyExchange - rounds, _, err := sendE2E(m, e2eParams) + rounds, _, err := sendE2E(m, e2eParams, stop) // If the send fails, returns the error so it can be handled. The caller // should ensure the calling session is in a state where the Rekey will // be triggered next time a key is used diff --git a/keyExchange/trigger.go b/keyExchange/trigger.go index e22dd76f0b30974254de9fb6bd356e8e3f69d88e..c88d068fe8d1463ea4c2b5408dd9b18c4030e488 100644 --- a/keyExchange/trigger.go +++ b/keyExchange/trigger.go @@ -32,14 +32,15 @@ const ( func startTrigger(sess *storage.Session, net interfaces.NetworkManager, c chan message.Receive, stop *stoppable.Single, params params.Rekey, cleanup func()) { - for true { + for { select { case <-stop.Quit(): cleanup() + stop.ToStopped() return case request := <-c: go func() { - err := handleTrigger(sess, net, request, params) + err := handleTrigger(sess, net, request, params, stop) if err != nil { jww.ERROR.Printf(errFailed, err) } @@ -49,7 +50,7 @@ func startTrigger(sess *storage.Session, net interfaces.NetworkManager, } func handleTrigger(sess *storage.Session, net interfaces.NetworkManager, - request message.Receive, param params.Rekey) error { + request message.Receive, param params.Rekey, stop *stoppable.Single) error { //ensure the message was encrypted properly if request.Encryption != message.E2E { errMsg := fmt.Sprintf(errBadTrigger, request.Sender) @@ -126,7 +127,10 @@ 
func handleTrigger(sess *storage.Session, net interfaces.NetworkManager, // send fails sess.GetCriticalMessages().AddProcessing(m, e2eParams) - rounds, _, err := net.SendE2E(m, e2eParams) + rounds, _, err := net.SendE2E(m, e2eParams, stop) + if err != nil { + return err + } //Register the event for all rounds sendResults := make(chan ds.EventReturn, len(rounds)) diff --git a/keyExchange/trigger_test.go b/keyExchange/trigger_test.go index 5dce2d6a7c50c8e774c3ecd88dd00f5e3db1cacc..d01f9101396d573c8c97892aa3256ede2ba91d49 100644 --- a/keyExchange/trigger_test.go +++ b/keyExchange/trigger_test.go @@ -10,6 +10,7 @@ package keyExchange import ( "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/e2e" dh "gitlab.com/elixxir/crypto/diffieHellman" "gitlab.com/xx_network/crypto/csprng" @@ -66,8 +67,9 @@ func TestHandleTrigger(t *testing.T) { // Handle the trigger and check for an error rekeyParams := params.GetDefaultRekey() + stop := stoppable.NewSingle("stoppable") rekeyParams.RoundTimeout = 0 * time.Second - err = handleTrigger(aliceSession, aliceManager, receiveMsg, rekeyParams) + err = handleTrigger(aliceSession, aliceManager, receiveMsg, rekeyParams, stop) if err != nil { t.Errorf("Handle trigger error: %v", err) } diff --git a/keyExchange/utils_test.go b/keyExchange/utils_test.go index d9dd4ef14b2dea6fd6c3f06db794e14a87b44cc8..84db8a67ca517eb36fd5f9a0982a996b5100c87f 100644 --- a/keyExchange/utils_test.go +++ b/keyExchange/utils_test.go @@ -66,7 +66,7 @@ func (t *testNetworkManagerGeneric) CheckGarbledMessages() { return } -func (t *testNetworkManagerGeneric) SendE2E(m message.Send, p params.E2E) ( +func (t *testNetworkManagerGeneric) SendE2E(message.Send, params.E2E, *stoppable.Single) ( []id.Round, cE2e.MessageID, error) { rounds := []id.Round{id.Round(0), id.Round(1), id.Round(2)} return rounds, cE2e.MessageID{}, nil @@ -163,7 +163,7 @@ func (t 
*testNetworkManagerFullExchange) CheckGarbledMessages() { // Intended for alice to send to bob. Trigger's Bob's confirmation, chaining the operation // together -func (t *testNetworkManagerFullExchange) SendE2E(m message.Send, p params.E2E) ( +func (t *testNetworkManagerFullExchange) SendE2E(message.Send, params.E2E, *stoppable.Single) ( []id.Round, cE2e.MessageID, error) { rounds := []id.Round{id.Round(0), id.Round(1), id.Round(2)} diff --git a/network/ephemeral/addressSpace_test.go b/network/ephemeral/addressSpace_test.go index e08f08d844a6aff3bf387db5456a2147fe89f16b..3761c3f55f080f61b90d9c1353143c57664b4011 100644 --- a/network/ephemeral/addressSpace_test.go +++ b/network/ephemeral/addressSpace_test.go @@ -72,7 +72,7 @@ func Test_addressSpace_Get_WaitBroadcast(t *testing.T) { t.Errorf("Get returned the wrong size.\nexpected: %d\nreceived: %d", initSize, size) } - case <-time.NewTimer(15 * time.Millisecond).C: + case <-time.NewTimer(25 * time.Millisecond).C: t.Error("Get blocking when the Cond has broadcast.") } }() diff --git a/network/ephemeral/testutil.go b/network/ephemeral/testutil.go index c213078eeecaca6178d6f2cd7617caf468da792e..fe34b4934ff2955658a3239f0dfcbe048a3a10d5 100644 --- a/network/ephemeral/testutil.go +++ b/network/ephemeral/testutil.go @@ -33,7 +33,7 @@ type testNetworkManager struct { msg message.Send } -func (t *testNetworkManager) SendE2E(m message.Send, _ params.E2E) ([]id.Round, +func (t *testNetworkManager) SendE2E(m message.Send, _ params.E2E, _ *stoppable.Single) ([]id.Round, e2e.MessageID, error) { rounds := []id.Round{ id.Round(0), diff --git a/network/ephemeral/tracker.go b/network/ephemeral/tracker.go index da81a326914389233631d15fd19bf9ee65744ed0..83f855375c23ad6fa0d95d175802de5d12f5659c 100644 --- a/network/ephemeral/tracker.go +++ b/network/ephemeral/tracker.go @@ -122,6 +122,7 @@ func track(session *storage.Session, addrSpace *AddressSpace, ourId *id.ID, stop case addressSize = <-addressSizeUpdate: 
receptionStore.SetToExpire(addressSize) case <-stop.Quit(): + stop.ToStopped() return } } diff --git a/network/ephemeral/tracker_test.go b/network/ephemeral/tracker_test.go index 3b9e03b29db0b470eebb6a56e490d5f98ac0385b..3307446bc17f25cbf8d8bc119e01a2889b1c2ae2 100644 --- a/network/ephemeral/tracker_test.go +++ b/network/ephemeral/tracker_test.go @@ -47,7 +47,7 @@ func TestCheck(t *testing.T) { ourId := id.NewIdFromBytes([]byte("Sauron"), t) stop := Track(session, NewTestAddressSpace(15, t), ourId) - err = stop.Close(3 * time.Second) + err = stop.Close() if err != nil { t.Errorf("Could not close thread: %+v", err) } @@ -83,7 +83,7 @@ func TestCheck_Thread(t *testing.T) { }() time.Sleep(3 * time.Second) - err = stop.Close(3 * time.Second) + err = stop.Close() if err != nil { t.Errorf("Could not close thread: %v", err) } diff --git a/network/follow.go b/network/follow.go index 8669ea162a38dc51bdbb7601f8023a19406b7351..aa505deecbb5cda93f1ebed0124ef070cf234705 100644 --- a/network/follow.go +++ b/network/follow.go @@ -28,6 +28,7 @@ import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces" "gitlab.com/elixxir/client/network/rounds" + "gitlab.com/elixxir/client/stoppable" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/primitives/knownRounds" "gitlab.com/elixxir/primitives/states" @@ -49,19 +50,20 @@ type followNetworkComms interface { // followNetwork polls the network to get updated on the state of nodes, the // round status, and informs the client when messages can be retrieved. 
-func (m *manager) followNetwork(report interfaces.ClientErrorReport, quitCh <-chan struct{}, isRunning interfaces.Running) { +func (m *manager) followNetwork(report interfaces.ClientErrorReport, + stop *stoppable.Single) { ticker := time.NewTicker(m.param.TrackNetworkPeriod) TrackTicker := time.NewTicker(debugTrackPeriod) rng := m.Rng.GetStream() - done := false - for !done { + for { select { - case <-quitCh: + case <-stop.Quit(): rng.Close() - done = true + stop.ToStopped() + return case <-ticker.C: - m.follow(report, rng, m.Comms, isRunning) + m.follow(report, rng, m.Comms, stop) case <-TrackTicker.C: numPolls := atomic.SwapUint64(m.tracker, 0) if m.numLatencies != 0 { @@ -75,19 +77,13 @@ func (m *manager) followNetwork(report interfaces.ClientErrorReport, quitCh <-ch jww.INFO.Printf("Polled the network %d times in the "+ "last %s", numPolls, debugTrackPeriod) } - - } - if !isRunning.IsRunning() { - jww.ERROR.Printf("Killing network follower " + - "due to failed exit") - return } } } // executes each iteration of the follower func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, - comms followNetworkComms, isRunning interfaces.Running) { + comms followNetworkComms, stop *stoppable.Single) { //Get the identity we will poll for identity, err := m.Session.Reception().GetIdentity(rng, m.addrSpace.GetWithoutWait()) @@ -119,10 +115,11 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, identity.EphId.Int64(), identity.Source, identity.StartRequest, identity.EndRequest, identity.EndRequest.Sub(identity.StartRequest), host.GetId()) return comms.SendPoll(host, &pollReq) - }) - if !isRunning.IsRunning() { - jww.ERROR.Printf("Killing network follower " + - "due to failed exit") + }, stop) + + // Exit if the thread has been stopped + if stoppable.CheckErr(err) { + jww.INFO.Print(err) return } @@ -167,9 +164,9 @@ func (m *manager) follow(report interfaces.ClientErrorReport, rng csprng.Source, // Update the address 
space size // todo: this is a fix for incompatibility with the live network // remove once the live network has been pushed to - if len(m.Instance.GetPartialNdf().Get().AddressSpace)!=0{ + if len(m.Instance.GetPartialNdf().Get().AddressSpace) != 0 { m.addrSpace.Update(m.Instance.GetPartialNdf().Get().AddressSpace[0].Size) - }else{ + } else { m.addrSpace.Update(18) } diff --git a/network/gateway/sender.go b/network/gateway/sender.go index 7e07b247353156960992382eb93bbe718ab9851a..e8943e21549e34bdbe692a689f6f789e7b01af73 100644 --- a/network/gateway/sender.go +++ b/network/gateway/sender.go @@ -11,6 +11,7 @@ package gateway import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/comms/network" "gitlab.com/elixxir/crypto/fastRNG" @@ -36,12 +37,14 @@ func NewSender(poolParams PoolParams, rng *fastRNG.StreamGenerator, ndf *ndf.Net } // SendToAny Call given sendFunc to any Host in the HostPool, attempting with up to numProxies destinations -func (s *Sender) SendToAny(sendFunc func(host *connect.Host) (interface{}, error)) (interface{}, error) { +func (s *Sender) SendToAny(sendFunc func(host *connect.Host) (interface{}, error), stop *stoppable.Single) (interface{}, error) { proxies := s.getAny(s.poolParams.ProxyAttempts, nil) for i := range proxies { result, err := sendFunc(proxies[i]) - if err == nil { + if stop != nil && !stop.IsRunning() { + return nil, errors.Errorf(stoppable.ErrMsg, stop.Name(), "SendToAny") + } else if err == nil { return result, nil } else { jww.WARN.Printf("Unable to SendToAny %s: %s", proxies[i].GetId().String(), err) @@ -57,7 +60,8 @@ func (s *Sender) SendToAny(sendFunc func(host *connect.Host) (interface{}, error // SendToPreferred Call given sendFunc to any Host in the HostPool, attempting with up to numProxies destinations func (s *Sender) SendToPreferred(targets []*id.ID, - sendFunc func(host *connect.Host, target *id.ID) 
(interface{}, bool, error)) (interface{}, error) { + sendFunc func(host *connect.Host, target *id.ID) (interface{}, bool, error), + stop *stoppable.Single) (interface{}, error) { // Get the hosts and shuffle randomly targetHosts := s.getPreferred(targets) @@ -65,7 +69,9 @@ func (s *Sender) SendToPreferred(targets []*id.ID, // Attempt to send directly to targets if they are in the HostPool for i := range targetHosts { result, didAbort, err := sendFunc(targetHosts[i], targets[i]) - if err == nil { + if stop != nil && !stop.IsRunning() { + return nil, errors.Errorf(stoppable.ErrMsg, stop.Name(), "SendToPreferred") + } else if err == nil { return result, nil } else { if didAbort { @@ -103,7 +109,9 @@ func (s *Sender) SendToPreferred(targets []*id.ID, } result, didAbort, err := sendFunc(targetProxies[proxyIdx], target) - if err == nil { + if stop != nil && !stop.IsRunning() { + return nil, errors.Errorf(stoppable.ErrMsg, stop.Name(), "SendToPreferred") + } else if err == nil { return result, nil } else { if didAbort { diff --git a/network/gateway/sender_test.go b/network/gateway/sender_test.go index 35868f1d6c32f83d28cfb3c534d1877cc94b53a4..d8dbf16227e8724ef509589bb7133efea888cc26 100644 --- a/network/gateway/sender_test.go +++ b/network/gateway/sender_test.go @@ -78,7 +78,7 @@ func TestSender_SendToAny(t *testing.T) { } // Test sendToAny with test interfaces - result, err := sender.SendToAny(SendToAny_HappyPath) + result, err := sender.SendToAny(SendToAny_HappyPath, nil) if err != nil { t.Errorf("Should not error in SendToAny happy path: %v", err) } @@ -89,12 +89,12 @@ func TestSender_SendToAny(t *testing.T) { "\n\tReceived: %v", happyPathReturn, result) } - _, err = sender.SendToAny(SendToAny_KnownError) + _, err = sender.SendToAny(SendToAny_KnownError, nil) if err == nil { t.Fatalf("Expected error path did not receive error") } - _, err = sender.SendToAny(SendToAny_UnknownError) + _, err = sender.SendToAny(SendToAny_UnknownError, nil) if err == nil { 
t.Fatalf("Expected error path did not receive error") } @@ -139,7 +139,7 @@ func TestSender_SendToPreferred(t *testing.T) { preferredHost := sender.hostList[preferredIndex] // Happy path - result, err := sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_HappyPath) + result, err := sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_HappyPath, nil) if err != nil { t.Errorf("Should not error in SendToPreferred happy path: %v", err) } @@ -151,7 +151,7 @@ func TestSender_SendToPreferred(t *testing.T) { } // Call a send which returns an error which triggers replacement - _, err = sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_KnownError) + _, err = sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_KnownError, nil) if err == nil { t.Fatalf("Expected error path did not receive error") } @@ -171,7 +171,7 @@ func TestSender_SendToPreferred(t *testing.T) { preferredHost = sender.hostList[preferredIndex] // Unknown error return will not trigger replacement - _, err = sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_UnknownError) + _, err = sender.SendToPreferred([]*id.ID{preferredHost.GetId()}, SendToPreferred_UnknownError, nil) if err == nil { t.Fatalf("Expected error path did not receive error") } diff --git a/network/health/tracker.go b/network/health/tracker.go index ff53904f2015fe81af4938efd9ca8222fd68eba1..1fc5e8cfb7285916192e34918e5fa9068b06a928 100644 --- a/network/health/tracker.go +++ b/network/health/tracker.go @@ -114,25 +114,26 @@ func (t *Tracker) Start() (stoppable.Stoppable, error) { stop := stoppable.NewSingle("Health Tracker") - go t.start(stop.Quit()) + go t.start(stop) return stop, nil } // Long-running thread used to monitor and report on network health -func (t *Tracker) start(quitCh <-chan struct{}) { +func (t *Tracker) start(stop *stoppable.Single) { timer := time.NewTimer(t.timeout) for { var heartbeat network.Heartbeat select { - case 
<-quitCh: + case <-stop.Quit(): t.mux.Lock() t.isHealthy = false t.running = false t.mux.Unlock() t.transmit(false) - break + stop.ToStopped() + return case heartbeat = <-t.heartbeat: if healthy(heartbeat) { // Stop and reset timer @@ -146,10 +147,9 @@ func (t *Tracker) start(quitCh <-chan struct{}) { timer.Reset(t.timeout) t.setHealth(true) } - break case <-timer.C: t.setHealth(false) - break + return } } } diff --git a/network/manager.go b/network/manager.go index c072da4eb05bd74f60fd95dfdcfcdd7bd47e7428..fcb20f2d77f94a57c4c2f01e581b03fca06b6178 100644 --- a/network/manager.go +++ b/network/manager.go @@ -136,7 +136,7 @@ func (m *manager) Follow(report interfaces.ClientErrorReport) (stoppable.Stoppab // Start the Network Tracker trackNetworkStopper := stoppable.NewSingle("TrackNetwork") - go m.followNetwork(report, trackNetworkStopper.Quit(), trackNetworkStopper) + go m.followNetwork(report, trackNetworkStopper) multi.Add(trackNetworkStopper) // Message reception diff --git a/network/message/critical.go b/network/message/critical.go index 384ac9981f1dec0471b72aa3083880c26fc56382..84b88da602ac8fb87a1562d25ab8a52c6f1c23ed 100644 --- a/network/message/critical.go +++ b/network/message/critical.go @@ -12,6 +12,7 @@ import ( "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/interfaces/utility" + "gitlab.com/elixxir/client/stoppable" ds "gitlab.com/elixxir/comms/network/dataStructures" "gitlab.com/elixxir/primitives/format" "gitlab.com/elixxir/primitives/states" @@ -27,22 +28,22 @@ import ( // Tracker (/network/Health/Tracker.g0) //Thread loop for processing critical messages -func (m *Manager) processCriticalMessages(quitCh <-chan struct{}) { - done := false - for !done { +func (m *Manager) processCriticalMessages(stop *stoppable.Single) { + for { select { - case <-quitCh: - done = true + case <-stop.Quit(): + stop.ToStopped() + return case isHealthy := <-m.networkIsHealthy: if isHealthy { - 
m.criticalMessages() + m.criticalMessages(stop) } } } } // processes all critical messages -func (m *Manager) criticalMessages() { +func (m *Manager) criticalMessages(stop *stoppable.Single) { critMsgs := m.Session.GetCriticalMessages() // try to send every message in the critical messages and the raw critical // messages buffer in parallel @@ -53,7 +54,7 @@ func (m *Manager) criticalMessages() { jww.INFO.Printf("Resending critical message to %s ", msg.Recipient) //send the message - rounds, _, err := m.SendE2E(msg, param) + rounds, _, err := m.SendE2E(msg, param, stop) //if the message fail to send, notify the buffer so it can be handled //in the future and exit if err != nil { @@ -95,7 +96,7 @@ func (m *Manager) criticalMessages() { jww.INFO.Printf("Resending critical raw message to %s "+ "(msgDigest: %s)", rid, msg.Digest()) //send the message - round, _, err := m.SendCMIX(m.sender, msg, rid, param) + round, _, err := m.SendCMIX(m.sender, msg, rid, param, stop) //if the message fail to send, notify the buffer so it can be handled //in the future and exit if err != nil { diff --git a/network/message/garbled.go b/network/message/garbled.go index d9e1ac6be427f2ca0e9c8e89572fb6eed2285483..3a828b5eed2584ff323df7cdbb6016726c9dfa18 100644 --- a/network/message/garbled.go +++ b/network/message/garbled.go @@ -10,6 +10,7 @@ package message import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/primitives/format" "time" ) @@ -33,12 +34,12 @@ func (m *Manager) CheckGarbledMessages() { } //long running thread which processes garbled messages -func (m *Manager) processGarbledMessages(quitCh <-chan struct{}) { - done := false - for !done { +func (m *Manager) processGarbledMessages(stop *stoppable.Single) { + for { select { - case <-quitCh: - done = true + case <-stop.Quit(): + stop.ToStopped() + return case <-m.triggerGarbled: m.handleGarbledMessages() } diff --git 
a/network/message/garbled_test.go b/network/message/garbled_test.go index ab4919f16d674ea10a609ca6d0e46a442259789f..d254c56c48329187955596c8367b106fab77cc08 100644 --- a/network/message/garbled_test.go +++ b/network/message/garbled_test.go @@ -7,6 +7,7 @@ import ( "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/internal" "gitlab.com/elixxir/client/network/message/parse" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" "gitlab.com/elixxir/client/switchboard" "gitlab.com/elixxir/comms/client" @@ -120,8 +121,8 @@ func TestManager_CheckGarbledMessages(t *testing.T) { encryptedMsg := key.Encrypt(msg) i.Session.GetGarbledMessages().Add(encryptedMsg) - quitch := make(chan struct{}) - go m.processGarbledMessages(quitch) + stop := stoppable.NewSingle("stop") + go m.processGarbledMessages(stop) m.CheckGarbledMessages() diff --git a/network/message/handler.go b/network/message/handler.go index b8892cff56f51009558998fcd74b1fa027c4a712..5c54e8308b73ed35f668620a731577f4cfe67271 100644 --- a/network/message/handler.go +++ b/network/message/handler.go @@ -10,6 +10,7 @@ package message import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/reception" "gitlab.com/elixxir/crypto/e2e" fingerprint2 "gitlab.com/elixxir/crypto/fingerprint" @@ -18,12 +19,12 @@ import ( "time" ) -func (m *Manager) handleMessages(quitCh <-chan struct{}) { - done := false - for !done { +func (m *Manager) handleMessages(stop *stoppable.Single) { + for { select { - case <-quitCh: - done = true + case <-stop.Quit(): + stop.ToStopped() + return case bundle := <-m.messageReception: for _, msg := range bundle.Messages { m.handleMessage(msg, bundle.Identity) diff --git a/network/message/manager.go b/network/message/manager.go index 7728910aa7e62186c682dcb97b297cb470dcac58..d5bf21a10db947a10f6566c93d005481770789ef 100644 --- 
a/network/message/manager.go +++ b/network/message/manager.go @@ -58,19 +58,19 @@ func (m *Manager) StartProcessies() stoppable.Stoppable { //create the message handler workers for i := uint(0); i < m.param.MessageReceptionWorkerPoolSize; i++ { stop := stoppable.NewSingle(fmt.Sprintf("MessageReception Worker %v", i)) - go m.handleMessages(stop.Quit()) + go m.handleMessages(stop) multi.Add(stop) } //create the critical messages thread critStop := stoppable.NewSingle("CriticalMessages") - go m.processCriticalMessages(critStop.Quit()) + go m.processCriticalMessages(critStop) m.Health.AddChannel(m.networkIsHealthy) multi.Add(critStop) //create the garbled messages thread garbledStop := stoppable.NewSingle("GarbledMessages") - go m.processGarbledMessages(garbledStop.Quit()) + go m.processGarbledMessages(garbledStop) multi.Add(garbledStop) return multi diff --git a/network/message/sendCmix.go b/network/message/sendCmix.go index 2263b903f2d05491ede47501802d7e3e3fab5135..12f9604a26cbd473753b81fd3b6707477d4c287b 100644 --- a/network/message/sendCmix.go +++ b/network/message/sendCmix.go @@ -13,6 +13,7 @@ import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/network/gateway" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/comms/network" @@ -38,9 +39,11 @@ const sendTimeBuffer = 500 * time.Millisecond // WARNING: Potentially Unsafe // Public manager function to send a message over CMIX -func (m *Manager) SendCMIX(sender *gateway.Sender, msg format.Message, recipient *id.ID, param params.CMIX) (id.Round, ephemeral.Id, error) { +func (m *Manager) SendCMIX(sender *gateway.Sender, msg format.Message, + recipient *id.ID, param params.CMIX, stop *stoppable.Single) (id.Round, ephemeral.Id, error) { msgCopy := msg.Copy() - return sendCmixHelper(sender, msgCopy, recipient, m.param, param, m.Instance, m.Session, 
m.nodeRegistration, m.Rng, m.TransmissionID, m.Comms) + return sendCmixHelper(sender, msgCopy, recipient, param, m.Instance, + m.Session, m.nodeRegistration, m.Rng, m.TransmissionID, m.Comms, stop) } // Payloads send are not End to End encrypted, MetaData is NOT protected with @@ -51,9 +54,11 @@ func (m *Manager) SendCMIX(sender *gateway.Sender, msg format.Message, recipient // If the message is successfully sent, the id of the round sent it is returned, // which can be registered with the network instance to get a callback on // its status -func sendCmixHelper(sender *gateway.Sender, msg format.Message, recipient *id.ID, messageParams params.Messages, cmixParams params.CMIX, instance *network.Instance, - session *storage.Session, nodeRegistration chan network.NodeGateway, rng *fastRNG.StreamGenerator, senderId *id.ID, - comms sendCmixCommsInterface) (id.Round, ephemeral.Id, error) { +func sendCmixHelper(sender *gateway.Sender, msg format.Message, + recipient *id.ID, cmixParams params.CMIX, instance *network.Instance, + session *storage.Session, nodeRegistration chan network.NodeGateway, + rng *fastRNG.StreamGenerator, senderId *id.ID, comms sendCmixCommsInterface, + stop *stoppable.Single) (id.Round, ephemeral.Id, error) { timeStart := netTime.Now() attempted := set.New() @@ -204,7 +209,12 @@ func sendCmixHelper(sender *gateway.Sender, msg format.Message, recipient *id.ID } return result, false, err } - result, err := sender.SendToPreferred([]*id.ID{firstGateway}, sendFunc) + result, err := sender.SendToPreferred([]*id.ID{firstGateway}, sendFunc, stop) + + // Exit if the thread has been stopped + if stoppable.CheckErr(err) { + return 0, ephemeral.Id{}, err + } //if the comm errors or the message fails to send, continue retrying. 
//return if it sends properly diff --git a/network/message/sendCmix_test.go b/network/message/sendCmix_test.go index f3182a8a2407c4f29c1c615160a28fd793ab4ee6..87412834c6a33d4165a77f674c0fe58bdd648797 100644 --- a/network/message/sendCmix_test.go +++ b/network/message/sendCmix_test.go @@ -143,9 +143,9 @@ func Test_attemptSendCmix(t *testing.T) { msgCmix := format.NewMessage(m.Session.Cmix().GetGroup().GetP().ByteLen()) msgCmix.SetContents([]byte("test")) e2e.SetUnencrypted(msgCmix, m.Session.User().GetCryptographicIdentity().GetTransmissionID()) - _, _, err = sendCmixHelper(sender, msgCmix, sess2.GetUser().ReceptionID, params.GetDefaultMessage(), params.GetDefaultCMIX(), - m.Instance, m.Session, m.nodeRegistration, m.Rng, - m.TransmissionID, &MockSendCMIXComms{t: t}) + _, _, err = sendCmixHelper(sender, msgCmix, sess2.GetUser().ReceptionID, + params.GetDefaultCMIX(), m.Instance, m.Session, m.nodeRegistration, + m.Rng, m.TransmissionID, &MockSendCMIXComms{t: t}, nil) if err != nil { t.Errorf("Failed to sendcmix: %+v", err) panic("t") diff --git a/network/message/sendE2E.go b/network/message/sendE2E.go index b09e96c4a19598e7724ccf0bd5781ff60745b3d1..7b468ad1a30c818396d631b4b9f85b5945dd88ab 100644 --- a/network/message/sendE2E.go +++ b/network/message/sendE2E.go @@ -13,6 +13,7 @@ import ( "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/keyExchange" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/crypto/e2e" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" @@ -21,7 +22,8 @@ import ( "time" ) -func (m *Manager) SendE2E(msg message.Send, param params.E2E) ([]id.Round, e2e.MessageID, error) { +func (m *Manager) SendE2E(msg message.Send, param params.E2E, + stop *stoppable.Single) ([]id.Round, e2e.MessageID, error) { if msg.MessageType == message.Raw { return nil, e2e.MessageID{}, errors.Errorf("Raw (%d) is a reserved "+ "message type", msg.MessageType) @@ -58,7 
+60,7 @@ func (m *Manager) SendE2E(msg message.Send, param params.E2E) ([]id.Round, e2e.M if msg.MessageType != message.KeyExchangeTrigger { // check if any rekeys need to happen and trigger them keyExchange.CheckKeyExchanges(m.Instance, m.SendE2E, - m.Session, partner, 1*time.Minute) + m.Session, partner, 1*time.Minute, stop) } //create the cmix message @@ -96,7 +98,7 @@ func (m *Manager) SendE2E(msg message.Send, param params.E2E) ([]id.Round, e2e.M go func(i int) { var err error roundIds[i], _, err = m.SendCMIX(m.sender, msgEnc, msg.Recipient, - param.CMIX) + param.CMIX, stop) if err != nil { errCh <- err } diff --git a/network/message/sendUnsafe.go b/network/message/sendUnsafe.go index 19f7d5d5ce0bfe6dd14df77b76f0a21c824b2404..938bcc07c8bab9bd24e1bdd53cb1c38c3b4111c4 100644 --- a/network/message/sendUnsafe.go +++ b/network/message/sendUnsafe.go @@ -64,7 +64,7 @@ func (m *Manager) SendUnsafe(msg message.Send, param params.Unsafe) ([]id.Round, wg.Add(1) go func(i int) { var err error - roundIds[i], _, err = m.SendCMIX(m.sender, msgCmix, msg.Recipient, param.CMIX) + roundIds[i], _, err = m.SendCMIX(m.sender, msgCmix, msg.Recipient, param.CMIX, nil) if err != nil { errCh <- err } diff --git a/network/node/register.go b/network/node/register.go index 82ab2cc3d937245adbc63812d6e875bc63657845..a16e1cd08486b3bd45abd59a3f3061b7cd43a3bd 100644 --- a/network/node/register.go +++ b/network/node/register.go @@ -55,7 +55,8 @@ func StartRegistration(sender *gateway.Sender, session *storage.Session, rngGen return multi } -func registerNodes(sender *gateway.Sender, session *storage.Session, rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface, +func registerNodes(sender *gateway.Sender, session *storage.Session, + rngGen *fastRNG.StreamGenerator, comms RegisterNodeCommsInterface, stop *stoppable.Single, c chan network.NodeGateway) { u := session.User() regSignature := u.GetTransmissionRegistrationValidationSignature() @@ -67,13 +68,15 @@ func registerNodes(sender 
*gateway.Sender, session *storage.Session, rngGen *fas rng := rngGen.GetStream() interval := time.Duration(500) * time.Millisecond t := time.NewTicker(interval) - for true { + for { select { case <-stop.Quit(): t.Stop() + stop.ToStopped() return case gw := <-c: - err := registerWithNode(sender, comms, gw, regSignature, regTimestamp, uci, cmix, rng) + err := registerWithNode(sender, comms, gw, regSignature, + regTimestamp, uci, cmix, rng, stop) if err != nil { jww.ERROR.Printf("Failed to register node: %+v", err) } @@ -84,9 +87,10 @@ func registerNodes(sender *gateway.Sender, session *storage.Session, rngGen *fas //registerWithNode serves as a helper for RegisterWithNodes // It registers a user with a specific in the client's ndf. -func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, ngw network.NodeGateway, - regSig []byte, registrationTimestampNano int64, uci *user.CryptographicIdentity, - store *cmix.Store, rng csprng.Source) error { +func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, + ngw network.NodeGateway, regSig []byte, registrationTimestampNano int64, + uci *user.CryptographicIdentity, store *cmix.Store, rng csprng.Source, + stop *stoppable.Single) error { nodeID, err := ngw.Node.GetNodeId() if err != nil { @@ -122,7 +126,9 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, // keys transmissionHash, _ := hash.NewCMixHash() - nonce, dhPub, err := requestNonce(sender, comms, gatewayID, regSig, registrationTimestampNano, uci, store, rng) + nonce, dhPub, err := requestNonce(sender, comms, gatewayID, regSig, + registrationTimestampNano, uci, store, rng, stop) + if err != nil { return errors.Errorf("Failed to request nonce: %+v", err) } @@ -133,7 +139,8 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, // Confirm received nonce jww.INFO.Printf("Register: Confirming received nonce from node %s", nodeID.String()) err = confirmNonce(sender, 
comms, uci.GetTransmissionID().Bytes(), - nonce, uci.GetTransmissionRSA(), gatewayID) + nonce, uci.GetTransmissionRSA(), gatewayID, stop) + if err != nil { errMsg := fmt.Sprintf("Register: Unable to confirm nonce: %v", err) return errors.New(errMsg) @@ -151,7 +158,7 @@ func registerWithNode(sender *gateway.Sender, comms RegisterNodeCommsInterface, func requestNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, gwId *id.ID, regSig []byte, registrationTimestampNano int64, uci *user.CryptographicIdentity, - store *cmix.Store, rng csprng.Source) ([]byte, []byte, error) { + store *cmix.Store, rng csprng.Source, stop *stoppable.Single) ([]byte, []byte, error) { dhPub := store.GetDHPublicKey().Bytes() opts := rsa.NewDefaultOptions() @@ -195,7 +202,8 @@ func requestNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, gwId return nil, err } return nonceResponse, nil - }) + }, stop) + if err != nil { return nil, nil, err } @@ -208,8 +216,9 @@ func requestNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, gwId // confirmNonce is a helper for the Register function // It signs a nonce and sends it for confirmation // Returns nil if successful, error otherwise -func confirmNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, UID, nonce []byte, - privateKeyRSA *rsa.PrivateKey, gwID *id.ID) error { +func confirmNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, UID, + nonce []byte, privateKeyRSA *rsa.PrivateKey, gwID *id.ID, + stop *stoppable.Single) error { opts := rsa.NewDefaultOptions() opts.Hash = hash.CMixHash h, _ := hash.NewCMixHash() @@ -248,6 +257,7 @@ func confirmNonce(sender *gateway.Sender, comms RegisterNodeCommsInterface, UID, return nil, err } return confirmResponse, nil - }) + }, stop) + return err } diff --git a/network/rounds/historical.go b/network/rounds/historical.go index 45aed7fdfaa7b296a46de79645e19fd010f88634..b4920335458283243fe7f4b77d4777bde2efb4e7 100644 --- a/network/rounds/historical.go +++ 
b/network/rounds/historical.go @@ -9,6 +9,7 @@ package rounds import ( jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/reception" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/xx_network/comms/connect" @@ -41,19 +42,18 @@ type historicalRoundRequest struct { // Long running thread which process historical rounds // Can be killed by sending a signal to the quit channel // takes a comms interface to aid in testing -func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, quitCh <-chan struct{}) { +func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, stop *stoppable.Single) { timerCh := make(<-chan time.Time) rng := m.Rng.GetStream() var roundRequests []historicalRoundRequest - done := false - for !done { + for { shouldProcess := false // wait for a quit or new round to check select { - case <-quitCh: + case <-stop.Quit(): rng.Close() // return all roundRequests in the queue to the input channel so they can // be checked in the future. 
If the queue is full, disable them as @@ -64,7 +64,8 @@ func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, quitCh <-c default: } } - done = true + stop.ToStopped() + return // if the timer elapses process roundRequests to ensure the delay isn't too long case <-timerCh: if len(roundRequests) > 0 { @@ -100,7 +101,7 @@ func (m *Manager) processHistoricalRounds(comm historicalRoundsComms, quitCh <-c jww.DEBUG.Printf("Requesting Historical rounds %v from "+ "gateway %s", rounds, host.GetId()) return comm.RequestHistoricalRounds(host, hr) - }) + }, stop) if err != nil { jww.ERROR.Printf("Failed to request historical roundRequests "+ diff --git a/network/rounds/manager.go b/network/rounds/manager.go index 4c1ec87a993afcdcfcf724d7af2ca69bf2ad6b98..c695cd7a8c78be622769a4388a669eda5757b304 100644 --- a/network/rounds/manager.go +++ b/network/rounds/manager.go @@ -8,12 +8,12 @@ package rounds import ( - "fmt" "gitlab.com/elixxir/client/interfaces/params" "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/internal" "gitlab.com/elixxir/client/network/message" "gitlab.com/elixxir/client/stoppable" + "strconv" ) type Manager struct { @@ -47,19 +47,19 @@ func (m *Manager) StartProcessors() stoppable.Stoppable { //start the historical rounds thread historicalRoundsStopper := stoppable.NewSingle("ProcessHistoricalRounds") - go m.processHistoricalRounds(m.Comms, historicalRoundsStopper.Quit()) + go m.processHistoricalRounds(m.Comms, historicalRoundsStopper) multi.Add(historicalRoundsStopper) //start the message retrieval worker pool for i := uint(0); i < m.params.NumMessageRetrievalWorkers; i++ { - stopper := stoppable.NewSingle(fmt.Sprintf("Messager Retriever %v", i)) - go m.processMessageRetrieval(m.Comms, stopper.Quit()) + stopper := stoppable.NewSingle("Message Retriever " + strconv.Itoa(int(i))) + go m.processMessageRetrieval(m.Comms, stopper) multi.Add(stopper) } // Start the periodic unchecked round worker stopper := 
stoppable.NewSingle("UncheckRound") - go m.processUncheckedRounds(m.params.UncheckRoundPeriod, backOffTable, stopper.Quit()) + go m.processUncheckedRounds(m.params.UncheckRoundPeriod, backOffTable, stopper) multi.Add(stopper) return multi diff --git a/network/rounds/retrieve.go b/network/rounds/retrieve.go index 6f37cb53d61f25d11001a758f8ea38554d1287cc..d69cc06b9a68a1391fa59440deaa5c6e3b6fad69 100644 --- a/network/rounds/retrieve.go +++ b/network/rounds/retrieve.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/network/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/reception" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/crypto/shuffle" @@ -37,12 +38,13 @@ const noRoundError = "does not have round %d" // processMessageRetrieval received a roundLookup request and pings the gateways // of that round for messages for the requested identity in the roundLookup func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, - quitCh <-chan struct{}) { - done := false - for !done { + stop *stoppable.Single) { + + for { select { - case <-quitCh: - done = true + case <-stop.Quit(): + stop.ToStopped() + return case rl := <-m.lookupRoundMessages: ri := rl.roundInfo jww.DEBUG.Printf("Checking for messages in round %d", ri.ID) @@ -84,7 +86,13 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, // round has not been ignored before var bundle message.Bundle if m.params.ForceMessagePickupRetry { - bundle, err = m.forceMessagePickupRetry(ri, rl, comms, gwIds) + bundle, err = m.forceMessagePickupRetry(ri, rl, comms, gwIds, stop) + + // Exit if the thread has been stopped + if stoppable.CheckErr(err) { + jww.ERROR.Print(err) + continue + } if err != nil { jww.ERROR.Printf("Failed to get pickup round %d "+ "from all gateways (%v): %s", @@ -92,7 +100,14 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, } } else 
{ // Attempt to request for this gateway - bundle, err = m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, gwIds) + bundle, err = m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, gwIds, stop) + + // Exit if the thread has been stopped + if stoppable.CheckErr(err) { + jww.ERROR.Print(err) + continue + } + // After trying all gateways, if none returned we mark the round as a // failure and print out the last error if err != nil { @@ -122,8 +137,9 @@ func (m *Manager) processMessageRetrieval(comms messageRetrievalComms, // getMessagesFromGateway attempts to get messages from their assigned // gateway host in the round specified. If successful -func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.IdentityUse, - comms messageRetrievalComms, gwIds []*id.ID) (message.Bundle, error) { +func (m *Manager) getMessagesFromGateway(roundID id.Round, + identity reception.IdentityUse, comms messageRetrievalComms, gwIds []*id.ID, + stop *stoppable.Single) (message.Bundle, error) { start := time.Now() // Send to the gateways using backup proxies result, err := m.sender.SendToPreferred(gwIds, func(host *connect.Host, target *id.ID) (interface{}, bool, error) { @@ -145,7 +161,7 @@ func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.Id } return msgResp, false, err - }) + }, stop) jww.INFO.Printf("Received message for round %d, processing...", roundID) // Fail the round if an error occurs so it can be tried again later if err != nil { @@ -191,7 +207,8 @@ func (m *Manager) getMessagesFromGateway(roundID id.Round, identity reception.Id // Helper function which forces processUncheckedRounds by randomly // not looking up messages func (m *Manager) forceMessagePickupRetry(ri *pb.RoundInfo, rl roundLookup, - comms messageRetrievalComms, gwIds []*id.ID) (bundle message.Bundle, err error) { + comms messageRetrievalComms, gwIds []*id.ID, + stop *stoppable.Single) (bundle message.Bundle, err error) { rnd, _ := 
m.Session.UncheckedRounds().GetRound(id.Round(ri.ID)) if rnd.NumChecks == 0 { // Flip a coin to determine whether to pick up message @@ -213,5 +230,5 @@ func (m *Manager) forceMessagePickupRetry(ri *pb.RoundInfo, rl roundLookup, } // Attempt to request for this gateway - return m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, gwIds) + return m.getMessagesFromGateway(id.Round(ri.ID), rl.identity, comms, gwIds, stop) } diff --git a/network/rounds/retrieve_test.go b/network/rounds/retrieve_test.go index ae54730119dd1c441116ba2187985b078caf6c84..6984e285c8ff9ac9d635a2ef0728c43c6c510c50 100644 --- a/network/rounds/retrieve_test.go +++ b/network/rounds/retrieve_test.go @@ -10,6 +10,7 @@ import ( "bytes" "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/reception" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/crypto/fastRNG" @@ -28,7 +29,7 @@ func TestManager_ProcessMessageRetrieval(t *testing.T) { testManager := newManager(t) roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") testNdf := getNDF() nodeId := id.NewIdFromString(ReturningGateway, id.Node, &testing.T{}) gwId := nodeId.DeepCopy() @@ -51,7 +52,7 @@ func TestManager_ProcessMessageRetrieval(t *testing.T) { testManager.messageBundles = messageBundleChan // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) + go testManager.processMessageRetrieval(mockComms, stop) // Construct expected values for checking expectedEphID := ephemeral.Id{1, 2, 3, 4, 5, 6, 7, 8} @@ -92,8 +93,10 @@ func TestManager_ProcessMessageRetrieval(t *testing.T) { testBundle = <-messageBundleChan // Close the process - quitChan <- struct{}{} - + err := stop.Close() + if err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } }() // 
Ensure bundle received and has expected values @@ -136,7 +139,7 @@ func TestManager_ProcessMessageRetrieval_NoRound(t *testing.T) { testManager.sender, _ = gateway.NewSender(p, testManager.Rng, testNdf, mockComms, testManager.Session, nil) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") // Create a local channel so reception is possible (testManager.messageBundles is // send only via newManager call above) @@ -144,7 +147,7 @@ func TestManager_ProcessMessageRetrieval_NoRound(t *testing.T) { testManager.messageBundles = messageBundleChan // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) + go testManager.processMessageRetrieval(mockComms, stop) expectedEphID := ephemeral.Id{1, 2, 3, 4, 5, 6, 7, 8} @@ -183,8 +186,9 @@ func TestManager_ProcessMessageRetrieval_NoRound(t *testing.T) { testBundle = <-messageBundleChan // Close the process - quitChan <- struct{}{} - + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } }() time.Sleep(2 * time.Second) @@ -202,7 +206,7 @@ func TestManager_ProcessMessageRetrieval_FalsePositive(t *testing.T) { testManager := newManager(t) roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") testNdf := getNDF() nodeId := id.NewIdFromString(FalsePositive, id.Node, &testing.T{}) gwId := nodeId.DeepCopy() @@ -222,7 +226,7 @@ func TestManager_ProcessMessageRetrieval_FalsePositive(t *testing.T) { testManager.messageBundles = messageBundleChan // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) + go testManager.processMessageRetrieval(mockComms, stop) // Construct expected values for checking expectedEphID := ephemeral.Id{1, 2, 3, 4, 5, 6, 7, 8} @@ -263,8 +267,9 @@ func TestManager_ProcessMessageRetrieval_FalsePositive(t *testing.T) { testBundle = <-messageBundleChan 
// Close the process - quitChan <- struct{}{} - + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } }() // Ensure no bundle was received due to false positive test @@ -282,7 +287,7 @@ func TestManager_ProcessMessageRetrieval_Quit(t *testing.T) { testManager := newManager(t) roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") // Create a local channel so reception is possible (testManager.messageBundles is // send only via newManager call above) @@ -290,10 +295,12 @@ func TestManager_ProcessMessageRetrieval_Quit(t *testing.T) { testManager.messageBundles = messageBundleChan // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) + go testManager.processMessageRetrieval(mockComms, stop) // Close the process early, before any logic below can be completed - quitChan <- struct{}{} + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } // Construct expected values for checking expectedEphID := ephemeral.Id{1, 2, 3, 4, 5, 6, 7, 8} @@ -348,7 +355,7 @@ func TestManager_ProcessMessageRetrieval_MultipleGateways(t *testing.T) { testManager := newManager(t) roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") testNdf := getNDF() nodeId := id.NewIdFromString(ReturningGateway, id.Node, &testing.T{}) gwId := nodeId.DeepCopy() @@ -368,7 +375,7 @@ func TestManager_ProcessMessageRetrieval_MultipleGateways(t *testing.T) { testManager.messageBundles = messageBundleChan // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) + go testManager.processMessageRetrieval(mockComms, stop) // Construct expected values for checking expectedEphID := ephemeral.Id{1, 2, 3, 4, 5, 6, 7, 8} @@ -410,8 
+417,9 @@ func TestManager_ProcessMessageRetrieval_MultipleGateways(t *testing.T) { testBundle = <-messageBundleChan // Close the process - quitChan <- struct{}{} - + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } }() // Ensure that expected bundle is still received from happy comm diff --git a/network/rounds/unchecked.go b/network/rounds/unchecked.go index e02f1181282d50bbd4c2d26e1bb8a4213115315d..e62bff0c6885d71080b4544cf6602ec684fa2a9a 100644 --- a/network/rounds/unchecked.go +++ b/network/rounds/unchecked.go @@ -9,6 +9,7 @@ package rounds import ( jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/client/storage/reception" "gitlab.com/xx_network/primitives/netTime" "time" @@ -36,14 +37,14 @@ var backOffTable = [cappedTries]time.Duration{tryZero, tryOne, tryTwo, tryThree, // If a round is found to be due on a periodical check, the round is sent // back to processMessageRetrieval. func (m *Manager) processUncheckedRounds(checkInterval time.Duration, backoffTable [cappedTries]time.Duration, - quitCh <-chan struct{}) { + stop *stoppable.Single) { ticker := time.NewTicker(checkInterval) uncheckedRoundStore := m.Session.UncheckedRounds() - done := false - for !done { + for { select { - case <-quitCh: - done = true + case <-stop.Quit(): + stop.ToStopped() + return case <-ticker.C: // Pull and iterate through uncheckedRound list @@ -67,7 +68,7 @@ func (m *Manager) processUncheckedRounds(checkInterval time.Duration, backoffTab // Send to processMessageRetrieval select { case m.lookupRoundMessages <- rl: - case <-time.After(500 * time.Second): + case <-time.After(1 * time.Second): jww.WARN.Printf("Timing out, not retrying round %d", rl.roundInfo.ID) } diff --git a/network/rounds/unchecked_test.go b/network/rounds/unchecked_test.go index cc66c02712acbffae509eeb7270cc88463e80c58..980ca1e6157a583cc239cab4332e358c5b04e84a 100644 --- a/network/rounds/unchecked_test.go +++ 
b/network/rounds/unchecked_test.go @@ -10,6 +10,7 @@ package rounds import ( "gitlab.com/elixxir/client/network/gateway" "gitlab.com/elixxir/client/network/message" + "gitlab.com/elixxir/client/stoppable" pb "gitlab.com/elixxir/comms/mixmessages" "gitlab.com/elixxir/crypto/fastRNG" "gitlab.com/xx_network/crypto/csprng" @@ -27,7 +28,8 @@ func TestUncheckedRoundScheduler(t *testing.T) { testManager := newManager(t) roundId := id.Round(5) mockComms := &mockMessageRetrievalComms{testingSignature: t} - quitChan := make(chan struct{}) + stop1 := stoppable.NewSingle("singleStoppable1") + stop2 := stoppable.NewSingle("singleStoppable2") testNdf := getNDF() nodeId := id.NewIdFromString(ReturningGateway, id.Node, &testing.T{}) gwId := nodeId.DeepCopy() @@ -47,8 +49,8 @@ func TestUncheckedRoundScheduler(t *testing.T) { testBackoffTable := newTestBackoffTable(t) checkInterval := 250 * time.Millisecond // Initialize the message retrieval - go testManager.processMessageRetrieval(mockComms, quitChan) - go testManager.processUncheckedRounds(checkInterval, testBackoffTable, quitChan) + go testManager.processMessageRetrieval(mockComms, stop1) + go testManager.processUncheckedRounds(checkInterval, testBackoffTable, stop2) requestGateway := id.NewIdFromString(ReturningGateway, id.Gateway, t) @@ -73,7 +75,12 @@ func TestUncheckedRoundScheduler(t *testing.T) { testBundle = <-messageBundleChan // Close the process - quitChan <- struct{}{} + if err := stop1.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } + if err := stop2.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } }() diff --git a/network/send.go b/network/send.go index d70a0c6661c2ee6f4c40f313219baa7cbb86fd52..d09390c6f3bbb97c729e2bed480bc3a68899d53c 100644 --- a/network/send.go +++ b/network/send.go @@ -12,6 +12,7 @@ import ( jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" "gitlab.com/elixxir/client/interfaces/params" + 
"gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/crypto/e2e" "gitlab.com/elixxir/primitives/format" "gitlab.com/xx_network/primitives/id" @@ -28,7 +29,7 @@ func (m *manager) SendCMIX(msg format.Message, recipient *id.ID, param params.CM "network is not healthy") } - return m.message.SendCMIX(m.GetSender(), msg, recipient, param) + return m.message.SendCMIX(m.GetSender(), msg, recipient, param, nil) } // SendUnsafe sends an unencrypted payload to the provided recipient @@ -52,7 +53,7 @@ func (m *manager) SendUnsafe(msg message.Send, param params.Unsafe) ([]id.Round, // SendE2E sends an end-to-end payload to the provided recipient with // the provided msgType. Returns the list of rounds in which parts of // the message were sent or an error if it fails. -func (m *manager) SendE2E(msg message.Send, e2eP params.E2E) ( +func (m *manager) SendE2E(msg message.Send, e2eP params.E2E, stop *stoppable.Single) ( []id.Round, e2e.MessageID, error) { if !m.Health.IsHealthy() { @@ -60,5 +61,5 @@ func (m *manager) SendE2E(msg message.Send, e2eP params.E2E) ( "message when the network is not healthy") } - return m.message.SendE2E(msg, e2eP) + return m.message.SendE2E(msg, e2eP, stop) } diff --git a/single/manager.go b/single/manager.go index 78481829180017f242ca45e43b2bacc4175cf4dd..063f7e8541e0f8f3f9063228fbc33d618247e2b7 100644 --- a/single/manager.go +++ b/single/manager.go @@ -74,13 +74,13 @@ func (m *Manager) StartProcesses() stoppable.Stoppable { transmissionStop := stoppable.NewSingle(singleUseTransmission) transmissionChan := make(chan message.Receive, rawMessageBuffSize) m.swb.RegisterChannel(singleUseReceiveTransmission, &id.ID{}, message.Raw, transmissionChan) - go m.receiveTransmissionHandler(transmissionChan, transmissionStop.Quit()) + go m.receiveTransmissionHandler(transmissionChan, transmissionStop) // Start waiting for single-use response responseStop := stoppable.NewSingle(singleUseResponse) responseChan := make(chan message.Receive, rawMessageBuffSize) 
m.swb.RegisterChannel(singleUseReceiveResponse, &id.ID{}, message.Raw, responseChan) - go m.receiveResponseHandler(responseChan, responseStop.Quit()) + go m.receiveResponseHandler(responseChan, responseStop) // Create a multi stoppable singleUseMulti := stoppable.NewMulti(singleUseStop) diff --git a/single/manager_test.go b/single/manager_test.go index a4384b166322df3206937ff58d56187aed0a6f52..8efc25fe0dcb94126cb374a2634e70fdc71be45c 100644 --- a/single/manager_test.go +++ b/single/manager_test.go @@ -181,7 +181,7 @@ func TestManager_StartProcesses_Stop(t *testing.T) { t.Error("Stoppable is not running.") } - err = stop.Close(1 * time.Millisecond) + err = stop.Close() if err != nil { t.Errorf("Failed to close: %+v", err) } @@ -283,8 +283,8 @@ func (tnm *testNetworkManager) GetMsg(i int) format.Message { return tnm.msgs[i] } -func (tnm *testNetworkManager) SendE2E(_ message.Send, _ params.E2E) ([]id.Round, e2e.MessageID, error) { - return nil, [32]byte{}, nil +func (tnm *testNetworkManager) SendE2E(message.Send, params.E2E, *stoppable.Single) ([]id.Round, e2e.MessageID, error) { + return nil, e2e.MessageID{}, nil } func (tnm *testNetworkManager) SendUnsafe(_ message.Send, _ params.Unsafe) ([]id.Round, error) { diff --git a/single/receiveResponse.go b/single/receiveResponse.go index d25880dae86dee70d882463aeba08d0310daa08b..6a6fa94fddbc10f71a6e311fea944a71ea8f35eb 100644 --- a/single/receiveResponse.go +++ b/single/receiveResponse.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/crypto/e2e/auth" "gitlab.com/elixxir/crypto/e2e/singleUse" "gitlab.com/elixxir/primitives/format" @@ -20,13 +21,14 @@ import ( // receiveResponseHandler handles the reception of single-use response messages. 
func (m *Manager) receiveResponseHandler(rawMessages chan message.Receive, - quitChan <-chan struct{}) { + stop *stoppable.Single) { jww.DEBUG.Print("Waiting to receive single-use response messages.") for { select { - case <-quitChan: + case <-stop.Quit(): jww.DEBUG.Printf("Stopping waiting to receive single-use " + "response message.") + stop.ToStopped() return case msg := <-rawMessages: jww.DEBUG.Printf("Received CMIX message; checking if it is a " + diff --git a/single/receiveResponse_test.go b/single/receiveResponse_test.go index 4e7e8c352e8196f0ff24d5a0d60633f78439e8dd..ea75c5e3ec10cb0af1971af162b4e0fc69b78871 100644 --- a/single/receiveResponse_test.go +++ b/single/receiveResponse_test.go @@ -10,6 +10,7 @@ package single import ( "bytes" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" "gitlab.com/elixxir/crypto/e2e/auth" "gitlab.com/elixxir/crypto/e2e/singleUse" "gitlab.com/elixxir/primitives/format" @@ -26,7 +27,7 @@ import ( func TestManager_ReceiveResponseHandler(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") partner := NewContact(id.NewIdFromString("recipientID", id.User, t), m.store.E2e().GetGroup().NewInt(43), m.store.E2e().GetGroup().NewInt(42), singleUse.TagFP{}, 8) @@ -52,7 +53,7 @@ func TestManager_ReceiveResponseHandler(t *testing.T) { } }() - go m.receiveResponseHandler(rawMessages, quitChan) + go m.receiveResponseHandler(rawMessages, stop) for _, msg := range msgs { rawMessages <- message.Receive{ @@ -78,14 +79,16 @@ func TestManager_ReceiveResponseHandler(t *testing.T) { t.Errorf("Callback failed to be called.") } - quitChan <- struct{}{} + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } } // Error path: invalid CMIX message. 
func TestManager_ReceiveResponseHandler_CmixMessageError(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") partner := NewContact(id.NewIdFromString("recipientID", id.User, t), m.store.E2e().GetGroup().NewInt(43), m.store.E2e().GetGroup().NewInt(42), singleUse.TagFP{}, 8) @@ -106,7 +109,7 @@ func TestManager_ReceiveResponseHandler_CmixMessageError(t *testing.T) { } }() - go m.receiveResponseHandler(rawMessages, quitChan) + go m.receiveResponseHandler(rawMessages, stop) rawMessages <- message.Receive{ Payload: make([]byte, format.MinimumPrimeSize*2), @@ -124,7 +127,9 @@ func TestManager_ReceiveResponseHandler_CmixMessageError(t *testing.T) { case <-timer.C: } - quitChan <- struct{}{} + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } } // Happy path. diff --git a/single/reception.go b/single/reception.go index 53b6c12eda52386a3544b547bee0b54b91bc1d2a..23ca6f8bef891b8ae4426fe24172d4a88c2510c5 100644 --- a/single/reception.go +++ b/single/reception.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" jww "github.com/spf13/jwalterweatherman" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" cAuth "gitlab.com/elixxir/crypto/e2e/auth" "gitlab.com/elixxir/crypto/e2e/singleUse" "gitlab.com/elixxir/primitives/format" @@ -19,14 +20,15 @@ import ( // receiveTransmissionHandler waits to receive single-use transmissions. When // a message is received, its is returned via its registered callback. 
func (m *Manager) receiveTransmissionHandler(rawMessages chan message.Receive, - quitChan <-chan struct{}) { + stop *stoppable.Single) { fp := singleUse.NewTransmitFingerprint(m.store.E2e().GetDHPublicKey()) jww.DEBUG.Print("Waiting to receive single-use transmission messages.") for { select { - case <-quitChan: + case <-stop.Quit(): jww.DEBUG.Printf("Stopping waiting to receive single-use " + "transmission message.") + stop.ToStopped() return case msg := <-rawMessages: jww.DEBUG.Printf("Received CMIX message; checking if it is a " + diff --git a/single/reception_test.go b/single/reception_test.go index 3266d9904b2fb51dd592c766b61a9c40409e0925..27a87b9a181e94437b8a3b471492123c2451547b 100644 --- a/single/reception_test.go +++ b/single/reception_test.go @@ -3,6 +3,7 @@ package single import ( "bytes" "gitlab.com/elixxir/client/interfaces/message" + "gitlab.com/elixxir/client/stoppable" contact2 "gitlab.com/elixxir/crypto/contact" "gitlab.com/elixxir/crypto/e2e/singleUse" "gitlab.com/elixxir/primitives/format" @@ -17,7 +18,6 @@ import ( func TestManager_receiveTransmissionHandler(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) partner := contact2.Contact{ ID: id.NewIdFromString("recipientID", id.User, t), DhPubKey: m.store.E2e().GetDHPublicKey(), @@ -35,7 +35,7 @@ func TestManager_receiveTransmissionHandler(t *testing.T) { m.callbackMap.registerCallback(tag, callback) - go m.receiveTransmissionHandler(rawMessages, quitChan) + go m.receiveTransmissionHandler(rawMessages, stoppable.NewSingle("singleStoppable")) rawMessages <- message.Receive{ Payload: msg.Marshal(), } @@ -57,7 +57,7 @@ func TestManager_receiveTransmissionHandler(t *testing.T) { func TestManager_receiveTransmissionHandler_QuitChan(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := 
stoppable.NewSingle("singleStoppable") tag := "Test tag" payload := make([]byte, 132) rand.New(rand.NewSource(42)).Read(payload) @@ -65,8 +65,11 @@ func TestManager_receiveTransmissionHandler_QuitChan(t *testing.T) { m.callbackMap.registerCallback(tag, callback) - go m.receiveTransmissionHandler(rawMessages, quitChan) - quitChan <- struct{}{} + go m.receiveTransmissionHandler(rawMessages, stop) + + if err := stop.Close(); err != nil { + t.Errorf("Failed to signal close to process: %+v", err) + } timer := time.NewTimer(50 * time.Millisecond) @@ -82,7 +85,7 @@ func TestManager_receiveTransmissionHandler_QuitChan(t *testing.T) { func TestManager_receiveTransmissionHandler_FingerPrintError(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") partner := contact2.Contact{ ID: id.NewIdFromString("recipientID", id.User, t), DhPubKey: m.store.E2e().GetGroup().NewInt(42), @@ -100,7 +103,7 @@ func TestManager_receiveTransmissionHandler_FingerPrintError(t *testing.T) { m.callbackMap.registerCallback(tag, callback) - go m.receiveTransmissionHandler(rawMessages, quitChan) + go m.receiveTransmissionHandler(rawMessages, stop) rawMessages <- message.Receive{ Payload: msg.Marshal(), } @@ -119,7 +122,7 @@ func TestManager_receiveTransmissionHandler_FingerPrintError(t *testing.T) { func TestManager_receiveTransmissionHandler_ProcessMessageError(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") partner := contact2.Contact{ ID: id.NewIdFromString("recipientID", id.User, t), DhPubKey: m.store.E2e().GetDHPublicKey(), @@ -139,7 +142,7 @@ func TestManager_receiveTransmissionHandler_ProcessMessageError(t *testing.T) { m.callbackMap.registerCallback(tag, callback) - go m.receiveTransmissionHandler(rawMessages, 
quitChan) + go m.receiveTransmissionHandler(rawMessages, stop) rawMessages <- message.Receive{ Payload: msg.Marshal(), } @@ -158,7 +161,7 @@ func TestManager_receiveTransmissionHandler_ProcessMessageError(t *testing.T) { func TestManager_receiveTransmissionHandler_TagFpError(t *testing.T) { m := newTestManager(0, false, t) rawMessages := make(chan message.Receive, rawMessageBuffSize) - quitChan := make(chan struct{}) + stop := stoppable.NewSingle("singleStoppable") partner := contact2.Contact{ ID: id.NewIdFromString("recipientID", id.User, t), DhPubKey: m.store.E2e().GetDHPublicKey(), @@ -173,7 +176,7 @@ func TestManager_receiveTransmissionHandler_TagFpError(t *testing.T) { t.Fatalf("Failed to create tranmission CMIX message: %+v", err) } - go m.receiveTransmissionHandler(rawMessages, quitChan) + go m.receiveTransmissionHandler(rawMessages, stop) rawMessages <- message.Receive{ Payload: msg.Marshal(), } diff --git a/single/singleUseMap_test.go b/single/singleUseMap_test.go index 3b4a18171640935735b5d3b28042141c77ded768..eb1f5ef5a4d3b2f942396be1ae72e6971081f5b4 100644 --- a/single/singleUseMap_test.go +++ b/single/singleUseMap_test.go @@ -117,7 +117,7 @@ func Test_pending_addState_TimeoutError(t *testing.T) { *expectedState, *state) } - timer := time.NewTimer(timeout * 2) + timer := time.NewTimer(timeout * 4) select { case results := <-callbackChan: diff --git a/stoppable/bindings.go b/stoppable/bindings.go deleted file mode 100644 index 55784d6522bcb0567d8eb86b282bb9895302ab57..0000000000000000000000000000000000000000 --- a/stoppable/bindings.go +++ /dev/null @@ -1,37 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// - -package stoppable - -import "time" - -type Bindings interface { - Close(timeoutMS 
int) error - IsRunning() bool - Name() string -} - -func WrapForBindings(s Stoppable) Bindings { - return &bindingsStoppable{s: s} -} - -type bindingsStoppable struct { - s Stoppable -} - -func (bs *bindingsStoppable) Close(timeoutMS int) error { - timeout := time.Duration(timeoutMS) * time.Millisecond - return bs.s.Close(timeout) -} - -func (bs *bindingsStoppable) IsRunning() bool { - return bs.s.IsRunning() -} - -func (bs *bindingsStoppable) Name() string { - return bs.s.Name() -} diff --git a/stoppable/cleanup.go b/stoppable/cleanup.go deleted file mode 100644 index 0a84235087dfc0d420c8695679f036294ac9460c..0000000000000000000000000000000000000000 --- a/stoppable/cleanup.go +++ /dev/null @@ -1,99 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // -/////////////////////////////////////////////////////////////////////////////// - -package stoppable - -import ( - "github.com/pkg/errors" - jww "github.com/spf13/jwalterweatherman" - "gitlab.com/xx_network/primitives/netTime" - "sync" - "sync/atomic" - "time" -) - -const nameTag = " with cleanup" - -type CleanFunc func(duration time.Duration) error - -// Cleanup wraps any stoppable and runs a callback after to stop for cleanup -// behavior. The cleanup is run under the remainder of the timeout but will not -// be canceled if the timeout runs out. The cleanup function does not run if the -// thread does not stop. -type Cleanup struct { - stop Stoppable - // clean receives how long it has to run before the timeout, this is not - // expected to be used in most cases - clean CleanFunc - running uint32 - once sync.Once -} - -// NewCleanup creates a new Cleanup from the passed in stoppable and clean -// function. 
-func NewCleanup(stop Stoppable, clean CleanFunc) *Cleanup { - return &Cleanup{ - stop: stop, - clean: clean, - running: stopped, - } -} - -// IsRunning returns true if the thread is still running and its cleanup has -// completed. -func (c *Cleanup) IsRunning() bool { - return atomic.LoadUint32(&c.running) == running -} - -// Name returns the name of the stoppable denoting it has cleanup. -func (c *Cleanup) Name() string { - return c.stop.Name() + nameTag -} - -// Close stops the wrapped stoppable and after, runs the cleanup function. The -// cleanup function does not run if the thread fails to stop. -func (c *Cleanup) Close(timeout time.Duration) error { - var err error - - c.once.Do( - func() { - defer atomic.StoreUint32(&c.running, stopped) - start := netTime.Now() - - // Close each stoppable - if err := c.stop.Close(timeout); err != nil { - err = errors.WithMessagef(err, "Cleanup not executed for %s", - c.stop.Name()) - return - } - - // Run the cleanup function with the remaining time as a timeout - elapsed := netTime.Now().Sub(start) - - complete := make(chan error, 1) - go func() { - complete <- c.clean(elapsed) - }() - - select { - case err := <-complete: - if err != nil { - err = errors.WithMessagef(err, "Cleanup for %s failed", - c.stop.Name()) - } - case <-time.NewTimer(elapsed).C: - err = errors.Errorf("Clean up for %s timed out after %s", - c.stop.Name(), elapsed) - } - }) - - if err != nil { - jww.ERROR.Printf(err.Error()) - } - - return err -} diff --git a/stoppable/cleanup_test.go b/stoppable/cleanup_test.go deleted file mode 100644 index bc60be5def7b0a6296711262c456da591adfd08e..0000000000000000000000000000000000000000 --- a/stoppable/cleanup_test.go +++ /dev/null @@ -1,78 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// Copyright © 2020 xx network SEZC // -// // -// Use of this source code is governed by a license that can be found in the // -// LICENSE file // 
-/////////////////////////////////////////////////////////////////////////////// - -package stoppable - -import ( - "testing" -) - -// Tests that NewCleanup returns a Cleanup that is stopped with the given -// Stoppable. -func TestNewCleanup(t *testing.T) { - single := NewSingle("testSingle") - cleanup := NewCleanup(single, single.Close) - - if cleanup.stop != single { - t.Errorf("NewCleanup returned cleanup with incorrect Stoppable."+ - "\nexpected: %+v\nreceived: %+v", single, cleanup.stop) - } - - if cleanup.running != stopped { - t.Errorf("NewMulti returned Multi with incorrect running."+ - "\nexpected: %d\nreceived: %d", stopped, cleanup.running) - } -} - -// Tests that Cleanup.IsRunning returns the expected value when the Cleanup is -// marked as both running and not running. -func TestCleanup_IsRunning(t *testing.T) { - single := NewSingle("threadName") - cleanup := NewCleanup(single, single.Close) - - if cleanup.IsRunning() { - t.Errorf("IsRunning returned the wrong value when running."+ - "\nexpected: %t\nreceived: %t", true, cleanup.IsRunning()) - } - - cleanup.running = running - if !single.IsRunning() { - t.Errorf("IsRunning returned the wrong value when running."+ - "\nexpected: %t\nreceived: %t", false, single.IsRunning()) - } -} - -// Unit test of Cleanup.Name. -func TestCleanup_Name(t *testing.T) { - name := "threadName" - single := NewSingle(name) - cleanup := NewCleanup(single, single.Close) - - if name+nameTag != cleanup.Name() { - t.Errorf("Name did not return the expected name."+ - "\nexpected: %s\nreceived: %s", name+nameTag, cleanup.Name()) - } -} - -// Tests happy path of Cleanup.Close(). 
-func TestCleanup_Close(t *testing.T) { - single := NewSingle("threadName") - cleanup := NewCleanup(single, single.Close) - - // go func() { - // select { - // case <-time.NewTimer(10 * time.Millisecond).C: - // t.Error("Timed out waiting for quit channel.") - // case <-single.Quit(): - // } - // }() - - err := cleanup.Close(0) - if err != nil { - t.Errorf("Close() returned an error: %+v", err) - } -} diff --git a/stoppable/multi.go b/stoppable/multi.go index 5e60ab2d7759d6874f8018a18643a96dc1e7af26..60d1c8530300f2d52babf469d923627f236a6f1d 100644 --- a/stoppable/multi.go +++ b/stoppable/multi.go @@ -13,16 +13,14 @@ import ( "strings" "sync" "sync/atomic" - "time" ) // Error message. -const closeMultiErr = "MultiStopper %s failed to close %d/%d stoppers" +const closeMultiErr = "multi stoppable %q failed to close %d/%d stoppables" type Multi struct { stoppables []Stoppable name string - running uint32 mux sync.RWMutex once sync.Once } @@ -30,17 +28,11 @@ type Multi struct { // NewMulti returns a new multi Stoppable. func NewMulti(name string) *Multi { return &Multi{ - name: name, - running: running, + name: name, } } -// IsRunning returns true if stoppable is marked as running. -func (m *Multi) IsRunning() bool { - return atomic.LoadUint32(&m.running) == running -} - -// Add adds the given stoppable to the list of stoppables. +// Add adds the given Stoppable to the list of stoppables. func (m *Multi) Add(stoppable Stoppable) { m.mux.Lock() m.stoppables = append(m.stoppables, stoppable) @@ -51,47 +43,85 @@ func (m *Multi) Add(stoppable Stoppable) { // it contains. func (m *Multi) Name() string { m.mux.RLock() - defer m.mux.RUnlock() names := make([]string, len(m.stoppables)) for i, s := range m.stoppables { names[i] = s.Name() } - return m.name + ": {" + strings.Join(names, ", ") + "}" + m.mux.RUnlock() + + return m.name + "{" + strings.Join(names, ", ") + "}" +} + +// GetStatus returns the lowest status of all of the Stoppable children. 
The +// status is not the status of all Stoppables, but the status of the Stoppable +// with the lowest status. +func (m *Multi) GetStatus() Status { + lowestStatus := Stopped + m.mux.RLock() + + for _, s := range m.stoppables { + status := s.GetStatus() + if status < lowestStatus { + lowestStatus = status + } + } + + m.mux.RUnlock() + + return lowestStatus +} + +// IsRunning returns true if Stoppable is marked as running. +func (m *Multi) IsRunning() bool { + return m.GetStatus() == Running +} + +// IsStopping returns true if Stoppable is marked as stopping. +func (m *Multi) IsStopping() bool { + return m.GetStatus() == Stopping +} + +// IsStopped returns true if Stoppable is marked as stopped. +func (m *Multi) IsStopped() bool { + return m.GetStatus() == Stopped } -// Close closes all child stoppers. It does not return their errors and assumes -// they print them to the log. -func (m *Multi) Close(timeout time.Duration) error { - var err error - m.once.Do( - func() { - atomic.StoreUint32(&m.running, stopped) - - var numErrors uint32 - var wg sync.WaitGroup - - m.mux.Lock() - for _, stoppable := range m.stoppables { - wg.Add(1) - go func(stoppable Stoppable) { - if stoppable.Close(timeout) != nil { - atomic.AddUint32(&numErrors, 1) - } - wg.Done() - }(stoppable) - } - m.mux.Unlock() - - wg.Wait() - - if numErrors > 0 { - err = errors.Errorf( - closeMultiErr, m.name, numErrors, len(m.stoppables)) - jww.ERROR.Print(err.Error()) - } - }) - - return err +// Close issues a close signal to all child stoppables and marks the status of +// the Multi Stoppable as stopping. Returns an error if one or more child +// stoppables failed to close but it does not return their specific errors and +// assumes they print them to the log. 
+func (m *Multi) Close() error { + var numErrors uint32 + + m.once.Do(func() { + var wg sync.WaitGroup + + jww.TRACE.Printf("Sending on quit channel to multi stoppable %q.", + m.Name()) + + m.mux.Lock() + // Attempt to stop each stoppable in its own goroutine + for _, stoppable := range m.stoppables { + wg.Add(1) + go func(stoppable Stoppable) { + if stoppable.Close() != nil { + atomic.AddUint32(&numErrors, 1) + } + wg.Done() + }(stoppable) + } + m.mux.Unlock() + + wg.Wait() + }) + + if numErrors > 0 { + err := errors.Errorf(closeMultiErr, m.name, numErrors, len(m.stoppables)) + jww.ERROR.Print(err.Error()) + return err + } + + return nil } diff --git a/stoppable/multi_test.go b/stoppable/multi_test.go index b633e1d77d4e80317cff8ca938c0ff440e855baa..4a7eaf0873019d7ab9bd89e4f6aac2ba8f1661c9 100644 --- a/stoppable/multi_test.go +++ b/stoppable/multi_test.go @@ -12,6 +12,8 @@ import ( "reflect" "strconv" "strings" + "sync" + "sync/atomic" "testing" "time" ) @@ -25,28 +27,6 @@ func TestNewMulti(t *testing.T) { t.Errorf("NewMulti returned Multi with incorrect name."+ "\nexpected: %s\nreceived: %s", name, multi.name) } - - if multi.running != running { - t.Errorf("NewMulti returned Multi with incorrect running."+ - "\nexpected: %d\nreceived: %d", running, multi.running) - } -} - -// Tests that Multi.IsRunning returns the expected value when the Multi is -// marked as both running and not running. -func TestMulti_IsRunning(t *testing.T) { - multi := NewMulti("testMulti") - - if !multi.IsRunning() { - t.Errorf("IsRunning returned the wrong value when running."+ - "\nexpected: %t\nreceived: %t", true, multi.IsRunning()) - } - - multi.running = stopped - if multi.IsRunning() { - t.Errorf("IsRunning returned the wrong value when not running."+ - "\nexpected: %t\nreceived: %t", false, multi.IsRunning()) - } } // Tests that Multi.Add adds all the stoppables to the list. 
@@ -93,7 +73,7 @@ func TestMulti_Name(t *testing.T) { nameList = append(nameList, newName) } - expected := name + ": {" + strings.Join(nameList, ", ") + "}" + expected := name + "{" + strings.Join(nameList, ", ") + "}" if multi.Name() != expected { t.Errorf("Name failed to return the expected string."+ @@ -106,7 +86,7 @@ func TestMulti_Name_NoStoppables(t *testing.T) { name := "testMulti" multi := NewMulti(name) - expected := name + ": {" + "}" + expected := name + "{}" if multi.Name() != expected { t.Errorf("Name failed to return the expected string."+ @@ -114,8 +94,136 @@ func TestMulti_Name_NoStoppables(t *testing.T) { } } -// Tests that Multi.Close sends on all Single quit channels. -func TestMulti_Close(t *testing.T) { +// Tests that Multi.GetStatus returns the expected Status. +func TestMulti_GetStatus(t *testing.T) { + multi := NewMulti("testMulti") + single1 := NewSingle("testSingle1") + single2 := NewSingle("testSingle2") + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopped)) + multi.Add(single1) + multi.Add(single2) + + status := multi.GetStatus() + if status != Running { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Running, status) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopping)) + status = multi.GetStatus() + if status != Stopping { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Stopping, status) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopped)) + status = multi.GetStatus() + if status != Stopped { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Stopped, status) + } +} + +// Tests that Multi.GetStatus returns the expected Status when it has no +// children. 
+func TestMulti_GetStatus_NoChildren(t *testing.T) { + multi := NewMulti("testMulti") + + status := multi.GetStatus() + if status != Stopped { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Stopped, status) + } +} + +// Tests that Multi.IsRunning returns the expected value when the Multi is +// marked as running, stopping, and stopped. +func TestMulti_IsRunning(t *testing.T) { + multi := NewMulti("testMulti") + single1 := NewSingle("testSingle1") + single2 := NewSingle("testSingle2") + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopping)) + multi.Add(single1) + multi.Add(single2) + + if result := multi.IsRunning(); !result { + t.Errorf("IsRunning returned the wrong value when running."+ + "\nexpected: %t\nreceived: %t", true, result) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopping)) + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopped)) + if result := multi.IsRunning(); result { + t.Errorf("IsRunning returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopped)) + if result := multi.IsRunning(); result { + t.Errorf("IsRunning returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Multi.IsStopping returns the expected value when the Multi is +// marked as running, stopping, and stopped. 
+func TestMulti_IsStopping(t *testing.T) { + multi := NewMulti("testMulti") + single1 := NewSingle("testSingle1") + single2 := NewSingle("testSingle2") + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopped)) + multi.Add(single1) + multi.Add(single2) + + if result := multi.IsStopping(); result { + t.Errorf("IsStopping returned the wrong value when running."+ + "\nexpected: %t\nreceived: %t", true, result) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopping)) + if result := multi.IsStopping(); !result { + t.Errorf("IsStopping returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopped)) + if result := multi.IsStopping(); result { + t.Errorf("IsStopping returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Multi.IsStopped returns the expected value when the Multi is +// marked as running, stopping, and stopped. +func TestMulti_IsStopped(t *testing.T) { + multi := NewMulti("testMulti") + single1 := NewSingle("testSingle1") + single2 := NewSingle("testSingle2") + atomic.StoreUint32((*uint32)(&single2.status), uint32(Stopped)) + multi.Add(single1) + multi.Add(single2) + + if result := multi.IsStopped(); result { + t.Errorf("IsStopped returned the wrong value when running."+ + "\nexpected: %t\nreceived: %t", true, result) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopping)) + if result := multi.IsStopped(); result { + t.Errorf("IsStopped returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + atomic.StoreUint32((*uint32)(&single1.status), uint32(Stopped)) + if result := multi.IsStopped(); !result { + t.Errorf("IsStopped returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Multi.IsStopped returns true when all of the child stoppables are +// stopped. 
+func TestMulti_IsStopped_StoppedStatus(t *testing.T) { multi := NewMulti("testMulti") singles := []*Single{ NewSingle("testSingle0"), @@ -125,37 +233,48 @@ func TestMulti_Close(t *testing.T) { NewSingle("testSingle4"), } for _, single := range singles[:3] { + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) multi.Add(single) } subMulti := NewMulti("subMulti") for _, single := range singles[3:] { + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) subMulti.Add(single) } multi.Add(subMulti) + if !multi.IsStopped() { + t.Error("IsStopped did not find all stoppables as stopped.") + } +} + +// Error path: tests that Multi.IsStopped returns false when not all of the +// child stoppables are stopped. +func TestMulti_IsStopped_NotStoppedError(t *testing.T) { + multi := NewMulti("testMulti") + singles := []*Single{ + NewSingle("testSingle0"), + NewSingle("testSingle1"), + NewSingle("testSingle2"), + NewSingle("testSingle3"), + NewSingle("testSingle4"), + } for _, single := range singles { - go func(single *Single) { - select { - case <-time.NewTimer(5 * time.Millisecond).C: - t.Errorf("Single %s failed to quit.", single.Name()) - case <-single.Quit(): - } - }(single) + multi.Add(single) } - err := multi.Close(5 * time.Millisecond) - if err != nil { - t.Errorf("Close() returned an error: %v", err) + for _, single := range singles[:4] { + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) } - err = multi.Close(0) - if err != nil { - t.Errorf("Close() returned an error: %v", err) + if multi.IsStopped() { + t.Error("IsStopped found all the stoppables as stopped when some are " + + "still running") } } // Tests that Multi.Close sends on all Single quit channels. 
-func TestMulti_Close_Error(t *testing.T) { +func TestMulti_Close(t *testing.T) { multi := NewMulti("testMulti") singles := []*Single{ NewSingle("testSingle0"), @@ -173,7 +292,7 @@ func TestMulti_Close_Error(t *testing.T) { } multi.Add(subMulti) - for _, single := range singles[:2] { + for _, single := range singles { go func(single *Single) { select { case <-time.NewTimer(5 * time.Millisecond).C: @@ -182,16 +301,56 @@ func TestMulti_Close_Error(t *testing.T) { } }(single) } + + err := multi.Close() + if err != nil { + t.Errorf("Close() returned an error: %v", err) + } + + err = multi.Close() + if err != nil { + t.Errorf("Close() returned an error: %v", err) + } +} + +// Error path: tests that Multi.Close returns the expected error when the Single +// stoppables are not running. +func TestMulti_Close_StoppableCloseError(t *testing.T) { + multi := NewMulti("testMulti") + var singles []*Single + for i := 0; i < 5; i++ { + single := NewSingle("testSingle" + strconv.Itoa(i)) + singles = append(singles, single) + multi.Add(single) + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) + } + + var wg sync.WaitGroup + for _, single := range singles { + wg.Add(1) + go func(single *Single) { + select { + case <-time.NewTimer(15 * time.Millisecond).C: + case <-single.Quit(): + t.Errorf("Single %s to quit when it should have failed.", + single.Name()) + } + wg.Done() + }(single) + } + expectedErr := fmt.Sprintf(closeMultiErr, multi.name, 0, 0) expectedErr = strings.SplitN(expectedErr, " 0/0", 2)[0] - err := multi.Close(5 * time.Millisecond) + err := multi.Close() if err == nil || !strings.Contains(err.Error(), expectedErr) { t.Errorf("Close() did not return the expected error."+ "\nexpected: %s\nreceived: %v", expectedErr, err) } - err = multi.Close(0) + wg.Wait() + + err = multi.Close() if err != nil { t.Errorf("Close() returned an error: %v", err) } diff --git a/stoppable/single.go b/stoppable/single.go index 
df2a5468c4aa1a348ea3b26c85949cca2b54ec0a..dfde7242ed83f0af975efa0656933b56d0b8145a 100644 --- a/stoppable/single.go +++ b/stoppable/single.go @@ -12,33 +12,79 @@ import ( jww "github.com/spf13/jwalterweatherman" "sync" "sync/atomic" - "time" ) // Error message. -const closeTimeoutErr = "stopper for %s failed to stop after timeout of %s" +const toStoppingErr = "failed to set the status of single stoppable %q to " + + "stopped when status is %s instead of %s" // Single allows stopping a single goroutine using a channel. It adheres to the // Stoppable interface. type Single struct { - name string - quit chan struct{} - running uint32 - once sync.Once + name string + quit chan struct{} + status Status + once sync.Once } // NewSingle returns a new single Stoppable. func NewSingle(name string) *Single { return &Single{ - name: name, - quit: make(chan struct{}), - running: running, + name: name, + quit: make(chan struct{}, 1), + status: Running, } } -// IsRunning returns true if stoppable is marked as running. +// Name returns the name of the Single Stoppable. +func (s *Single) Name() string { + return s.name +} + +// GetStatus returns the status of the Stoppable. +func (s *Single) GetStatus() Status { + return Status(atomic.LoadUint32((*uint32)(&s.status))) +} + +// IsRunning returns true if Stoppable is marked as running. func (s *Single) IsRunning() bool { - return atomic.LoadUint32(&s.running) == running + return s.GetStatus() == Running +} + +// IsStopping returns true if Stoppable is marked as stopping. +func (s *Single) IsStopping() bool { + return s.GetStatus() == Stopping +} + +// IsStopped returns true if Stoppable is marked as stopped. +func (s *Single) IsStopped() bool { + return s.GetStatus() == Stopped +} + +// toStopping changes the status from running to stopping. An error is returned +// if the status is not already set to running. 
+func (s *Single) toStopping() error { + if !atomic.CompareAndSwapUint32((*uint32)(&s.status), uint32(Running), uint32(Stopping)) { + return errors.Errorf(toStoppingErr, s.Name(), s.GetStatus(), Running) + } + + jww.TRACE.Printf("Switched status of single stoppable %q from %s to %s.", + s.Name(), Running, Stopping) + + return nil +} + +// ToStopped changes the status from stopping to stopped. Panics if the status +// is not already set to stopping. +func (s *Single) ToStopped() { + if !atomic.CompareAndSwapUint32((*uint32)(&s.status), uint32(Stopping), uint32(Stopped)) { + jww.FATAL.Panicf("Failed to set the status of single stoppable %q to "+ + "stopped when status is %s instead of %s.", + s.Name(), s.GetStatus(), Stopping) + } + + jww.TRACE.Printf("Switched status of single stoppable %q from %s to %s.", + s.Name(), Stopping, Stopped) } // Quit returns a receive-only channel that will be triggered when the Stoppable @@ -47,24 +93,28 @@ func (s *Single) Quit() <-chan struct{} { return s.quit } -// Name returns the name of the Single Stoppable. -func (s *Single) Name() string { - return s.name -} - // Close signals the Single to close via the quit channel. Returns an error if -// sending on the quit channel times out. -func (s *Single) Close(timeout time.Duration) error { +// the status of the Single is not Running. 
+func (s *Single) Close() error { var err error + s.once.Do(func() { - select { - case <-time.NewTimer(timeout).C: - err = errors.Errorf(closeTimeoutErr, s.name, timeout) - jww.ERROR.Print(err.Error()) - case s.quit <- struct{}{}: + // Attempt to set status to stopping or return an error if unable + err = s.toStopping() + if err != nil { + return } - atomic.StoreUint32(&s.running, stopped) + + jww.TRACE.Printf("Sending on quit channel to single stoppable %q.", + s.Name()) + + // Send on quit channel + s.quit <- struct{}{} }) + if err != nil { + jww.ERROR.Print(err.Error()) + } + return err } diff --git a/stoppable/single_test.go b/stoppable/single_test.go index 4898a4a59008c6ea747f7a54c59ee4b84388b23c..c93a1ebfcc1815b2d5c82e7f2e2dc43ff96a076c 100644 --- a/stoppable/single_test.go +++ b/stoppable/single_test.go @@ -9,6 +9,7 @@ package stoppable import ( "fmt" + "sync/atomic" "testing" "time" ) @@ -23,29 +24,186 @@ func TestNewSingle(t *testing.T) { "\nexpected: %s\nreceived: %s", name, single.name) } - if single.running != running { - t.Errorf("NewSingle returned Single with incorrect running."+ - "\nexpected: %d\nreceived: %d", running, single.running) + if single.status != Running { + t.Errorf("NewSingle returned Single with incorrect status."+ + "\nexpected: %s\nreceived: %s", Running, single.status) + } +} + +// Unit test of Single.Name. +func TestSingle_Name(t *testing.T) { + name := "threadName" + single := NewSingle(name) + + if name != single.Name() { + t.Errorf("Name did not return the expected name."+ + "\nexpected: %s\nreceived: %s", name, single.Name()) + } +} + +// Tests that Single.GetStatus returns the expected Status. 
+func TestSingle_GetStatus(t *testing.T) { + single := NewSingle("threadName") + + status := single.GetStatus() + if status != Running { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Running, status) + } + + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopping)) + status = single.GetStatus() + if status != Stopping { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Stopping, status) + } + + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) + status = single.GetStatus() + if status != Stopped { + t.Errorf("GetStatus returned the wrong status."+ + "\nexpected: %s\nreceived: %s", Stopped, status) } } // Tests that Single.IsRunning returns the expected value when the Single is -// marked as both running and not running. +// marked as running, stopping, and stopped. func TestSingle_IsRunning(t *testing.T) { single := NewSingle("threadName") - if !single.IsRunning() { + if result := single.IsRunning(); !result { t.Errorf("IsRunning returned the wrong value when running."+ - "\nexpected: %t\nreceived: %t", true, single.IsRunning()) + "\nexpected: %t\nreceived: %t", true, result) + } + + single.status = Stopping + if result := single.IsRunning(); result { + t.Errorf("IsRunning returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + single.status = Stopped + if result := single.IsRunning(); result { + t.Errorf("IsRunning returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Single.IsStopping returns the expected value when the Single is +// marked as running, stopping, and stopped. 
+func TestSingle_IsStopping(t *testing.T) { + single := NewSingle("threadName") + + if result := single.IsStopping(); result { + t.Errorf("IsStopping returned the wrong value when running."+ + "\nexpected: %t\nreceived: %t", true, result) + } + + single.status = Stopping + if result := single.IsStopping(); !result { + t.Errorf("IsStopping returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + single.status = Stopped + if result := single.IsStopping(); result { + t.Errorf("IsStopping returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Single.IsStopped returns the expected value when the Single is +// marked as running, stopping, and stopped. +func TestSingle_IsStopped(t *testing.T) { + single := NewSingle("threadName") + + if result := single.IsStopped(); result { + t.Errorf("IsStopped returned the wrong value when running."+ + "\nexpected: %t\nreceived: %t", true, result) + } + + single.status = Stopping + if result := single.IsStopped(); result { + t.Errorf("IsStopped returned the wrong value when stopping."+ + "\nexpected: %t\nreceived: %t", false, result) + } + + single.status = Stopped + if result := single.IsStopped(); !result { + t.Errorf("IsStopped returned the wrong value when stopped."+ + "\nexpected: %t\nreceived: %t", false, result) + } +} + +// Tests that Single.toStopping changes the status to stopping. 
+func TestSingle_toStopping(t *testing.T) { + single := NewSingle("threadName") + + err := single.toStopping() + if err != nil { + t.Errorf("toStopping returned an error: %+v", err) } - single.running = stopped - if single.IsRunning() { - t.Errorf("IsRunning returned the wrong value when not running."+ - "\nexpected: %t\nreceived: %t", false, single.IsRunning()) + if single.status != Stopping { + t.Errorf("toStopping failed to set the status correctly."+ + "\nexpected: %s\nreceived: %s", Stopping, single.status) } } +// Error path: tests that Single.toStopping returns an error when failing to +// change the status to stopping when the current status is not running. +func TestSingle_toStopping_StatusError(t *testing.T) { + single := NewSingle("threadName") + single.status = Stopped + expectedErr := fmt.Sprintf( + toStoppingErr, single.Name(), single.GetStatus(), Running) + + err := single.toStopping() + if err == nil || err.Error() != expectedErr { + t.Errorf("toStopping failed to return the expected error."+ + "\nexpected: %s\nreceived: %+v", expectedErr, err) + } + + if single.status != Stopped { + t.Errorf("toStopping changed the status when the compare failed."+ + "\nexpected: %s\nreceived: %s", Stopped, single.status) + } +} + +// Tests that Single.ToStopped changes the status to stopped. +func TestSingle_ToStopped(t *testing.T) { + single := NewSingle("threadName") + + single.status = Stopping + single.ToStopped() + + if single.status != Stopped { + t.Errorf("ToStopped failed to set the status correctly."+ + "\nexpected: %s\nreceived: %s", Stopped, single.status) + } +} + +// Panic path: tests that Single.ToStopped panics when failing to change the +// status to stopped when the current status is not stopping. 
+func TestSingle_ToStopped_StatusPanic(t *testing.T) { + single := NewSingle("threadName") + + defer func() { + if r := recover(); r == nil { + t.Errorf("ToStopped failed to panic when the status should not " + + "have changed.") + } else { + if single.status != Running { + t.Errorf("ToStopped changed the status when the compare failed."+ + "\nexpected: %s\nreceived: %s", Running, single.status) + } + } + }() + + single.status = Running + single.ToStopped() +} + // Tests that Single.Quit returns a channel that is triggered when the Single // quit channel is triggered. func TestSingle_Quit(t *testing.T) { @@ -62,48 +220,40 @@ func TestSingle_Quit(t *testing.T) { single.quit <- struct{}{} } -// Unit test of Single.Name. -func TestSingle_Name(t *testing.T) { - name := "threadName" - single := NewSingle(name) - - if name != single.Name() { - t.Errorf("Name did not return the expected name."+ - "\nexpected: %s\nreceived: %s", name, single.Name()) - } -} - // Test happy path of Single.Close(). func TestSingle_Close(t *testing.T) { single := NewSingle("threadName") + timeout := 10 * time.Millisecond go func() { select { - case <-time.NewTimer(10 * time.Millisecond).C: - t.Error("Timed out waiting for quit channel.") + case <-time.NewTimer(timeout).C: + t.Errorf("Timed out waiting to receive on quit channel after %s.", + timeout) case <-single.Quit(): + if !single.IsStopping() { + t.Errorf("Status of stoppable incorrect."+ + "\nexpected: %s\nreceived: %s", Stopping, single.status) + } + atomic.StoreUint32((*uint32)(&single.status), uint32(Stopped)) } }() - err := single.Close(5 * time.Millisecond) + err := single.Close() if err != nil { t.Errorf("Close returned an error: %v", err) } } -// Error path: tests that Single.Close returns an error when the timeout is -// reached. +// Error path: tests that Single.Close returns an error when the status fails +// to change to stopping. 
func TestSingle_Close_Error(t *testing.T) { single := NewSingle("threadName") - timeout := time.Millisecond - expectedErr := fmt.Sprintf(closeTimeoutErr, single.Name(), timeout) - - go func() { - time.Sleep(5 * time.Millisecond) - <-single.Quit() - }() + single.status = Stopped + expectedErr := fmt.Sprintf( + toStoppingErr, single.Name(), single.GetStatus(), Running) - err := single.Close(timeout) + err := single.Close() if err == nil || err.Error() != expectedErr { t.Errorf("Close did not return the expected error."+ "\nexpected: %s\nreceived: %v", expectedErr, err) diff --git a/stoppable/status.go b/stoppable/status.go new file mode 100644 index 0000000000000000000000000000000000000000..1b306bd69a13394d8321c52b742ef5ecf82a83a9 --- /dev/null +++ b/stoppable/status.go @@ -0,0 +1,36 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package stoppable + +import ( + "strconv" +) + +const ( + Running Status = iota + Stopping + Stopped +) + +// Status holds the current status of a Stoppable. +type Status uint32 + +// String prints a string representation of the current Status. This functions +// satisfies the fmt.Stringer interface. 
+func (s Status) String() string { + switch s { + case Running: + return "running" + case Stopping: + return "stopping" + case Stopped: + return "stopped" + default: + return "INVALID STATUS: " + strconv.FormatUint(uint64(s), 10) + } +} diff --git a/stoppable/status_test.go b/stoppable/status_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c2f4bcd05f106e1f0bc6e6083a88096a3cdde1f7 --- /dev/null +++ b/stoppable/status_test.go @@ -0,0 +1,32 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +//////////////////////////////////////////////////////////////////////////////// + +package stoppable + +import ( + "testing" +) + +// Unit test of Status.String. +func TestStatus_String(t *testing.T) { + testValues := []struct { + status Status + expected string + }{ + {Running, "running"}, + {Stopping, "stopping"}, + {Stopped, "stopped"}, + {100, "INVALID STATUS: 100"}, + } + + for i, val := range testValues { + if val.status.String() != val.expected { + t.Errorf("String did not return the expected value (%d)."+ + "\nexpected: %s\nreceived: %s", i, val.status.String(), val.expected) + } + } +} diff --git a/stoppable/stoppable.go b/stoppable/stoppable.go index b3b219edf9883e12cc301c8570199c35e40bbb27..b5b072d1424feddf97cd029fa08d519ef2998c01 100644 --- a/stoppable/stoppable.go +++ b/stoppable/stoppable.go @@ -7,16 +7,80 @@ package stoppable -import "time" +import ( + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "strings" + "time" +) +// Error message returned after a comms operations ends and finds that its +// parent thread is stopping or stopped. 
const ( - stopped = 0 - running = 1 + errKey = "[StoppableNotRunning]" + ErrMsg = "stoppable %q is not running, exiting %s early " + errKey + timeoutErr = "timed out after %s waiting for the stoppable to stop for %q" ) -// Stoppable interface for stopping a goroutine. +// pollPeriod is the duration to wait between polls to see if stoppables are +// stopped. +const pollPeriod = 100 * time.Millisecond + +// Stoppable interface for stopping a goroutine. All functions are thread safe. type Stoppable interface { - Close(timeout time.Duration) error - IsRunning() bool + // Name returns the name of the Stoppable. Name() string + + // GetStatus returns the status of the Stoppable. + GetStatus() Status + + // IsRunning returns true if the Stoppable is running. + IsRunning() bool + + // IsStopping returns true if Stoppable is marked as stopping. + IsStopping() bool + + // IsStopped returns true if Stoppable is marked as stopped. + IsStopped() bool + + // Close marks the Stoppable as stopping and issues a close signal to the + // Stoppable or any children it may have. + Close() error +} + +// WaitForStopped polls the stoppable and all its children to see if they are +// stopped. Returns an error if it times out waiting for all children to stop. +func WaitForStopped(s Stoppable, timeout time.Duration) error { + done := make(chan struct{}) + + // Launch the processes to check if all stoppables are stopped in separate + // goroutine so that when the timeout is reached, no time is wasted exiting + go func() { + for !s.IsStopped() { + time.Sleep(pollPeriod) + } + + select { + case done <- struct{}{}: + case <-time.NewTimer(50 * time.Millisecond).C: + } + }() + + select { + case <-done: + jww.INFO.Printf("All stoppables have stopped for %q.", s.Name()) + return nil + case <-time.NewTimer(timeout).C: + return errors.Errorf(timeoutErr, timeout, s.Name()) + } +} + +// CheckErr returns true if the error contains a stoppable error message.
This +// function is used by callers to determine if a sub function quit due to a +// stoppable closing and tells the caller to exit. +func CheckErr(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), errKey) } diff --git a/stoppable/stoppable_test.go b/stoppable/stoppable_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0f070317aedef9fca29a0b66d023b4e7d151517f --- /dev/null +++ b/stoppable/stoppable_test.go @@ -0,0 +1,123 @@ +/////////////////////////////////////////////////////////////////////////////// +// Copyright © 2020 xx network SEZC // +// // +// Use of this source code is governed by a license that can be found in the // +// LICENSE file // +/////////////////////////////////////////////////////////////////////////////// + +package stoppable + +import ( + "fmt" + "github.com/pkg/errors" + jww "github.com/spf13/jwalterweatherman" + "os" + "strconv" + "testing" + "time" +) + +func TestMain(m *testing.M) { + jww.SetStdoutThreshold(jww.LevelTrace) + + os.Exit(m.Run()) +} + +// Tests that WaitForStopped does not return an error when all children are +// stopped. +func TestWaitForStopped(t *testing.T) { + m := newTestMulti() + + err := m.Close() + if err != nil { + t.Errorf("Failed to close multi stoppable: %+v", err) + } + + err = WaitForStopped(m, 2*time.Second) + if err != nil { + t.Errorf("WaitForStopped returned an error: %+v", err) + } +} + +// Error path: tests that WaitForStopped returns an error if the timeout is +// reached before all stoppables are checked. 
+func TestWaitForStopped_TimeoutError(t *testing.T) { + m := newTestMulti() + + err := m.Close() + if err != nil { + t.Errorf("Failed to close multi stoppable: %+v", err) + } + + expectedErr := fmt.Sprintf(timeoutErr, time.Duration(0), m.Name()) + + err = WaitForStopped(m, 0) + if err == nil || err.Error() != expectedErr { + t.Errorf("WaitForStopped did not return the expected error."+ + "\nexpected: %s\nrecieved: %+v", expectedErr, err) + } +} + +// Tests that CheckErr returns true for stoppable errors and false for all +// other errors. +func TestCheckErr(t *testing.T) { + testValues := []struct { + err error + expected bool + }{ + {errors.Errorf(ErrMsg, "testThre", "testFunc"), true}, + {errors.Errorf(ErrMsg, "", ""), true}, + {errors.Errorf(errKey), true}, + {errors.Errorf("Random error"), false}, + {errors.Errorf(""), false}, + {nil, false}, + } + + for i, val := range testValues { + result := CheckErr(val.err) + if result != val.expected { + t.Errorf("CheckErr failed to return the expected value (%d)."+ + "\nexpected: %t\nreceived: %t", i, val.expected, result) + } + } +} + +// newTestMulti creates a new Multi Stoppable that has many Single and Multi +// stoppable children. +func newTestMulti() *Multi { + singles := make([]*Single, 15) + for i := range singles { + singles[i] = NewSingle("testSingle_" + strconv.Itoa(i)) + go func(single *Single) { + <-single.Quit() + time.Sleep(600 * time.Millisecond) + single.ToStopped() + }(singles[i]) + } + + m := NewMulti("testMulti") + for _, s := range singles[:5] { + m.Add(s) + } + m0 := NewMulti("testMulti_0") + for _, s := range singles[5:8] { + m0.Add(s) + } + m.Add(m0) + m1 := NewMulti("testMulti_1") + for _, s := range singles[8:10] { + m1.Add(s) + } + m2 := NewMulti("testMulti_2") + for _, s := range singles[10:13] { + m2.Add(s) + } + m1.Add(m2) + m.Add(m1) + for _, s := range singles[13:] { + m.Add(s) + } + m.Add(NewMulti("testMulti_3")) + + return m +}