Merge pull request #4 from mohini-crl/custom-sasl
changefeedccl: add support for nonstandard oauthbearer kafka authentication
mohini-crl authored Dec 16, 2024

2 parents b9d147b + 80f75fe commit 48ef9df
Showing 5 changed files with 205 additions and 8 deletions.
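
For context, the new mechanism is selected through the changefeed sink URI: sasl_mechanism=CUSTOM is combined with the existing sasl_client_id and sasl_token_url parameters and the three new sasl_custom_* parameters added in options.go below. A minimal sketch of such a URI (the host and all values are placeholders, not taken from this commit):

kafka://broker.example.com:9093/?sasl_enabled=true&sasl_mechanism=CUSTOM
  &sasl_client_id=<client-id>&sasl_token_url=<token-endpoint-url>
  &sasl_custom_resource=<resource>
  &sasl_custom_client_assertion_type=<assertion-type>
  &sasl_custom_client_assertion=<assertion>
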
3 changes: 3 additions & 0 deletions pkg/ccl/changefeedccl/changefeed_test.go
@@ -5381,6 +5381,9 @@ func TestChangefeedErrors(t *testing.T) {
t, `sasl_client_id must be provided when SASL is enabled using mechanism OAUTHBEARER`,
`CREATE CHANGEFEED FOR foo INTO $1`, `kafka://nope/?sasl_enabled=true&sasl_mechanism=OAUTHBEARER`,
)

// TODO: add some test cases in here for option validation

sqlDB.ExpectErrWithTimeout(
t, `sasl_enabled must be enabled if sasl_user is provided`,
`CREATE CHANGEFEED FOR foo INTO $1`, `kafka://nope/?sasl_user=a`,
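
The TODO above asks for option-validation test cases. A sketch of one such case for the new CUSTOM mechanism, in the style of the surrounding tests (the expected error text is an assumption based on the required-parameter check added in sink_kafka.go below; the real message may differ):

sqlDB.ExpectErrWithTimeout(
t, `sasl_custom_resource must be provided`,
`CREATE CHANGEFEED FOR foo INTO $1`, `kafka://nope/?sasl_enabled=true&sasl_mechanism=CUSTOM&sasl_client_id=a&sasl_token_url=b`,
)
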
14 changes: 12 additions & 2 deletions pkg/ccl/changefeedccl/changefeedbase/options.go
@@ -221,6 +221,12 @@ const (
SinkParamSASLAwsIAMSessionName = `sasl_aws_iam_session_name`
SinkParamTableNameAttribute = `with_table_name_attribute`

// These are custom fields required for custom auth. They should not be
// documented.
SinkParamSASLCustomResource = `sasl_custom_resource`
SinkParamSASLCustomClientAssertionType = `sasl_custom_client_assertion_type`
SinkParamSASLCustomClientAssertion = `sasl_custom_client_assertion`

SinkSchemeConfluentKafka = `confluent-cloud`
SinkParamConfluentAPIKey = `api_key`
SinkParamConfluentAPISecret = `api_secret`
@@ -239,8 +245,12 @@ const (
Topics = `topics`
)

// Support additional mechanism on top of the default SASL mechanism.
const SASLTypeAWSMSKIAM = "AWS_MSK_IAM"
// Support additional mechanisms on top of the default SASL mechanisms.
const (
SASLTypeAWSMSKIAM = "AWS_MSK_IAM"
// TODO: find a better (but still generic) name.
SASLTypeCustom = "CUSTOM"
)

func makeStringSet(opts ...string) map[string]struct{} {
res := make(map[string]struct{}, len(opts))
47 changes: 41 additions & 6 deletions pkg/ccl/changefeedccl/sink_kafka.go
@@ -858,6 +858,14 @@ func newTokenProvider(
}, nil
}

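// newCustomTokenProvider returns a sarama AccessTokenProvider backed by the
// nonstandard token endpoint configured through the sasl_custom_* parameters.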
func newCustomTokenProvider(
ctx context.Context, dialConfig kafkaDialConfig,
) (sarama.AccessTokenProvider, error) {
return &tokenProvider{
tokenSource: oauth2.ReuseTokenSource(nil, customTokenSource{dialConfig: dialConfig, ctx: ctx}),
}, nil
}

// Apply configures provided kafka configuration struct based on this config.
func (c *saramaConfig) Apply(kafka *sarama.Config) error {
// Sarama limits the size of each message to be MaxMessageSize (1MB) bytes.
@@ -936,6 +944,15 @@ type kafkaDialConfig struct {
saslAwsIAMRoleArn string
saslAwsRegion string
saslAwsIAMSessionName string

// These are specific to the custom SASL mechanism, which also uses saslTokenURL and saslClientID.
saslCustomConfig saslCustomConfig
}

type saslCustomConfig struct {
resource string
clientAssertionType string
clientAssertion string
}

func buildDialConfig(u sinkURL) (kafkaDialConfig, error) {
@@ -999,7 +1016,7 @@ func buildDefaultKafkaConfig(u sinkURL) (kafkaDialConfig, error) {
dialConfig.saslMechanism = sarama.SASLTypePlaintext
}
switch dialConfig.saslMechanism {
case sarama.SASLTypeSCRAMSHA256, sarama.SASLTypeSCRAMSHA512, sarama.SASLTypeOAuth, sarama.SASLTypePlaintext, changefeedbase.SASLTypeAWSMSKIAM:
case sarama.SASLTypeSCRAMSHA256, sarama.SASLTypeSCRAMSHA512, sarama.SASLTypeOAuth, sarama.SASLTypePlaintext, changefeedbase.SASLTypeAWSMSKIAM, changefeedbase.SASLTypeCustom:
default:
return kafkaDialConfig{}, errors.Errorf(`param %s must be one of %s, %s, %s, %s or %s`,
changefeedbase.SinkParamSASLMechanism,
@@ -1014,6 +1031,14 @@ func buildDefaultKafkaConfig(u sinkURL) (kafkaDialConfig, error) {
case changefeedbase.SASLTypeAWSMSKIAM:
requiredSASLParams = []string{changefeedbase.SinkParamSASLAwsRegion, changefeedbase.SinkParamSASLAwsIAMRoleArn,
changefeedbase.SinkParamSASLAwsIAMSessionName}
case changefeedbase.SASLTypeCustom:
requiredSASLParams = []string{
changefeedbase.SinkParamSASLClientID,
changefeedbase.SinkParamSASLTokenURL,
changefeedbase.SinkParamSASLCustomResource,
changefeedbase.SinkParamSASLCustomClientAssertion,
changefeedbase.SinkParamSASLCustomClientAssertionType,
}
default:
requiredSASLParams = []string{changefeedbase.SinkParamSASLUser, changefeedbase.SinkParamSASLPassword}
}
@@ -1034,7 +1059,8 @@ func buildDefaultKafkaConfig(u sinkURL) (kafkaDialConfig, error) {
}
}

if dialConfig.saslMechanism != sarama.SASLTypeOAuth {
// TODO: this can be done a little more cleanly.
if dialConfig.saslMechanism != sarama.SASLTypeOAuth && dialConfig.saslMechanism != changefeedbase.SASLTypeCustom {
oauthParams := []string{changefeedbase.SinkParamSASLClientID, changefeedbase.SinkParamSASLClientSecret,
changefeedbase.SinkParamSASLTokenURL, changefeedbase.SinkParamSASLGrantType,
changefeedbase.SinkParamSASLScopes}
@@ -1058,6 +1084,10 @@ func buildDefaultKafkaConfig(u sinkURL) (kafkaDialConfig, error) {
dialConfig.saslAwsIAMSessionName = u.consumeParam(changefeedbase.SinkParamSASLAwsIAMSessionName)
dialConfig.saslAwsIAMRoleArn = u.consumeParam(changefeedbase.SinkParamSASLAwsIAMRoleArn)

dialConfig.saslCustomConfig.resource = u.consumeParam(changefeedbase.SinkParamSASLCustomResource)
dialConfig.saslCustomConfig.clientAssertion = u.consumeParam(changefeedbase.SinkParamSASLCustomClientAssertion)
dialConfig.saslCustomConfig.clientAssertionType = u.consumeParam(changefeedbase.SinkParamSASLCustomClientAssertionType)

var decodedClientSecret []byte
if err := u.decodeBase64(changefeedbase.SinkParamSASLClientSecret, &decodedClientSecret); err != nil {
return kafkaDialConfig{}, err
@@ -1289,10 +1319,12 @@ func buildKafkaConfig(
config.Net.SASL.User = dialConfig.saslUser
config.Net.SASL.Password = dialConfig.saslPassword

// Substitute our fake SASL mechanism with the backing implementation (OAuth).
var mechanism sarama.SASLMechanism
if dialConfig.saslMechanism == changefeedbase.SASLTypeAWSMSKIAM {
switch dialConfig.saslMechanism {
case changefeedbase.SASLTypeAWSMSKIAM, changefeedbase.SASLTypeCustom:
mechanism = sarama.SASLTypeOAuth
} else {
default:
mechanism = sarama.SASLMechanism(dialConfig.saslMechanism)
}

@@ -1305,9 +1337,12 @@ func buildKafkaConfig(
config.Net.SASL.SCRAMClientGeneratorFunc = sha256ClientGenerator
case sarama.SASLTypeOAuth:
var err error
if dialConfig.saslMechanism == changefeedbase.SASLTypeAWSMSKIAM {
switch dialConfig.saslMechanism {
case changefeedbase.SASLTypeAWSMSKIAM:
config.Net.SASL.TokenProvider, err = newAwsIAMRoleSASLTokenProvider(ctx, dialConfig)
} else {
case changefeedbase.SASLTypeCustom:
config.Net.SASL.TokenProvider, err = newCustomTokenProvider(ctx, dialConfig)
default:
config.Net.SASL.TokenProvider, err = newTokenProvider(ctx, dialConfig)
}
if err != nil {
97 changes: 97 additions & 0 deletions pkg/ccl/changefeedccl/sink_kafka_v2.go
@@ -9,9 +9,12 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"hash/fnv"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
@@ -39,6 +42,7 @@ import (
sasloauth "github.com/twmb/franz-go/pkg/sasl/oauth"
saslplain "github.com/twmb/franz-go/pkg/sasl/plain"
saslscram "github.com/twmb/franz-go/pkg/sasl/scram"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)

@@ -452,13 +456,27 @@ func buildKgoConfig(
return nil, err
}
s = sasloauth.Oauth(tp)
// This is another "fake" mechanism we use to signify that we're using
// the custom identity provider, which is not OAUTH-compliant and requires special handling.
case changefeedbase.SASLTypeCustom:
tp, err := newCustomOauthTokenProvider(ctx, dialConfig)
if err != nil {
return nil, err
}
s = sasloauth.Oauth(tp)
// TODO(#126991): Remove this sarama dependency.
case sarama.SASLTypeOAuth:
tp, err := newKgoOauthTokenProvider(ctx, dialConfig)
if err != nil {
return nil, err
}
s = sasloauth.Oauth(tp)
case sarama.SASLTypePlaintext, "":
s = saslplain.Plain(func(ctc context.Context) (saslplain.Auth, error) {
return saslplain.Auth{
@@ -704,6 +722,85 @@ func newKgoAWSIAMRoleOauthTokenProvider(
}, nil
}

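// customTokenSource implements oauth2.TokenSource by POSTing a client
// assertion to the nonstandard token endpoint configured through the
// sasl_custom_* sink parameters.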
type customTokenSource struct {
dialConfig kafkaDialConfig
// The oauth2.TokenSource API seems to require us to keep a context in here.
ctx context.Context
client *http.Client
}

// Token implements the oauth2.TokenSource interface.
func (iats customTokenSource) Token() (*oauth2.Token, error) {
tokenURL, err := url.Parse(iats.dialConfig.saslTokenURL)
if err != nil {
return nil, errors.Wrap(err, "malformed token url")
}

bodyParams := url.Values{
"grant_type": {"client_credentials"},
"client_id": {iats.dialConfig.saslClientID},
"client_assertion_type": {iats.dialConfig.saslCustomConfig.clientAssertionType},
"client_assertion": {iats.dialConfig.saslCustomConfig.clientAssertion},
"resource": {iats.dialConfig.saslCustomConfig.resource},
}

req, err := http.NewRequestWithContext(iats.ctx, "POST", tokenURL.String(), strings.NewReader(bodyParams.Encode()))
if err != nil {
return nil, errors.Wrap(err, "failed to create oauth token request")
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
httpClient := iats.client
if httpClient == nil {
// The sarama path (newCustomTokenProvider) constructs customTokenSource without a client.
httpClient = http.DefaultClient
}
res, err := httpClient.Do(req)
if err != nil {
return nil, errors.Wrapf(err, "failed to make oauth token request")
}
body, err := io.ReadAll(io.LimitReader(res.Body, 1<<20))
if err != nil {
return nil, errors.Wrap(err, "failed to read oauth response body")
}
if err := res.Body.Close(); err != nil {
return nil, errors.Wrap(err, "failed to close oauth response body")
}

// The endpoint returns JSON; this format was given to us by the customer.
var resp struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
if err := json.Unmarshal(body, &resp); err != nil {
return nil, errors.Wrapf(err, "failed to parse oauth response")
}
if resp.AccessToken == "" {
return nil, errors.Errorf("no access token in oauth response")
}

tok := &oauth2.Token{AccessToken: resp.AccessToken, TokenType: resp.TokenType}

if resp.ExpiresIn > 0 {
tok.Expiry = time.Now().Add(time.Duration(resp.ExpiresIn) * time.Second)
}

return tok, nil
}

var _ oauth2.TokenSource = customTokenSource{}

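// newCustomOauthTokenProvider returns the token callback used by franz-go's
// OAUTHBEARER support, backed by customTokenSource.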
func newCustomOauthTokenProvider(
ctx context.Context, dialConfig kafkaDialConfig,
) (func(ctx context.Context) (sasloauth.Auth, error), error) {
// TODO: it is unclear whether wrapping with ReuseTokenSource is necessary here. The
// standard OAuth token source wraps itself with it, but the custom endpoint may not
// return expiry information. It is probably not necessary; see
// https://docs.confluent.io/legacy/platform/5.2.4/kafka/authentication_sasl/authentication_sasl_oauth.html#token-refresh-for-sasl-oauthbearer
ts := oauth2.ReuseTokenSource(nil, customTokenSource{dialConfig: dialConfig, ctx: ctx, client: &http.Client{}})
return func(ctx context.Context) (sasloauth.Auth, error) {
tok, err := ts.Token()
if err != nil {
return sasloauth.Auth{}, err
}
return sasloauth.Auth{Token: tok.AccessToken}, nil
}, nil
}

type kgoMetricsAdapter struct {
throttling metrics.Histogram
}
52 changes: 52 additions & 0 deletions pkg/cmd/roachtest/tests/cdc.go
@@ -2100,6 +2100,57 @@ func registerCDC(r registry.Registry) {
feed.waitForCompletion()
},
})

r.Add(registry.TestSpec{
Name: "cdc/kafka-custom-auth",
Owner: `cdc`,
// Only Kafka 3 supports ARM64, but the broker-side OAuth setup used here appears to work only with Kafka 2.
Cluster: r.MakeClusterSpec(4, spec.WorkloadNode(), spec.Arch(vm.ArchAMD64)),
Leases: registry.MetamorphicLeases,
CompatibleClouds: registry.AllExceptAWS,
Suites: registry.Suites(registry.Nightly),
RequiresLicense: true,
Run: func(ctx context.Context, t test.Test, c cluster.Cluster) {
if c.Cloud() == spec.Local && runtime.GOARCH == "arm64" {
// N.B. We have to skip locally since amd64 emulation may not be available everywhere.
t.L().PrintfCtx(ctx, "Skipping test under ARM64")
return
}

ct := newCDCTester(ctx, t, c)
defer ct.Close()

// Run the tpcc workload for a short time. The roachtest monitor does not
// like it when no tasks have been started with the monitor.
// (This can be removed once #108530 is resolved.)
ct.runTPCCWorkload(tpccArgs{warehouses: 1, duration: "30s"})

kafkaNode := ct.sinkNodes
kafka := kafkaManager{
t: ct.t,
c: ct.cluster,
kafkaSinkNodes: kafkaNode,
mon: ct.mon,
useKafka2: true, // The broker-side oauth configuration used only works with Kafka 2
}
kafka.install(ct.ctx)

// TODO: how?
creds, kafkaEnv := kafka.configureOauth(ct.ctx)

kafka.start(ctx, "kafka", kafkaEnv)

feed := ct.newChangefeed(feedArgs{
sinkType: kafkaSink,
sinkURIOverride: kafka.sinkURLOAuth(ct.ctx, creds),
targets: allTpccTargets,
opts: map[string]string{"initial_scan": "'only'"},
})

feed.waitForCompletion()
},
})

r.Add(registry.TestSpec{
Name: "cdc/kafka-topics",
Owner: `cdc`,
@@ -2932,6 +2983,7 @@ func (k kafkaManager) configureOauth(ctx context.Context) (clientcredentials.Con
// CLASSPATH allows Kafka to load in the custom implementation
kafkaEnv := "CLASSPATH='/home/ubuntu/kafka-oauth/target/*'"

// TODO: and here?
// Hydra is used as an open source OAuth server
clientID, clientSecret := k.configureHydraOauth(ctx)
tokenURL := fmt.Sprintf("http://%s:4444/oauth2/token", nodeIP)
