diff --git a/connect/comms.go b/connect/comms.go index 54843bb19631581703f03b3423ebc0c7681cca98..e08865d432cc761e32e0e0984585d7f8d63d36ec 100644 --- a/connect/comms.go +++ b/connect/comms.go @@ -157,10 +157,11 @@ listen: // Build the comms object pc := &ProtoComms{ - networkId: id, - netListener: lis, - tokens: token.NewMap(), - Manager: newManager(), + networkId: id, + netListener: lis, + tokens: token.NewMap(), + Manager: newManager(), + httpsCredChan: make(chan tls.Certificate, 1), } for _, h := range preloadedHosts { @@ -259,10 +260,12 @@ func (c *ProtoComms) ServeWithWeb() { tlsConf.NextProtos = append(tlsConf.NextProtos, "h2", "http/1.1") // Wait for creds to be provided for https - c.httpsCredChan = make(chan tls.Certificate, 1) wg.Done() creds := <-c.httpsCredChan + // We use the GetCertificate field and a function which returns + // c.httpsCertificate wrapped by a ReadLock to allow for future + // changes to the certificate without downtime c.httpsCertificate = &creds tlsConf.GetCertificate = c.GetCertificateFunc() @@ -274,6 +277,7 @@ func (c *ProtoComms) ServeWithWeb() { serverName = cert.DNSNames[0] tlsConf.ServerName = serverName + // Start thread which will update the certificate going forward go c.startUpdateCertificate() tlsLis := tls.NewListener(l, tlsConf) @@ -307,9 +311,10 @@ func (c *ProtoComms) ServeWithWeb() { } // ProvisionHttps provides a tls cert and key to the thread which serves the -// grpcweb endpoints, allowing it to continue and serve with https. -// This function should be called exactly once after ServeWithWeb, any further -// calls will have undefined behavior +// grpcweb endpoints, allowing it to serve with https. Note that https will +// not be usable until this has been called at least once, unblocking the +// listenHTTP func in ServeWithWeb. Future calls will be handled by the +// startUpdateCertificate thread. func (c *ProtoComms) ProvisionHttps(cert, key []byte) error { if c.httpsCredChan == nil { return errors.New("https cred chan does not exist; is https enabled?") @@ -381,6 +386,9 @@ func (c *ProtoComms) Stream(host *Host, f func(conn Connection) ( return c.transmit(host, f) } +// GetCertificateFunc returns a function which serves as the value of +// GetCertificate in the https TLS config, enabling certificate changing +// without downtime func (c *ProtoComms) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { c.httpsCertificateLock.RLock() @@ -389,6 +397,9 @@ func (c *ProtoComms) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certi } } +// startUpdateCertificate starts a thread which listens for further updates to +// the TLS certificate. When received, it takes the write lock & updates +// the stored certificate func (c *ProtoComms) startUpdateCertificate() { for { select {