generated from grafana/xk6-sql-driver-ramsql
-
Notifications
You must be signed in to change notification settings - Fork 1
/
tls.go
126 lines (109 loc) · 3.7 KB
/
tls.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Adapted more or less unchanged from: https://github.com/grafana/xk6-sql/blob/v0.4.1/sql.go
// It will have to be refactored.
package mysql
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"os"
"strings"
"github.com/go-sql-driver/mysql"
"go.k6.io/k6/js/common"
"go.k6.io/k6/lib/netext"
)
// supportedTLSVersions maps the TLS version names defined by k6's netext package to the
// corresponding crypto/tls protocol version constants.
//
// It is used both to validate the minimum version requested in a TLSConfig and to
// translate that name into the value assigned to tls.Config.MinVersion.
var supportedTLSVersions = map[string]uint16{ //nolint: gochecknoglobals
netext.TLS_1_0: tls.VersionTLS10,
netext.TLS_1_1: tls.VersionTLS11,
netext.TLS_1_2: tls.VersionTLS12,
netext.TLS_1_3: tls.VersionTLS13,
}
// tlsExports adds the TLS-related entries to the module's named exports: the TLS
// version identifiers and the loadTLS/addTLS functions.
func (mod *module) tlsExports() {
	named := mod.exports.Named

	// TLS version identifiers, re-exported under their netext names.
	versions := [...]struct{ name, value string }{
		{"TLS_1_0", netext.TLS_1_0},
		{"TLS_1_1", netext.TLS_1_1},
		{"TLS_1_2", netext.TLS_1_2},
		{"TLS_1_3", netext.TLS_1_3},
	}
	for _, v := range versions {
		named[v.name] = v.value
	}

	// Functions callable from the JS side.
	named["loadTLS"] = mod.LoadTLS
	named["addTLS"] = mod.AddTLS
}
// tlsConfigKey is the name under which the TLS configuration is registered with the
// MySQL driver, and the value written into the "tls" connection string parameter.
const tlsConfigKey = "custom"
// TLSConfig contains all the TLS configuration options passed between the JS and Go code.
// It is decoded from the JS parameter object using the JSON names in the struct tags.
type TLSConfig struct {
// EnableTLS controls whether LoadTLS registers the configuration with the driver and
// whether AddTLS adds the "tls" parameter to connection strings.
EnableTLS bool `json:"enableTLS"`
// InsecureSkipTLSverify is copied to tls.Config.InsecureSkipVerify.
InsecureSkipTLSverify bool `json:"insecureSkipTLSverify"`
// MinVersion is the minimum TLS version name; it must be a key of supportedTLSVersions.
MinVersion string `json:"minVersion"`
// CAcertFile is the path of the PEM file whose certificates form the root CA pool.
CAcertFile string `json:"caCertFile"`
// ClientCertFile is the path of the client certificate, loaded together with ClientKeyFile.
ClientCertFile string `json:"clientCertFile"`
// ClientKeyFile is the path of the private key belonging to ClientCertFile.
ClientKeyFile string `json:"clientKeyFile"`
}
// LoadTLS loads the TLS configuration for the SQL module.
//
// params is the option object received from the JS side; it is round-tripped through JSON
// to populate a TLSConfig using its struct tags. The requested minimum version must be a
// key of supportedTLSVersions. The decoded configuration is stored on the module and, when
// EnableTLS is set, registered with the MySQL driver under tlsConfigKey.
//
// Conversion failures and an unsupported TLS version are reported through common.Throw;
// a failure to register the configuration is returned as an error.
func (mod *module) LoadTLS(params map[string]interface{}) error {
	runtime := mod.vu.Runtime()

	b, err := json.Marshal(params)
	if err != nil {
		common.Throw(runtime, err)
	}

	// Decode into a value rather than a pointer: a nil params map marshals to JSON
	// "null", which would reset a pointer target to nil and make the field accesses
	// below dereference nil. With a value, "null" leaves the zero TLSConfig in place
	// and the version check reports a regular error instead.
	var tlsConfig TLSConfig
	if err := json.Unmarshal(b, &tlsConfig); err != nil {
		common.Throw(runtime, err)
	}

	if _, ok := supportedTLSVersions[tlsConfig.MinVersion]; !ok {
		common.Throw(runtime, fmt.Errorf("unsupported TLS version: %s", tlsConfig.MinVersion))
	}

	// The configuration is stored before registration on purpose: if registration
	// fails, AddTLS still requests TLS, so a later connection cannot silently fall
	// back to an unencrypted one.
	mod.tlsConfig = tlsConfig
	if !tlsConfig.EnableTLS {
		return nil
	}

	return registerTLS(tlsConfigKey, mod.tlsConfig)
}
// AddTLS returns connectionString with the "tls" connection parameter added when TLS is
// enabled in the loaded configuration; otherwise the string is returned untouched.
func (mod *module) AddTLS(connectionString string) string {
	if !mod.tlsConfig.EnableTLS {
		return connectionString
	}

	return prefixConnectionString(connectionString, tlsConfigKey)
}
// prefixConnectionString returns connectionString with the parameter "tls=<tlsConfigKey>"
// appended (despite its name, nothing is prepended), unless that exact parameter is
// already present in the parameter list.
//
// The parameter list is taken to start at the first "?" following the last "/" of the
// connection string, which is where the go-sql-driver DSN layout
// (user:password@protocol(address)/dbname?param=value) places it. Looking only there
// keeps a "?" or the text "tls=<key>" that occurs earlier in the string (for example
// inside the password) from being mistaken for the parameter list or for an existing TLS
// parameter. Parameters are compared as a whole, so "tls=customer" does not count as
// "tls=custom". When the string contains no "/", the whole string is searched for "?".
func prefixConnectionString(connectionString string, tlsConfigKey string) string {
	tlsParam := "tls=" + tlsConfigKey

	// Locate the start of the parameter list, if there is one.
	paramsAt := -1
	dbStart := strings.LastIndex(connectionString, "/") + 1
	if i := strings.Index(connectionString[dbStart:], "?"); i >= 0 {
		paramsAt = dbStart + i + 1
	}

	if paramsAt < 0 {
		return connectionString + "?" + tlsParam
	}

	for _, param := range strings.Split(connectionString[paramsAt:], "&") {
		if param == tlsParam {
			return connectionString
		}
	}

	return connectionString + "&" + tlsParam
}
// registerTLS builds a crypto/tls configuration from tlsConfig and registers it with the
// MySQL driver under the name tlsConfigKey.
//
// The file at tlsConfig.CAcertFile is read and its PEM certificates become the root CA
// pool; the client certificate and key are loaded from tlsConfig.ClientCertFile and
// tlsConfig.ClientKeyFile. All three files are required. The minimum protocol version is
// looked up in supportedTLSVersions, and certificate verification is skipped when
// tlsConfig.InsecureSkipTLSverify is set.
//
// It returns an error, annotated with the step and file involved, if a file cannot be
// read or parsed, or whatever error the driver reports for the registration.
func registerTLS(tlsConfigKey string, tlsConfig TLSConfig) error {
	pem, err := os.ReadFile(tlsConfig.CAcertFile) //nolint: forbidigo
	if err != nil {
		return fmt.Errorf("reading CA certificate file: %w", err)
	}

	rootCAs := x509.NewCertPool()
	if ok := rootCAs.AppendCertsFromPEM(pem); !ok {
		return fmt.Errorf("no valid PEM certificate found in CA certificate file %q", tlsConfig.CAcertFile)
	}

	clientCert, err := tls.LoadX509KeyPair(tlsConfig.ClientCertFile, tlsConfig.ClientKeyFile)
	if err != nil {
		return fmt.Errorf("loading client key pair (cert %q, key %q): %w",
			tlsConfig.ClientCertFile, tlsConfig.ClientKeyFile, err)
	}

	mysqlTLSConfig := &tls.Config{
		RootCAs:            rootCAs,
		Certificates:       []tls.Certificate{clientCert},
		MinVersion:         supportedTLSVersions[tlsConfig.MinVersion],
		InsecureSkipVerify: tlsConfig.InsecureSkipTLSverify, // #nosec G402
	}

	return mysql.RegisterTLSConfig(tlsConfigKey, mysqlTLSConfig)
}