diff --git a/bindings/channels.go b/bindings/channels.go index d3563bdd1a7ffdcece5ecbf7a508c2236fc115ad..2119f816bcbc323207586bb10e07b39c6b76aa88 100644 --- a/bindings/channels.go +++ b/bindings/channels.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "gitlab.com/elixxir/client/channels" "gitlab.com/elixxir/client/cmix/rounds" + "gitlab.com/elixxir/client/storage/utility" "gitlab.com/elixxir/client/storage/versioned" "gitlab.com/elixxir/client/xxdk" cryptoBroadcast "gitlab.com/elixxir/crypto/broadcast" @@ -1422,3 +1423,139 @@ func (tem *toEventModel) UpdateSentStatus(uuid uint64, tem.em.UpdateSentStatus(int64(uuid), messageID[:], timestamp.UnixNano(), int64(round.ID), int64(status)) } + +//////////////////////////////////////////////////////////////////////////////// +// Channel ChannelDbCipher // +//////////////////////////////////////////////////////////////////////////////// + +// ChannelDbCipher is the bindings layer representation of the [channel.Cipher]. +type ChannelDbCipher struct { + api cryptoChannel.Cipher + salt []byte + id int +} + +// channelDbCipherTrackerSingleton is used to track ChannelDbCipher objects +// so that they can be referenced by ID back over the bindings. +var channelDbCipherTrackerSingleton = &channelDbCipherTracker{ + tracked: make(map[int]*ChannelDbCipher), + count: 0, +} + +// channelDbCipherTracker is a singleton used to keep track of extant +// ChannelDbCipher objects, preventing race conditions created by passing it +// over the bindings. +type channelDbCipherTracker struct { + tracked map[int]*ChannelDbCipher + count int + mux sync.RWMutex +} + +// create creates a ChannelDbCipher from a [channel.Cipher], assigns it a unique +// ID, and adds it to the channelDbCipherTracker. +func (ct *channelDbCipherTracker) create(c cryptoChannel.Cipher) *ChannelDbCipher { + ct.mux.Lock() + defer ct.mux.Unlock() + + chID := ct.count + ct.count++ + + ct.tracked[chID] = &ChannelDbCipher{ + api: c, + id: chID, + } + + return ct.tracked[chID] +} + +// get an ChannelDbCipher from the channelDbCipherTracker given its ID. +func (ct *channelDbCipherTracker) get(id int) (*ChannelDbCipher, error) { + ct.mux.RLock() + defer ct.mux.RUnlock() + + c, exist := ct.tracked[id] + if !exist { + return nil, errors.Errorf( + "Cannot get ChannelDbCipher for ID %d, does not exist", id) + } + + return c, nil +} + +// delete removes a ChannelDbCipher from the channelDbCipherTracker. +func (ct *channelDbCipherTracker) delete(id int) { + ct.mux.Lock() + defer ct.mux.Unlock() + + delete(ct.tracked, id) +} + +// GetChannelDbCipherTrackerFromID returns the ChannelDbCipher with the +// corresponding ID in the tracker. +func GetChannelDbCipherTrackerFromID(id int) (*ChannelDbCipher, error) { + return channelDbCipherTrackerSingleton.get(id) +} + +// NewChannelsDatabaseCipher constructs a ChannelDbCipher object. +// +// Parameters: +// - cmixID - The tracked [Cmix] object ID. +// - password - The password for storage. This should be the same password +// passed into [NewCmix]. +// - plaintTextBlockSize - The maximum size of a payload to be encrypted. +// A payload passed into [ChannelDbCipher.Encrypt] that is larger than +// plaintTextBlockSize will result in an error. +func NewChannelsDatabaseCipher(cmixID int, password []byte, + plaintTextBlockSize int) (*ChannelDbCipher, error) { + // Get user from singleton + user, err := cmixTrackerSingleton.get(cmixID) + if err != nil { + return nil, err + } + + // Generate RNG + stream := user.api.GetRng().GetStream() + + // Load or generate a salt + salt, err := utility.NewOrLoadSalt( + user.api.GetStorage().GetKV(), stream) + if err != nil { + return nil, err + } + + // Construct a cipher + c, err := cryptoChannel.NewCipher(password, salt, + plaintTextBlockSize, stream) + if err != nil { + return nil, err + } + + // Return a cipher + return channelDbCipherTrackerSingleton.create(c), nil +} + +// GetID returns the ID for this ChannelDbCipher in the channelDbCipherTracker. +func (c *ChannelDbCipher) GetID() int { + return c.id +} + +// Encrypt will encrypt the raw data. It will return a ciphertext. Padding is +// done on the plaintext so all encrypted data looks uniform at rest. +// +// Parameters: +// - plaintext - The data to be encrypted. This must be smaller than the block +// size passed into [NewChannelsDatabaseCipher]. If it is larger, this will +// return an error. +func (c *ChannelDbCipher) Encrypt(plaintext []byte) ([]byte, error) { + return c.api.Encrypt(plaintext) +} + +// Decrypt will decrypt the passed in encrypted value. The plaintext will +// be returned by this function. Any padding will be discarded within +// this function. +// +// Parameters: +// - ciphertext - the encrypted data returned by [ChannelDbCipher.Encrypt]. +func (c *ChannelDbCipher) Decrypt(ciphertext []byte) ([]byte, error) { + return c.api.Decrypt(ciphertext) +} diff --git a/storage/utility/encryptionSalt.go b/storage/utility/encryptionSalt.go new file mode 100644 index 0000000000000000000000000000000000000000..18b60c15d4d9f3d39aba0d34fb7efb4877fd92e0 --- /dev/null +++ b/storage/utility/encryptionSalt.go @@ -0,0 +1,69 @@ +package utility + +import ( + jww "github.com/spf13/jwalterweatherman" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/xx_network/primitives/netTime" + "io" +) + +// Storage constats +const ( + saltKey = "encryptionSalt" + saltVersion = 0 + saltPrefix = "encryptionSaltPrefix" +) + +// saltSize is the defined size in bytes of the salt generated in +// newSalt. +const saltSize = 32 + +// NewOrLoadSalt will attempt to find a stored salt if one exists. +// If one does not exist in storage, a new one will be generated. The newly +// generated salt will be stored. +func NewOrLoadSalt(kv *versioned.KV, stream io.Reader) ([]byte, error) { + kv = kv.Prefix(saltPrefix) + salt, err := loadSalt(kv) + if err != nil { + jww.WARN.Printf("Failed to load salt, generating new one...") + salt, err = newSalt(kv, stream) + } + + return salt, err +} + +// loadSalt is a helper function which attempts to load a stored salt from +// memory. +func loadSalt(kv *versioned.KV) ([]byte, error) { + obj, err := kv.Get(saltKey, saltVersion) + if err != nil { + return nil, err + } + + return obj.Data, nil +} + +// newSalt generates a new random salt. This salt is stored and returned +// to the caller. +func newSalt(kv *versioned.KV, stream io.Reader) ([]byte, error) { + // Generate a new salt + salt := make([]byte, saltSize) + _, err := stream.Read(salt) + if err != nil { + return nil, err + } + + // Store salt in storage + obj := &versioned.Object{ + Version: saltVersion, + Timestamp: netTime.Now(), + Data: salt, + } + + err = kv.Set(saltKey, obj) + if err != nil { + return nil, err + } + + return salt, nil +} \ No newline at end of file diff --git a/storage/utility/encryptionSalt_test.go b/storage/utility/encryptionSalt_test.go new file mode 100644 index 0000000000000000000000000000000000000000..f041b9ded2673277022ae55ce3c356e9744ce712 --- /dev/null +++ b/storage/utility/encryptionSalt_test.go @@ -0,0 +1,49 @@ +package utility + +import ( + "bytes" + "gitlab.com/elixxir/client/storage/versioned" + "gitlab.com/elixxir/ekv" + "gitlab.com/xx_network/crypto/csprng" + "testing" +) + +// Smoke test +func TestNewOrLoadSalt(t *testing.T) { + kv := ekv.MakeMemstore() + vkv := versioned.NewKV(kv) + + rng := csprng.NewSystemRNG() + + _, err := NewOrLoadSalt(vkv, rng) + if err != nil { + t.Fatalf("NewOrLoadSalt error: %+v", err) + } + +} + +// Test that calling NewOrLoadSalt twice returns the same +// salt that exists in storage. +func TestLoadSalt(t *testing.T) { + + kv := ekv.MakeMemstore() + vkv := versioned.NewKV(kv) + + rng := csprng.NewSystemRNG() + + original, err := NewOrLoadSalt(vkv, rng) + if err != nil { + t.Fatalf("NewOrLoadSalt error: %+v", err) + } + + loaded, err := NewOrLoadSalt(vkv, rng) + if err != nil { + t.Fatalf("NewOrLoadSalt error: %+v", err) + } + + // Test that loaded matches the original (ie a new one was not generated) + if !bytes.Equal(original, loaded) { + t.Fatalf("Failed to load salt.") + } + +}