Skip to content
Snippets Groups Projects
Commit 66b6194e authored by Jonah Husson's avatar Jonah Husson
Browse files

Add state table, use to store last epoch

parent 4f1fbf6d
No related branches found
No related tags found
3 merge requests!32Release,!26M203/auto recover,!25Restructure notification code for batching
...@@ -4,14 +4,17 @@ import ( ...@@ -4,14 +4,17 @@ import (
"errors" "errors"
"fmt" "fmt"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"gitlab.com/elixxir/notifications-bot/storage"
"gitlab.com/xx_network/primitives/id/ephemeral" "gitlab.com/xx_network/primitives/id/ephemeral"
"gorm.io/gorm" "gorm.io/gorm"
"strconv"
"time" "time"
) )
const offsetPhase = ephemeral.Period / ephemeral.NumOffsets const offsetPhase = ephemeral.Period / ephemeral.NumOffsets
const creationLead = 5 * time.Minute const creationLead = 5 * time.Minute
const deletionDelay = -(time.Duration(ephemeral.Period) + creationLead) const deletionDelay = -(time.Duration(ephemeral.Period) + creationLead)
const ephemeralStateKey = "lastEphemeralOffset"
// EphIdCreator runs as a thread to track ephemeral IDs for users who registered to receive push notifications // EphIdCreator runs as a thread to track ephemeral IDs for users who registered to receive push notifications
func (nb *Impl) EphIdCreator() { func (nb *Impl) EphIdCreator() {
...@@ -28,14 +31,18 @@ func (nb *Impl) EphIdCreator() { ...@@ -28,14 +31,18 @@ func (nb *Impl) EphIdCreator() {
func (nb *Impl) initCreator() { func (nb *Impl) initCreator() {
// Retrieve most recent ephemeral from storage // Retrieve most recent ephemeral from storage
var lastEpochTime time.Time var lastEpochTime time.Time
lastEph, err := nb.Storage.GetLatestEphemeral() lastEphEpoch, err := nb.Storage.GetStateValue(ephemeralStateKey)
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
jww.WARN.Printf("Failed to get latest ephemeral (no records found): %+v", err) jww.WARN.Printf("Failed to get latest ephemeral (no records found): %+v", err)
lastEpochTime = time.Now().Add(-time.Duration(ephemeral.Period)) lastEpochTime = time.Now().Add(-time.Duration(ephemeral.Period))
} else if err != nil { } else if err != nil {
jww.FATAL.Panicf("Database lookup for latest ephemeral failed: %+v", err) jww.FATAL.Panicf("Database lookup for latest ephemeral failed: %+v", err)
} else { } else {
lastEpochTime = time.Unix(0, int64(lastEph.Epoch)*offsetPhase) // Epoch time of last ephemeral ID lastEpochInt, err := strconv.Atoi(lastEphEpoch)
if err != nil {
jww.FATAL.Printf("Failed to convert last epoch to int: %+v", err)
}
lastEpochTime = time.Unix(0, int64(lastEpochInt)*offsetPhase) // Epoch time of last ephemeral ID
// If the last epoch is further back than the ephemeral ID period, only go back one period for generation // If the last epoch is further back than the ephemeral ID period, only go back one period for generation
if lastEpochTime.Before(time.Now().Add(-time.Duration(ephemeral.Period))) { if lastEpochTime.Before(time.Now().Add(-time.Duration(ephemeral.Period))) {
lastEpochTime = time.Now().Add(-time.Duration(ephemeral.Period)) lastEpochTime = time.Now().Add(-time.Duration(ephemeral.Period))
...@@ -61,6 +68,10 @@ func (nb *Impl) addEphemerals(start time.Time) { ...@@ -61,6 +68,10 @@ func (nb *Impl) addEphemerals(start time.Time) {
if err != nil { if err != nil {
jww.WARN.Printf("failed to update ephemerals: %+v", err) jww.WARN.Printf("failed to update ephemerals: %+v", err)
} }
err = nb.Storage.UpsertState(&storage.State{
Key: ephemeralStateKey,
Value: strconv.Itoa(int(epoch)),
})
} }
func (nb *Impl) EphIdDeleter() { func (nb *Impl) EphIdDeleter() {
......
...@@ -12,6 +12,7 @@ package notifications ...@@ -12,6 +12,7 @@ package notifications
import ( import (
//"github.com/pkg/errors" //"github.com/pkg/errors"
"bytes" "bytes"
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
pb "gitlab.com/elixxir/comms/mixmessages" pb "gitlab.com/elixxir/comms/mixmessages"
"gitlab.com/elixxir/notifications-bot/io" "gitlab.com/elixxir/notifications-bot/io"
...@@ -24,24 +25,24 @@ type Stopper func(timeout time.Duration) bool ...@@ -24,24 +25,24 @@ type Stopper func(timeout time.Duration) bool
// GatewaysChanged function processes the gateways changed event when detected // GatewaysChanged function processes the gateways changed event when detected
// in the NDF // in the NDF
type GatewaysChanged func(ndf pb.NDF) []byte type GatewaysChanged func(ndf pb.NDF) ([]byte, error)
// TrackNdf kicks off the ndf tracking thread // TrackNdf kicks off the ndf tracking thread
func (nb *Impl) TrackNdf() { func (nb *Impl) TrackNdf() {
// Handler function for the gateways changed event // Handler function for the gateways changed event
gatewayEventHandler := func(ndf pb.NDF) []byte { gatewayEventHandler := func(ndf pb.NDF) ([]byte, error) {
jww.DEBUG.Printf("Updating Gateways with new NDF") jww.DEBUG.Printf("Updating Gateways with new NDF")
// TODO: If this returns an error, print that error if it occurs // TODO: If this returns an error, print that error if it occurs
err := nb.inst.UpdatePartialNdf(&ndf) err := nb.inst.UpdatePartialNdf(&ndf)
if err != nil { if err != nil {
jww.ERROR.Printf("Failed to update partial NDF: %+v", err) return nil, errors.WithMessage(err, "Failed to update partial NDF")
} }
err = nb.inst.UpdateGatewayConnections() err = nb.inst.UpdateGatewayConnections()
if err != nil { if err != nil {
jww.ERROR.Printf("Failed to update gateway connections: %+v", err) return nil, errors.WithMessage(err, "Failed to update gateway connections")
} }
atomic.SwapUint32(nb.receivedNdf, 1) atomic.SwapUint32(nb.receivedNdf, 1)
return nb.inst.GetPartialNdf().GetHash() return nb.inst.GetPartialNdf().GetHash(), nil
} }
// Stopping function for the thread // Stopping function for the thread
...@@ -96,7 +97,14 @@ func trackNdf(poller io.PollingConn, quitCh chan bool, gwEvt GatewaysChanged) { ...@@ -96,7 +97,14 @@ func trackNdf(poller io.PollingConn, quitCh chan bool, gwEvt GatewaysChanged) {
if ndf != nil && len(ndf.Ndf) > 0 && !bytes.Equal(ndf.Ndf, lastNdf.Ndf) { if ndf != nil && len(ndf.Ndf) > 0 && !bytes.Equal(ndf.Ndf, lastNdf.Ndf) {
// FIXME: we should be able to get hash from the ndf // FIXME: we should be able to get hash from the ndf
// object, but we can't. // object, but we can't.
go func() { hashCh <- gwEvt(*ndf) }() go func() {
h, err := gwEvt(*ndf)
if err != nil {
jww.ERROR.Println(err)
return
}
hashCh <- h
}()
lastNdf = *ndf lastNdf = *ndf
} }
......
...@@ -7,11 +7,14 @@ import ( ...@@ -7,11 +7,14 @@ import (
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"sync"
"time" "time"
) )
// interface declaration for storage methods // interface declaration for storage methods
type database interface { type database interface {
UpsertState(state *State) error
GetStateValue(key string) (string, error)
upsertUser(user *User) error upsertUser(user *User) error
GetUser(iid []byte) (*User, error) GetUser(iid []byte) (*User, error)
GetUserByHash(transmissionRsaHash []byte) (*User, error) GetUserByHash(transmissionRsaHash []byte) (*User, error)
...@@ -32,6 +35,8 @@ type DatabaseImpl struct { ...@@ -32,6 +35,8 @@ type DatabaseImpl struct {
// Struct implementing the Database Interface with an underlying Map // Struct implementing the Database Interface with an underlying Map
type MapImpl struct { type MapImpl struct {
mut sync.Mutex
states map[string]string
usersById map[string]*User usersById map[string]*User
usersByRsaHash map[string]*User usersByRsaHash map[string]*User
usersByOffset map[int64][]*User usersByOffset map[int64][]*User
...@@ -41,6 +46,11 @@ type MapImpl struct { ...@@ -41,6 +46,11 @@ type MapImpl struct {
ephIDSeq int ephIDSeq int
} }
type State struct {
Key string `gorm:"primary_key"`
Value string `gorm:"NOT NULL"`
}
// Structure representing a User in the Storage backend // Structure representing a User in the Storage backend
type User struct { type User struct {
TransmissionRSAHash []byte `gorm:"primaryKey"` TransmissionRSAHash []byte `gorm:"primaryKey"`
......
...@@ -10,6 +10,7 @@ package storage ...@@ -10,6 +10,7 @@ package storage
import ( import (
"github.com/pkg/errors" "github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -105,3 +106,39 @@ func (impl *DatabaseImpl) GetLatestEphemeral() (*Ephemeral, error) { ...@@ -105,3 +106,39 @@ func (impl *DatabaseImpl) GetLatestEphemeral() (*Ephemeral, error) {
} }
return result[0], nil return result[0], nil
} }
// Inserts the given State into Storage if it does not exist
// Or updates the Database State if its value does not match the given State
func (d *DatabaseImpl) UpsertState(state *State) error {
jww.TRACE.Printf("Attempting to insert State into DB: %+v", state)
// Build a transaction to prevent race conditions
return d.db.Transaction(func(tx *gorm.DB) error {
// Make a copy of the provided state
newState := *state
// Attempt to insert state into the Database,
// or if it already exists, replace state with the Database value
err := tx.FirstOrCreate(state, &State{Key: state.Key}).Error
if err != nil {
return err
}
// If state is already present in the Database, overwrite it with newState
if newState.Value != state.Value {
return tx.Save(newState).Error
}
// Commit
return nil
})
}
// Returns a State's value from Storage with the given key
// Or an error if a matching State does not exist
func (d *DatabaseImpl) GetStateValue(key string) (string, error) {
result := &State{Key: key}
err := d.db.Take(result).Error
jww.TRACE.Printf("Obtained State from DB: %+v", result)
return result.Value, err
}
...@@ -10,8 +10,9 @@ package storage ...@@ -10,8 +10,9 @@ package storage
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman"
"gorm.io/gorm" "gorm.io/gorm"
) )
...@@ -138,3 +139,30 @@ func (m *MapImpl) GetLatestEphemeral() (*Ephemeral, error) { ...@@ -138,3 +139,30 @@ func (m *MapImpl) GetLatestEphemeral() (*Ephemeral, error) {
} }
return e, nil return e, nil
} }
// Inserts the given State into Storage if it does not exist
// Or updates the Database State if its value does not match the given State
func (m *MapImpl) UpsertState(state *State) error {
jww.TRACE.Printf("Attempting to insert State into Map: %+v", state)
m.mut.Lock()
defer m.mut.Unlock()
m.states[state.Key] = state.Value
return nil
}
// Returns a State's value from Storage with the given key
// Or an error if a matching State does not exist
func (m *MapImpl) GetStateValue(key string) (string, error) {
m.mut.Lock()
defer m.mut.Unlock()
if val, ok := m.states[key]; ok {
jww.TRACE.Printf("Obtained State from Map: %+v", val)
return val, nil
}
// NOTE: Other code depends on this error string
return "", errors.Errorf("Unable to locate state for key %s", key)
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment