Skip to content
Snippets Groups Projects
Commit 4f536665 authored by Jonah Husson's avatar Jonah Husson
Browse files

Add basic rate limiting

parent 418c8430
No related branches found
No related tags found
1 merge request!17Add basic rate limiting
...@@ -151,6 +151,7 @@ func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) { ...@@ -151,6 +151,7 @@ func StartNotifications(params Params, noTLS, noFirebase bool) (*Impl, error) {
i.SetGatewayAuthentication() i.SetGatewayAuthentication()
impl.inst = i impl.inst = i
impl.Storage.StartLastSentCleaner()
go impl.Cleaner() go impl.Cleaner()
return impl, nil return impl, nil
...@@ -192,6 +193,13 @@ func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *fireba ...@@ -192,6 +193,13 @@ func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *fireba
if err != nil { if err != nil {
return errors.WithMessagef(err, "Failed to lookup user with tRSA hash %+v", e.TransmissionRSAHash) return errors.WithMessagef(err, "Failed to lookup user with tRSA hash %+v", e.TransmissionRSAHash)
} }
if t, ok := db.LastSent.Load(u.Token); ok {
lastSent := t.(time.Time)
if time.Since(lastSent) > 3*time.Second {
jww.DEBUG.Printf("Rate limiting token %s [last sent at %+v]", u.Token, lastSent)
continue
}
}
isAPNS := !strings.Contains(u.Token, ":") isAPNS := !strings.Contains(u.Token, ":")
// mutableContent := 1 // mutableContent := 1
...@@ -219,6 +227,7 @@ func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *fireba ...@@ -219,6 +227,7 @@ func notifyUser(data *pb.NotificationData, apnsClient *apns.ApnsComm, fc *fireba
// return errors.WithMessagef(err, "Failed to remove user registration tRSA hash: %+v", u.TransmissionRSAHash) // return errors.WithMessagef(err, "Failed to remove user registration tRSA hash: %+v", u.TransmissionRSAHash)
//} //}
} else { } else {
db.LastSent.Store(u.Token, time.Now())
jww.INFO.Printf("Notified ephemeral ID %+v [%+v] and received response %+v", data.EphemeralID, u.Token, resp) jww.INFO.Printf("Notified ephemeral ID %+v [%+v] and received response %+v", data.EphemeralID, u.Token, resp)
} }
} else { } else {
......
...@@ -6,18 +6,20 @@ import ( ...@@ -6,18 +6,20 @@ import (
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/crypto/hash" "gitlab.com/elixxir/crypto/hash"
"gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/id/ephemeral"
"sync"
"time" "time"
) )
type Storage struct { type Storage struct {
database database
LastSent sync.Map
} }
// Create a new Storage object wrapping a database interface // Create a new Storage object wrapping a database interface
// Returns a Storage object and error // Returns a Storage object and error
func NewStorage(username, password, dbName, address, port string) (*Storage, error) { func NewStorage(username, password, dbName, address, port string) (*Storage, error) {
db, err := newDatabase(username, password, dbName, address, port) db, err := newDatabase(username, password, dbName, address, port)
storage := &Storage{db} storage := &Storage{db, sync.Map{}}
return storage, err return storage, err
} }
...@@ -117,3 +119,16 @@ func (s *Storage) AddEphemeralsForOffset(offset int64, epoch int32, size uint, t ...@@ -117,3 +119,16 @@ func (s *Storage) AddEphemeralsForOffset(offset int64, epoch int32, size uint, t
} }
return nil return nil
} }
func (s *Storage) StartLastSentCleaner() {
go func() {
time.Sleep(5 * time.Second)
s.LastSent.Range(func(key, value interface{}) bool {
lastSent := value.(time.Time)
if time.Since(lastSent) > 3*time.Second {
s.LastSent.Delete(key)
}
return true
})
}()
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment