Skip to content
Snippets Groups Projects
Commit 336d8dca authored by Jonah Husson's avatar Jonah Husson
Browse files

Merge db open conditions, remove database interface, insertUser should be an insert not upsert

parent 3b67ac3e
No related branches found
No related tags found
2 merge requests!38Project/haven beta,!36Update notifications bot to allow multiple identities & devices
...@@ -18,37 +18,11 @@ import ( ...@@ -18,37 +18,11 @@ import (
"time" "time"
) )
// interface declaration for storage methods const postgresConnectString = "host=%s port=%s user=%s dbname=%s sslmode=disable"
type database interface { const sqliteDatabasePath = "file:%s?mode=memory&cache=shared"
UpsertState(state *State) error
GetStateValue(key string) (string, error)
insertUser(user *User) error // databaseImpl is a struct which implements database on an underlying gorm.DB
GetUser(transmissionRsaHash []byte) (*User, error) type databaseImpl struct {
deleteUser(transmissionRsaHash []byte) error
GetAllUsers() ([]*User, error)
getIdentity(iid []byte) (*Identity, error)
insertIdentity(identity *Identity) error
getIdentitiesByOffset(offset int64) ([]*Identity, error)
GetOrphanedIdentities() ([]*Identity, error)
insertEphemeral(ephemeral *Ephemeral) error
GetEphemeral(ephemeralId int64) ([]*Ephemeral, error)
GetLatestEphemeral() (*Ephemeral, error)
DeleteOldEphemerals(currentEpoch int32) error
GetToNotify(ephemeralIds []int64) ([]GTNResult, error)
DeleteToken(token string) error
unregisterIdentities(u *User, iids []Identity) error
unregisterTokens(u *User, tokens []Token) error
registerForNotifications(u *User, identity Identity, token string) error
LegacyUnregister(iid []byte) error
}
// DatabaseImpl is a struct which implements database on an underlying gorm.DB
type DatabaseImpl struct {
db *gorm.DB // Stored database connection db *gorm.DB // Stored database connection
} }
...@@ -111,43 +85,38 @@ type Ephemeral struct { ...@@ -111,43 +85,38 @@ type Ephemeral struct {
// Initialize the database interface with database backend // Initialize the database interface with database backend
// Returns a database interface, close function, and error // Returns a database interface, close function, and error
func newDatabase(username, password, dbName, address, func newDatabase(username, password, dbName, address,
port string) (database, error) { port string) (*databaseImpl, error) {
var err error var err error
var db *gorm.DB var db *gorm.DB
var useSqlite bool
var dialector gorm.Dialector
// Connect to the database if the correct information is provided // Connect to the database if the correct information is provided
if address != "" && port != "" { if address != "" && port != "" {
// Create the database connection // Create the database connection
connectString := fmt.Sprintf( connectString := fmt.Sprintf(
"host=%s port=%s user=%s dbname=%s sslmode=disable", postgresConnectString,
address, port, username, dbName) address, port, username, dbName)
// Handle empty database password // Handle empty database password
if len(password) > 0 { if len(password) > 0 {
connectString += fmt.Sprintf(" password=%s", password) connectString += fmt.Sprintf(" password=%s", password)
} }
db, err = gorm.Open(postgres.Open(connectString), &gorm.Config{ dialector = postgres.Open(connectString)
Logger: logger.New(jww.TRACE, logger.Config{LogLevel: logger.Info}),
})
}
// Return the map-backend interface
// in the event there is a database error or information is not provided
if (address == "" || port == "") || err != nil {
if err != nil {
jww.WARN.Printf("Unable to initialize database backend: %+v", err)
} else { } else {
useSqlite = true
jww.WARN.Printf("Database backend connection information not provided") jww.WARN.Printf("Database backend connection information not provided")
temporaryDbPath := fmt.Sprintf(sqliteDatabasePath, dbName)
dialector = sqlite.Open(temporaryDbPath)
} }
temporaryDbPath := fmt.Sprintf("file:%s?mode=memory&cache=shared", dbName)
// Create the database connection // Create the database connection
db, err = gorm.Open(sqlite.Open(temporaryDbPath), &gorm.Config{ db, err = gorm.Open(dialector, &gorm.Config{
Logger: logger.New(jww.TRACE, logger.Config{LogLevel: logger.Info}), Logger: logger.New(jww.TRACE, logger.Config{LogLevel: logger.Info}),
}) })
if err != nil { if err != nil {
return nil, errors.Errorf("Unable to initialize in-memory sqlite database backend: %+v", err) return nil, errors.Errorf("Unable to initialize in-memory sqlite database backend: %+v", err)
} }
if useSqlite {
// Enable foreign keys because they are disabled in SQLite by default // Enable foreign keys because they are disabled in SQLite by default
if err = db.Exec("PRAGMA foreign_keys = ON", nil).Error; err != nil { if err = db.Exec("PRAGMA foreign_keys = ON", nil).Error; err != nil {
return nil, err return nil, err
...@@ -157,14 +126,12 @@ func newDatabase(username, password, dbName, address, ...@@ -157,14 +126,12 @@ func newDatabase(username, password, dbName, address,
if err = db.Exec("PRAGMA journal_mode = WAL;", nil).Error; err != nil { if err = db.Exec("PRAGMA journal_mode = WAL;", nil).Error; err != nil {
return nil, err return nil, err
} }
defer jww.INFO.Println("SQLite in-memory backend initialized successfully!")
} }
// Get and configure the internal database ConnPool // Get and configure the internal database ConnPool
sqlDb, err := db.DB() sqlDb, err := db.DB()
if err != nil { if err != nil {
return database(&DatabaseImpl{}), errors.Errorf("Unable to configure database connection pool: %+v", err) return nil, errors.Errorf("Unable to configure database connection pool: %+v", err)
} }
// SetMaxIdleConns sets the maximum number of connections in the idle connection pool. // SetMaxIdleConns sets the maximum number of connections in the idle connection pool.
sqlDb.SetMaxIdleConns(10) sqlDb.SetMaxIdleConns(10)
...@@ -181,15 +148,15 @@ func newDatabase(username, password, dbName, address, ...@@ -181,15 +148,15 @@ func newDatabase(username, password, dbName, address,
for _, model := range models { for _, model := range models {
err = db.AutoMigrate(model) err = db.AutoMigrate(model)
if err != nil { if err != nil {
return database(&DatabaseImpl{}), err return nil, err
} }
} }
// Build the interface // Build the interface
di := &DatabaseImpl{ di := &databaseImpl{
db: db, db: db,
} }
jww.INFO.Println("Database backend initialized successfully!") jww.INFO.Println("Database backend initialized successfully!")
return database(di), nil return di, nil
} }
...@@ -10,7 +10,6 @@ ...@@ -10,7 +10,6 @@
package storage package storage
import ( import (
"bytes"
"github.com/pkg/errors" "github.com/pkg/errors"
jww "github.com/spf13/jwalterweatherman" jww "github.com/spf13/jwalterweatherman"
"gorm.io/gorm" "gorm.io/gorm"
...@@ -19,7 +18,7 @@ import ( ...@@ -19,7 +18,7 @@ import (
// Inserts the given State into Storage if it does not exist // Inserts the given State into Storage if it does not exist
// Or updates the Database State if its value does not match the given State // Or updates the Database State if its value does not match the given State
func (d *DatabaseImpl) UpsertState(state *State) error { func (d *databaseImpl) UpsertState(state *State) error {
jww.TRACE.Printf("Attempting to insert State into DB: %+v", state) jww.TRACE.Printf("Attempting to insert State into DB: %+v", state)
// Build a transaction to prevent race conditions // Build a transaction to prevent race conditions
...@@ -33,36 +32,24 @@ func (d *DatabaseImpl) UpsertState(state *State) error { ...@@ -33,36 +32,24 @@ func (d *DatabaseImpl) UpsertState(state *State) error {
// Returns a State's value from Storage with the given key // Returns a State's value from Storage with the given key
// Or an error if a matching State does not exist // Or an error if a matching State does not exist
func (d *DatabaseImpl) GetStateValue(key string) (string, error) { func (d *databaseImpl) GetStateValue(key string) (string, error) {
result := &State{Key: key} result := &State{Key: key}
err := d.db.Take(result).Error err := d.db.Take(result).Error
jww.TRACE.Printf("Obtained State from DB: %+v", result) jww.TRACE.Printf("Obtained State from DB: %+v", result)
return result.Value, err return result.Value, err
} }
func (d *DatabaseImpl) DeleteToken(token string) error { func (d *databaseImpl) DeleteToken(token string) error {
return d.db.Where("token = ?", token).Delete(&Token{Token: token}).Error return d.db.Where("token = ?", token).Delete(&Token{Token: token}).Error
} }
// Insert or Update User into backend // Insert or Update User into backend
func (d *DatabaseImpl) insertUser(user *User) error { func (d *databaseImpl) insertUser(user *User) error {
return d.db.Transaction(func(tx *gorm.DB) error { return d.db.Clauses(clause.OnConflict{DoNothing: true}).Create(user).Error
ret := &User{}
err := tx.FirstOrCreate(ret, user).Error
if err != nil {
return err
}
if bytes.Equal(ret.TransmissionRSAHash, user.TransmissionRSAHash) {
return tx.Save(&user).Error
}
return nil
})
} }
// Obtain User from backend by primary key // Obtain User from backend by primary key
func (d *DatabaseImpl) GetUser(transmissionRsaHash []byte) (*User, error) { func (d *databaseImpl) GetUser(transmissionRsaHash []byte) (*User, error) {
u := &User{} u := &User{}
err := d.db.Preload("Identities").Preload("Tokens").Take(u, "transmission_rsa_hash = ?", transmissionRsaHash).Error err := d.db.Preload("Identities").Preload("Tokens").Take(u, "transmission_rsa_hash = ?", transmissionRsaHash).Error
if err != nil { if err != nil {
...@@ -72,7 +59,7 @@ func (d *DatabaseImpl) GetUser(transmissionRsaHash []byte) (*User, error) { ...@@ -72,7 +59,7 @@ func (d *DatabaseImpl) GetUser(transmissionRsaHash []byte) (*User, error) {
} }
// Delete User from backend by primary key // Delete User from backend by primary key
func (d *DatabaseImpl) deleteUser(transmissionRsaHash []byte) error { func (d *databaseImpl) deleteUser(transmissionRsaHash []byte) error {
err := d.db.Delete(&User{ err := d.db.Delete(&User{
TransmissionRSAHash: transmissionRsaHash, TransmissionRSAHash: transmissionRsaHash,
}).Error }).Error
...@@ -82,13 +69,13 @@ func (d *DatabaseImpl) deleteUser(transmissionRsaHash []byte) error { ...@@ -82,13 +69,13 @@ func (d *DatabaseImpl) deleteUser(transmissionRsaHash []byte) error {
return nil return nil
} }
func (d *DatabaseImpl) GetAllUsers() ([]*User, error) { func (d *databaseImpl) GetAllUsers() ([]*User, error) {
var dest []*User var dest []*User
return dest, d.db.Find(&dest).Error return dest, d.db.Find(&dest).Error
} }
// Obtain Identity from backend by primary key // Obtain Identity from backend by primary key
func (d *DatabaseImpl) getIdentity(iid []byte) (*Identity, error) { func (d *databaseImpl) getIdentity(iid []byte) (*Identity, error) {
i := &Identity{} i := &Identity{}
err := d.db.Take(i, "intermediary_id = ?", iid).Error err := d.db.Take(i, "intermediary_id = ?", iid).Error
if err != nil { if err != nil {
...@@ -97,28 +84,28 @@ func (d *DatabaseImpl) getIdentity(iid []byte) (*Identity, error) { ...@@ -97,28 +84,28 @@ func (d *DatabaseImpl) getIdentity(iid []byte) (*Identity, error) {
return i, nil return i, nil
} }
func (d *DatabaseImpl) insertIdentity(identity *Identity) error { func (d *databaseImpl) insertIdentity(identity *Identity) error {
return d.db.Clauses(clause.OnConflict{ return d.db.Clauses(clause.OnConflict{
DoNothing: true, DoNothing: true,
}).Create(identity).Error }).Create(identity).Error
} }
func (d *DatabaseImpl) getIdentitiesByOffset(offset int64) ([]*Identity, error) { func (d *databaseImpl) getIdentitiesByOffset(offset int64) ([]*Identity, error) {
var result []*Identity var result []*Identity
err := d.db.Where(&Identity{OffsetNum: offset}).Find(&result).Error err := d.db.Where(&Identity{OffsetNum: offset}).Find(&result).Error
return result, err return result, err
} }
func (d *DatabaseImpl) GetOrphanedIdentities() ([]*Identity, error) { func (d *databaseImpl) GetOrphanedIdentities() ([]*Identity, error) {
var dest []*Identity var dest []*Identity
return dest, d.db.Find(&dest, "NOT EXISTS (select * from ephemerals where ephemerals.intermediary_id = identities.intermediary_id)").Error return dest, d.db.Find(&dest, "NOT EXISTS (select * from ephemerals where ephemerals.intermediary_id = identities.intermediary_id)").Error
} }
func (d *DatabaseImpl) insertEphemeral(ephemeral *Ephemeral) error { func (d *databaseImpl) insertEphemeral(ephemeral *Ephemeral) error {
return d.db.Create(&ephemeral).Error return d.db.Create(&ephemeral).Error
} }
func (d *DatabaseImpl) GetEphemeral(ephemeralId int64) ([]*Ephemeral, error) { func (d *databaseImpl) GetEphemeral(ephemeralId int64) ([]*Ephemeral, error) {
var result []*Ephemeral var result []*Ephemeral
err := d.db.Where("ephemeral_id = ?", ephemeralId).Find(&result).Error err := d.db.Where("ephemeral_id = ?", ephemeralId).Find(&result).Error
if err != nil { if err != nil {
...@@ -136,7 +123,7 @@ type GTNResult struct { ...@@ -136,7 +123,7 @@ type GTNResult struct {
EphemeralId int64 EphemeralId int64
} }
func (d *DatabaseImpl) GetToNotify(ephemeralIds []int64) ([]GTNResult, error) { func (d *databaseImpl) GetToNotify(ephemeralIds []int64) ([]GTNResult, error) {
var result []GTNResult var result []GTNResult
err := d.db.Transaction(func(tx *gorm.DB) error { err := d.db.Transaction(func(tx *gorm.DB) error {
t1 := tx.Table("identities").Select("ephemerals.ephemeral_id, identities.intermediary_id").Joins("inner join ephemerals on ephemerals.intermediary_id = identities.intermediary_id").Where("ephemerals.ephemeral_id in ?", ephemeralIds) t1 := tx.Table("identities").Select("ephemerals.ephemeral_id, identities.intermediary_id").Joins("inner join ephemerals on ephemerals.intermediary_id = identities.intermediary_id").Where("ephemerals.ephemeral_id in ?", ephemeralIds)
...@@ -147,12 +134,12 @@ func (d *DatabaseImpl) GetToNotify(ephemeralIds []int64) ([]GTNResult, error) { ...@@ -147,12 +134,12 @@ func (d *DatabaseImpl) GetToNotify(ephemeralIds []int64) ([]GTNResult, error) {
return result, err return result, err
} }
func (d *DatabaseImpl) DeleteOldEphemerals(currentEpoch int32) error { func (d *databaseImpl) DeleteOldEphemerals(currentEpoch int32) error {
res := d.db.Where("epoch < ?", currentEpoch).Delete(&Ephemeral{}) res := d.db.Where("epoch < ?", currentEpoch).Delete(&Ephemeral{})
return res.Error return res.Error
} }
func (d *DatabaseImpl) GetLatestEphemeral() (*Ephemeral, error) { func (d *databaseImpl) GetLatestEphemeral() (*Ephemeral, error) {
var result []*Ephemeral var result []*Ephemeral
err := d.db.Order("epoch desc").Limit(1).Find(&result).Error err := d.db.Order("epoch desc").Limit(1).Find(&result).Error
if err != nil { if err != nil {
...@@ -164,7 +151,7 @@ func (d *DatabaseImpl) GetLatestEphemeral() (*Ephemeral, error) { ...@@ -164,7 +151,7 @@ func (d *DatabaseImpl) GetLatestEphemeral() (*Ephemeral, error) {
return result[0], nil return result[0], nil
} }
func (d *DatabaseImpl) registerForNotifications(u *User, identity Identity, token string) error { func (d *databaseImpl) registerForNotifications(u *User, identity Identity, token string) error {
return d.db.Transaction(func(tx *gorm.DB) error { return d.db.Transaction(func(tx *gorm.DB) error {
err := tx.Model(u).Association("Identities").Append(&identity) err := tx.Model(u).Association("Identities").Append(&identity)
if err != nil { if err != nil {
...@@ -182,7 +169,7 @@ func (d *DatabaseImpl) registerForNotifications(u *User, identity Identity, toke ...@@ -182,7 +169,7 @@ func (d *DatabaseImpl) registerForNotifications(u *User, identity Identity, toke
}) })
} }
func (d *DatabaseImpl) unregisterIdentities(u *User, iids []Identity) error { func (d *databaseImpl) unregisterIdentities(u *User, iids []Identity) error {
return d.db.Transaction(func(tx *gorm.DB) error { return d.db.Transaction(func(tx *gorm.DB) error {
err := tx.Model(&u).Association("Identities").Delete(iids) err := tx.Model(&u).Association("Identities").Delete(iids)
if err != nil { if err != nil {
...@@ -216,7 +203,7 @@ func (d *DatabaseImpl) unregisterIdentities(u *User, iids []Identity) error { ...@@ -216,7 +203,7 @@ func (d *DatabaseImpl) unregisterIdentities(u *User, iids []Identity) error {
}) })
} }
func (d *DatabaseImpl) unregisterTokens(u *User, tokens []Token) error { func (d *databaseImpl) unregisterTokens(u *User, tokens []Token) error {
return d.db.Transaction(func(tx *gorm.DB) error { return d.db.Transaction(func(tx *gorm.DB) error {
for _, t := range tokens { for _, t := range tokens {
err := tx.Delete(t).Error err := tx.Delete(t).Error
...@@ -241,7 +228,7 @@ func (d *DatabaseImpl) unregisterTokens(u *User, tokens []Token) error { ...@@ -241,7 +228,7 @@ func (d *DatabaseImpl) unregisterTokens(u *User, tokens []Token) error {
}) })
} }
func (d *DatabaseImpl) LegacyUnregister(iid []byte) error { func (d *databaseImpl) LegacyUnregister(iid []byte) error {
return d.db.Transaction(func(tx *gorm.DB) error { return d.db.Transaction(func(tx *gorm.DB) error {
var res Identity var res Identity
err := tx.Preload("Users").Find(&res, "intermediary_id = ?", iid).Error err := tx.Preload("Users").Find(&res, "intermediary_id = ?", iid).Error
......
...@@ -18,7 +18,7 @@ import ( ...@@ -18,7 +18,7 @@ import (
) )
type Storage struct { type Storage struct {
database *databaseImpl
notificationBuffer *NotificationBuffer notificationBuffer *NotificationBuffer
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment