diff --git a/cmd/connect.go b/cmd/connect.go index 53eeca3941b58132acd37c76ba41986c2051e5fb..478af0988a67202ef3af7f5f8df86eb241e142ed 100644 --- a/cmd/connect.go +++ b/cmd/connect.go @@ -17,6 +17,9 @@ import ( "gitlab.com/elixxir/client/e2e" "gitlab.com/elixxir/client/e2e/receive" "gitlab.com/elixxir/client/xxdk" + "os" + "os/signal" + "syscall" "time" ) @@ -36,8 +39,8 @@ var connectionCmd = &cobra.Command{ regCode := viper.GetString("regcode") cmixParams, e2eParams := initParams() forceLegacy := viper.GetBool("force-legacy") - if viper.GetBool("startServer") { - if viper.GetBool("authenticated") { + if viper.GetBool(connectionStartServerFlag) { + if viper.GetBool(connectionAuthenticatedFlag) { secureConnServer(forceLegacy, statePass, statePath, regCode, cmixParams, e2eParams) } else { @@ -45,7 +48,7 @@ var connectionCmd = &cobra.Command{ cmixParams, e2eParams) } } else { - if viper.GetBool("authenticated") { + if viper.GetBool(connectionAuthenticatedFlag) { secureConnClient(forceLegacy, statePass, statePath, regCode, cmixParams, e2eParams) } else { @@ -152,7 +155,7 @@ func secureConnServer(forceLegacy bool, statePass []byte, statePath, regCode str } // Keep server running to receive messages------------------------------------ - serverTimeout := viper.GetDuration("serverTimeout") + serverTimeout := viper.GetDuration(connectionServerTimeoutFlag) if serverTimeout != 0 { timer := time.NewTimer(serverTimeout) select { @@ -163,8 +166,21 @@ func secureConnServer(forceLegacy bool, statePass []byte, statePath, regCode str } } - // If timeout is not specified, leave as long-running thread - select {} + // Keep app running to receive messages------------------------------------ + + // Wait until the user terminates the program + c := make(chan os.Signal) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + + err = connectServer.Messenger.StopNetworkFollower() + if err != nil { + jww.ERROR.Printf("Failed to stop network follower: %+v", err) + } else { + jww.INFO.Printf("Stopped network follower.") + } + + os.Exit(0) } @@ -257,8 +273,8 @@ func insecureConnServer(forceLegacy bool, statePass []byte, statePath, regCode s } // Keep server running to receive messages------------------------------------ - if viper.GetDuration("serverTimeout") != 0 { - timer := time.NewTimer(viper.GetDuration("serverTimeout")) + if viper.GetDuration(connectionServerTimeoutFlag) != 0 { + timer := time.NewTimer(viper.GetDuration(connectionServerTimeoutFlag)) select { case <-timer.C: fmt.Println("Shutting down connection server") @@ -266,9 +282,21 @@ func insecureConnServer(forceLegacy bool, statePass []byte, statePath, regCode s return } } + // Keep app running to receive messages------------------------------------ - // If timeout is not specified, leave as long-running thread - select {} + // Wait until the user terminates the program + c := make(chan os.Signal) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + + err = connectServer.Messenger.StopNetworkFollower() + if err != nil { + jww.ERROR.Printf("Failed to stop network follower: %+v", err) + } else { + jww.INFO.Printf("Stopped network follower.") + } + + os.Exit(0) } @@ -321,7 +349,7 @@ func secureConnClient(forceLegacy bool, statePass []byte, statePath, regCode str waitUntilConnected(connected) // Connect with the server------------------------------------------------- - contactPath := viper.GetString("connect") + contactPath := viper.GetString(connectionFlag) serverContact := getContactFromFile(contactPath) fmt.Println("Sending connection request") @@ -387,7 +415,7 @@ func insecureConnClient(forceLegacy bool, statePass []byte, statePath, regCode s waitUntilConnected(connected) // Connect with the server------------------------------------------------- - contactPath := viper.GetString("connect") + contactPath := viper.GetString(connectionFlag) serverContact := getContactFromFile(contactPath) fmt.Println("Sending connection request") jww.INFO.Printf("[CONN] Sending connection request to %s", @@ -445,7 +473,7 @@ func miscConnectionFunctions(client *xxdk.E2e, conn connect.Connection) { } // Disconnect from connection partner-------------------------------------------- - if viper.GetBool("disconnect") { + if viper.GetBool(connectionDisconnectFlag) { // Close the connection if err := conn.Close(); err != nil { jww.FATAL.Panicf("Failed to disconnect with %s: %v", @@ -486,7 +514,7 @@ func (l listener) Name() string { // init initializes commands and flags for Cobra. func init() { - connectionCmd.Flags().String("connect", "", + connectionCmd.Flags().String(connectionFlag, "", "This flag is a client side operation. "+ "This flag expects a path to a contact file (similar "+ "to destfile). It will parse this into an contact object,"+ @@ -495,51 +523,36 @@ func init() { "If a connection already exists between "+ "the client and the server, this will be used instead of "+ "resending a connection request to the server.") - err := viper.BindPFlag("connect", connectionCmd.Flags().Lookup("connect")) - if err != nil { - jww.ERROR.Printf("viper.BindPFlag failed for %q: %+v", "connect", err) - } + bindFlagHelper(connectionFlag, connectionCmd) - connectionCmd.Flags().Bool("startServer", false, + connectionCmd.Flags().Bool(connectionStartServerFlag, false, "This flag is a server-side operation and takes no arguments. "+ "This initiates a connection server. "+ "Calling this flag will have this process call "+ "connection.StartServer().") - err = viper.BindPFlag("startServer", connectionCmd.Flags().Lookup("startServer")) - if err != nil { - jww.ERROR.Printf("viper.BindPFlag failed for %q: %+v", "startServer", err) - } + bindFlagHelper(connectionStartServerFlag, connectionCmd) - connectionCmd.Flags().Duration("serverTimeout", time.Duration(0), + connectionCmd.Flags().Duration(connectionServerTimeoutFlag, time.Duration(0), "This flag is a connection parameter. "+ "This takes as an argument a time.Duration. "+ "This duration specifies how long a server will run before "+ "closing. Without this flag present, a server will be "+ "long-running.") - err = viper.BindPFlag("serverTimeout", connectionCmd.Flags().Lookup("serverTimeout")) - if err != nil { - jww.ERROR.Printf("viper.BindPFlag failed for %q: %+v", "serverTimeout", err) - } + bindFlagHelper(connectionServerTimeoutFlag, connectionCmd) - connectionCmd.Flags().Bool("disconnect", false, + connectionCmd.Flags().Bool(connectionDisconnectFlag, false, "This flag is available to both server and client. "+ "This uses a contact object from a file specified by --destfile."+ "This will close the connection with the given contact "+ "if it exists.") - err = viper.BindPFlag("disconnect", connectionCmd.Flags().Lookup("disconnect")) - if err != nil { - jww.ERROR.Printf("viper.BindPFlag failed for %q: %+v", "disconnect", err) - } + bindFlagHelper(connectionDisconnectFlag, connectionCmd) - connectionCmd.Flags().Bool("authenticated", false, + connectionCmd.Flags().Bool(connectionAuthenticatedFlag, false, "This flag is available to both server and client. "+ "This flag operates as a switch for the authenticated code-path. "+ "With this flag present, any additional connection related flags"+ " will call the applicable authenticated counterpart") - err = viper.BindPFlag("authenticated", connectionCmd.Flags().Lookup("authenticated")) - if err != nil { - jww.ERROR.Printf("viper.BindPFlag failed for %q: %+v", "authenticated", err) - } + bindFlagHelper(connectionAuthenticatedFlag, connectionCmd) rootCmd.AddCommand(connectionCmd) }