Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions pkg/networking/http_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ const HttpsScheme = "https"
// HttpScheme is the HTTP scheme
const HttpScheme = "http"

// Dialer control function for validating addresses prior to connection
func protectedDialerControl(_, address string, _ syscall.RawConn) error {
// ProtectedDialerControl is a Dialer control function for validating addresses
// prior to connection. It returns an error if the resolved address points at a
// private, loopback, or link-local IP, providing an SSRF guard at dial time.
// Pass it to (&net.Dialer{Control: ...}).DialContext on outbound HTTP transports.
func ProtectedDialerControl(_, address string, _ syscall.RawConn) error {
err := AddressReferencesPrivateIp(address)
if err != nil {
return err
Expand Down Expand Up @@ -100,9 +103,12 @@ type HttpClientBuilder struct {
tlsHandshakeTimeout time.Duration
responseHeaderTimeout time.Duration
caCertPath string
clientCertPath string
clientKeyPath string
authTokenFile string
allowPrivate bool
insecureAllowHTTP bool
insecureSkipVerify bool
disableKeepAlives bool
}

Expand All @@ -121,6 +127,13 @@ func (b *HttpClientBuilder) WithCABundle(path string) *HttpClientBuilder {
return b
}

// WithClientCert sets a client certificate and key for mutual TLS (mTLS) authentication.
func (b *HttpClientBuilder) WithClientCert(certPath, keyPath string) *HttpClientBuilder {
b.clientCertPath = certPath
b.clientKeyPath = keyPath
return b
}

// WithTokenFromFile sets the auth token file path
func (b *HttpClientBuilder) WithTokenFromFile(path string) *HttpClientBuilder {
b.authTokenFile = path
Expand All @@ -140,6 +153,13 @@ func (b *HttpClientBuilder) WithInsecureAllowHTTP(allow bool) *HttpClientBuilder
return b
}

// WithInsecureSkipVerify disables TLS server certificate verification.
// WARNING: This is insecure and should NEVER be used in production
func (b *HttpClientBuilder) WithInsecureSkipVerify(skip bool) *HttpClientBuilder {
b.insecureSkipVerify = skip
return b
}

// WithDisableKeepAlives disables HTTP keep-alive on the transport. When true,
// each request uses a fresh connection, ensuring the per-dial SSRF check fires
// on every request rather than being bypassed by a reused connection.
Expand All @@ -164,7 +184,7 @@ func (b *HttpClientBuilder) Build() (*http.Client, error) {

if !b.allowPrivate {
transport.DialContext = (&net.Dialer{
Control: protectedDialerControl,
Control: ProtectedDialerControl,
}).DialContext
}

Expand All @@ -187,6 +207,34 @@ func (b *HttpClientBuilder) Build() (*http.Client, error) {
transport.TLSClientConfig.RootCAs = caCertPool
}

if (b.clientCertPath == "") != (b.clientKeyPath == "") {
return nil, fmt.Errorf("both client certificate and key paths must be set for mTLS")
}
if b.clientCertPath != "" {
cert, err := tls.LoadX509KeyPair(b.clientCertPath, b.clientKeyPath)
if err != nil {
return nil, fmt.Errorf("failed to load client certificate: %w", err)
}

if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}
}
transport.TLSClientConfig.Certificates = []tls.Certificate{cert}
}

if b.insecureSkipVerify {
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}
}
//#nosec G402 -- InsecureSkipVerify is intentionally user-configurable via WithInsecureSkipVerify;
// callers must opt in explicitly.
transport.TLSClientConfig.InsecureSkipVerify = true
}

// Start with validation transport
var clientTransport http.RoundTripper = &ValidatingTransport{
Transport: transport,
Expand Down
177 changes: 177 additions & 0 deletions pkg/networking/http_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@
package networking

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"io"
"math/big"
"net/http"
"net/http/httptest"
"os"
Expand All @@ -19,6 +26,35 @@ import (
"golang.org/x/oauth2"
)

// generateTestClientCert creates a self-signed certificate/key pair in PEM
// format for testing mTLS client certificate loading.
func generateTestClientCert(t *testing.T) (certPEM, keyPEM []byte) {
t.Helper()

key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "test-client"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}

certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)

certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})

keyDER, err := x509.MarshalECPrivateKey(key)
require.NoError(t, err)
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})

return certPEM, keyPEM
}

func TestNewHttpClientBuilder(t *testing.T) {
t.Parallel()

Expand All @@ -44,6 +80,43 @@ func TestHttpClientBuilder_WithCABundle(t *testing.T) {
assert.Equal(t, path, builder.caCertPath)
}

func TestHttpClientBuilder_WithClientCert(t *testing.T) {
t.Parallel()

builder := NewHttpClientBuilder()
certPath, keyPath := "/path/to/client.crt", "/path/to/client.key"

result := builder.WithClientCert(certPath, keyPath)

assert.Same(t, builder, result) // fluent interface
assert.Equal(t, certPath, builder.clientCertPath)
assert.Equal(t, keyPath, builder.clientKeyPath)
}

func TestHttpClientBuilder_WithInsecureSkipVerify(t *testing.T) {
t.Parallel()

tests := []struct {
name string
skip bool
}{
{name: "skip verification", skip: true},
{name: "verify normally", skip: false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

builder := NewHttpClientBuilder()
result := builder.WithInsecureSkipVerify(tt.skip)

assert.Same(t, builder, result) // fluent interface
assert.Equal(t, tt.skip, builder.insecureSkipVerify)
})
}
}

func TestHttpClientBuilder_WithTokenFromFile(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -248,6 +321,23 @@ lT/G27CBRUlDiDhthwY1dccTCFhICg6ENUGqh2I=
assert.NotNil(t, httpTransport.DialContext)
},
},
{
name: "insecure skip verify",
setupBuilder: func() *HttpClientBuilder {
return NewHttpClientBuilder().WithInsecureSkipVerify(true)
},
setupFiles: func(_ *testing.T) (string, string) {
return "", ""
},
expectError: false,
validateClient: func(t *testing.T, client *http.Client) {
t.Helper()
transport := client.Transport.(*ValidatingTransport)
httpTransport := transport.Transport.(*http.Transport)
require.NotNil(t, httpTransport.TLSClientConfig)
assert.True(t, httpTransport.TLSClientConfig.InsecureSkipVerify)
},
},
{
name: "invalid CA certificate file",
setupBuilder: func() *HttpClientBuilder {
Expand Down Expand Up @@ -347,6 +437,93 @@ lT/G27CBRUlDiDhthwY1dccTCFhICg6ENUGqh2I=
}
}

func TestHttpClientBuilder_Build_ClientCert(t *testing.T) {
t.Parallel()

validCertPEM, validKeyPEM := generateTestClientCert(t)

writeFile := func(t *testing.T, name string, data []byte) string {
t.Helper()
path := filepath.Join(t.TempDir(), name)
require.NoError(t, os.WriteFile(path, data, 0600))
return path
}

tests := []struct {
name string
setupBuilder func(t *testing.T) *HttpClientBuilder
expectError bool
errorContains string
validateCert bool
}{
{
name: "valid client certificate and key",
setupBuilder: func(t *testing.T) *HttpClientBuilder {
t.Helper()
certPath := writeFile(t, "client.crt", validCertPEM)
keyPath := writeFile(t, "client.key", validKeyPEM)
return NewHttpClientBuilder().WithClientCert(certPath, keyPath)
},
validateCert: true,
},
{
name: "certificate without matching key",
setupBuilder: func(t *testing.T) *HttpClientBuilder {
t.Helper()
certPath := writeFile(t, "client.crt", validCertPEM)
return NewHttpClientBuilder().WithClientCert(certPath, "")
},
expectError: true,
errorContains: "both client certificate and key paths must be set",
},
{
name: "key without matching certificate",
setupBuilder: func(t *testing.T) *HttpClientBuilder {
t.Helper()
keyPath := writeFile(t, "client.key", validKeyPEM)
return NewHttpClientBuilder().WithClientCert("", keyPath)
},
expectError: true,
errorContains: "both client certificate and key paths must be set",
},
{
name: "invalid certificate content",
setupBuilder: func(t *testing.T) *HttpClientBuilder {
t.Helper()
certPath := writeFile(t, "client.crt", []byte("invalid cert"))
keyPath := writeFile(t, "client.key", []byte("invalid key"))
return NewHttpClientBuilder().WithClientCert(certPath, keyPath)
},
expectError: true,
errorContains: "failed to load client certificate",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

client, err := tt.setupBuilder(t).Build()

if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.errorContains)
assert.Nil(t, client)
return
}

require.NoError(t, err)
require.NotNil(t, client)
if tt.validateCert {
transport := client.Transport.(*ValidatingTransport)
httpTransport := transport.Transport.(*http.Transport)
require.NotNil(t, httpTransport.TLSClientConfig)
assert.Len(t, httpTransport.TLSClientConfig.Certificates, 1)
}
})
}
}

func TestValidatingTransport_RoundTrip(t *testing.T) {
t.Parallel()

Expand Down
7 changes: 6 additions & 1 deletion pkg/runner/webhook_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ import (

// TestWebhookMiddlewareChainIntegration tests the full execution of the webhook middleware chain
// populated by PopulateMiddlewareConfigs in the runner.
//
//nolint:paralleltest // mutates package-level allowPrivateIPsForTesting flag via SetAllowPrivateIPsForTesting
func TestWebhookMiddlewareChainIntegration(t *testing.T) {
t.Parallel()
// The webhook clients built by the middleware factories use the production
// dialer guard, which rejects the 127.0.0.1 httptest servers below as part of
// the SSRF protection. Disable the guard for this test so the dial succeeds.
webhook.SetAllowPrivateIPsForTesting(t)

// 1. Set up a mutating webhook server that adds a new argument field
mutatingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading