Skip to content

Commit

Permalink
Add mTLS support into the example app
Browse files Browse the repository at this point in the history
  • Loading branch information
TheSpiritXIII committed Sep 15, 2023
1 parent a240c01 commit f5b8a57
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 3 deletions.
192 changes: 192 additions & 0 deletions examples/instrumentation/go-synthetic/auth.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
package main

import (
"crypto/ed25519"
cryptorand "crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"flag"
"fmt"
"log"
"math/big"
"math/rand"
"net"
"net/http"
"os"
"sort"
"strings"
"time"

"github.com/google/go-cmp/cmp"
)

const (
defaultRSABits = 4096
keyAlgorithmRSA = "rsa"
keyAlgorithmEd25519 = "ed25519"
)

// isFlagSet returns true if the flag was explicitly set in the command line by the user.
func isFlagSet(name string) bool {
found := false
flag.Visit(func(f *flag.Flag) {
Expand All @@ -22,6 +41,170 @@ func isFlagSet(name string) bool {
return found
}

type tlsConfig struct {
// Provide a custom certificate.
certPath string
keyPath string

// Create a new self-signed certificate.
createSelfSigned bool
keyAlgorithm string
serverIP string
serverName string

// General mTLS flags.
insecureSkipVerify bool
minVersion uint
maxVersion uint
}

func newTLSConfigFromFlags() *tlsConfig {
c := &tlsConfig{}
flag.StringVar(&c.certPath, "tls-cert", "", "Path to the server TLS certificate")
flag.StringVar(&c.keyPath, "tls-key", "", "Path to the server TLS key")

flag.BoolVar(&c.createSelfSigned, "tls-create-self-signed", false, "If true, a self-signed certificate will be created and used as the TLS server certificate.")
flag.StringVar(&c.keyAlgorithm, "tls-key-algorithm", keyAlgorithmRSA, fmt.Sprintf("Which algorithm to use when creating a self-signed certificate. Supports %q or %q", keyAlgorithmRSA, keyAlgorithmEd25519))
flag.StringVar(&c.serverName, "tls-server-name", "Example", "Name of the server, used to verify the TLS certificate")
flag.StringVar(&c.serverIP, "tls-server-ip", "", "IP of the server. If unset, this will look for the POD_IP environment variable")

flag.BoolVar(&c.insecureSkipVerify, "tls-insecure-skip-verify", false, "Whether to skip verifying the certificate")
flag.UintVar(&c.minVersion, "tls-min-version", tls.VersionTLS12, "Minimum TLS version")
flag.UintVar(&c.maxVersion, "tls-max-version", tls.VersionTLS13, "Maximum TLS version")
return c
}

func (c *tlsConfig) isUserProvidedCertificate() bool {
return c.certPath != "" || c.keyPath != ""
}

func (c *tlsConfig) isSelfSignedCertificate() bool {
return c.createSelfSigned || isFlagSet("tls-key-algorithm") || isFlagSet("tls-server-name") || isFlagSet("tls-server-ip")
}

func (c *tlsConfig) hasCertificate() bool {
return c.isUserProvidedCertificate() || c.isSelfSignedCertificate()
}

func (c *tlsConfig) isEnabled() bool {
return c.hasCertificate() || isFlagSet("tls-insecure-skip-verify") || isFlagSet("tls-min-version") || isFlagSet("tls-max-version")
}

func (c *tlsConfig) validate() error {
errs := []error{}
if c.createSelfSigned {
if c.isUserProvidedCertificate() {
errs = append(errs, errors.New("--tls-create-self-signed and cannot be used together with use-provided certificate flags --tls-cert or --tls-key"))
}
} else {
for _, flagName := range []string{"tls-key-algorithm", "tls-server-name", "tls-server-ip"} {
if isFlagSet(flagName) {
errs = append(errs, fmt.Errorf("--%s can only be specified with --tls-create-self-signed", flagName))
}
}
}
if c.isUserProvidedCertificate() && (c.certPath == "" || c.keyPath == "") {
errs = append(errs, errors.New("--tls-cert and --tls-key must both be set"))
}
if c.isEnabled() && !c.hasCertificate() {
for _, flagName := range []string{"tls-insecure-skip-verify", "tls-min-version", "tls-max-version"} {
if isFlagSet(flagName) {
errs = append(errs, fmt.Errorf("--%s can only be specified with --tls-cert or --tls-create-self-signed", flagName))
}
}
}

if c.keyAlgorithm != keyAlgorithmRSA && c.keyAlgorithm != keyAlgorithmEd25519 {
errs = append(errs, fmt.Errorf("key algorithm %q is invalid", c.keyAlgorithm))
}
if c.serverIP == "" {
c.serverIP = os.Getenv("POD_IP")
}

return errors.Join(errs...)
}

func (c *tlsConfig) getTLSConfig() (*tls.Config, error) {
if !c.isEnabled() {
return nil, nil
}
config := &tls.Config{
ServerName: c.serverName,
InsecureSkipVerify: c.insecureSkipVerify,
MinVersion: uint16(c.minVersion),
MaxVersion: uint16(c.maxVersion),
}
if c.createSelfSigned {
var privateKey, publicKey any
if c.keyAlgorithm == keyAlgorithmRSA {
rsaPrivateKey, err := rsa.GenerateKey(cryptorand.Reader, defaultRSABits)
if err != nil {
return nil, fmt.Errorf("unable to generate RSA key: %w", err)
}
privateKey = rsaPrivateKey
publicKey = &rsaPrivateKey.PublicKey
} else {
var err error
publicKey, privateKey, err = ed25519.GenerateKey(cryptorand.Reader)
if err != nil {
return nil, fmt.Errorf("unable to generate ed25519 key: %w", err)
}
}

template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{c.serverName},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour * 24 * 30),

KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}
if c.serverIP != "" {
template.IPAddresses = append(template.IPAddresses, net.ParseIP(c.serverIP))
}

certBytes, err := x509.CreateCertificate(cryptorand.Reader, &template, &template, publicKey, privateKey)
if err != nil {
return nil, fmt.Errorf("unable to create self-signed certificate: %w", err)
}
certPem := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
})

privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
log.Println("Unable to marshal private key", err)
os.Exit(1)
}
privateKeyPem := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: privateKeyBytes,
})

cert, err := tls.X509KeyPair(certPem, privateKeyPem)
if err != nil {
log.Println("Unable to encode self-signed certificate", err)
os.Exit(1)
}

config.Certificates = []tls.Certificate{cert}
} else if c.certPath != "" && c.keyPath != "" {
cert, err := tls.LoadX509KeyPair(c.certPath, c.keyPath)
if err != nil {
log.Println("Unable to load server cert and key", err)
os.Exit(1)
}
config.Certificates = []tls.Certificate{cert}
}

return config, nil
}

type basicAuthConfig struct {
username string
password string
Expand Down Expand Up @@ -177,13 +360,15 @@ func (c *oauth2Config) handle(handler http.Handler) http.Handler {
}

type httpClientConfig struct {
tls *tlsConfig
basicAuth *basicAuthConfig
auth *authorizationConfig
oauth2 *oauth2Config
}

func newHttpClientConfigFromFlags() *httpClientConfig {
return &httpClientConfig{
tls: newTLSConfigFromFlags(),
basicAuth: newBasicAuthConfigFromFlags(),
auth: newAuthorizationConfigFromFlags(),
oauth2: newOAuth2ConfigFromFlags(),
Expand All @@ -192,6 +377,9 @@ func newHttpClientConfigFromFlags() *httpClientConfig {

func (c *httpClientConfig) validate() error {
var errs []error
if err := c.tls.validate(); err != nil {
errs = append(errs, err)
}
if c.basicAuth.isEnabled() {
if c.auth.isEnabled() {
errs = append(errs, errors.New("cannot specify both --basic-auth and --auth flags"))
Expand Down Expand Up @@ -230,3 +418,7 @@ func (c *httpClientConfig) handle(handler http.Handler) http.Handler {
}
return handler
}

func (c *httpClientConfig) getTLSConfig() (*tls.Config, error) {
return c.tls.getTLSConfig()
}
16 changes: 14 additions & 2 deletions examples/instrumentation/go-synthetic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,24 @@ func main() {
mux.Handle("/metrics", httpClientConfig.handle(promhttp.HandlerFor(metrics, promhttp.HandlerOpts{Registry: metrics, EnableOpenMetrics: true})))
httpClientConfig.register(mux)

tlsConfig, err := httpClientConfig.getTLSConfig()
if err != nil {
log.Println("Unable to create TLS config", err)
os.Exit(1)
}

server := &http.Server{
Addr: *addr,
Handler: mux,
Addr: *addr,
Handler: mux,
TLSConfig: tlsConfig,
}

g.Add(func() error {
if tlsConfig != nil {
fmt.Printf("Starting server on %q with TLS\n", *addr)
return server.ListenAndServeTLS("", "")
}
fmt.Printf("Starting server on %q\n", *addr)
return server.ListenAndServe()
}, func(err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/GoogleCloudPlatform/prometheus-engine

go 1.18
go 1.20

require (
cloud.google.com/go/compute/metadata v0.2.2
Expand Down

0 comments on commit f5b8a57

Please sign in to comment.