Skip to content
Snippets Groups Projects
Commit 370b1ae1 authored by Jonah Husson's avatar Jonah Husson
Browse files

re-enable https for servewithweb, add function to provide certs when available

parent 8732b403
No related branches found
No related tags found
3 merge requests!47Project/https support,!45re-enable https for servewithweb, add function to provide certs when available,!39Merge release into master
...@@ -11,6 +11,7 @@ package connect ...@@ -11,6 +11,7 @@ package connect
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"github.com/golang/protobuf/ptypes/any" "github.com/golang/protobuf/ptypes/any"
"github.com/improbable-eng/grpc-web/go/grpcweb" "github.com/improbable-eng/grpc-web/go/grpcweb"
"github.com/pkg/errors" "github.com/pkg/errors"
...@@ -26,6 +27,8 @@ import ( ...@@ -26,6 +27,8 @@ import (
"net" "net"
"net/http" "net/http"
"strings" "strings"
"sync"
"sync/atomic"
"time" "time"
) )
...@@ -96,6 +99,9 @@ type ProtoComms struct { ...@@ -96,6 +99,9 @@ type ProtoComms struct {
// Used to store the salt used for generating Client Id // Used to store the salt used for generating Client Id
salt []byte salt []byte
needsHttpsCreds atomic.Bool
httpsCredChan chan tls.Certificate
// ------------------------------------------------------------------------- // -------------------------------------------------------------------------
} }
...@@ -228,54 +234,58 @@ func (c *ProtoComms) ServeWithWeb() { ...@@ -228,54 +234,58 @@ func (c *ProtoComms) ServeWithWeb() {
grpcMatcher = cmux.HTTP2() grpcMatcher = cmux.HTTP2()
} }
grpcL := mux.Match(grpcMatcher) grpcL := mux.Match(grpcMatcher)
httpL := mux.Match(cmux.HTTP1()) // TODO did we find a better way to do this than using any with https?
httpL := mux.Match(cmux.Any())
var wg sync.WaitGroup
wg.Add(1)
listenHTTP := func(l net.Listener) { listenHTTP := func(l net.Listener) {
jww.INFO.Printf("Starting HTTP listener on GRPC endpoints: %+v", jww.INFO.Printf("Starting HTTP listener on GRPC endpoints: %+v",
grpcweb.ListGRPCResources(grpcServer)) grpcweb.ListGRPCResources(grpcServer))
httpServer := grpcweb.WrapServer(grpcServer, httpServer := grpcweb.WrapServer(grpcServer,
grpcweb.WithOriginFunc(func(origin string) bool { return true })) grpcweb.WithOriginFunc(func(origin string) bool { return true }))
// This blocks for the lifetime of the listener.
jww.CRITICAL.Printf("Starting HTTP server to without TLS!") if TestingOnlyDisableTLS && c.privateKey == nil {
wg.Done()
jww.WARN.Printf("Starting HTTP server to without TLS!")
if err := http.Serve(l, httpServer); err != nil { if err := http.Serve(l, httpServer); err != nil {
// Cannot panic here due to shared net.Listener // Cannot panic here due to shared net.Listener
jww.ERROR.Printf("Failed to serve HTTP: %+v", err) jww.ERROR.Printf("Failed to serve HTTP: %+v", err)
} }
} else {
// Configure TLS for this listener, using the config from
// http.ServeTLS
tlsConf := &tls.Config{}
tlsConf.NextProtos = append(tlsConf.NextProtos, "h2", "http/1.1")
// Wait for creds to be provided for https
c.httpsCredChan = make(chan tls.Certificate, 1)
wg.Done()
creds := <-c.httpsCredChan
var err error
tlsConf.Certificates = make([]tls.Certificate, 1)
tlsConf.Certificates[0] = creds
var serverName string
if tlsConf.Certificates[0].Leaf == nil {
var cert *x509.Certificate
cert, err = x509.ParseCertificate(creds.Certificate[0])
if err != nil {
jww.FATAL.Panicf("Failed to load TLS certificate: %+v", err)
}
serverName = cert.DNSNames[0]
} else {
serverName = tlsConf.Certificates[0].Leaf.DNSNames[0]
}
tlsConf.ServerName = serverName
// FIXME: Currently only HTTP is used. This must be fixed to use HTTPS tlsLis := tls.NewListener(l, tlsConf)
// before production use. if err := http.Serve(tlsLis, httpServer); err != nil {
// if TestingOnlyDisableTLS && c.privateKey == nil { // Cannot panic here due to shared net.Listener
// jww.WARN.Printf("Starting HTTP server to without TLS!") jww.ERROR.Printf("Failed to serve HTTP: %+v", err)
// if err := http.Serve(l, httpServer); err != nil { }
// // Cannot panic here due to shared net.Listener }
// jww.ERROR.Printf("Failed to serve HTTP: %+v", err)
// }
// } else {
// // Configure TLS for this listener, using the config from
// // http.ServeTLS
// tlsConf := &tls.Config{}
// tlsConf.NextProtos = append(tlsConf.NextProtos, "h2", "http/1.1")
//
// var err error
// var cert *x509.Certificate
// cert, err = tlsCreds.LoadCertificate(string(c.pubKeyPem))
// if err != nil {
// jww.FATAL.Panicf("Failed to load TLS certificate: %+v", err)
// }
// tlsConf.ServerName = cert.DNSNames[0]
//
// tlsConf.Certificates = make([]tls.Certificate, 1)
// tlsConf.Certificates[0], err = tls.X509KeyPair(
// c.pubKeyPem, rsa.CreatePrivateKeyPem(c.privateKey))
// if err != nil {
// jww.FATAL.Panicf("Failed to load TLS key: %+v", err)
// }
// tlsLis := tls.NewListener(l, tlsConf)
// if err := http.Serve(tlsLis, httpServer); err != nil {
// // Cannot panic here due to shared net.Listener
// jww.ERROR.Printf("Failed to serve HTTP: %+v", err)
// }
// }
jww.INFO.Printf("Shutting down HTTP server listener") jww.INFO.Printf("Shutting down HTTP server listener")
} }
...@@ -297,6 +307,24 @@ func (c *ProtoComms) ServeWithWeb() { ...@@ -297,6 +307,24 @@ func (c *ProtoComms) ServeWithWeb() {
go listenHTTP(httpL) go listenHTTP(httpL)
go listenGRPC(grpcL) go listenGRPC(grpcL)
go listenPort() go listenPort()
wg.Wait() // Don't return until https is ready for creds
}
// ProvisionHttps provides a tls cert and key to the thread which serves the
// grpcweb endpoints, allowing it to continue and serve with https.
// This function should be called exactly once after ServeWithWeb, any further
// calls will have undefined behavior
func (c *ProtoComms) ProvisionHttps(cert, key []byte) error {
if c.httpsCredChan == nil {
return errors.New("https cred chan does not exist; is https enabled?")
}
keyPair, err := tls.X509KeyPair(
cert, key)
if err != nil {
return errors.WithMessage(err, "cert & key could not be parsed to valid tls certificate")
}
c.httpsCredChan <- keyPair
return nil
} }
// Shutdown performs a graceful shutdown of the local server. // Shutdown performs a graceful shutdown of the local server.
......
...@@ -105,7 +105,6 @@ func TestWebConnection_TLS(t *testing.T) { ...@@ -105,7 +105,6 @@ func TestWebConnection_TLS(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
hostParams := GetDefaultHostParams() hostParams := GetDefaultHostParams()
TestingOnlyDisableTLS = true
hostParams.ConnectionType = Web hostParams.ConnectionType = Web
h, err := newHost(hostId, addr, certBytes, hostParams) h, err := newHost(hostId, addr, certBytes, hostParams)
...@@ -114,7 +113,7 @@ func TestWebConnection_TLS(t *testing.T) { ...@@ -114,7 +113,7 @@ func TestWebConnection_TLS(t *testing.T) {
} }
s := grpc.NewServer() s := grpc.NewServer()
pb.RegisterGenericServer(s, &TestGenericServer{}) pb.RegisterGenericServer(s, &TestGenericServer{resp: "hello!"})
pk, err := rsa.LoadPrivateKeyFromPem(keyBytes) pk, err := rsa.LoadPrivateKeyFromPem(keyBytes)
if err != nil { if err != nil {
...@@ -125,6 +124,7 @@ func TestWebConnection_TLS(t *testing.T) { ...@@ -125,6 +124,7 @@ func TestWebConnection_TLS(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
errCh := make(chan error)
go func() { go func() {
pc := ProtoComms{ pc := ProtoComms{
networkId: id.NewIdFromString("zezima", id.User, t), networkId: id.NewIdFromString("zezima", id.User, t),
...@@ -137,9 +137,15 @@ func TestWebConnection_TLS(t *testing.T) { ...@@ -137,9 +137,15 @@ func TestWebConnection_TLS(t *testing.T) {
pubKeyPem: certBytes, pubKeyPem: certBytes,
salt: nil, salt: nil,
} }
pc.ServeWithWeb() pc.ServeWithWeb()
errCh <- pc.ProvisionHttps(certBytes, keyBytes)
}() }()
time.Sleep(time.Second * 5) time.Sleep(60)
err = <-errCh
if err != nil {
t.Fatal(err)
}
err = h.connect() err = h.connect()
if err != nil { if err != nil {
...@@ -152,6 +158,9 @@ func TestWebConnection_TLS(t *testing.T) { ...@@ -152,6 +158,9 @@ func TestWebConnection_TLS(t *testing.T) {
resp := &pb.Ack{} resp := &pb.Ack{}
err = h.connection.GetWebConn().Invoke(ctx, "/messages.Generic/AuthenticateToken", &pb.AuthenticatedMessage{}, resp) err = h.connection.GetWebConn().Invoke(ctx, "/messages.Generic/AuthenticateToken", &pb.AuthenticatedMessage{}, resp)
if err != nil { if err != nil {
t.Fatal(err) t.Fatalf("Failed to invoke authenticate: %+v", err)
}
if resp.Error != "hello!" {
t.Errorf("Did not receive expected payload")
} }
} }
...@@ -2,6 +2,7 @@ package connect ...@@ -2,6 +2,7 @@ package connect
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"net/http" "net/http"
"net/http/httptrace" "net/http/httptrace"
"regexp" "regexp"
...@@ -73,18 +74,15 @@ func (wc *webConn) connectWebHelper() (err error) { ...@@ -73,18 +74,15 @@ func (wc *webConn) connectWebHelper() (err error) {
// FIXME: Currently only HTTP is used. This must be fixed to use HTTPS // FIXME: Currently only HTTP is used. This must be fixed to use HTTPS
// before production use. // before production use.
// Configure TLS options // Configure TLS options
jww.WARN.Printf("grpcWeb connecting to %s without TLS! This is insecure "+ var securityDial []grpcweb.DialOption
"and should only be used for testing.", wc.h.GetAddress()) if wc.h.credentials != nil {
securityDial := []grpcweb.DialOption{grpcweb.WithInsecure()} securityDial = []grpcweb.DialOption{grpcweb.WithTlsCertificate(wc.h.certificate)}
// var securityDial []grpcweb.DialOption } else if TestingOnlyDisableTLS {
// if wc.h.credentials != nil { jww.WARN.Printf("Connecting to %s without TLS!", wc.h.GetAddress())
// securityDial = []grpcweb.DialOption{grpcweb.WithTlsCertificate(wc.h.certificate)} securityDial = []grpcweb.DialOption{grpcweb.WithInsecure()}
// } else if TestingOnlyDisableTLS { } else {
// jww.WARN.Printf("Connecting to %s without TLS!", wc.h.GetAddress()) jww.FATAL.Panicf(tlsError)
// securityDial = []grpcweb.DialOption{grpcweb.WithInsecure()} }
// } else {
// jww.FATAL.Panicf(tlsError)
// }
jww.DEBUG.Printf("Attempting to establish connection to %s using "+ jww.DEBUG.Printf("Attempting to establish connection to %s using "+
"credentials: %v", wc.h.GetAddress(), securityDial) "credentials: %v", wc.h.GetAddress(), securityDial)
...@@ -211,11 +209,12 @@ func (wc *webConn) IsOnline() (time.Duration, bool) { ...@@ -211,11 +209,12 @@ func (wc *webConn) IsOnline() (time.Duration, bool) {
} }
// IMPORTANT - enables better HTTP(S) discovery, because many browsers block CORS by default. // IMPORTANT - enables better HTTP(S) discovery, because many browsers block CORS by default.
req.Header.Add("js.fetch:mode", "no-cors") //req.Header.Add("js.fetch:mode", "no-cors")
jww.TRACE.Printf("(GO request): %+v", req) jww.TRACE.Printf("(GO request): %+v", req)
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
if _, err = client.Do(req); err != nil { var resp *http.Response
if resp, err = client.Do(req); err != nil {
jww.TRACE.Printf("(GO error): %s", err.Error()) jww.TRACE.Printf("(GO error): %s", err.Error())
if checkErrorExceptions(err) { if checkErrorExceptions(err) {
jww.DEBUG.Printf( jww.DEBUG.Printf(
...@@ -227,6 +226,7 @@ func (wc *webConn) IsOnline() (time.Duration, bool) { ...@@ -227,6 +226,7 @@ func (wc *webConn) IsOnline() (time.Duration, bool) {
return time.Since(start), false return time.Since(start), false
} }
} }
fmt.Println(resp)
client.CloseIdleConnections() client.CloseIdleConnections()
return time.Since(start), true return time.Since(start), true
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment