Skip to content
Snippets Groups Projects
Commit fd83dd50 authored by Jonah Husson's avatar Jonah Husson
Browse files

Merge branch 'batchedNotification' into 'release'

Batched notification

See merge request elixxir/notifications-bot!18
parents d520e6e2 ebd2c87c
No related branches found
No related tags found
2 merge requests!21Release,!18Batched notification
...@@ -42,5 +42,9 @@ apnsKeyID: "" ...@@ -42,5 +42,9 @@ apnsKeyID: ""
apnsIssuer: "" apnsIssuer: ""
apnsBundleID: "" apnsBundleID: ""
apnsDev: true apnsDev: true
# Notification params
notificationRate: 30 # Duration in seconds
notificationsPerBatch: 20
# === END YAML # === END YAML
``` ```
...@@ -69,13 +69,16 @@ var rootCmd = &cobra.Command{ ...@@ -69,13 +69,16 @@ var rootCmd = &cobra.Command{
if err != nil { if err != nil {
jww.FATAL.Panicf("Unable to expand apns key path: %+v", err) jww.FATAL.Panicf("Unable to expand apns key path: %+v", err)
} }
viper.SetDefault("notificationRate", 30)
viper.SetDefault("notificationsPerBatch", 20)
// Populate params // Populate params
NotificationParams = notifications.Params{ NotificationParams = notifications.Params{
Address: localAddress, Address: localAddress,
CertPath: certPath, CertPath: certPath,
KeyPath: keyPath, KeyPath: keyPath,
FBCreds: fbCreds, FBCreds: fbCreds,
NotificationRate: viper.GetInt("notificationRate"),
NotificationsPerBatch: viper.GetInt("notificationsPerBatch"),
APNS: notifications.APNSParams{ APNS: notifications.APNSParams{
KeyPath: apnsKeyPath, KeyPath: apnsKeyPath,
KeyID: viper.GetString("apnsKeyID"), KeyID: viper.GetString("apnsKeyID"),
......
...@@ -13,7 +13,7 @@ require ( ...@@ -13,7 +13,7 @@ require (
github.com/spf13/jwalterweatherman v1.1.0 github.com/spf13/jwalterweatherman v1.1.0
github.com/spf13/viper v1.7.0 github.com/spf13/viper v1.7.0
github.com/stretchr/testify v1.7.0 // indirect github.com/stretchr/testify v1.7.0 // indirect
gitlab.com/elixxir/comms v0.0.4-0.20211220224127-670d882e4067 gitlab.com/elixxir/comms v0.0.4-0.20211222154743-2f5bc6365f8d
gitlab.com/elixxir/crypto v0.0.7-0.20211220224022-1f518df56c0f gitlab.com/elixxir/crypto v0.0.7-0.20211220224022-1f518df56c0f
gitlab.com/xx_network/comms v0.0.4-0.20211220223937-660809b69338 gitlab.com/xx_network/comms v0.0.4-0.20211220223937-660809b69338
gitlab.com/xx_network/crypto v0.0.5-0.20211220223913-6088f05a1191 gitlab.com/xx_network/crypto v0.0.5-0.20211220223913-6088f05a1191
......
...@@ -381,8 +381,8 @@ github.com/zeebo/pcg v0.0.0-20181207190024-3cdc6b625a05/go.mod h1:Gr+78ptB0MwXxm ...@@ -381,8 +381,8 @@ github.com/zeebo/pcg v0.0.0-20181207190024-3cdc6b625a05/go.mod h1:Gr+78ptB0MwXxm
github.com/zeebo/pcg v1.0.0 h1:dt+dx+HvX8g7Un32rY9XWoYnd0NmKmrIzpHF7qiTDj0= github.com/zeebo/pcg v1.0.0 h1:dt+dx+HvX8g7Un32rY9XWoYnd0NmKmrIzpHF7qiTDj0=
github.com/zeebo/pcg v1.0.0/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= github.com/zeebo/pcg v1.0.0/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
gitlab.com/elixxir/comms v0.0.4-0.20211220224127-670d882e4067 h1:FlZjW7xdzFkivwTu+D/ZSvH7JLqDYEICuf7OPePNs6U= gitlab.com/elixxir/comms v0.0.4-0.20211222154743-2f5bc6365f8d h1:jcEIpFn2DISLNe1nWiCocJCBgd81KUoxVCe4EVP+QUk=
gitlab.com/elixxir/comms v0.0.4-0.20211220224127-670d882e4067/go.mod h1:kvn6jAHZcRTXYdo9LvZDbO+Nw8V7Gn749fpbyUYlSN8= gitlab.com/elixxir/comms v0.0.4-0.20211222154743-2f5bc6365f8d/go.mod h1:kvn6jAHZcRTXYdo9LvZDbO+Nw8V7Gn749fpbyUYlSN8=
gitlab.com/elixxir/crypto v0.0.0-20200804182833-984246dea2c4/go.mod h1:ucm9SFKJo+K0N2GwRRpaNr+tKXMIOVWzmyUD0SbOu2c= gitlab.com/elixxir/crypto v0.0.0-20200804182833-984246dea2c4/go.mod h1:ucm9SFKJo+K0N2GwRRpaNr+tKXMIOVWzmyUD0SbOu2c=
gitlab.com/elixxir/crypto v0.0.3/go.mod h1:ZNgBOblhYToR4m8tj4cMvJ9UsJAUKq+p0gCp07WQmhA= gitlab.com/elixxir/crypto v0.0.3/go.mod h1:ZNgBOblhYToR4m8tj4cMvJ9UsJAUKq+p0gCp07WQmhA=
gitlab.com/elixxir/crypto v0.0.7-0.20211220224022-1f518df56c0f h1:/VKJYB++EcVy5WvCRm/yIOWHmtqNYDodWZiNKc4bsXg= gitlab.com/elixxir/crypto v0.0.7-0.20211220224022-1f518df56c0f h1:/VKJYB++EcVy5WvCRm/yIOWHmtqNYDodWZiNKc4bsXg=
......
...@@ -65,6 +65,8 @@ func TestImpl_InitCreator(t *testing.T) { ...@@ -65,6 +65,8 @@ func TestImpl_InitCreator(t *testing.T) {
t.FailNow() t.FailNow()
} }
impl, err := StartNotifications(Params{ impl, err := StartNotifications(Params{
NotificationsPerBatch: 20,
NotificationRate: 30,
Address: "", Address: "",
CertPath: "", CertPath: "",
KeyPath: "", KeyPath: "",
......
...@@ -6,10 +6,8 @@ ...@@ -6,10 +6,8 @@
package firebase package firebase
import ( import (
"encoding/base64"
"firebase.google.com/go/messaging" "firebase.google.com/go/messaging"
"github.com/pkg/errors" "github.com/pkg/errors"
"gitlab.com/elixxir/comms/mixmessages"
"testing" "testing"
"time" "time"
...@@ -21,7 +19,7 @@ import ( ...@@ -21,7 +19,7 @@ import (
) )
// function types for use in notificationsbot struct // function types for use in notificationsbot struct
type SendFunc func(FBSender, string, *mixmessages.NotificationData) (string, error) type SendFunc func(FBSender, string, string) (string, error)
// FirebaseComm is a struct which holds the functions to setup the messaging app and sending notifications // FirebaseComm is a struct which holds the functions to setup the messaging app and sending notifications
// Using a struct in this manner allows us to properly unit test the NotifyUser function // Using a struct in this manner allows us to properly unit test the NotifyUser function
...@@ -74,13 +72,12 @@ func SetupMessagingApp(serviceKeyPath string) (*messaging.Client, error) { ...@@ -74,13 +72,12 @@ func SetupMessagingApp(serviceKeyPath string) (*messaging.Client, error) {
// sendNotification accepts a registration token and service account file // sendNotification accepts a registration token and service account file
// It gets the proper infrastructure, then builds & sends a notification through the firebase admin API // It gets the proper infrastructure, then builds & sends a notification through the firebase admin API
// returns string, error (string is of dubious use, but is returned for the time being) // returns string, error (string is of dubious use, but is returned for the time being)
func sendNotification(app FBSender, token string, data *mixmessages.NotificationData) (string, error) { func sendNotification(app FBSender, token string, data string) (string, error) {
ctx := context.Background() ctx := context.Background()
ttl := 7 * 24 * time.Hour ttl := 7 * 24 * time.Hour
message := &messaging.Message{ message := &messaging.Message{
Data: map[string]string{ Data: map[string]string{
"messagehash": base64.StdEncoding.EncodeToString(data.MessageHash), "notificationsTag": data,
"identityfingerprint": base64.StdEncoding.EncodeToString(data.IdentityFP),
}, },
Android: &messaging.AndroidConfig{ Android: &messaging.AndroidConfig{
Priority: "high", Priority: "high",
......
...@@ -8,7 +8,6 @@ package firebase ...@@ -8,7 +8,6 @@ package firebase
import ( import (
"context" "context"
"firebase.google.com/go/messaging" "firebase.google.com/go/messaging"
"gitlab.com/elixxir/comms/mixmessages"
"testing" "testing"
) )
...@@ -25,11 +24,7 @@ func (MockSender) Send(ctx context.Context, app *messaging.Message) (string, err ...@@ -25,11 +24,7 @@ func (MockSender) Send(ctx context.Context, app *messaging.Message) (string, err
func TestSendNotification(t *testing.T) { func TestSendNotification(t *testing.T) {
app := MockSender{} app := MockSender{}
_, err := sendNotification(app, token, &mixmessages.NotificationData{ _, err := sendNotification(app, token, "data")
EphemeralID: 12345,
IdentityFP: []byte("testfp"),
MessageHash: []byte("testmsghash"),
})
if err != nil { if err != nil {
t.Error(err.Error()) t.Error(err.Error())
} }
......
...@@ -38,8 +38,10 @@ import ( ...@@ -38,8 +38,10 @@ import (
"time" "time"
) )
const notificationsTag = "notificationData"
// Function type definitions for the main operations (poll and notify) // Function type definitions for the main operations (poll and notify)
type NotifyFunc func(*pb.NotificationData, *apns.ApnsComm, *firebase.FirebaseComm, *storage.Storage) error type NotifyFunc func(int64, []*pb.NotificationData, *apns.ApnsComm, *firebase.FirebaseComm, *storage.Storage) error
// Params struct holds info passed in for configuration // Params struct holds info passed in for configuration
type Params struct { type Params struct {
...@@ -47,6 +49,8 @@ type Params struct { ...@@ -47,6 +49,8 @@ type Params struct {
CertPath string CertPath string
KeyPath string KeyPath string
FBCreds string FBCreds string
NotificationsPerBatch int
NotificationRate int
APNS APNSParams APNS APNSParams
} }
type APNSParams struct { type APNSParams struct {
...@@ -67,6 +71,7 @@ type Impl struct { ...@@ -67,6 +71,7 @@ type Impl struct {
apnsClient *apns.ApnsComm apnsClient *apns.ApnsComm
receivedNdf *uint32 receivedNdf *uint32
roundStore sync.Map roundStore sync.Map
maxNotifications int
ndfStopper Stopper ndfStopper Stopper
} }
...@@ -109,6 +114,7 @@ func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) { ...@@ -109,6 +114,7 @@ func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) {
notifyFunc: notifyUser, notifyFunc: notifyUser,
fcm: fbComm, fcm: fbComm,
receivedNdf: &receivedNdf, receivedNdf: &receivedNdf,
maxNotifications: params.NotificationsPerBatch,
} }
if params.APNS.KeyPath == "" { if params.APNS.KeyPath == "" {
...@@ -152,6 +158,7 @@ func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) { ...@@ -152,6 +158,7 @@ func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) {
impl.inst = i impl.inst = i
go impl.Cleaner() go impl.Cleaner()
go impl.Sender(params.NotificationRate)
return impl, nil return impl, nil
} }
...@@ -177,15 +184,15 @@ func NewImplementation(instance *Impl) *notificationBot.Implementation { ...@@ -177,15 +184,15 @@ func NewImplementation(instance *Impl) *notificationBot.Implementation {
// NotifyUser accepts a UID and service key file path. // NotifyUser accepts a UID and service key file path.
// It handles the logic involved in retrieving a user's token and sending the notification // It handles the logic involved in retrieving a user's token and sending the notification
func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *firebase.FirebaseComm, db *storage.Storage) error { func notifyUser(ephID int64, data []*pb.NotificationData, apnsClient *apns.ApnsComm, fc *firebase.FirebaseComm, db *storage.Storage) error {
elist, err := db.GetEphemeral(data.EphemeralID) elist, err := db.GetEphemeral(ephID)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
jww.DEBUG.Printf("No registration found for ephemeral ID %+v", data.EphemeralID) jww.DEBUG.Printf("No registration found for ephemeral ID %+v", ephID)
// This path is not an error. if no results are returned, the user hasn't registered for notifications // This path is not an error. if no results are returned, the user hasn't registered for notifications
return nil return nil
} }
return errors.WithMessagef(err, "Could not retrieve registration for ephemeral ID %+v", data.EphemeralID) return errors.WithMessagef(err, "Could not retrieve registration for ephemeral ID %+v", ephID)
} }
for _, e := range elist { for _, e := range elist {
u, err := db.GetUserByHash(e.TransmissionRSAHash) u, err := db.GetUserByHash(e.TransmissionRSAHash)
...@@ -193,14 +200,15 @@ func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *fireba ...@@ -193,14 +200,15 @@ func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *fireba
return errors.WithMessagef(err, "Failed to lookup user with tRSA hash %+v", e.TransmissionRSAHash) return errors.WithMessagef(err, "Failed to lookup user with tRSA hash %+v", e.TransmissionRSAHash)
} }
notificationsCSV := pb.MakeNotificationsCSV(data)
isAPNS := !strings.Contains(u.Token, ":") isAPNS := !strings.Contains(u.Token, ":")
// mutableContent := 1 // mutableContent := 1
if isAPNS { if isAPNS {
jww.INFO.Printf("Notifying ephemeral ID %+v via APNS to token %+v", data.EphemeralID, u.Token) jww.INFO.Printf("Notifying ephemeral ID %+v via APNS to token %+v", ephID, u.Token)
notifPayload := payload.NewPayload().AlertTitle("Privacy: protected!").AlertBody( notifPayload := payload.NewPayload().AlertTitle("Privacy: protected!").AlertBody(
"Some notifications are not for you to ensure privacy; we hope to remove this notification soon").MutableContent().Custom( "Some notifications are not for you to ensure privacy; we hope to remove this notification soon").MutableContent().Custom(
"messagehash", base64.StdEncoding.EncodeToString(data.MessageHash)).Custom( notificationsTag, notificationsCSV)
"identityfingerprint", base64.StdEncoding.EncodeToString(data.IdentityFP))
notif := &apns2.Notification{ notif := &apns2.Notification{
CollapseID: base64.StdEncoding.EncodeToString(u.TransmissionRSAHash), CollapseID: base64.StdEncoding.EncodeToString(u.TransmissionRSAHash),
DeviceToken: u.Token, DeviceToken: u.Token,
...@@ -219,10 +227,10 @@ func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *fireba ...@@ -219,10 +227,10 @@ func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *fireba
// return errors.WithMessagef(err, "Failed to remove user registration tRSA hash: %+v", u.TransmissionRSAHash) // return errors.WithMessagef(err, "Failed to remove user registration tRSA hash: %+v", u.TransmissionRSAHash)
//} //}
} else { } else {
jww.INFO.Printf("Notified ephemeral ID %+v [%+v] and received response %+v", data.EphemeralID, u.Token, resp) jww.INFO.Printf("Notified ephemeral ID %+v [%+v] and received response %+v", ephID, u.Token, resp)
} }
} else { } else {
resp, err := fc.SendNotification(fc.Client, u.Token, data) resp, err := fc.SendNotification(fc.Client, u.Token, notificationsCSV)
if err != nil { if err != nil {
// Catch two firebase errors that we don't want to crash on // Catch two firebase errors that we don't want to crash on
// 400 and 404 indicate that the token stored is incorrect // 400 and 404 indicate that the token stored is incorrect
...@@ -240,7 +248,7 @@ func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *fireba ...@@ -240,7 +248,7 @@ func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *fireba
return errors.WithMessagef(err, "Failed to send notification to user with tRSA hash %+v", u.TransmissionRSAHash) return errors.WithMessagef(err, "Failed to send notification to user with tRSA hash %+v", u.TransmissionRSAHash)
} }
} }
jww.INFO.Printf("Notified ephemeral ID %+v [%+v] and received response %+v", data.EphemeralID, u.Token, resp) jww.INFO.Printf("Notified ephemeral ID %+v [%+v] and received response %+v", ephID, u.Token, resp)
} }
} }
return nil return nil
...@@ -347,11 +355,10 @@ func (nb *Impl) ReceiveNotificationBatch(notifBatch *pb.NotificationBatch, auth ...@@ -347,11 +355,10 @@ func (nb *Impl) ReceiveNotificationBatch(notifBatch *pb.NotificationBatch, auth
jww.INFO.Printf("Received notification batch for round %+v", notifBatch.RoundID) jww.INFO.Printf("Received notification batch for round %+v", notifBatch.RoundID)
buffer := nb.Storage.GetNotificationBuffer()
for _, notifData := range notifBatch.GetNotifications() { for _, notifData := range notifBatch.GetNotifications() {
err := nb.notifyFunc(notifData, nb.apnsClient, nb.fcm, nb.Storage) buffer.Add(notifData)
if err != nil {
return err
}
} }
return nil return nil
...@@ -362,15 +369,41 @@ func (nb *Impl) ReceivedNdf() *uint32 { ...@@ -362,15 +369,41 @@ func (nb *Impl) ReceivedNdf() *uint32 {
} }
func (nb *Impl) Cleaner() { func (nb *Impl) Cleaner() {
for { cleanF := func(key, val interface{}) bool {
f := func(key, val interface{}) bool {
t := val.(time.Time) t := val.(time.Time)
if time.Since(t) > (5 * time.Minute) { if time.Since(t) > (5 * time.Minute) {
nb.roundStore.Delete(key) nb.roundStore.Delete(key)
} }
return true return true
} }
nb.roundStore.Range(f)
time.Sleep(time.Minute * 10) cleanTicker := time.NewTicker(time.Minute * 10)
for {
select {
case <-cleanTicker.C:
nb.roundStore.Range(cleanF)
}
}
}
func (nb *Impl) Sender(sendFreq int) {
sendTicker := time.NewTicker(time.Duration(sendFreq) * time.Second)
for {
select {
case <-sendTicker.C:
notifBuf := nb.Storage.GetNotificationBuffer()
notifMap := notifBuf.Swap(uint(nb.maxNotifications))
for ephID := range notifMap {
localEphID := ephID
notifList := notifMap[localEphID]
go func() {
err := nb.notifyFunc(localEphID, notifList, nb.apnsClient, nb.fcm, nb.Storage)
if err != nil {
jww.ERROR.Printf("Failed to notify %d: %+v", localEphID, err)
}
}()
}
}
} }
} }
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
"gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/id/ephemeral"
"gitlab.com/xx_network/primitives/utils" "gitlab.com/xx_network/primitives/utils"
"os" "os"
"reflect" "strings"
"testing" "testing"
"time" "time"
) )
...@@ -32,10 +32,10 @@ var port = 4200 ...@@ -32,10 +32,10 @@ var port = 4200
// Test notificationbot's notifyuser function // Test notificationbot's notifyuser function
// this mocks the setup and send functions, and only tests the core logic of this function // this mocks the setup and send functions, and only tests the core logic of this function
func TestNotifyUser(t *testing.T) { func TestNotifyUser(t *testing.T) {
badsend := func(firebase.FBSender, string, *pb.NotificationData) (string, error) { badsend := func(firebase.FBSender, string, string) (string, error) {
return "", errors.New("Failed") return "", errors.New("Failed")
} }
send := func(firebase.FBSender, string, *pb.NotificationData) (string, error) { send := func(firebase.FBSender, string, string) (string, error) {
return "", nil return "", nil
} }
fcBadSend := firebase.NewMockFirebaseComm(t, badsend) fcBadSend := firebase.NewMockFirebaseComm(t, badsend)
...@@ -64,20 +64,20 @@ func TestNotifyUser(t *testing.T) { ...@@ -64,20 +64,20 @@ func TestNotifyUser(t *testing.T) {
} }
ac := apns.NewApnsComm(apns2.NewTokenClient(nil), "") ac := apns.NewApnsComm(apns2.NewTokenClient(nil), "")
err = notifyUser(&pb.NotificationData{ err = notifyUser(eph.EphemeralId, []*pb.NotificationData{{
EphemeralID: eph.EphemeralId, EphemeralID: eph.EphemeralId,
IdentityFP: nil, IdentityFP: nil,
MessageHash: nil, MessageHash: nil,
}, ac, fcBadSend, s) }}, ac, fcBadSend, s)
if err == nil { if err == nil {
t.Errorf("Should have returned an error") t.Errorf("Should have returned an error")
} }
err = notifyUser(&pb.NotificationData{ err = notifyUser(eph.EphemeralId, []*pb.NotificationData{{
EphemeralID: eph.EphemeralId, EphemeralID: eph.EphemeralId,
IdentityFP: nil, IdentityFP: nil,
MessageHash: nil, MessageHash: nil,
}, ac, fc, s) }}, ac, fc, s)
if err != nil { if err != nil {
t.Errorf("Failed to notify user properly") t.Errorf("Failed to notify user properly")
} }
...@@ -85,35 +85,37 @@ func TestNotifyUser(t *testing.T) { ...@@ -85,35 +85,37 @@ func TestNotifyUser(t *testing.T) {
// Unit test for startnotifications // Unit test for startnotifications
// tests logic including error cases // tests logic including error cases
//func TestStartNotifications(t *testing.T) { func TestStartNotifications(t *testing.T) {
// wd, err := os.Getwd() wd, err := os.Getwd()
// if err != nil { if err != nil {
// t.Errorf("Failed to get working dir: %+v", err) t.Errorf("Failed to get working dir: %+v", err)
// return return
// } }
//
// params := Params{ params := Params{
// Address: "0.0.0.0:42010", Address: "0.0.0.0:42010",
// APNS: APNSParams{ NotificationsPerBatch: 20,
// KeyPath: wd + "/../testutil/apnsKey.key", NotificationRate: 30,
// KeyID: "WQT68265C5", APNS: APNSParams{
// Issuer: "S6JDM2WW29", KeyPath: "",
// BundleID: "io.xxlabs.messenger", KeyID: "WQT68265C5",
// }, Issuer: "S6JDM2WW29",
// } BundleID: "io.xxlabs.messenger",
// },
// params.KeyPath = wd + "/../testutil/cmix.rip.key" }
// _, err = StartNotifications(params, false, true)
// if err == nil || !strings.Contains(err.Error(), "failed to read certificate at") { params.KeyPath = wd + "/../testutil/cmix.rip.key"
// t.Errorf("Should have thrown an error for no cert path") _, err = StartNotifications(params, false, true)
// } if err == nil || !strings.Contains(err.Error(), "failed to read certificate at") {
// t.Errorf("Should have thrown an error for no cert path")
// params.CertPath = wd + "/../testutil/cmix.rip.crt" }
// _, err = StartNotifications(params, false, true)
// if err != nil { params.CertPath = wd + "/../testutil/cmix.rip.crt"
// t.Errorf("Failed to start notifications successfully: %+v", err) _, err = StartNotifications(params, false, true)
// } if err != nil {
//} t.Errorf("Failed to start notifications successfully: %+v", err)
}
}
// unit test for newimplementation // unit test for newimplementation
// tests logic and error cases // tests logic and error cases
...@@ -288,8 +290,8 @@ func TestImpl_UnregisterForNotifications(t *testing.T) { ...@@ -288,8 +290,8 @@ func TestImpl_UnregisterForNotifications(t *testing.T) {
func TestImpl_ReceiveNotificationBatch(t *testing.T) { func TestImpl_ReceiveNotificationBatch(t *testing.T) {
impl := getNewImpl() impl := getNewImpl()
dataChan := make(chan *pb.NotificationData) dataChan := make(chan *pb.NotificationData)
impl.notifyFunc = func(data *pb.NotificationData, apns *apns.ApnsComm, fc *firebase.FirebaseComm, s *storage.Storage) error { impl.notifyFunc = func(eph int64, data []*pb.NotificationData, apns *apns.ApnsComm, fc *firebase.FirebaseComm, s *storage.Storage) error {
go func() { dataChan <- data }() go func() { dataChan <- data[0] }()
return nil return nil
} }
...@@ -313,14 +315,9 @@ func TestImpl_ReceiveNotificationBatch(t *testing.T) { ...@@ -313,14 +315,9 @@ func TestImpl_ReceiveNotificationBatch(t *testing.T) {
t.Errorf("ReceiveNotificationBatch() returned an error: %+v", err) t.Errorf("ReceiveNotificationBatch() returned an error: %+v", err)
} }
select { nbm := impl.Storage.GetNotificationBuffer().Swap(20)
case result := <-dataChan: if len(nbm[5]) < 1 {
if !reflect.DeepEqual(notifBatch.Notifications[0], result) { t.Errorf("Notification was not added to notification buffer")
t.Errorf("Failed to receive expected NotificationData."+
"\nexpected: %s\nreceived: %s", notifBatch.Notifications[0], result)
}
case <-time.NewTimer(50 * time.Millisecond).C:
t.Error("Timed out while waiting for NotificationData.")
} }
} }
...@@ -328,6 +325,8 @@ func TestImpl_ReceiveNotificationBatch(t *testing.T) { ...@@ -328,6 +325,8 @@ func TestImpl_ReceiveNotificationBatch(t *testing.T) {
func getNewImpl() *Impl { func getNewImpl() *Impl {
wd, _ := os.Getwd() wd, _ := os.Getwd()
params := Params{ params := Params{
NotificationsPerBatch: 20,
NotificationRate: 30,
Address: fmt.Sprintf("0.0.0.0:%d", port), Address: fmt.Sprintf("0.0.0.0:%d", port),
KeyPath: wd + "/../testutil/cmix.rip.key", KeyPath: wd + "/../testutil/cmix.rip.key",
CertPath: wd + "/../testutil/cmix.rip.crt", CertPath: wd + "/../testutil/cmix.rip.crt",
...@@ -335,5 +334,6 @@ func getNewImpl() *Impl { ...@@ -335,5 +334,6 @@ func getNewImpl() *Impl {
} }
port += 1 port += 1
instance, _ := StartNotifications(params, false, true) instance, _ := StartNotifications(params, false, true)
instance.Storage, _ = storage.NewStorage("", "", "", "", "")
return instance return instance
} }
package storage
import (
pb "gitlab.com/elixxir/comms/mixmessages"
"strconv"
"sync"
"sync/atomic"
)
type NotificationBuffer struct {
bufMap atomic.Value
counter *int64
}
func NewNotificationBuffer() *NotificationBuffer {
u := int64(0)
sm := sync.Map{}
nb := &NotificationBuffer{
counter: &u,
bufMap: atomic.Value{},
}
nb.bufMap.Store(&sm)
return nb
}
func (bnm *NotificationBuffer) Swap(maxNotifications uint) map[int64][]*pb.NotificationData {
newSM := &sync.Map{}
m := bnm.bufMap.Swap(newSM).(*sync.Map)
outMap := make(map[int64][]*pb.NotificationData)
f := func(_, value interface{}) bool {
n := value.(*pb.NotificationData)
nSlice, exists := outMap[n.EphemeralID]
if exists {
if uint(len(nSlice)) >= maxNotifications {
bnm.Add(n)
} else {
nSlice = append(nSlice, n)
}
} else {
nSlice = []*pb.NotificationData{n}
}
outMap[n.EphemeralID] = nSlice
return true
}
m.Range(f)
return outMap
}
func (bnm *NotificationBuffer) Add(n *pb.NotificationData) {
c := atomic.AddInt64(bnm.counter, 1)
bnm.bufMap.Load().(*sync.Map).Store(strconv.FormatInt(c, 16), n)
}
...@@ -11,13 +11,15 @@ import ( ...@@ -11,13 +11,15 @@ import (
type Storage struct { type Storage struct {
database database
notificationBuffer *NotificationBuffer
} }
// Create a new Storage object wrapping a database interface // Create a new Storage object wrapping a database interface
// Returns a Storage object and error // Returns a Storage object and error
func NewStorage(username, password, dbName, address, port string) (*Storage, error) { func NewStorage(username, password, dbName, address, port string) (*Storage, error) {
db, err := newDatabase(username, password, dbName, address, port) db, err := newDatabase(username, password, dbName, address, port)
storage := &Storage{db} nb := NewNotificationBuffer()
storage := &Storage{db, nb}
return storage, err return storage, err
} }
...@@ -117,3 +119,7 @@ func (s *Storage) AddEphemeralsForOffset(offset int64, epoch int32, size uint, t ...@@ -117,3 +119,7 @@ func (s *Storage) AddEphemeralsForOffset(offset int64, epoch int32, size uint, t
} }
return nil return nil
} }
func (s *Storage) GetNotificationBuffer()*NotificationBuffer {
return s.notificationBuffer
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment