From cb0bfcaee7f6730dc0fd77bb5dbc83479c6e02c0 Mon Sep 17 00:00:00 2001
From: jbhusson <jonah@elixxir.io>
Date: Fri, 17 Sep 2021 12:44:26 -0400
Subject: [PATCH] Fixes for apns sending

---
 notifications/notifications.go      | 17 +++++++++++++----
 notifications/notifications_test.go |  6 +++---
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/notifications/notifications.go b/notifications/notifications.go
index 808ea5a..aa8e941 100644
--- a/notifications/notifications.go
+++ b/notifications/notifications.go
@@ -38,7 +38,7 @@ import (
 )
 
 // Function type definitions for the main operations (poll and notify)
-type NotifyFunc func(*pb.NotificationData, ApnsSender, *messaging.Client, *firebase.FirebaseComm, *storage.Storage) error
+type NotifyFunc func(*pb.NotificationData, ApnsSender, *messaging.Client, *firebase.FirebaseComm, *storage.Storage, string) error
 type ApnsSender interface {
 	//Send(token string, p apns.Payload, opts ...apns.SendOption) (*apns.Response, error)
 	Push(n *apns2.Notification) (*apns2.Response, error)
@@ -69,6 +69,7 @@ type Impl struct {
 	fcm         *messaging.Client
 	apnsClient  *apns2.Client
 	receivedNdf *uint32
+	params      *Params
 
 	ndfStopper Stopper
 }
@@ -110,6 +111,7 @@ func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) {
 		notifyFunc:  notifyUser,
 		fcm:         app,
 		receivedNdf: &receivedNdf,
+		params:      &params,
 	}
 
 	if params.APNS.KeyPath == "" {
@@ -119,7 +121,7 @@ func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) {
 			return nil, errors.WithMessagef(err, "APNS not properly configured: %+v", params.APNS)
 		}
 
-		authKey, err := apnstoken.AuthKeyFromFile("params.APNS.KeyPath")
+		authKey, err := apnstoken.AuthKeyFromFile(params.APNS.KeyPath)
 		if err != nil {
 			log.Fatal("token error:", err)
 		}
@@ -131,6 +133,12 @@ func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) {
 			TeamID: params.APNS.Issuer,
 		}
 		apnsClient := apns2.NewTokenClient(token)
+		if params.APNS.Dev {
+			jww.INFO.Printf("Running with dev apns gateway")
+			apnsClient.Development()
+		} else {
+			apnsClient.Production()
+		}
 
 		impl.apnsClient = apnsClient
 	}
@@ -170,7 +178,7 @@ func NewImplementation(instance *Impl) *notificationBot.Implementation {
 
 // NotifyUser accepts a UID and service key file path.
 // It handles the logic involved in retrieving a user's token and sending the notification
-func notifyUser(data *pb.NotificationData, apnsClient ApnsSender, fcm *messaging.Client, fc *firebase.FirebaseComm, db *storage.Storage) error {
+func notifyUser(data *pb.NotificationData, apnsClient ApnsSender, fcm *messaging.Client, fc *firebase.FirebaseComm, db *storage.Storage, topic string) error {
 	elist, err := db.GetEphemeral(data.EphemeralID)
 	if err != nil {
 		if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -201,6 +209,7 @@ func notifyUser(data *pb.NotificationData, apnsClient ApnsSender, fcm *messaging
 				Priority:    apns2.PriorityHigh,
 				Payload:     notifPayload,
 				PushType:    apns2.PushTypeAlert,
+				Topic:       topic,
 			}
 			resp, err := apnsClient.Push(notif)
 			//resp, err := apnsClient.Send(u.Token, apns.Payload{
@@ -349,7 +358,7 @@ func (nb *Impl) ReceiveNotificationBatch(notifBatch *pb.NotificationBatch, auth
 
 	fbComm := firebase.NewFirebaseComm()
 	for _, notifData := range notifBatch.GetNotifications() {
-		err := nb.notifyFunc(notifData, nb.apnsClient, nb.fcm, fbComm, nb.Storage)
+		err := nb.notifyFunc(notifData, nb.apnsClient, nb.fcm, fbComm, nb.Storage, nb.params.APNS.BundleID)
 		if err != nil {
 			return err
 		}
diff --git a/notifications/notifications_test.go b/notifications/notifications_test.go
index 8f92766..d54017a 100644
--- a/notifications/notifications_test.go
+++ b/notifications/notifications_test.go
@@ -77,7 +77,7 @@ func TestNotifyUser(t *testing.T) {
 		EphemeralID: eph.EphemeralId,
 		IdentityFP:  nil,
 		MessageHash: nil,
-	}, &MockApns{}, nil, fcBadSend, s)
+	}, &MockApns{}, nil, fcBadSend, s, "")
 	if err == nil {
 		t.Errorf("Should have returned an error")
 	}
@@ -86,7 +86,7 @@ func TestNotifyUser(t *testing.T) {
 		EphemeralID: eph.EphemeralId,
 		IdentityFP:  nil,
 		MessageHash: nil,
-	}, &MockApns{}, nil, fc, s)
+	}, &MockApns{}, nil, fc, s, "")
 	if err != nil {
 		t.Errorf("Failed to notify user properly")
 	}
@@ -297,7 +297,7 @@ func TestImpl_UnregisterForNotifications(t *testing.T) {
 func TestImpl_ReceiveNotificationBatch(t *testing.T) {
 	impl := getNewImpl()
 	dataChan := make(chan *pb.NotificationData)
-	impl.notifyFunc = func(data *pb.NotificationData, apns ApnsSender, f *messaging.Client, fc *firebase.FirebaseComm, s *storage.Storage) error {
+	impl.notifyFunc = func(data *pb.NotificationData, apns ApnsSender, f *messaging.Client, fc *firebase.FirebaseComm, s *storage.Storage, topic string) error {
 		go func() { dataChan <- data }()
 		return nil
 	}
-- 
GitLab