From 8b140d33acb056bdbf50d81f3963b85341412f67 Mon Sep 17 00:00:00 2001
From: Jono <jono@privategrity.com>
Date: Thu, 23 Apr 2020 16:36:33 -0700
Subject: [PATCH] Add whitelist and fix tests

---
 cmd/impl.go              | 10 +++++++++-
 cmd/poll.go              |  9 ++++++---
 cmd/poll_test.go         | 14 ++++++++++++--
 cmd/registration_test.go | 34 ++++++++++++++++++++++++++++++++++
 4 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/cmd/impl.go b/cmd/impl.go
index a76374a8..253b633d 100644
--- a/cmd/impl.go
+++ b/cmd/impl.go
@@ -41,6 +41,7 @@ type RegistrationImpl struct {
 	registrationsRemaining  *uint64
 	maxRegistrationAttempts uint64
 	ipNdfBucket             *rateLimiting.BucketMap
+	ipNdfWhitelist          *rateLimiting.Whitelist
 }
 
 //function used to schedule nodes
@@ -104,6 +105,12 @@ func StartRegistration(params Params) (*RegistrationImpl, error) {
 		return nil, err
 	}
 
+	// Initialize IP bucket whitelist file
+	ipWhitelist, err := rateLimiting.InitWhitelist(params.ipBucket.WhitelistFile, nil)
+	if err != nil {
+		return nil, err
+	}
+
 	// Build default parameters
 	regImpl := &RegistrationImpl{
 		State:                   state,
@@ -114,7 +121,8 @@ func StartRegistration(params Params) (*RegistrationImpl, error) {
 		nodeCompleted:           make(chan string, nodeCompletionChanLen),
 		NdfReady:                &ndfReady,
 
-		ipNdfBucket: rateLimiting.CreateBucketMapFromParams(params.ipBucket),
+		ipNdfBucket:    rateLimiting.CreateBucketMapFromParams(params.ipBucket),
+		ipNdfWhitelist: ipWhitelist,
 	}
 
 	// Create timer and channel to be used by routine that clears the number of
diff --git a/cmd/poll.go b/cmd/poll.go
index 63cd9569..036a5552 100644
--- a/cmd/poll.go
+++ b/cmd/poll.go
@@ -16,6 +16,7 @@ import (
 	"gitlab.com/elixxir/primitives/current"
 	"gitlab.com/elixxir/primitives/id"
 	"gitlab.com/elixxir/primitives/ndf"
+	"strings"
 	"sync/atomic"
 )
 
@@ -23,8 +24,6 @@ import (
 func (m *RegistrationImpl) Poll(msg *pb.PermissioningPoll,
 	auth *connect.Auth) (response *pb.PermissionPollResponse, err error) {
 
-	auth.Sender.GetId()
-
 	// Initialize the response
 	response = &pb.PermissionPollResponse{}
 
@@ -89,8 +88,12 @@ func (m *RegistrationImpl) Poll(msg *pb.PermissioningPoll,
 // PollNdf handles the client polling for an updated NDF
 func (m *RegistrationImpl) PollNdf(theirNdfHash []byte, auth *connect.Auth) ([]byte, error) {
 
+	// Remove port from IP address
+	ipAddress := strings.Split(auth.Sender.GetAddress(), ":")
+
 	// Check if the IP bucket is full; if it is, then return an error
-	if m.ipNdfBucket.LookupBucket(auth.Sender.GetAddress()).Add(1) {
+	if m.ipNdfBucket.LookupBucket(ipAddress[0]).Add(1) && !m.ipNdfWhitelist.Exists(ipAddress[0]) {
+		jww.FATAL.Printf("Client has exceeded communications rate limit")
 		return nil, errors.Errorf("Client has exceeded communications rate limit")
 	}
 
diff --git a/cmd/poll_test.go b/cmd/poll_test.go
index 115cf4f5..374dbd4f 100644
--- a/cmd/poll_test.go
+++ b/cmd/poll_test.go
@@ -229,7 +229,12 @@ func TestRegistrationImpl_PollNdf(t *testing.T) {
 	case <-beginScheduling:
 	}
 
-	observedNDFBytes, err := impl.PollNdf(nil, &connect.Auth{})
+	testHost, err := connect.NewHost("", "0.0.0.0", nil, false, false)
+	if err != nil {
+		t.Errorf("failed to create new hosy: %v", err)
+	}
+
+	observedNDFBytes, err := impl.PollNdf(nil, &connect.Auth{Sender: testHost})
 	if err != nil {
 		t.Errorf("failed to update ndf: %v", err)
 	}
@@ -295,10 +300,15 @@ func TestRegistrationImpl_PollNdf_NoNDF(t *testing.T) {
 		t.Errorf("Expected happy path, recieved error: %+v", err)
 	}
 
+	testHost, err := connect.NewHost("", "0.0.0.0", nil, false, false)
+	if err != nil {
+		t.Errorf("failed to create new hosy: %v", err)
+	}
+
 	//Make a client ndf hash that is not up to date
 	clientNdfHash := []byte("test")
 
-	_, err = impl.PollNdf(clientNdfHash, &connect.Auth{})
+	_, err = impl.PollNdf(clientNdfHash, &connect.Auth{Sender: testHost})
 	if err == nil {
 		t.Error("Expected error path, should not have an ndf ready")
 	}
diff --git a/cmd/registration_test.go b/cmd/registration_test.go
index 5cdbc98a..aebfd613 100644
--- a/cmd/registration_test.go
+++ b/cmd/registration_test.go
@@ -9,6 +9,7 @@ import (
 	"fmt"
 	jww "github.com/spf13/jwalterweatherman"
 	nodeComms "gitlab.com/elixxir/comms/node"
+	"gitlab.com/elixxir/primitives/rateLimiting"
 	"gitlab.com/elixxir/primitives/utils"
 	"gitlab.com/elixxir/registration/storage"
 	"gitlab.com/elixxir/registration/storage/node"
@@ -53,6 +54,13 @@ func TestMain(m *testing.M) {
 		publicAddress:             permAddr,
 		maxRegistrationAttempts:   5,
 		registrationCountDuration: time.Hour,
+		ipBucket: rateLimiting.Params{
+			Capacity:      10,
+			LeakRate:      .0000000003,
+			CleanPeriod:   time.Hour,
+			MaxDuration:   time.Hour,
+			WhitelistFile: "test_whitelist.txt",
+		},
 	}
 	nodeComm = nodeComms.StartNode("tmp", nodeAddr, nodeComms.NewImplementation(), nodeCert, nodeKey)
 
@@ -151,6 +159,19 @@ func TestRegCodeExists_RegUser(t *testing.T) {
 
 //Attempt to register a node after the
 func TestCompleteRegistration_HappyPath(t *testing.T) {
+	// Create temporary whitelist file
+	whitelistFilePath := "test_whitelist.txt"
+	err := utils.WriteFile(whitelistFilePath, nil, utils.FilePerms, utils.DirPerms)
+	if err != nil {
+		t.Errorf("Error creating temporary whitelist file:\n%v", err)
+	}
+	defer func() {
+		err = os.RemoveAll(whitelistFilePath)
+		if err != nil {
+			t.Errorf("Error deleting temporary whitelist file:\n%v", err)
+		}
+	}()
+
 	//Crate database
 	storage.PermissioningDb = storage.NewDatabase("test", "password", "regCodes", "0.0.0.0:6969")
 	//Insert a sample regCode
@@ -196,6 +217,19 @@ func TestCompleteRegistration_HappyPath(t *testing.T) {
 
 //Error path: test that trying to register with the same reg code fails
 func TestDoubleRegistration(t *testing.T) {
+	// Create temporary whitelist file
+	whitelistFilePath := "test_whitelist.txt"
+	err := utils.WriteFile(whitelistFilePath, nil, utils.FilePerms, utils.DirPerms)
+	if err != nil {
+		t.Errorf("Error creating temporary whitelist file:\n%v", err)
+	}
+	defer func() {
+		err = os.RemoveAll(whitelistFilePath)
+		if err != nil {
+			t.Errorf("Error deleting temporary whitelist file:\n%v", err)
+		}
+	}()
+
 	//Create database
 	storage.PermissioningDb = storage.NewDatabase("test", "password", "regCodes", "0.0.0.0:6969")
 
-- 
GitLab