Skip to content
Snippets Groups Projects
Commit 2ed1fec7 authored by Jonah Husson's avatar Jonah Husson
Browse files

Merge branch 'hotfix/atomic-cert' into 'release'

Hotfix/atomic cert

See merge request !7
parents 527c45dd 436c7ab7
No related branches found
No related tags found
1 merge request!7Hotfix/atomic cert
...@@ -182,8 +182,8 @@ func (c *ClientConn) applyCallOptions(opts []CallOption) *callOptions { ...@@ -182,8 +182,8 @@ func (c *ClientConn) applyCallOptions(opts []CallOption) *callOptions {
return &callOptions return &callOptions
} }
func (c *ClientConn) GetReceivedCertificate() (*x509.Certificate, error) { func (c *ClientConn) GetRemoteCertificate() (*x509.Certificate, error) {
return c.transport.GetReceivedCertificate() return c.transport.GetRemoteCertificate()
} }
// copied from rpc_util.go#msgHeader // copied from rpc_util.go#msgHeader
......
...@@ -28,7 +28,7 @@ type unaryTransport struct { ...@@ -28,7 +28,7 @@ type unaryTransport struct {
err error err error
} }
func (t *unaryTransport) GetReceivedCertificate() (*x509.Certificate, error) { func (t *unaryTransport) GetRemoteCertificate() (*x509.Certificate, error) {
panic("UNIMPLEMENTED") panic("UNIMPLEMENTED")
return nil, nil return nil, nil
} }
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
...@@ -25,7 +26,7 @@ import ( ...@@ -25,7 +26,7 @@ import (
type UnaryTransport interface { type UnaryTransport interface {
Header() http.Header Header() http.Header
Send(ctx context.Context, endpoint, contentType string, body io.Reader) (http.Header, []byte, error) Send(ctx context.Context, endpoint, contentType string, body io.Reader) (http.Header, []byte, error)
GetReceivedCertificate() (*x509.Certificate, error) GetRemoteCertificate() (*x509.Certificate, error)
Close() error Close() error
} }
...@@ -34,8 +35,7 @@ type httpTransport struct { ...@@ -34,8 +35,7 @@ type httpTransport struct {
client *http.Client client *http.Client
clientLock *sync.RWMutex clientLock *sync.RWMutex
opts *ConnectOptions opts *ConnectOptions
receivedCertLock sync.RWMutex receivedCertAtomic atomic.Value
receivedCert *x509.Certificate
header http.Header header http.Header
} }
...@@ -102,34 +102,19 @@ func (t *httpTransport) Send(ctx context.Context, endpoint, contentType string, ...@@ -102,34 +102,19 @@ func (t *httpTransport) Send(ctx context.Context, endpoint, contentType string,
if res.TLS != nil { if res.TLS != nil {
if res.TLS.PeerCertificates != nil && len(res.TLS.PeerCertificates) > 0 { if res.TLS.PeerCertificates != nil && len(res.TLS.PeerCertificates) > 0 {
serverCert := res.TLS.PeerCertificates[0] serverCert := res.TLS.PeerCertificates[0]
t.receivedCertLock.RLock() t.receivedCertAtomic.Store(serverCert)
newCert := t.receivedCert == nil || !certsEqual(t.receivedCert, serverCert)
t.receivedCertLock.RUnlock()
if newCert {
t.receivedCertLock.Lock()
stillNewCert := t.receivedCert == nil || !certsEqual(t.receivedCert, serverCert)
if stillNewCert {
t.receivedCert = serverCert
}
t.receivedCertLock.Unlock()
}
} }
} }
return res.Header, respBody, nil return res.Header, respBody, nil
} }
func (t *httpTransport) GetReceivedCertificate() (*x509.Certificate, error) { func (t *httpTransport) GetRemoteCertificate() (*x509.Certificate, error) {
t.receivedCertLock.RLock() receivedCert := t.receivedCertAtomic.Load()
defer t.receivedCertLock.RUnlock() if receivedCert == nil {
if t.receivedCert == nil {
return nil, errors.New("http transport has not yet received a tls certificate") return nil, errors.New("http transport has not yet received a tls certificate")
} }
return t.receivedCert, nil return receivedCert.(*x509.Certificate), nil
}
func certsEqual(c1, c2 *x509.Certificate) bool {
return c1.Issuer.String() == c2.Issuer.String() && c1.SerialNumber == c2.SerialNumber
} }
// Close the httpTransport object // Close the httpTransport object
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment