Skip to content
Snippets Groups Projects
Commit d7b425c5 authored by Jonah Husson's avatar Jonah Husson
Browse files

Fixes from ben's code review

parent 813cd688
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ package cmd
import (
"fmt"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/spf13/cobra"
jww "github.com/spf13/jwalterweatherman"
"github.com/spf13/viper"
......@@ -18,9 +19,11 @@ import (
"gitlab.com/elixxir/notifications-bot/notifications"
"gitlab.com/elixxir/notifications-bot/storage"
"gitlab.com/elixxir/primitives/id"
"gitlab.com/elixxir/primitives/ndf"
"gitlab.com/elixxir/primitives/utils"
"os"
"path"
"strings"
)
var (
......@@ -54,19 +57,20 @@ var rootCmd = &cobra.Command{
certPath := viper.GetString("certPath")
keyPath := viper.GetString("keyPath")
localAddress := fmt.Sprintf("0.0.0.0:%d", viper.GetInt("port"))
fbCreds := viper.GetString("firebaseCredentialsPath")
// Populate params
NotificationParams = notifications.Params{
Address: localAddress,
CertPath: certPath,
KeyPath: keyPath,
FBCreds: fbCreds,
}
jww.INFO.Println("Starting Notifications...")
impl, err := notifications.StartNotifications(NotificationParams, noTLS)
impl, err := notifications.StartNotifications(NotificationParams, noTLS, false)
if err != nil {
err = fmt.Errorf("Failed to start notifications server: %+v", err)
panic(err)
jww.FATAL.Panicf("Failed to start notifications server: %+v", err)
}
impl.Storage = storage.NewDatabase(
......@@ -76,8 +80,23 @@ var rootCmd = &cobra.Command{
viper.GetString("dbAddress"),
)
permissioningAddr := viper.GetString("permissioningAddress")
permissioningCertPath := viper.GetString("permissioningCertPath")
err = setupConnection(impl, viper.GetString("permissioningCertPath"), viper.GetString("permissioningAddress"))
if err != nil {
jww.FATAL.Panicf("Failed to set up connections: %+v", err)
}
// Start notification loop
killChan := make(chan struct{})
errChan := make(chan error)
go impl.RunNotificationLoop(loopDelay, killChan, errChan)
// Wait forever to prevent process from ending
err = <-errChan
panic(err)
},
}
func setupConnection(impl *notifications.Impl, permissioningCertPath, permissioningAddr string) error {
cert, err := utils.ReadFile(permissioningCertPath)
if err != nil {
jww.FATAL.Panicf("Could not read permissioning cert: %+v", err)
......@@ -87,28 +106,20 @@ var rootCmd = &cobra.Command{
if err != nil {
jww.FATAL.Panicf("Failed to Create permissioning host: %+v", err)
}
Setup:
ndf, err := notifications.PollNdf(nil, impl.Comms)
if err != nil {
jww.FATAL.Panicf("Failed to get NDF: %+v", err)
var def *ndf.NetworkDefinition
for def == nil {
def, err = notifications.PollNdf(nil, impl.Comms)
if err != nil && !strings.Contains(err.Error(), ndf.NO_NDF) {
return errors.Wrap(err, "Failed to get NDF")
}
}
err = impl.UpdateNdf(ndf)
err = impl.UpdateNdf(def)
if err != nil {
jww.FATAL.Panicf("Failed to update impl's NDF: %+v", err)
return errors.Wrap(err, "Failed to update impl's NDF")
}
// Start notification loop
killChan := make(chan struct{})
errChan := make(chan error)
go impl.RunNotificationLoop(viper.GetString("firebaseCredentialsPath"),
loopDelay, killChan, errChan)
// Wait forever to prevent process from ending
err = <-errChan
jww.ERROR.Println(err)
goto Setup
},
return nil
}
// Execute adds all child commands to the root command and sets flags
......@@ -142,8 +153,8 @@ func init() {
rootCmd.Flags().BoolVar(&noTLS, "noTLS", false,
"Runs without TLS enabled")
rootCmd.Flags().IntVarP(&loopDelay, "loopDelay", "", 5,
"Set the delay between notification loops (in seconds)")
rootCmd.Flags().IntVarP(&loopDelay, "loopDelay", "", 500,
"Set the delay between notification loops (in milliseconds)")
// Bind config and command line flags of the same name
err := viper.BindPFlag("verbose", rootCmd.Flags().Lookup("verbose"))
......
......@@ -19,12 +19,11 @@ import (
// function types for use in notificationsbot struct
type SetupFunc func(string) (*messaging.Client, context.Context, error)
type SendFunc func(FBSender, context.Context, string) (string, error)
type SendFunc func(FBSender, string) (string, error)
// FirebaseComm is a struct which holds the functions to setup the messaging app and sending notifications
// Using a struct in this manner allows us to properly unit test the NotifyUser function
type FirebaseComm struct {
SetupMessagingApp SetupFunc
SendNotification SendFunc
}
......@@ -36,44 +35,43 @@ type FBSender interface {
// Set up a notificationbot object with the proper setup and send functions
func NewFirebaseComm() *FirebaseComm {
return &FirebaseComm{
SetupMessagingApp: setupMessagingApp,
SendNotification: sendNotification,
}
}
// FOR TESTING USE ONLY: setup a notificationbot object with mocked setup and send funcs
func NewMockFirebaseComm(t *testing.T, setupFunc SetupFunc, sendFunc SendFunc) *FirebaseComm {
func NewMockFirebaseComm(t *testing.T, sendFunc SendFunc) *FirebaseComm {
if t == nil {
panic("This method should only be used in tests")
}
return &FirebaseComm{
SetupMessagingApp: setupFunc,
SendNotification: sendFunc,
}
}
// setupApp is a helper function which sets up a connection with firebase
// It returns a messaging client, a context object and an error
func setupMessagingApp(serviceKeyPath string) (*messaging.Client, context.Context, error) {
func SetupMessagingApp(serviceKeyPath string) (*messaging.Client, error) {
ctx := context.Background()
opt := option.WithCredentialsFile(serviceKeyPath)
app, err := firebase.NewApp(context.Background(), nil, opt)
if err != nil {
return nil, nil, errors.Errorf("Error initializing app: %v", err)
return nil, errors.Errorf("Error initializing app: %v", err)
}
client, err := app.Messaging(ctx)
if err != nil {
return nil, nil, errors.Errorf("Error getting Messaging app: %+v", err)
return nil, errors.Errorf("Error getting Messaging app: %+v", err)
}
return client, ctx, nil
return client, nil
}
// SendNotification accepts a registration token and service account file
// It gets the proper infrastructure, then builds & sends a notification through the firebase admin API
// returns string, error (string is of dubious use, but is returned for the time being)
func sendNotification(app FBSender, ctx context.Context, token string) (string, error) {
func sendNotification(app FBSender, token string) (string, error) {
ctx := context.Background()
message := &messaging.Message{
Notification: &messaging.Notification{
Title: "xx Messenger",
......@@ -84,7 +82,7 @@ func sendNotification(app FBSender, ctx context.Context, token string) (string,
resp, err := app.Send(ctx, message)
if err != nil {
return "", errors.Errorf("Failed to send notification: %+v", err)
return "", errors.Wrap(err, "Failed to send notification")
}
return resp, nil
}
......@@ -22,10 +22,9 @@ func (MockSender) Send(ctx context.Context, app *messaging.Message) (string, err
// This tests the function which sends a notification to firebase.
// Note: this requires you to have a valid token & service credentials
func TestSendNotification(t *testing.T) {
ctx := context.Background()
app := MockSender{}
_, err := sendNotification(app, ctx, token)
_, err := sendNotification(app, token)
if err != nil {
t.Error(err.Error())
}
......@@ -34,7 +33,7 @@ func TestSendNotification(t *testing.T) {
// Unit test the NewFirebaseComm method
func TestNewFirebaseComm(t *testing.T) {
comm := NewFirebaseComm()
if comm.SendNotification == nil || comm.SetupMessagingApp == nil {
if comm.SendNotification == nil {
t.Error("Failed to set functions in comm")
}
}
......
......@@ -10,6 +10,7 @@ package notifications
import (
"crypto/x509"
"firebase.google.com/go/messaging"
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/comms/connect"
......@@ -27,7 +28,7 @@ import (
// Function type definitions for the main operations (poll and notify)
type PollFunc func(*Impl) ([]string, error)
type NotifyFunc func(string, string, *firebase.FirebaseComm, storage.Storage) (string, error)
type NotifyFunc func(*messaging.Client, string, *firebase.FirebaseComm, storage.Storage) (string, error)
// Params struct holds info passed in for configuration
type Params struct {
......@@ -35,6 +36,7 @@ type Params struct {
CertPath string
KeyPath string
PublicAddress string
FBCreds string
}
// Local impl for notifications; holds comms, storage object, creds and main functions
......@@ -47,6 +49,7 @@ type Impl struct {
ndf *ndf.NetworkDefinition
pollFunc PollFunc
notifyFunc NotifyFunc
fcm *messaging.Client
}
// We use an interface here inorder to allow us to mock the getHost and RequestNDF in the notifcationsBot.Comms for testing
......@@ -59,36 +62,34 @@ type NotificationComms interface {
// Main function for this repo accepts credentials and an impl
// loops continuously, polling for notifications and notifying the relevant users
func (nb *Impl) RunNotificationLoop(fbCreds string, loopDuration int, killChan chan struct{}, errChan chan error) {
func (nb *Impl) RunNotificationLoop(loopDuration int, killChan chan struct{}, errChan chan error) {
fc := firebase.NewFirebaseComm()
for {
// Stop execution if killed by channel
select {
case <-killChan:
return
default:
case <-time.After(time.Millisecond * time.Duration(loopDuration)):
}
UIDs, err := nb.pollFunc(nb)
if err != nil {
errChan <- errors.Errorf("Failed to poll gateway for users to notify: %+v", err)
errChan <- errors.Wrap(err, "Failed to poll gateway for users to notify")
return
}
for _, id := range UIDs {
_, err := nb.notifyFunc(id, fbCreds, fc, nb.Storage)
_, err := nb.notifyFunc(nb.fcm, id, fc, nb.Storage)
if err != nil {
errChan <- errors.Errorf("Failed to notify user with ID %+v: %+v", id, err)
errChan <- errors.Wrapf(err, "Failed to notify user with ID %+v", id)
return
}
}
time.Sleep(time.Second * time.Duration(loopDuration))
}
}
// StartNotifications creates an Impl from the information passed in
func StartNotifications(params Params, noTLS bool) (*Impl, error) {
func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) {
impl := &Impl{}
var cert, key []byte
......@@ -97,28 +98,25 @@ func StartNotifications(params Params, noTLS bool) (*Impl, error) {
// Read in private key
key, err = utils.ReadFile(params.KeyPath)
if err != nil {
return nil, errors.Errorf("failed to read key at %+v: %+v", params.KeyPath, err)
return nil, errors.Wrapf(err, "failed to read key at %+v", params.KeyPath)
}
impl.notificationKey, err = rsa.LoadPrivateKeyFromPem(key)
if err != nil {
return nil, errors.Errorf("Failed to parse notifications server key: %+v. "+
"NotificationsKey is %+v",
err, impl.notificationKey)
return nil, errors.Wrapf(err, "Failed to parse notifications server key (%+v)", impl.notificationKey)
}
if !noTLS {
// Read in TLS keys from files
cert, err = utils.ReadFile(params.CertPath)
if err != nil {
return nil, errors.Errorf("failed to read certificate at %+v: %+v", params.CertPath, err)
return nil, errors.Wrapf(err, "failed to read certificate at %+v", params.CertPath)
}
// Set globals for notification server
impl.certFromFile = string(cert)
impl.notificationCert, err = tls.LoadCertificate(string(cert))
if err != nil {
return nil, errors.Errorf("Failed to parse notifications server cert: %+v. "+
"Notifications cert is %+v",
err, impl.notificationCert)
return nil, errors.Wrapf(err, "Failed to parse notifications server cert. "+
"Notifications cert is %+v", impl.notificationCert)
}
}
......@@ -129,6 +127,14 @@ func StartNotifications(params Params, noTLS bool) (*Impl, error) {
impl.Comms = notificationBot.StartNotificationBot(id.NOTIFICATION_BOT, params.PublicAddress, handler, cert, key)
if !noFirebase {
app, err := firebase.SetupMessagingApp(params.FBCreds)
if err != nil {
return nil, errors.Wrap(err, "Failed to setup firebase messaging app")
}
impl.fcm = app
}
return impl, nil
}
......@@ -149,19 +155,14 @@ func NewImplementation(instance *Impl) *notificationBot.Implementation {
// NotifyUser accepts a UID and service key file path.
// It handles the logic involved in retrieving a user's token and sending the notification
func notifyUser(uid string, serviceKeyPath string, fc *firebase.FirebaseComm, db storage.Storage) (string, error) {
func notifyUser(fcm *messaging.Client, uid string, fc *firebase.FirebaseComm, db storage.Storage) (string, error) {
u, err := db.GetUser(uid)
if err != nil {
jww.DEBUG.Printf("No registration found for user with ID %+v", uid)
return "", nil
}
app, ctx, err := fc.SetupMessagingApp(serviceKeyPath)
if err != nil {
return "", errors.Errorf("Failed to setup messaging app: %+v", err)
}
resp, err := fc.SendNotification(app, ctx, u.Token)
resp, err := fc.SendNotification(fcm, u.Token)
if err != nil {
return "", errors.Errorf("Failed to send notification to user with ID %+v: %+v", uid, err)
}
......@@ -178,7 +179,7 @@ func pollForNotifications(nb *Impl) (strings []string, e error) {
users, err := nb.Comms.RequestNotifications(h)
if err != nil {
return nil, errors.Errorf("Failed to retrieve notifications from gateway: %+v", err)
return nil, errors.Wrap(err, "Failed to retrieve notifications from gateway")
}
return users.IDs, nil
......@@ -193,7 +194,7 @@ func (nb *Impl) RegisterForNotifications(clientToken []byte, auth *connect.Auth)
}
err := nb.Storage.UpsertUser(u)
if err != nil {
return errors.Errorf("Failed to register user with notifications: %+v", err)
return errors.Wrap(err, "Failed to register user with notifications")
}
return nil
}
......@@ -202,7 +203,7 @@ func (nb *Impl) RegisterForNotifications(clientToken []byte, auth *connect.Auth)
func (nb *Impl) UnregisterForNotifications(auth *connect.Auth) error {
err := nb.Storage.DeleteUser(auth.Sender.GetId())
if err != nil {
return errors.Errorf("Failed to unregister user with notifications: %+v", err)
return errors.Wrap(err, "Failed to unregister user with notifications")
}
return nil
}
......@@ -211,7 +212,7 @@ func (nb *Impl) UpdateNdf(ndf *ndf.NetworkDefinition) error {
gw := ndf.Gateways[len(ndf.Gateways)-1]
_, err := nb.Comms.AddHost("gw", gw.Address, []byte(gw.TlsCertificate), false, true)
if err != nil {
return errors.Errorf("Failed to add gateway host from NDF: %+v", err)
return errors.Wrap(err, "Failed to add gateway host from NDF")
}
nb.ndf = ndf
......
......@@ -6,7 +6,6 @@
package notifications
import (
"context"
"firebase.google.com/go/messaging"
"github.com/pkg/errors"
"gitlab.com/elixxir/comms/connect"
......@@ -28,7 +27,7 @@ func TestRunNotificationLoop(t *testing.T) {
impl.pollFunc = func(*Impl) (i []string, e error) {
return []string{"test1", "test2"}, nil
}
impl.notifyFunc = func(s3 string, s2 string, comm *firebase.FirebaseComm, storage storage.Storage) (s string, e error) {
impl.notifyFunc = func(fcm *messaging.Client, s3 string, comm *firebase.FirebaseComm, storage storage.Storage) (s string, e error) {
return "good", nil
}
killChan := make(chan struct{})
......@@ -37,42 +36,27 @@ func TestRunNotificationLoop(t *testing.T) {
time.Sleep(10 * time.Second)
killChan <- struct{}{}
}()
impl.RunNotificationLoop("", 3, killChan, errChan)
impl.RunNotificationLoop(3, killChan, errChan)
}
// Test notificationbot's notifyuser function
// this mocks the setup and send functions, and only tests the core logic of this function
func TestNotifyUser(t *testing.T) {
badsetup := func(string) (*messaging.Client, context.Context, error) {
ctx := context.Background()
return &messaging.Client{}, ctx, errors.New("Failed")
}
setup := func(string) (*messaging.Client, context.Context, error) {
ctx := context.Background()
return &messaging.Client{}, ctx, nil
}
badsend := func(firebase.FBSender, context.Context, string) (string, error) {
badsend := func(firebase.FBSender, string) (string, error) {
return "", errors.New("Failed")
}
send := func(firebase.FBSender, context.Context, string) (string, error) {
send := func(firebase.FBSender, string) (string, error) {
return "", nil
}
fc_badsetup := firebase.NewMockFirebaseComm(t, badsetup, send)
fc_badsend := firebase.NewMockFirebaseComm(t, setup, badsend)
fc := firebase.NewMockFirebaseComm(t, setup, send)
_, err := notifyUser("test", "testpath", fc_badsetup, testutil.MockStorage{})
if err == nil {
t.Error("Should have returned an error")
return
}
fc_badsend := firebase.NewMockFirebaseComm(t, badsend)
fc := firebase.NewMockFirebaseComm(t, send)
_, err = notifyUser("test", "testpath", fc_badsend, testutil.MockStorage{})
_, err := notifyUser(nil, "test", fc_badsend, testutil.MockStorage{})
if err == nil {
t.Errorf("Should have returned an error")
}
_, err = notifyUser("test", "testpath", fc, testutil.MockStorage{})
_, err = notifyUser(nil, "test", fc, testutil.MockStorage{})
if err != nil {
t.Errorf("Failed to notify user properly")
}
......@@ -92,31 +76,31 @@ func TestStartNotifications(t *testing.T) {
PublicAddress: "0.0.0.0:42010",
}
n, err := StartNotifications(params, false)
n, err := StartNotifications(params, false, true)
if err == nil || !strings.Contains(err.Error(), "failed to read key at") {
t.Errorf("Should have thrown an error for no key path")
}
params.KeyPath = wd + "/../testutil/badkey"
n, err = StartNotifications(params, false)
n, err = StartNotifications(params, false, true)
if err == nil || !strings.Contains(err.Error(), "Failed to parse notifications server key") {
t.Errorf("Should have thrown an error bad key")
}
params.KeyPath = wd + "/../testutil/cmix.rip.key"
n, err = StartNotifications(params, false)
n, err = StartNotifications(params, false, true)
if err == nil || !strings.Contains(err.Error(), "failed to read certificate at") {
t.Errorf("Should have thrown an error for no cert path")
}
params.CertPath = wd + "/../testutil/badkey"
n, err = StartNotifications(params, false)
n, err = StartNotifications(params, false, true)
if err == nil || !strings.Contains(err.Error(), "Failed to parse notifications server cert") {
t.Errorf("Should have thrown an error for bad certificate")
}
params.CertPath = wd + "/../testutil/cmix.rip.crt"
n, err = StartNotifications(params, false)
n, err = StartNotifications(params, false, true)
if err != nil {
t.Errorf("Failed to start notifications successfully: %+v", err)
}
......@@ -256,7 +240,8 @@ func getNewImpl() *Impl {
KeyPath: wd + "/../testutil/cmix.rip.key",
CertPath: wd + "/../testutil/cmix.rip.crt",
PublicAddress: "0.0.0.0:0",
FBCreds: "",
}
instance, _ := StartNotifications(params, false)
instance, _ := StartNotifications(params, false, true)
return instance
}
......@@ -32,7 +32,7 @@ func PollNdf(currentDef *ndf.NetworkDefinition, comms NotificationComms) (*ndf.N
}
//Put the hash in a message
msg := &pb.NDFHash{Hash: ndfHash}
msg := &pb.NDFHash{Hash: ndfHash} // TODO: this should be a helper somewhere
regHost, ok := comms.GetHost(id.PERMISSIONING)
if !ok {
......@@ -42,7 +42,7 @@ func PollNdf(currentDef *ndf.NetworkDefinition, comms NotificationComms) (*ndf.N
//Send the hash to registration
response, err := comms.RequestNdf(regHost, msg)
if err != nil {
errMsg := errors.Errorf("Failed to get ndf from permissioning: %v", err)
errMsg := errors.Wrap(err, "Failed to get ndf from permissioning")
if strings.Contains(errMsg.Error(), noNDFErr.Error()) {
jww.WARN.Println("Continuing without an updated NDF")
return nil, nil
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment