Skip to content

Commit

Permalink
*: adjust TLS behaviour for dumpling and lightning (#37479)
Browse files Browse the repository at this point in the history
close #37480
  • Loading branch information
lance6716 authored Sep 7, 2022
1 parent 4cb0d1f commit 796fb1f
Show file tree
Hide file tree
Showing 19 changed files with 498 additions and 404 deletions.
117 changes: 43 additions & 74 deletions br/pkg/lightning/common/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,97 +17,62 @@ package common
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"net/http"
"net/http/httptest"
"os"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/httputil"
"github.com/pingcap/tidb/util"
"github.com/tikv/client-go/v2/config"
pd "github.com/tikv/pd/client"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

type TLS struct {
caPath string
certPath string
keyPath string
inner *tls.Config
client *http.Client
url string
caPath string
certPath string
keyPath string
caBytes []byte
certBytes []byte
keyBytes []byte
inner *tls.Config
client *http.Client
url string
}

// ToTLSConfig constructs a `*tls.Config` from the CA, certification and key
// paths.
//
// If the CA path is empty, returns nil.
func ToTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
if len(caPath) == 0 {
return nil, nil
}

// Create a certificate pool from CA
certPool := x509.NewCertPool()
ca, err := os.ReadFile(caPath)
// NewTLS constructs a new HTTP client with TLS configured with the CA,
// certificate and key paths.
func NewTLS(caPath, certPath, keyPath, host string, caBytes, certBytes, keyBytes []byte) (*TLS, error) {
inner, err := util.NewTLSConfig(
util.WithCAPath(caPath),
util.WithCertAndKeyPath(certPath, keyPath),
util.WithCAContent(caBytes),
util.WithCertAndKeyContent(certBytes, keyBytes),
)
if err != nil {
return nil, errors.Annotate(err, "could not read ca certificate")
}

// Append the certificates from the CA
if !certPool.AppendCertsFromPEM(ca) {
return nil, errors.New("failed to append ca certs")
}

tlsConfig := &tls.Config{
RootCAs: certPool,
NextProtos: []string{"h2", "http/1.1"}, // specify `h2` to let Go use HTTP/2.
MinVersion: tls.VersionTLS12,
}

if len(certPath) != 0 && len(keyPath) != 0 {
loadCert := func() (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, errors.Annotate(err, "could not load client key pair")
}
return &cert, nil
}
tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return loadCert()
}
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return loadCert()
}
return nil, errors.Trace(err)
}
return tlsConfig, nil
}

// NewTLS constructs a new HTTP client with TLS configured with the CA,
// certificate and key paths.
//
// If the CA path is empty, returns an instance where TLS is disabled.
func NewTLS(caPath, certPath, keyPath, host string) (*TLS, error) {
if len(caPath) == 0 {
if inner == nil {
return &TLS{
inner: nil,
client: &http.Client{},
url: "http://" + host,
}, nil
}
inner, err := ToTLSConfig(caPath, certPath, keyPath)
if err != nil {
return nil, errors.Trace(err)
}

return &TLS{
caPath: caPath,
certPath: certPath,
keyPath: keyPath,
inner: inner,
client: httputil.NewClient(inner),
url: "https://" + host,
caPath: caPath,
certPath: certPath,
keyPath: keyPath,
caBytes: caBytes,
certBytes: certBytes,
keyBytes: keyBytes,
inner: inner,
client: httputil.NewClient(inner),
url: "https://" + host,
}, nil
}

Expand All @@ -129,11 +94,9 @@ func (tc *TLS) WithHost(host string) *TLS {
} else {
url = "http://" + host
}
return &TLS{
inner: tc.inner,
client: tc.client,
url: url,
}
shallowClone := *tc
shallowClone.url = url
return &shallowClone
}

// ToGRPCDialOption constructs a gRPC dial option.
Expand All @@ -156,14 +119,20 @@ func (tc *TLS) GetJSON(ctx context.Context, path string, v interface{}) error {
return GetJSON(ctx, tc.client, tc.url+path, v)
}

// ToPDSecurityOption converts the TLS configuration to a PD security option.
func (tc *TLS) ToPDSecurityOption() pd.SecurityOption {
return pd.SecurityOption{
CAPath: tc.caPath,
CertPath: tc.certPath,
KeyPath: tc.keyPath,
CAPath: tc.caPath,
CertPath: tc.certPath,
KeyPath: tc.keyPath,
SSLCABytes: tc.caBytes,
SSLCertBytes: tc.certBytes,
SSLKEYBytes: tc.keyBytes,
}
}

// ToTiKVSecurityConfig converts the TLS configuration to a TiKV security config.
// TODO: TiKV does not support pass in content.
func (tc *TLS) ToTiKVSecurityConfig() config.Security {
return config.Security{
ClusterSSLCA: tc.caPath,
Expand Down
28 changes: 10 additions & 18 deletions br/pkg/lightning/common/security_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestGetJSONInsecure(t *testing.T) {
u, err := url.Parse(mockServer.URL)
require.NoError(t, err)

tls, err := common.NewTLS("", "", "", u.Host)
tls, err := common.NewTLS("", "", "", u.Host, nil, nil, nil)
require.NoError(t, err)

var result struct{ Path string }
Expand Down Expand Up @@ -73,15 +73,8 @@ func TestGetJSONSecure(t *testing.T) {
func TestInvalidTLS(t *testing.T) {
tempDir := t.TempDir()
caPath := filepath.Join(tempDir, "ca.pem")
_, err := common.NewTLS(caPath, "", "", "localhost")
require.Regexp(t, "could not read ca certificate:.*", err.Error())

err = os.WriteFile(caPath, []byte("invalid ca content"), 0o644)
require.NoError(t, err)
_, err = common.NewTLS(caPath, "", "", "localhost")
require.Regexp(t, "failed to append ca certs", err.Error())

err = os.WriteFile(caPath, []byte(`-----BEGIN CERTIFICATE-----
caContent := []byte(`-----BEGIN CERTIFICATE-----
MIIBITCBxwIUf04/Hucshr7AynmgF8JeuFUEf9EwCgYIKoZIzj0EAwIwEzERMA8G
A1UEAwwIYnJfdGVzdHMwHhcNMjIwNDEzMDcyNDQxWhcNMjIwNDE1MDcyNDQxWjAT
MREwDwYDVQQDDAhicl90ZXN0czBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABL+X
Expand All @@ -90,20 +83,19 @@ wczUg0AbaFFaCI+FAk3K9vbB9JeIORgGKS+F1TKip5tvm96g7S5lq8SgY38SXVc3
ze4ZnCkwJdP2VdpI3WZsoI7zAiEAjP8X1c0iFwYxdAbQAveX+9msVrzyUpZOohi4
RtgQTNI=
-----END CERTIFICATE-----
`), 0o644)
`)
err := os.WriteFile(caPath, caContent, 0o644)
require.NoError(t, err)

certPath := filepath.Join(tempDir, "test.pem")
keyPath := filepath.Join(tempDir, "test.key")
tls, err := common.NewTLS(caPath, certPath, keyPath, "localhost")
_, err = tls.TLSConfig().GetCertificate(nil)
require.Regexp(t, "could not load client key pair: open.*", err.Error())

err = os.WriteFile(certPath, []byte("invalid cert content"), 0o644)
certContent := []byte("invalid cert content")
err = os.WriteFile(certPath, certContent, 0o644)
require.NoError(t, err)
err = os.WriteFile(keyPath, []byte("invalid key content"), 0o600)
keyContent := []byte("invalid key content")
err = os.WriteFile(keyPath, keyContent, 0o600)
require.NoError(t, err)
tls, err = common.NewTLS(caPath, certPath, keyPath, "localhost")
_, err = tls.TLSConfig().GetCertificate(nil)
require.Regexp(t, "could not load client key pair: tls.*", err.Error())
_, err = common.NewTLS(caPath, "", "", "localhost", caContent, certContent, keyContent)
require.ErrorContains(t, err, "tls: failed to find any PEM data in certificate input")
}
31 changes: 27 additions & 4 deletions br/pkg/lightning/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ import (
"github.com/BurntSushi/toml"
"github.com/docker/go-units"
gomysql "github.com/go-sql-driver/mysql"
"github.com/google/uuid"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/log"
tidbcfg "github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/util"
filter "github.com/pingcap/tidb/util/table-filter"
router "github.com/pingcap/tidb/util/table-router"
"go.uber.org/atomic"
Expand Down Expand Up @@ -155,7 +157,15 @@ func (cfg *Config) String() string {

func (cfg *Config) ToTLS() (*common.TLS, error) {
hostPort := net.JoinHostPort(cfg.TiDB.Host, strconv.Itoa(cfg.TiDB.StatusPort))
return common.NewTLS(cfg.Security.CAPath, cfg.Security.CertPath, cfg.Security.KeyPath, hostPort)
return common.NewTLS(
cfg.Security.CAPath,
cfg.Security.CertPath,
cfg.Security.KeyPath,
hostPort,
cfg.Security.CABytes,
cfg.Security.CertBytes,
cfg.Security.KeyBytes,
)
}

type Lightning struct {
Expand Down Expand Up @@ -559,6 +569,11 @@ type Security struct {
// TLSConfigName is used to set tls config for lightning in DM, so we don't expose this field to user
// DM may running many lightning instances at same time, so we need to set different tls config name for each lightning
TLSConfigName string `toml:"-" json:"-"`

// When DM/engine uses lightning as a library, it can directly pass in the content
CABytes []byte `toml:"-" json:"-"`
CertBytes []byte `toml:"-" json:"-"`
KeyBytes []byte `toml:"-" json:"-"`
}

// RegisterMySQL registers the TLS config with name "cluster" or security.TLSConfigName
Expand All @@ -567,7 +582,13 @@ func (sec *Security) RegisterMySQL() error {
if sec == nil {
return nil
}
tlsConfig, err := common.ToTLSConfig(sec.CAPath, sec.CertPath, sec.KeyPath)

tlsConfig, err := util.NewTLSConfig(
util.WithCAPath(sec.CAPath),
util.WithCertAndKeyPath(sec.CertPath, sec.KeyPath),
util.WithCAContent(sec.CABytes),
util.WithCertAndKeyContent(sec.CertBytes, sec.KeyBytes),
)
if err != nil {
return errors.Trace(err)
}
Expand Down Expand Up @@ -1151,9 +1172,11 @@ func (cfg *Config) CheckAndAdjustSecurity() error {

switch cfg.TiDB.TLS {
case "":
if len(cfg.TiDB.Security.CAPath) > 0 {
if len(cfg.TiDB.Security.CAPath) > 0 || len(cfg.TiDB.Security.CABytes) > 0 ||
len(cfg.TiDB.Security.CertPath) > 0 || len(cfg.TiDB.Security.CertBytes) > 0 ||
len(cfg.TiDB.Security.KeyPath) > 0 || len(cfg.TiDB.Security.KeyBytes) > 0 {
if cfg.TiDB.Security.TLSConfigName == "" {
cfg.TiDB.Security.TLSConfigName = "cluster" // adjust this the default value
cfg.TiDB.Security.TLSConfigName = uuid.NewString() // adjust this the default value
}
cfg.TiDB.TLS = cfg.TiDB.Security.TLSConfigName
} else {
Expand Down
13 changes: 9 additions & 4 deletions br/pkg/lightning/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ func TestAdjustWillBatchImportRatioInvalid(t *testing.T) {
}

func TestAdjustSecuritySection(t *testing.T) {
uuidHolder := "<uuid>"
testCases := []struct {
input string
expectedCA string
Expand All @@ -302,7 +303,7 @@ func TestAdjustSecuritySection(t *testing.T) {
ca-path = "/path/to/ca.pem"
`,
expectedCA: "/path/to/ca.pem",
expectedTLS: "cluster",
expectedTLS: uuidHolder,
},
{
input: `
Expand All @@ -321,7 +322,7 @@ func TestAdjustSecuritySection(t *testing.T) {
ca-path = "/path/to/ca2.pem"
`,
expectedCA: "/path/to/ca2.pem",
expectedTLS: "cluster",
expectedTLS: uuidHolder,
},
{
input: `
Expand All @@ -330,7 +331,7 @@ func TestAdjustSecuritySection(t *testing.T) {
ca-path = "/path/to/ca2.pem"
`,
expectedCA: "/path/to/ca2.pem",
expectedTLS: "cluster",
expectedTLS: uuidHolder,
},
{
input: `
Expand All @@ -356,7 +357,11 @@ func TestAdjustSecuritySection(t *testing.T) {
err = cfg.Adjust(context.Background())
require.NoError(t, err, comment)
require.Equal(t, tc.expectedCA, cfg.TiDB.Security.CAPath, comment)
require.Equal(t, tc.expectedTLS, cfg.TiDB.TLS, comment)
if tc.expectedTLS == uuidHolder {
require.NotEmpty(t, cfg.TiDB.TLS, comment)
} else {
require.Equal(t, tc.expectedTLS, cfg.TiDB.TLS, comment)
}
}
// test different tls config name
cfg := config.NewConfig()
Expand Down
10 changes: 9 additions & 1 deletion br/pkg/lightning/lightning.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,15 @@ func New(globalCfg *config.GlobalConfig) *Lightning {
os.Exit(1)
}

tls, err := common.NewTLS(globalCfg.Security.CAPath, globalCfg.Security.CertPath, globalCfg.Security.KeyPath, globalCfg.App.StatusAddr)
tls, err := common.NewTLS(
globalCfg.Security.CAPath,
globalCfg.Security.CertPath,
globalCfg.Security.KeyPath,
globalCfg.App.StatusAddr,
globalCfg.Security.CABytes,
globalCfg.Security.CertBytes,
globalCfg.Security.KeyBytes,
)
if err != nil {
log.L().Fatal("failed to load TLS certificates", zap.Error(err))
}
Expand Down
Loading

0 comments on commit 796fb1f

Please sign in to comment.