Skip to content
Snippets Groups Projects
Commit c743ca8e authored by ksparakis's avatar ksparakis
Browse files

Merge branch 'release' into hotfix/poll-error

# Conflicts:
#	go.mod
#	go.sum
#	notifications/updateNDF.go
#	notifications/updateNDF_test.go
parents 5d02765f 6dff1068
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ package cmd ...@@ -11,6 +11,7 @@ package cmd
import ( import (
"fmt" "fmt"
"github.com/mitchellh/go-homedir" "github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/spf13/cobra" "github.com/spf13/cobra"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"github.com/spf13/viper" "github.com/spf13/viper"
...@@ -18,8 +19,11 @@ import ( ...@@ -18,8 +19,11 @@ import (
"gitlab.com/elixxir/notifications-bot/notifications" "gitlab.com/elixxir/notifications-bot/notifications"
"gitlab.com/elixxir/notifications-bot/storage" "gitlab.com/elixxir/notifications-bot/storage"
"gitlab.com/elixxir/primitives/id" "gitlab.com/elixxir/primitives/id"
"gitlab.com/elixxir/primitives/ndf"
"gitlab.com/elixxir/primitives/utils"
"os" "os"
"path" "path"
"strings"
) )
var ( var (
...@@ -53,19 +57,20 @@ var rootCmd = &cobra.Command{ ...@@ -53,19 +57,20 @@ var rootCmd = &cobra.Command{
certPath := viper.GetString("certPath") certPath := viper.GetString("certPath")
keyPath := viper.GetString("keyPath") keyPath := viper.GetString("keyPath")
localAddress := fmt.Sprintf("0.0.0.0:%d", viper.GetInt("port")) localAddress := fmt.Sprintf("0.0.0.0:%d", viper.GetInt("port"))
fbCreds := viper.GetString("firebaseCredentialsPath")
// Populate params // Populate params
NotificationParams = notifications.Params{ NotificationParams = notifications.Params{
Address: localAddress, Address: localAddress,
CertPath: certPath, CertPath: certPath,
KeyPath: keyPath, KeyPath: keyPath,
FBCreds: fbCreds,
} }
jww.INFO.Println("Starting Notifications...") jww.INFO.Println("Starting Notifications...")
impl, err := notifications.StartNotifications(NotificationParams, noTLS) impl, err := notifications.StartNotifications(NotificationParams, noTLS, false)
if err != nil { if err != nil {
err = fmt.Errorf("Failed to start notifications server: %+v", err) jww.FATAL.Panicf("Failed to start notifications server: %+v", err)
panic(err)
} }
impl.Storage = storage.NewDatabase( impl.Storage = storage.NewDatabase(
...@@ -75,25 +80,54 @@ var rootCmd = &cobra.Command{ ...@@ -75,25 +80,54 @@ var rootCmd = &cobra.Command{
viper.GetString("dbAddress"), viper.GetString("dbAddress"),
) )
permissioningAddr := viper.GetString("permissioningAddress") err = setupConnection(impl, viper.GetString("permissioningCertPath"), viper.GetString("permissioningAddress"))
permissioningCertPath := viper.GetString("permissioningCertPath")
_, err = impl.Comms.AddHost(id.PERMISSIONING, permissioningAddr, []byte(permissioningCertPath), true, true)
if err != nil { if err != nil {
jww.FATAL.Panicf("Failed to Create permissioning host: %+v", err) jww.FATAL.Panicf("Failed to set up connections: %+v", err)
} }
// Start notification loop // Start notification loop
killChan := make(chan struct{}) killChan := make(chan struct{})
go impl.RunNotificationLoop(viper.GetString("firebaseCredentialsPath"), errChan := make(chan error)
loopDelay, go impl.RunNotificationLoop(loopDelay, killChan, errChan)
killChan)
// Wait forever to prevent process from ending // Wait forever to prevent process from ending
select {} err = <-errChan
panic(err)
}, },
} }
// setupConnection handles connecting to permissioning and polling for the NDF once connected
func setupConnection(impl *notifications.Impl, permissioningCertPath, permissioningAddr string) error {
// Read in permissioning certificate
cert, err := utils.ReadFile(permissioningCertPath)
if err != nil {
return errors.Wrap(err, "Could not read permissioning cert")
}
// Add host for permissioning server
_, err = impl.Comms.AddHost(id.PERMISSIONING, permissioningAddr, cert, true, false)
if err != nil {
return errors.Wrap(err, "Failed to Create permissioning host")
}
// Loop until an NDF is received
var def *ndf.NetworkDefinition
for def == nil {
def, err = notifications.PollNdf(nil, impl.Comms)
// Don't stop if error is expected
if err != nil && !strings.Contains(err.Error(), ndf.NO_NDF) {
return errors.Wrap(err, "Failed to get NDF")
}
}
// Update NDF & gateway host
err = impl.UpdateNdf(def)
if err != nil {
return errors.Wrap(err, "Failed to update impl's NDF")
}
return nil
}
// Execute adds all child commands to the root command and sets flags // Execute adds all child commands to the root command and sets flags
// appropriately. This is called by main.main(). It only needs to // appropriately. This is called by main.main(). It only needs to
// happen once to the rootCmd. // happen once to the rootCmd.
...@@ -125,8 +159,8 @@ func init() { ...@@ -125,8 +159,8 @@ func init() {
rootCmd.Flags().BoolVar(&noTLS, "noTLS", false, rootCmd.Flags().BoolVar(&noTLS, "noTLS", false,
"Runs without TLS enabled") "Runs without TLS enabled")
rootCmd.Flags().IntVarP(&loopDelay, "loopDelay", "", 5, rootCmd.Flags().IntVarP(&loopDelay, "loopDelay", "", 500,
"Set the delay between notification loops (in seconds)") "Set the delay between notification loops (in milliseconds)")
// Bind config and command line flags of the same name // Bind config and command line flags of the same name
err := viper.BindPFlag("verbose", rootCmd.Flags().Lookup("verbose")) err := viper.BindPFlag("verbose", rootCmd.Flags().Lookup("verbose"))
......
...@@ -19,12 +19,11 @@ import ( ...@@ -19,12 +19,11 @@ import (
// function types for use in notificationsbot struct // function types for use in notificationsbot struct
type SetupFunc func(string) (*messaging.Client, context.Context, error) type SetupFunc func(string) (*messaging.Client, context.Context, error)
type SendFunc func(FBSender, context.Context, string) (string, error) type SendFunc func(FBSender, string) (string, error)
// FirebaseComm is a struct which holds the functions to setup the messaging app and sending notifications // FirebaseComm is a struct which holds the functions to setup the messaging app and sending notifications
// Using a struct in this manner allows us to properly unit test the NotifyUser function // Using a struct in this manner allows us to properly unit test the NotifyUser function
type FirebaseComm struct { type FirebaseComm struct {
SetupMessagingApp SetupFunc
SendNotification SendFunc SendNotification SendFunc
} }
...@@ -36,44 +35,43 @@ type FBSender interface { ...@@ -36,44 +35,43 @@ type FBSender interface {
// Set up a notificationbot object with the proper setup and send functions // Set up a notificationbot object with the proper setup and send functions
func NewFirebaseComm() *FirebaseComm { func NewFirebaseComm() *FirebaseComm {
return &FirebaseComm{ return &FirebaseComm{
SetupMessagingApp: setupMessagingApp,
SendNotification: sendNotification, SendNotification: sendNotification,
} }
} }
// FOR TESTING USE ONLY: setup a notificationbot object with mocked setup and send funcs // FOR TESTING USE ONLY: setup a notificationbot object with mocked setup and send funcs
func NewMockFirebaseComm(t *testing.T, setupFunc SetupFunc, sendFunc SendFunc) *FirebaseComm { func NewMockFirebaseComm(t *testing.T, sendFunc SendFunc) *FirebaseComm {
if t == nil { if t == nil {
panic("This method should only be used in tests") panic("This method should only be used in tests")
} }
return &FirebaseComm{ return &FirebaseComm{
SetupMessagingApp: setupFunc,
SendNotification: sendFunc, SendNotification: sendFunc,
} }
} }
// setupApp is a helper function which sets up a connection with firebase // setupApp is a helper function which sets up a connection with firebase
// It returns a messaging client, a context object and an error // It returns a messaging client, a context object and an error
func setupMessagingApp(serviceKeyPath string) (*messaging.Client, context.Context, error) { func SetupMessagingApp(serviceKeyPath string) (*messaging.Client, error) {
ctx := context.Background() ctx := context.Background()
opt := option.WithCredentialsFile(serviceKeyPath) opt := option.WithCredentialsFile(serviceKeyPath)
app, err := firebase.NewApp(context.Background(), nil, opt) app, err := firebase.NewApp(context.Background(), nil, opt)
if err != nil { if err != nil {
return nil, nil, errors.Errorf("Error initializing app: %v", err) return nil, errors.Errorf("Error initializing app: %v", err)
} }
client, err := app.Messaging(ctx) client, err := app.Messaging(ctx)
if err != nil { if err != nil {
return nil, nil, errors.Errorf("Error getting Messaging app: %+v", err) return nil, errors.Errorf("Error getting Messaging app: %+v", err)
} }
return client, ctx, nil return client, nil
} }
// SendNotification accepts a registration token and service account file // SendNotification accepts a registration token and service account file
// It gets the proper infrastructure, then builds & sends a notification through the firebase admin API // It gets the proper infrastructure, then builds & sends a notification through the firebase admin API
// returns string, error (string is of dubious use, but is returned for the time being) // returns string, error (string is of dubious use, but is returned for the time being)
func sendNotification(app FBSender, ctx context.Context, token string) (string, error) { func sendNotification(app FBSender, token string) (string, error) {
ctx := context.Background()
message := &messaging.Message{ message := &messaging.Message{
Notification: &messaging.Notification{ Notification: &messaging.Notification{
Title: "xx Messenger", Title: "xx Messenger",
...@@ -84,7 +82,7 @@ func sendNotification(app FBSender, ctx context.Context, token string) (string, ...@@ -84,7 +82,7 @@ func sendNotification(app FBSender, ctx context.Context, token string) (string,
resp, err := app.Send(ctx, message) resp, err := app.Send(ctx, message)
if err != nil { if err != nil {
return "", errors.Errorf("Failed to send notification: %+v", err) return "", errors.Wrap(err, "Failed to send notification")
} }
return resp, nil return resp, nil
} }
...@@ -22,10 +22,9 @@ func (MockSender) Send(ctx context.Context, app *messaging.Message) (string, err ...@@ -22,10 +22,9 @@ func (MockSender) Send(ctx context.Context, app *messaging.Message) (string, err
// This tests the function which sends a notification to firebase. // This tests the function which sends a notification to firebase.
// Note: this requires you to have a valid token & service credentials // Note: this requires you to have a valid token & service credentials
func TestSendNotification(t *testing.T) { func TestSendNotification(t *testing.T) {
ctx := context.Background()
app := MockSender{} app := MockSender{}
_, err := sendNotification(app, ctx, token) _, err := sendNotification(app, token)
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
...@@ -34,7 +33,7 @@ func TestSendNotification(t *testing.T) { ...@@ -34,7 +33,7 @@ func TestSendNotification(t *testing.T) {
// Unit test the NewFirebaseComm method // Unit test the NewFirebaseComm method
func TestNewFirebaseComm(t *testing.T) { func TestNewFirebaseComm(t *testing.T) {
comm := NewFirebaseComm() comm := NewFirebaseComm()
if comm.SendNotification == nil || comm.SetupMessagingApp == nil { if comm.SendNotification == nil {
t.Error("Failed to set functions in comm") t.Error("Failed to set functions in comm")
} }
} }
......
...@@ -10,6 +10,7 @@ package notifications ...@@ -10,6 +10,7 @@ package notifications
import ( import (
"crypto/x509" "crypto/x509"
"firebase.google.com/go/messaging"
"github.com/pkg/errors" "github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/comms/connect" "gitlab.com/elixxir/comms/connect"
...@@ -20,14 +21,14 @@ import ( ...@@ -20,14 +21,14 @@ import (
"gitlab.com/elixxir/notifications-bot/firebase" "gitlab.com/elixxir/notifications-bot/firebase"
"gitlab.com/elixxir/notifications-bot/storage" "gitlab.com/elixxir/notifications-bot/storage"
"gitlab.com/elixxir/primitives/id" "gitlab.com/elixxir/primitives/id"
ndf "gitlab.com/elixxir/primitives/ndf" "gitlab.com/elixxir/primitives/ndf"
"gitlab.com/elixxir/primitives/utils" "gitlab.com/elixxir/primitives/utils"
"time" "time"
) )
// Function type definitions for the main operations (poll and notify) // Function type definitions for the main operations (poll and notify)
type PollFunc func(*connect.Host, RequestInterface) ([]string, error) type PollFunc func(*Impl) ([]string, error)
type NotifyFunc func(string, string, *firebase.FirebaseComm, storage.Storage) (string, error) type NotifyFunc func(*messaging.Client, string, *firebase.FirebaseComm, storage.Storage) (string, error)
// Params struct holds info passed in for configuration // Params struct holds info passed in for configuration
type Params struct { type Params struct {
...@@ -35,55 +36,63 @@ type Params struct { ...@@ -35,55 +36,63 @@ type Params struct {
CertPath string CertPath string
KeyPath string KeyPath string
PublicAddress string PublicAddress string
FBCreds string
} }
// Local impl for notifications; holds comms, storage object, creds and main functions // Local impl for notifications; holds comms, storage object, creds and main functions
type Impl struct { type Impl struct {
Comms *notificationBot.Comms Comms NotificationComms
Storage storage.Storage Storage storage.Storage
notificationCert *x509.Certificate notificationCert *x509.Certificate
notificationKey *rsa.PrivateKey notificationKey *rsa.PrivateKey
certFromFile string certFromFile string
gatewayHost *connect.Host // TODO: populate this field from ndf ndf *ndf.NetworkDefinition
pollFunc PollFunc pollFunc PollFunc
notifyFunc NotifyFunc notifyFunc NotifyFunc
ndf *ndf.NetworkDefinition fcm *messaging.Client
gwId *id.Gateway
} }
// Request interface holds the request function from comms, allowing us to unit test polling // We use an interface here inorder to allow us to mock the getHost and RequestNDF in the notifcationsBot.Comms for testing
type RequestInterface interface { type NotificationComms interface {
GetHost(hostId string) (*connect.Host, bool)
AddHost(id, address string, cert []byte, disableTimeout, enableAuth bool) (host *connect.Host, err error)
RequestNotifications(host *connect.Host) (*pb.IDList, error) RequestNotifications(host *connect.Host) (*pb.IDList, error)
RequestNdf(host *connect.Host, message *pb.NDFHash) (*pb.NDF, error)
} }
// Main function for this repo accepts credentials and an impl // Main function for this repo accepts credentials and an impl
// loops continuously, polling for notifications and notifying the relevant users // loops continuously, polling for notifications and notifying the relevant users
func (nb *Impl) RunNotificationLoop(fbCreds string, loopDuration int, killChan chan struct{}) { func (nb *Impl) RunNotificationLoop(loopDuration int, killChan chan struct{}, errChan chan error) {
fc := firebase.NewFirebaseComm() fc := firebase.NewFirebaseComm()
for { for {
// Stop execution if killed by channel
select { select {
case <-killChan: case <-killChan:
return return
default: case <-time.After(time.Millisecond * time.Duration(loopDuration)):
} }
// TODO: fill in body of main loop, should poll gateway and send relevant notifications to firebase
UIDs, err := nb.pollFunc(nb.gatewayHost, nb.Comms) // Poll for UIDs to notify
UIDs, err := nb.pollFunc(nb)
if err != nil { if err != nil {
jww.ERROR.Printf("Failed to poll gateway for users to notify: %+v", err) errChan <- errors.Wrap(err, "Failed to poll gateway for users to notify")
return
} }
for _, id := range UIDs { for _, id := range UIDs {
_, err := nb.notifyFunc(id, fbCreds, fc, nb.Storage) // Attempt to notify a given user (will not error if UID not registered)
_, err := nb.notifyFunc(nb.fcm, id, fc, nb.Storage)
if err != nil { if err != nil {
jww.ERROR.Printf("Failed to notify user with ID %+v: %+v", id, err) errChan <- errors.Wrapf(err, "Failed to notify user with ID %+v", id)
return
} }
} }
time.Sleep(time.Second * time.Duration(loopDuration))
} }
} }
// StartNotifications creates an Impl from the information passed in // StartNotifications creates an Impl from the information passed in
func StartNotifications(params Params, noTLS bool) (*Impl, error) { func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) {
impl := &Impl{} impl := &Impl{}
var cert, key []byte var cert, key []byte
...@@ -92,38 +101,45 @@ func StartNotifications(params Params, noTLS bool) (*Impl, error) { ...@@ -92,38 +101,45 @@ func StartNotifications(params Params, noTLS bool) (*Impl, error) {
// Read in private key // Read in private key
key, err = utils.ReadFile(params.KeyPath) key, err = utils.ReadFile(params.KeyPath)
if err != nil { if err != nil {
return nil, errors.Errorf("failed to read key at %+v: %+v", params.KeyPath, err) return nil, errors.Wrapf(err, "failed to read key at %+v", params.KeyPath)
} }
impl.notificationKey, err = rsa.LoadPrivateKeyFromPem(key) impl.notificationKey, err = rsa.LoadPrivateKeyFromPem(key)
if err != nil { if err != nil {
return nil, errors.Errorf("Failed to parse notifications server key: %+v. "+ return nil, errors.Wrapf(err, "Failed to parse notifications server key (%+v)", impl.notificationKey)
"NotificationsKey is %+v",
err, impl.notificationKey)
} }
if !noTLS { if !noTLS {
// Read in TLS keys from files // Read in TLS keys from files
cert, err = utils.ReadFile(params.CertPath) cert, err = utils.ReadFile(params.CertPath)
if err != nil { if err != nil {
return nil, errors.Errorf("failed to read certificate at %+v: %+v", params.CertPath, err) return nil, errors.Wrapf(err, "failed to read certificate at %+v", params.CertPath)
} }
// Set globals for notification server // Set globals for notification server
impl.certFromFile = string(cert) impl.certFromFile = string(cert)
impl.notificationCert, err = tls.LoadCertificate(string(cert)) impl.notificationCert, err = tls.LoadCertificate(string(cert))
if err != nil { if err != nil {
return nil, errors.Errorf("Failed to parse notifications server cert: %+v. "+ return nil, errors.Wrapf(err, "Failed to parse notifications server cert. "+
"Notifications cert is %+v", "Notifications cert is %+v", impl.notificationCert)
err, impl.notificationCert)
} }
} }
// set up stored functions
impl.pollFunc = pollForNotifications impl.pollFunc = pollForNotifications
impl.notifyFunc = notifyUser impl.notifyFunc = notifyUser
// Start notification comms server
handler := NewImplementation(impl) handler := NewImplementation(impl)
impl.Comms = notificationBot.StartNotificationBot(id.NOTIFICATION_BOT, params.PublicAddress, handler, cert, key) impl.Comms = notificationBot.StartNotificationBot(id.NOTIFICATION_BOT, params.PublicAddress, handler, cert, key)
// Set up firebase messaging client
if !noFirebase {
app, err := firebase.SetupMessagingApp(params.FBCreds)
if err != nil {
return nil, errors.Wrap(err, "Failed to setup firebase messaging app")
}
impl.fcm = app
}
return impl, nil return impl, nil
} }
...@@ -144,18 +160,15 @@ func NewImplementation(instance *Impl) *notificationBot.Implementation { ...@@ -144,18 +160,15 @@ func NewImplementation(instance *Impl) *notificationBot.Implementation {
// NotifyUser accepts a UID and service key file path. // NotifyUser accepts a UID and service key file path.
// It handles the logic involved in retrieving a user's token and sending the notification // It handles the logic involved in retrieving a user's token and sending the notification
func notifyUser(uid string, serviceKeyPath string, fc *firebase.FirebaseComm, db storage.Storage) (string, error) { func notifyUser(fcm *messaging.Client, uid string, fc *firebase.FirebaseComm, db storage.Storage) (string, error) {
u, err := db.GetUser(uid) u, err := db.GetUser(uid)
if err != nil { if err != nil {
return "", errors.Errorf("Failed to get token for UID %+v: %+v", uid, err) jww.DEBUG.Printf("No registration found for user with ID %+v", uid)
// This path is not an error. if no results are returned, the user hasn't registered for notifications
return "", nil
} }
app, ctx, err := fc.SetupMessagingApp(serviceKeyPath) resp, err := fc.SendNotification(fcm, u.Token)
if err != nil {
return "", errors.Errorf("Failed to setup messaging app: %+v", err)
}
resp, err := fc.SendNotification(app, ctx, u.Token)
if err != nil { if err != nil {
return "", errors.Errorf("Failed to send notification to user with ID %+v: %+v", uid, err) return "", errors.Errorf("Failed to send notification to user with ID %+v: %+v", uid, err)
} }
...@@ -164,10 +177,15 @@ func notifyUser(uid string, serviceKeyPath string, fc *firebase.FirebaseComm, db ...@@ -164,10 +177,15 @@ func notifyUser(uid string, serviceKeyPath string, fc *firebase.FirebaseComm, db
// pollForNotifications accepts a gateway host and a RequestInterface (a comms object) // pollForNotifications accepts a gateway host and a RequestInterface (a comms object)
// It retrieves a list of user ids to be notified from the gateway // It retrieves a list of user ids to be notified from the gateway
func pollForNotifications(h *connect.Host, comms RequestInterface) (strings []string, e error) { func pollForNotifications(nb *Impl) (strings []string, e error) {
users, err := comms.RequestNotifications(h) h, ok := nb.Comms.GetHost(nb.gwId.String())
if !ok {
return nil, errors.New("Could not find gateway host")
}
users, err := nb.Comms.RequestNotifications(h)
if err != nil { if err != nil {
return nil, errors.Errorf("Failed to retrieve notifications from gateway: %+v", err) return nil, errors.Wrap(err, "Failed to retrieve notifications from gateway")
} }
return users.IDs, nil return users.IDs, nil
...@@ -182,7 +200,7 @@ func (nb *Impl) RegisterForNotifications(clientToken []byte, auth *connect.Auth) ...@@ -182,7 +200,7 @@ func (nb *Impl) RegisterForNotifications(clientToken []byte, auth *connect.Auth)
} }
err := nb.Storage.UpsertUser(u) err := nb.Storage.UpsertUser(u)
if err != nil { if err != nil {
return errors.Errorf("Failed to register user with notifications: %+v", err) return errors.Wrap(err, "Failed to register user with notifications")
} }
return nil return nil
} }
...@@ -191,12 +209,20 @@ func (nb *Impl) RegisterForNotifications(clientToken []byte, auth *connect.Auth) ...@@ -191,12 +209,20 @@ func (nb *Impl) RegisterForNotifications(clientToken []byte, auth *connect.Auth)
func (nb *Impl) UnregisterForNotifications(auth *connect.Auth) error { func (nb *Impl) UnregisterForNotifications(auth *connect.Auth) error {
err := nb.Storage.DeleteUser(auth.Sender.GetId()) err := nb.Storage.DeleteUser(auth.Sender.GetId())
if err != nil { if err != nil {
return errors.Errorf("Failed to unregister user with notifications: %+v", err) return errors.Wrap(err, "Failed to unregister user with notifications")
} }
return nil return nil
} }
// Setter function to, set NDF into our Impl structure // Update stored NDF and add host for gateway to poll
func (nb *Impl) updateNdf(ndf *ndf.NetworkDefinition) { func (nb *Impl) UpdateNdf(ndf *ndf.NetworkDefinition) error {
gw := ndf.Gateways[len(ndf.Gateways)-1]
nb.gwId = id.NewNodeFromBytes(ndf.Nodes[len(ndf.Nodes)-1].ID).NewGateway()
_, err := nb.Comms.AddHost(nb.gwId.String(), gw.Address, []byte(gw.TlsCertificate), true, true)
if err != nil {
return errors.Wrap(err, "Failed to add gateway host from NDF")
}
nb.ndf = ndf nb.ndf = ndf
return nil
} }
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
package notifications package notifications
import ( import (
"context"
"firebase.google.com/go/messaging" "firebase.google.com/go/messaging"
"github.com/pkg/errors" "github.com/pkg/errors"
"gitlab.com/elixxir/comms/connect" "gitlab.com/elixxir/comms/connect"
...@@ -14,6 +13,7 @@ import ( ...@@ -14,6 +13,7 @@ import (
"gitlab.com/elixxir/notifications-bot/firebase" "gitlab.com/elixxir/notifications-bot/firebase"
"gitlab.com/elixxir/notifications-bot/storage" "gitlab.com/elixxir/notifications-bot/storage"
"gitlab.com/elixxir/notifications-bot/testutil" "gitlab.com/elixxir/notifications-bot/testutil"
"gitlab.com/elixxir/primitives/id"
"gitlab.com/elixxir/primitives/ndf" "gitlab.com/elixxir/primitives/ndf"
"gitlab.com/elixxir/primitives/utils" "gitlab.com/elixxir/primitives/utils"
"os" "os"
...@@ -28,56 +28,39 @@ var ExampleNdfJSON = `{"Timestamp":"2019-06-04T20:48:48-07:00","gateways":[{"Add ...@@ -28,56 +28,39 @@ var ExampleNdfJSON = `{"Timestamp":"2019-06-04T20:48:48-07:00","gateways":[{"Add
// Basic test to cover RunNotificationLoop, including error sending // Basic test to cover RunNotificationLoop, including error sending
func TestRunNotificationLoop(t *testing.T) { func TestRunNotificationLoop(t *testing.T) {
impl := getNewImpl() impl := getNewImpl()
impl.pollFunc = func(host *connect.Host, requestInterface RequestInterface) (i []string, e error) { impl.pollFunc = func(*Impl) (i []string, e error) {
return []string{"test1", "test2"}, nil return []string{"test1", "test2"}, nil
} }
impl.notifyFunc = func(s3 string, s2 string, comm *firebase.FirebaseComm, storage storage.Storage) (s string, e error) { impl.notifyFunc = func(fcm *messaging.Client, s3 string, comm *firebase.FirebaseComm, storage storage.Storage) (s string, e error) {
if s3 == "test1" {
return "", errors.New("Failed to notify")
}
return "good", nil return "good", nil
} }
killChan := make(chan struct{}) killChan := make(chan struct{})
errChan := make(chan error)
go func() { go func() {
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
killChan <- struct{}{} killChan <- struct{}{}
}() }()
impl.RunNotificationLoop("", 3, killChan) impl.RunNotificationLoop(3, killChan, errChan)
} }
// Test notificationbot's notifyuser function // Test notificationbot's notifyuser function
// this mocks the setup and send functions, and only tests the core logic of this function // this mocks the setup and send functions, and only tests the core logic of this function
func TestNotifyUser(t *testing.T) { func TestNotifyUser(t *testing.T) {
badsetup := func(string) (*messaging.Client, context.Context, error) { badsend := func(firebase.FBSender, string) (string, error) {
ctx := context.Background()
return &messaging.Client{}, ctx, errors.New("Failed")
}
setup := func(string) (*messaging.Client, context.Context, error) {
ctx := context.Background()
return &messaging.Client{}, ctx, nil
}
badsend := func(firebase.FBSender, context.Context, string) (string, error) {
return "", errors.New("Failed") return "", errors.New("Failed")
} }
send := func(firebase.FBSender, context.Context, string) (string, error) { send := func(firebase.FBSender, string) (string, error) {
return "", nil return "", nil
} }
fc_badsetup := firebase.NewMockFirebaseComm(t, badsetup, send) fc_badsend := firebase.NewMockFirebaseComm(t, badsend)
fc_badsend := firebase.NewMockFirebaseComm(t, setup, badsend) fc := firebase.NewMockFirebaseComm(t, send)
fc := firebase.NewMockFirebaseComm(t, setup, send)
_, err := notifyUser("test", "testpath", fc_badsetup, testutil.MockStorage{}) _, err := notifyUser(nil, "test", fc_badsend, testutil.MockStorage{})
if err == nil {
t.Error("Should have returned an error")
return
}
_, err = notifyUser("test", "testpath", fc_badsend, testutil.MockStorage{})
if err == nil { if err == nil {
t.Errorf("Should have returned an error") t.Errorf("Should have returned an error")
} }
_, err = notifyUser("test", "testpath", fc, testutil.MockStorage{}) _, err = notifyUser(nil, "test", fc, testutil.MockStorage{})
if err != nil { if err != nil {
t.Errorf("Failed to notify user properly") t.Errorf("Failed to notify user properly")
} }
...@@ -97,31 +80,31 @@ func TestStartNotifications(t *testing.T) { ...@@ -97,31 +80,31 @@ func TestStartNotifications(t *testing.T) {
PublicAddress: "0.0.0.0:42010", PublicAddress: "0.0.0.0:42010",
} }
n, err := StartNotifications(params, false) n, err := StartNotifications(params, false, true)
if err == nil || !strings.Contains(err.Error(), "failed to read key at") { if err == nil || !strings.Contains(err.Error(), "failed to read key at") {
t.Errorf("Should have thrown an error for no key path") t.Errorf("Should have thrown an error for no key path")
} }
params.KeyPath = wd + "/../testutil/badkey" params.KeyPath = wd + "/../testutil/badkey"
n, err = StartNotifications(params, false) n, err = StartNotifications(params, false, true)
if err == nil || !strings.Contains(err.Error(), "Failed to parse notifications server key") { if err == nil || !strings.Contains(err.Error(), "Failed to parse notifications server key") {
t.Errorf("Should have thrown an error bad key") t.Errorf("Should have thrown an error bad key")
} }
params.KeyPath = wd + "/../testutil/cmix.rip.key" params.KeyPath = wd + "/../testutil/cmix.rip.key"
n, err = StartNotifications(params, false) n, err = StartNotifications(params, false, true)
if err == nil || !strings.Contains(err.Error(), "failed to read certificate at") { if err == nil || !strings.Contains(err.Error(), "failed to read certificate at") {
t.Errorf("Should have thrown an error for no cert path") t.Errorf("Should have thrown an error for no cert path")
} }
params.CertPath = wd + "/../testutil/badkey" params.CertPath = wd + "/../testutil/badkey"
n, err = StartNotifications(params, false) n, err = StartNotifications(params, false, true)
if err == nil || !strings.Contains(err.Error(), "Failed to parse notifications server cert") { if err == nil || !strings.Contains(err.Error(), "Failed to parse notifications server cert") {
t.Errorf("Should have thrown an error for bad certificate") t.Errorf("Should have thrown an error for bad certificate")
} }
params.CertPath = wd + "/../testutil/cmix.rip.crt" params.CertPath = wd + "/../testutil/cmix.rip.crt"
n, err = StartNotifications(params, false) n, err = StartNotifications(params, false, true)
if err != nil { if err != nil {
t.Errorf("Failed to start notifications successfully: %+v", err) t.Errorf("Failed to start notifications successfully: %+v", err)
} }
...@@ -152,21 +135,47 @@ func (m mockPollComm) RequestNotifications(host *connect.Host) (*pb.IDList, erro ...@@ -152,21 +135,47 @@ func (m mockPollComm) RequestNotifications(host *connect.Host) (*pb.IDList, erro
IDs: []string{"test"}, IDs: []string{"test"},
}, nil }, nil
} }
func (m mockPollComm) GetHost(hostId string) (*connect.Host, bool) {
return &connect.Host{}, true
}
func (m mockPollComm) AddHost(id, address string, cert []byte, disableTimeout, enableAuth bool) (host *connect.Host, err error) {
return nil, nil
}
func (m mockPollComm) RequestNdf(host *connect.Host, message *pb.NDFHash) (*pb.NDF, error) {
return nil, nil
}
type mockPollErrComm struct{} type mockPollErrComm struct{}
func (m mockPollErrComm) RequestNotifications(host *connect.Host) (*pb.IDList, error) { func (m mockPollErrComm) RequestNotifications(host *connect.Host) (*pb.IDList, error) {
return nil, errors.New("failed to poll") return nil, errors.New("failed to poll")
} }
func (m mockPollErrComm) GetHost(hostId string) (*connect.Host, bool) {
return nil, false
}
func (m mockPollErrComm) AddHost(id, address string, cert []byte, disableTimeout, enableAuth bool) (host *connect.Host, err error) {
return nil, nil
}
func (m mockPollErrComm) RequestNdf(host *connect.Host, message *pb.NDFHash) (*pb.NDF, error) {
return nil, nil
}
// Unit test for PollForNotifications // Unit test for PollForNotifications
func TestPollForNotifications(t *testing.T) { func TestPollForNotifications(t *testing.T) {
_, err := pollForNotifications(nil, mockPollErrComm{}) impl := &Impl{
Comms: mockPollComm{},
gwId: id.NewNodeFromBytes([]byte("test")).NewGateway(),
}
errImpl := &Impl{
Comms: mockPollErrComm{},
gwId: id.NewNodeFromBytes([]byte("test")).NewGateway(),
}
_, err := pollForNotifications(errImpl)
if err == nil { if err == nil {
t.Errorf("Failed to poll for notifications: %+v", err) t.Errorf("Failed to poll for notifications: %+v", err)
} }
_, err = pollForNotifications(nil, mockPollComm{}) _, err = pollForNotifications(impl)
if err != nil { if err != nil {
t.Errorf("Failed to poll for notifications: %+v", err) t.Errorf("Failed to poll for notifications: %+v", err)
} }
...@@ -199,7 +208,10 @@ func TestImpl_UpdateNdf(t *testing.T) { ...@@ -199,7 +208,10 @@ func TestImpl_UpdateNdf(t *testing.T) {
t.Logf("%+v", err) t.Logf("%+v", err)
} }
impl.updateNdf(testNdf) err = impl.UpdateNdf(testNdf)
if err != nil {
t.Errorf("Failed to update ndf")
}
if impl.ndf != testNdf { if impl.ndf != testNdf {
t.Logf("Failed to change ndf") t.Logf("Failed to change ndf")
...@@ -234,7 +246,8 @@ func getNewImpl() *Impl { ...@@ -234,7 +246,8 @@ func getNewImpl() *Impl {
KeyPath: wd + "/../testutil/cmix.rip.key", KeyPath: wd + "/../testutil/cmix.rip.key",
CertPath: wd + "/../testutil/cmix.rip.crt", CertPath: wd + "/../testutil/cmix.rip.crt",
PublicAddress: "0.0.0.0:0", PublicAddress: "0.0.0.0:0",
FBCreds: "",
} }
instance, _ := StartNotifications(params, false) instance, _ := StartNotifications(params, false, true)
return instance return instance
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment