diff --git a/examples/instrumentation/go-synthetic/README.md b/examples/instrumentation/go-synthetic/README.md index c26a728573..995dd1b846 100644 --- a/examples/instrumentation/go-synthetic/README.md +++ b/examples/instrumentation/go-synthetic/README.md @@ -50,6 +50,13 @@ curl "localhost:8080/token?grant_type=client_credentials&client_id=abc&client_se curl -H "Authorization: Bearer DZ~9UYwD" localhost:8080/metrics ``` +#### mTLS + +```bash +go run ./examples/instrumentation/go-synthetic/ --tls-create-self-signed=true +curl -k https://localhost:8080/metrics +``` + ## Running on Kubernetes If running managed-collection on a Kubernetes cluster, the `go-synthetic` can be diff --git a/examples/instrumentation/go-synthetic/auth.go b/examples/instrumentation/go-synthetic/auth.go index e2291b95a7..5266125ab5 100644 --- a/examples/instrumentation/go-synthetic/auth.go +++ b/examples/instrumentation/go-synthetic/auth.go @@ -1,17 +1,36 @@ package main import ( + "crypto/ed25519" + cryptorand "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "errors" "flag" "fmt" + "log" + "math/big" "math/rand" + "net" "net/http" + "os" "sort" "strings" + "time" "github.com/google/go-cmp/cmp" ) +const ( + defaultRSABits = 4096 + keyAlgorithmRSA = "rsa" + keyAlgorithmEd25519 = "ed25519" +) + +// isFlagSet returns true if the flag was explicitly set in the command line by the user. func isFlagSet(name string) bool { found := false flag.Visit(func(f *flag.Flag) { @@ -22,6 +41,170 @@ func isFlagSet(name string) bool { return found } +type tlsConfig struct { + // Provide a custom certificate. + certPath string + keyPath string + + // Create a new self-signed certificate. + createSelfSigned bool + keyAlgorithm string + serverIP string + serverName string + + // General mTLS flags. + insecureSkipVerify bool + minVersion uint + maxVersion uint +} + +func newTLSConfigFromFlags() *tlsConfig { + c := &tlsConfig{} + flag.StringVar(&c.certPath, "tls-cert", "", "Path to the server TLS certificate") + flag.StringVar(&c.keyPath, "tls-key", "", "Path to the server TLS key") + + flag.BoolVar(&c.createSelfSigned, "tls-create-self-signed", false, "If true, a self-signed certificate will be created and used as the TLS server certificate.") + flag.StringVar(&c.keyAlgorithm, "tls-key-algorithm", keyAlgorithmRSA, fmt.Sprintf("Which algorithm to use when creating a self-signed certificate. Supports %q or %q", keyAlgorithmRSA, keyAlgorithmEd25519)) + flag.StringVar(&c.serverName, "tls-server-name", "Example", "Name of the server, used to verify the TLS certificate") + flag.StringVar(&c.serverIP, "tls-server-ip", "", "IP of the server. If unset, this will look for the POD_IP environment variable") + + flag.BoolVar(&c.insecureSkipVerify, "tls-insecure-skip-verify", false, "Whether to skip verifying the certificate") + flag.UintVar(&c.minVersion, "tls-min-version", tls.VersionTLS12, "Minimum TLS version") + flag.UintVar(&c.maxVersion, "tls-max-version", tls.VersionTLS13, "Maximum TLS version") + return c +} + +func (c *tlsConfig) isUserProvidedCertificate() bool { + return c.certPath != "" || c.keyPath != "" +} + +func (c *tlsConfig) isSelfSignedCertificate() bool { + return c.createSelfSigned || isFlagSet("tls-key-algorithm") || isFlagSet("tls-server-name") || isFlagSet("tls-server-ip") +} + +func (c *tlsConfig) hasCertificate() bool { + return c.isUserProvidedCertificate() || c.isSelfSignedCertificate() +} + +func (c *tlsConfig) isEnabled() bool { + return c.hasCertificate() || isFlagSet("tls-insecure-skip-verify") || isFlagSet("tls-min-version") || isFlagSet("tls-max-version") +} + +func (c *tlsConfig) validate() error { + errs := []error{} + if c.createSelfSigned { + if c.isUserProvidedCertificate() { + errs = append(errs, errors.New("--tls-create-self-signed and cannot be used together with use-provided certificate flags --tls-cert or --tls-key")) + } + } else { + for _, flagName := range []string{"tls-key-algorithm", "tls-server-name", "tls-server-ip"} { + if isFlagSet(flagName) { + errs = append(errs, fmt.Errorf("--%s can only be specified with --tls-create-self-signed", flagName)) + } + } + } + if c.isUserProvidedCertificate() && (c.certPath == "" || c.keyPath == "") { + errs = append(errs, errors.New("--tls-cert and --tls-key must both be set")) + } + if c.isEnabled() && !c.hasCertificate() { + for _, flagName := range []string{"tls-insecure-skip-verify", "tls-min-version", "tls-max-version"} { + if isFlagSet(flagName) { + errs = append(errs, fmt.Errorf("--%s can only be specified with --tls-cert or --tls-create-self-signed", flagName)) + } + } + } + + if c.keyAlgorithm != keyAlgorithmRSA && c.keyAlgorithm != keyAlgorithmEd25519 { + errs = append(errs, fmt.Errorf("key algorithm %q is invalid", c.keyAlgorithm)) + } + if c.serverIP == "" { + c.serverIP = os.Getenv("POD_IP") + } + + return errors.Join(errs...) +} + +func (c *tlsConfig) getTLSConfig() (*tls.Config, error) { + if !c.isEnabled() { + return nil, nil + } + config := &tls.Config{ + ServerName: c.serverName, + InsecureSkipVerify: c.insecureSkipVerify, + MinVersion: uint16(c.minVersion), + MaxVersion: uint16(c.maxVersion), + } + if c.createSelfSigned { + var privateKey, publicKey any + if c.keyAlgorithm == keyAlgorithmRSA { + rsaPrivateKey, err := rsa.GenerateKey(cryptorand.Reader, defaultRSABits) + if err != nil { + return nil, fmt.Errorf("unable to generate RSA key: %w", err) + } + privateKey = rsaPrivateKey + publicKey = &rsaPrivateKey.PublicKey + } else { + var err error + publicKey, privateKey, err = ed25519.GenerateKey(cryptorand.Reader) + if err != nil { + return nil, fmt.Errorf("unable to generate ed25519 key: %w", err) + } + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{c.serverName}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour * 24 * 30), + + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + if c.serverIP != "" { + template.IPAddresses = append(template.IPAddresses, net.ParseIP(c.serverIP)) + } + + certBytes, err := x509.CreateCertificate(cryptorand.Reader, &template, &template, publicKey, privateKey) + if err != nil { + return nil, fmt.Errorf("unable to create self-signed certificate: %w", err) + } + certPem := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + + privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + if err != nil { + log.Println("Unable to marshal private key", err) + os.Exit(1) + } + privateKeyPem := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: privateKeyBytes, + }) + + cert, err := tls.X509KeyPair(certPem, privateKeyPem) + if err != nil { + log.Println("Unable to encode self-signed certificate", err) + os.Exit(1) + } + + config.Certificates = []tls.Certificate{cert} + } else if c.certPath != "" && c.keyPath != "" { + cert, err := tls.LoadX509KeyPair(c.certPath, c.keyPath) + if err != nil { + log.Println("Unable to load server cert and key", err) + os.Exit(1) + } + config.Certificates = []tls.Certificate{cert} + } + + return config, nil +} + type basicAuthConfig struct { username string password string @@ -177,6 +360,7 @@ func (c *oauth2Config) handle(handler http.Handler) http.Handler { } type httpClientConfig struct { + tls *tlsConfig basicAuth *basicAuthConfig auth *authorizationConfig oauth2 *oauth2Config @@ -184,6 +368,7 @@ type httpClientConfig struct { func newHttpClientConfigFromFlags() *httpClientConfig { return &httpClientConfig{ + tls: newTLSConfigFromFlags(), basicAuth: newBasicAuthConfigFromFlags(), auth: newAuthorizationConfigFromFlags(), oauth2: newOAuth2ConfigFromFlags(), @@ -192,6 +377,9 @@ func newHttpClientConfigFromFlags() *httpClientConfig { func (c *httpClientConfig) validate() error { var errs []error + if err := c.tls.validate(); err != nil { + errs = append(errs, err) + } if c.basicAuth.isEnabled() { if c.auth.isEnabled() { errs = append(errs, errors.New("cannot specify both --basic-auth and --auth flags")) @@ -230,3 +418,7 @@ func (c *httpClientConfig) handle(handler http.Handler) http.Handler { } return handler } + +func (c *httpClientConfig) getTLSConfig() (*tls.Config, error) { + return c.tls.getTLSConfig() +} diff --git a/examples/instrumentation/go-synthetic/main.go b/examples/instrumentation/go-synthetic/main.go index 372e9f526f..6d59a00098 100644 --- a/examples/instrumentation/go-synthetic/main.go +++ b/examples/instrumentation/go-synthetic/main.go @@ -230,12 +230,24 @@ func main() { }))) httpClientConfig.register(mux) + tlsConfig, err := httpClientConfig.getTLSConfig() + if err != nil { + log.Println("Unable to create TLS config", err) + os.Exit(1) + } + server := &http.Server{ - Addr: *addr, - Handler: mux, + Addr: *addr, + Handler: mux, + TLSConfig: tlsConfig, } g.Add(func() error { + if tlsConfig != nil { + fmt.Printf("Starting server on %q with TLS\n", *addr) + return server.ListenAndServeTLS("", "") + } + fmt.Printf("Starting server on %q\n", *addr) return server.ListenAndServe() }, func(err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) diff --git a/go.mod b/go.mod index 065fd4994c..6511678cd1 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/GoogleCloudPlatform/prometheus-engine -go 1.18 +go 1.20 require ( cloud.google.com/go/compute/metadata v0.2.2