From 4aa436bc53fa8868a3c278ec6cbe244ac1e9f7f7 Mon Sep 17 00:00:00 2001 From: Albert <26584478+albertteoh@users.noreply.github.com> Date: Wed, 3 Nov 2021 03:28:35 +1100 Subject: [PATCH] Add auth token propagation for metrics reader (#3341) * Add auth token propagation for metrics reader Signed-off-by: albertteoh * Consolidate token propagation code Signed-off-by: albertteoh * Rename to use PropagationHandler for consistency Signed-off-by: albertteoh * Split and move common transports Signed-off-by: albertteoh * Give bearerToken value less meaning Signed-off-by: albertteoh * Fix import grouping Signed-off-by: albertteoh * Give token value less meaning Signed-off-by: albertteoh * make fmt Signed-off-by: albertteoh * Remove functional options Signed-off-by: albertteoh * Fix build Signed-off-by: albertteoh * Address review comments Signed-off-by: albertteoh * Fix staticcheck err Signed-off-by: albertteoh --- cmd/query/app/server.go | 3 +- cmd/query/main.go | 4 +- .../bearertoken/context.go | 19 ++- .../bearertoken/context_test.go | 6 +- .../bearertoken/http.go | 14 +-- .../bearertoken/http_test.go | 28 +++-- pkg/bearertoken/transport.go | 52 ++++++++ pkg/bearertoken/transport_test.go | 111 ++++++++++++++++++ pkg/es/config/config.go | 29 +---- .../metrics/prometheus/metricsstore/reader.go | 7 +- .../prometheus/metricsstore/reader_test.go | 28 ++++- plugin/storage/es/options.go | 4 +- plugin/storage/grpc/README.md | 5 +- plugin/storage/grpc/shared/grpc_client.go | 8 +- .../storage/grpc/shared/grpc_client_test.go | 5 +- 15 files changed, 247 insertions(+), 76 deletions(-) rename storage/spanstore/token_propagation.go => pkg/bearertoken/context.go (70%) rename storage/spanstore/token_propagation_test.go => pkg/bearertoken/context_test.go (83%) rename cmd/query/app/token_propagation_handler.go => pkg/bearertoken/http.go (76%) rename cmd/query/app/token_propagation_hander_test.go => pkg/bearertoken/http_test.go (82%) create mode 100644 pkg/bearertoken/transport.go create mode 100644 pkg/bearertoken/transport_test.go diff --git a/cmd/query/app/server.go b/cmd/query/app/server.go index 3c410294fba..d1f1b676ebe 100644 --- a/cmd/query/app/server.go +++ b/cmd/query/app/server.go @@ -32,6 +32,7 @@ import ( "github.com/jaegertracing/jaeger/cmd/query/app/apiv3" "github.com/jaegertracing/jaeger/cmd/query/app/querysvc" + "github.com/jaegertracing/jaeger/pkg/bearertoken" "github.com/jaegertracing/jaeger/pkg/healthcheck" "github.com/jaegertracing/jaeger/pkg/netutils" "github.com/jaegertracing/jaeger/pkg/recoveryhandler" @@ -158,7 +159,7 @@ func createHTTPServer(querySvc *querysvc.QueryService, metricsQuerySvc querysvc. var handler http.Handler = r handler = additionalHeadersHandler(handler, queryOpts.AdditionalHeaders) if queryOpts.BearerTokenPropagation { - handler = bearerTokenPropagationHandler(logger, handler) + handler = bearertoken.PropagationHandler(logger, handler) } handler = handlers.CompressHandler(handler) recoveryHandler := recoveryhandler.NewRecoveryHandler(logger, true) diff --git a/cmd/query/main.go b/cmd/query/main.go index 4659ad84bab..5ec3e708e40 100644 --- a/cmd/query/main.go +++ b/cmd/query/main.go @@ -35,12 +35,12 @@ import ( "github.com/jaegertracing/jaeger/cmd/query/app" "github.com/jaegertracing/jaeger/cmd/query/app/querysvc" "github.com/jaegertracing/jaeger/cmd/status" + "github.com/jaegertracing/jaeger/pkg/bearertoken" "github.com/jaegertracing/jaeger/pkg/config" "github.com/jaegertracing/jaeger/pkg/version" metricsPlugin "github.com/jaegertracing/jaeger/plugin/metrics" "github.com/jaegertracing/jaeger/plugin/storage" "github.com/jaegertracing/jaeger/ports" - "github.com/jaegertracing/jaeger/storage/spanstore" storageMetrics "github.com/jaegertracing/jaeger/storage/spanstore/metrics" ) @@ -95,7 +95,7 @@ func main() { opentracing.SetGlobalTracer(tracer) queryOpts := new(app.QueryOptions).InitFromViper(v, logger) // TODO: Need to figure out set enable/disable propagation on storage plugins. - v.Set(spanstore.StoragePropagationKey, queryOpts.BearerTokenPropagation) + v.Set(bearertoken.StoragePropagationKey, queryOpts.BearerTokenPropagation) storageFactory.InitFromViper(v, logger) if err := storageFactory.Initialize(baseFactory, logger); err != nil { logger.Fatal("Failed to init storage factory", zap.Error(err)) diff --git a/storage/spanstore/token_propagation.go b/pkg/bearertoken/context.go similarity index 70% rename from storage/spanstore/token_propagation.go rename to pkg/bearertoken/context.go index ba023822b2e..4f9221fe631 100644 --- a/storage/spanstore/token_propagation.go +++ b/pkg/bearertoken/context.go @@ -1,4 +1,4 @@ -// Copyright (c) 2019 The Jaeger Authors. +// Copyright (c) 2021 The Jaeger Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,30 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -package spanstore +package bearertoken import "context" -type contextKey string +type contextKeyType int -// BearerTokenKey is the string literal used internally in the implementation of this context. -const BearerTokenKey = "bearer.token" -const bearerToken = contextKey(BearerTokenKey) +const contextKey = contextKeyType(iota) // StoragePropagationKey is a key for viper configuration to pass this option to storage plugins. const StoragePropagationKey = "storage.propagate.token" -// ContextWithBearerToken set bearer token in context +// ContextWithBearerToken set bearer token in context. func ContextWithBearerToken(ctx context.Context, token string) context.Context { if token == "" { return ctx } - return context.WithValue(ctx, bearerToken, token) - + return context.WithValue(ctx, contextKey, token) } -// GetBearerToken from context, or empty string if there is no token +// GetBearerToken from context, or empty string if there is no token. func GetBearerToken(ctx context.Context) (string, bool) { - val, ok := ctx.Value(bearerToken).(string) + val, ok := ctx.Value(contextKey).(string) return val, ok } diff --git a/storage/spanstore/token_propagation_test.go b/pkg/bearertoken/context_test.go similarity index 83% rename from storage/spanstore/token_propagation_test.go rename to pkg/bearertoken/context_test.go index a5fe03b5381..7a3f5184f20 100644 --- a/storage/spanstore/token_propagation_test.go +++ b/pkg/bearertoken/context_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2019 The Jaeger Authors. +// Copyright (c) 2021 The Jaeger Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package spanstore +package bearertoken import ( "context" @@ -22,7 +22,7 @@ import ( ) func Test_GetBearerToken(t *testing.T) { - token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJhZG1pbiIsIm5hbWUiOiJKb2huIERvZSIsImlhdCI" + const token = "blah" ctx := context.Background() ctx = ContextWithBearerToken(ctx, token) contextToken, ok := GetBearerToken(ctx) diff --git a/cmd/query/app/token_propagation_handler.go b/pkg/bearertoken/http.go similarity index 76% rename from cmd/query/app/token_propagation_handler.go rename to pkg/bearertoken/http.go index b44a5087e8e..76cb0014bec 100644 --- a/cmd/query/app/token_propagation_handler.go +++ b/pkg/bearertoken/http.go @@ -12,18 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -package app +package bearertoken import ( "net/http" "strings" "go.uber.org/zap" - - "github.com/jaegertracing/jaeger/storage/spanstore" ) -func bearerTokenPropagationHandler(logger *zap.Logger, h http.Handler) http.Handler { +// PropagationHandler returns a http.Handler containing the logic to extract +// the Bearer token from the Authorization header of the http.Request and insert it into request.Context +// for propagation. The token can be accessed via GetBearerToken. +func PropagationHandler(logger *zap.Logger, h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() authHeaderValue := r.Header.Get("Authorization") @@ -40,15 +41,14 @@ func bearerTokenPropagationHandler(logger *zap.Logger, h http.Handler) http.Hand token = headerValue[1] } } else if len(headerValue) == 1 { - // Tread all value as a token + // Treat the entire value as a token. token = authHeaderValue } else { logger.Warn("Invalid authorization header value, skipping token propagation") } - h.ServeHTTP(w, r.WithContext(spanstore.ContextWithBearerToken(ctx, token))) + h.ServeHTTP(w, r.WithContext(ContextWithBearerToken(ctx, token))) } else { h.ServeHTTP(w, r.WithContext(ctx)) } }) - } diff --git a/cmd/query/app/token_propagation_hander_test.go b/pkg/bearertoken/http_test.go similarity index 82% rename from cmd/query/app/token_propagation_hander_test.go rename to pkg/bearertoken/http_test.go index 9d238da4924..88d2d4618b7 100644 --- a/cmd/query/app/token_propagation_hander_test.go +++ b/pkg/bearertoken/http_test.go @@ -12,41 +12,44 @@ // See the License for the specific language governing permissions and // limitations under the License. -package app +package bearertoken import ( "net/http" "net/http/httptest" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "go.uber.org/zap" - - "github.com/jaegertracing/jaeger/storage/spanstore" ) -func Test_bearTokenPropagationHandler(t *testing.T) { +func Test_PropagationHandler(t *testing.T) { + httpClient := &http.Client{ + Timeout: 2 * time.Second, + } + logger := zap.NewNop() - bearerToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJhZG1pbiIsIm5hbWUiOiJKb2huIERvZSIsImlhdCI" + const bearerToken = "blah" validTokenHandler := func(stop *sync.WaitGroup) http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - token, ok := spanstore.GetBearerToken(ctx) + token, ok := GetBearerToken(ctx) assert.Equal(t, token, bearerToken) assert.True(t, ok) stop.Done() - }) + } } emptyHandler := func(stop *sync.WaitGroup) http.HandlerFunc { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - token, _ := spanstore.GetBearerToken(ctx) + token, _ := GetBearerToken(ctx) assert.Empty(t, token, bearerToken) stop.Done() - }) + } } testCases := []struct { @@ -68,7 +71,7 @@ func Test_bearTokenPropagationHandler(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { stop := sync.WaitGroup{} stop.Add(1) - r := bearerTokenPropagationHandler(logger, testCase.handler(&stop)) + r := PropagationHandler(logger, testCase.handler(&stop)) server := httptest.NewServer(r) defer server.Close() req, err := http.NewRequest("GET", server.URL, nil) @@ -81,5 +84,4 @@ func Test_bearTokenPropagationHandler(t *testing.T) { stop.Wait() }) } - } diff --git a/pkg/bearertoken/transport.go b/pkg/bearertoken/transport.go new file mode 100644 index 00000000000..7e13e6300b5 --- /dev/null +++ b/pkg/bearertoken/transport.go @@ -0,0 +1,52 @@ +// Copyright (c) 2021 The Jaeger Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bearertoken + +import ( + "errors" + "net/http" +) + +// RoundTripper wraps another http.RoundTripper and injects +// an authentication header with bearer token into requests. +type RoundTripper struct { + // Transport is the underlying http.RoundTripper being wrapped. Required. + Transport http.RoundTripper + + // StaticToken is the pre-configured bearer token. Optional. + StaticToken string + + // OverrideFromCtx enables reading bearer token from Context. + OverrideFromCtx bool +} + +// RoundTrip injects the outbound Authorization header with the +// token provided in the inbound request. +func (tr RoundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + if tr.Transport == nil { + return nil, errors.New("no http.RoundTripper provided") + } + token := tr.StaticToken + if tr.OverrideFromCtx { + headerToken, _ := GetBearerToken(r.Context()) + if headerToken != "" { + token = headerToken + } + } + if token != "" { + r.Header.Set("Authorization", "Bearer "+token) + } + return tr.Transport.RoundTrip(r) +} diff --git a/pkg/bearertoken/transport_test.go b/pkg/bearertoken/transport_test.go new file mode 100644 index 00000000000..05992c5b97b --- /dev/null +++ b/pkg/bearertoken/transport_test.go @@ -0,0 +1,111 @@ +// Copyright (c) 2021 The Jaeger Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bearertoken + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type roundTripFunc func(r *http.Request) (*http.Response, error) + +func (s roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return s(r) +} + +func TestRoundTripper(t *testing.T) { + for _, tc := range []struct { + name string + staticToken string + overrideFromCtx bool + wrappedTransport http.RoundTripper + requestContext context.Context + wantError bool + }{ + { + name: "Default RoundTripper and request context set should have empty Bearer token", + wrappedTransport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + assert.Empty(t, r.Header.Get("Authorization")) + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }), + requestContext: ContextWithBearerToken(context.Background(), "tokenFromContext"), + }, + { + name: "Override from context provided, and request context set should use request context token", + overrideFromCtx: true, + wrappedTransport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + assert.Equal(t, "Bearer tokenFromContext", r.Header.Get("Authorization")) + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }), + requestContext: ContextWithBearerToken(context.Background(), "tokenFromContext"), + }, + { + name: "Allow override from context and token provided, and request context unset should use defaultToken", + overrideFromCtx: true, + staticToken: "initToken", + wrappedTransport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + assert.Equal(t, "Bearer initToken", r.Header.Get("Authorization")) + return &http.Response{}, nil + }), + requestContext: context.Background(), + }, + { + name: "Allow override from context and token provided, and request context set should use context token", + overrideFromCtx: true, + staticToken: "initToken", + wrappedTransport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + assert.Equal(t, "Bearer tokenFromContext", r.Header.Get("Authorization")) + return &http.Response{}, nil + }), + requestContext: ContextWithBearerToken(context.Background(), "tokenFromContext"), + }, + { + name: "Nil roundTripper provided should return an error", + requestContext: context.Background(), + wantError: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(nil) + defer server.Close() + req, err := http.NewRequestWithContext(tc.requestContext, "GET", server.URL, nil) + require.NoError(t, err) + + tr := RoundTripper{ + Transport: tc.wrappedTransport, + OverrideFromCtx: tc.overrideFromCtx, + StaticToken: tc.staticToken, + } + resp, err := tr.RoundTrip(req) + + if tc.wantError { + assert.Nil(t, resp) + assert.Error(t, err) + } else { + assert.NotNil(t, resp) + assert.NoError(t, err) + } + }) + } +} diff --git a/pkg/es/config/config.go b/pkg/es/config/config.go index f9c810e6c6e..7fe513aa776 100644 --- a/pkg/es/config/config.go +++ b/pkg/es/config/config.go @@ -34,10 +34,10 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zapgrpc" + "github.com/jaegertracing/jaeger/pkg/bearertoken" "github.com/jaegertracing/jaeger/pkg/config/tlscfg" "github.com/jaegertracing/jaeger/pkg/es" eswrapper "github.com/jaegertracing/jaeger/pkg/es/wrapper" - "github.com/jaegertracing/jaeger/storage/spanstore" storageMetrics "github.com/jaegertracing/jaeger/storage/spanstore/metrics" ) @@ -519,34 +519,15 @@ func GetHTTPRoundTripper(c *Configuration, logger *zap.Logger) (http.RoundTrippe token = tokenFromFile } if token != "" || c.AllowTokenFromContext { - transport = &tokenAuthTransport{ - token: token, - allowOverrideFromCtx: c.AllowTokenFromContext, - wrapped: httpTransport, + transport = bearertoken.RoundTripper{ + Transport: httpTransport, + OverrideFromCtx: c.AllowTokenFromContext, + StaticToken: token, } } return transport, nil } -// TokenAuthTransport -type tokenAuthTransport struct { - token string - allowOverrideFromCtx bool - wrapped *http.Transport -} - -func (tr *tokenAuthTransport) RoundTrip(r *http.Request) (*http.Response, error) { - token := tr.token - if tr.allowOverrideFromCtx { - headerToken, _ := spanstore.GetBearerToken(r.Context()) - if headerToken != "" { - token = headerToken - } - } - r.Header.Set("Authorization", "Bearer "+token) - return tr.wrapped.RoundTrip(r) -} - func loadToken(path string) (string, error) { b, err := os.ReadFile(filepath.Clean(path)) if err != nil { diff --git a/plugin/metrics/prometheus/metricsstore/reader.go b/plugin/metrics/prometheus/metricsstore/reader.go index 257d1d1616a..a5d05e131d3 100644 --- a/plugin/metrics/prometheus/metricsstore/reader.go +++ b/plugin/metrics/prometheus/metricsstore/reader.go @@ -31,6 +31,7 @@ import ( promapi "github.com/prometheus/client_golang/api/prometheus/v1" "go.uber.org/zap" + "github.com/jaegertracing/jaeger/pkg/bearertoken" "github.com/jaegertracing/jaeger/pkg/prometheus/config" "github.com/jaegertracing/jaeger/plugin/metrics/prometheus/metricsstore/dbmodel" "github.com/jaegertracing/jaeger/proto-gen/api_v2/metrics" @@ -253,7 +254,7 @@ func getHTTPRoundTripper(c *config.Configuration, logger *zap.Logger) (rt http.R // KeepAlive and TLSHandshake timeouts are kept to existing Prometheus client's // DefaultRoundTripper to simplify user configuration and may be made configurable when required. - return &http.Transport{ + httpTransport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: c.ConnectTimeout, @@ -261,5 +262,9 @@ func getHTTPRoundTripper(c *config.Configuration, logger *zap.Logger) (rt http.R }).DialContext, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: ctlsConfig, + } + return bearertoken.RoundTripper{ + Transport: httpTransport, + OverrideFromCtx: true, }, nil } diff --git a/plugin/metrics/prometheus/metricsstore/reader_test.go b/plugin/metrics/prometheus/metricsstore/reader_test.go index 4857c938cf5..3c25bb52516 100644 --- a/plugin/metrics/prometheus/metricsstore/reader_test.go +++ b/plugin/metrics/prometheus/metricsstore/reader_test.go @@ -31,6 +31,7 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/zap" + "github.com/jaegertracing/jaeger/pkg/bearertoken" "github.com/jaegertracing/jaeger/pkg/config/tlscfg" "github.com/jaegertracing/jaeger/pkg/prometheus/config" "github.com/jaegertracing/jaeger/proto-gen/api_v2/metrics" @@ -331,12 +332,27 @@ func TestGetRoundTripper(t *testing.T) { }, }, logger) require.NoError(t, err) - assert.IsType(t, &http.Transport{}, rt) - if tc.tlsEnabled { - assert.NotNil(t, rt.(*http.Transport).TLSClientConfig) - } else { - assert.Nil(t, rt.(*http.Transport).TLSClientConfig) - } + + server := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Bearer foo", r.Header.Get("Authorization")) + }, + ), + ) + defer server.Close() + + req, err := http.NewRequestWithContext( + bearertoken.ContextWithBearerToken(context.Background(), "foo"), + http.MethodGet, + server.URL, + nil, + ) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) }) } } diff --git a/plugin/storage/es/options.go b/plugin/storage/es/options.go index 93e3a78d418..58f2aff0149 100644 --- a/plugin/storage/es/options.go +++ b/plugin/storage/es/options.go @@ -22,9 +22,9 @@ import ( "github.com/spf13/viper" + "github.com/jaegertracing/jaeger/pkg/bearertoken" "github.com/jaegertracing/jaeger/pkg/config/tlscfg" "github.com/jaegertracing/jaeger/pkg/es/config" - "github.com/jaegertracing/jaeger/storage/spanstore" ) const ( @@ -328,7 +328,7 @@ func initFromViper(cfg *namespaceConfig, v *viper.Viper) { cfg.UseILM = v.GetBool(cfg.namespace + suffixUseILM) // TODO: Need to figure out a better way for do this. - cfg.AllowTokenFromContext = v.GetBool(spanstore.StoragePropagationKey) + cfg.AllowTokenFromContext = v.GetBool(bearertoken.StoragePropagationKey) cfg.TLS = cfg.getTLSFlagsConfig().InitFromViper(v) remoteReadClusters := stripWhiteSpace(v.GetString(cfg.namespace + suffixRemoteReadClusters)) diff --git a/plugin/storage/grpc/README.md b/plugin/storage/grpc/README.md index fd64d06c79a..0efca2d4dd0 100644 --- a/plugin/storage/grpc/README.md +++ b/plugin/storage/grpc/README.md @@ -180,15 +180,16 @@ When using `--query.bearer-token-propagation=true`, the bearer token will be pro import ( // ... other imports "fmt" - "github.com/jaegertracing/jaeger/storage/spanstore" "google.golang.org/grpc/metadata" + + "github.com/jaegertracing/jaeger/plugin/storage/grpc" ) // ... spanReader type declared here func (r *spanReader) extractBearerToken(ctx context.Context) (string, bool) { if md, ok := metadata.FromIncomingContext(ctx); ok { - values := md.Get(spanstore.BearerTokenKey) + values := md.Get(grpc.BearerTokenKey) if len(values) > 0 { return values[0], true } diff --git a/plugin/storage/grpc/shared/grpc_client.go b/plugin/storage/grpc/shared/grpc_client.go index 4d27c0eb269..33d166de717 100644 --- a/plugin/storage/grpc/shared/grpc_client.go +++ b/plugin/storage/grpc/shared/grpc_client.go @@ -25,11 +25,15 @@ import ( "google.golang.org/grpc/status" "github.com/jaegertracing/jaeger/model" + "github.com/jaegertracing/jaeger/pkg/bearertoken" "github.com/jaegertracing/jaeger/proto-gen/storage_v1" "github.com/jaegertracing/jaeger/storage/dependencystore" "github.com/jaegertracing/jaeger/storage/spanstore" ) +// BearerTokenKey is the key name for the bearer token context value. +const BearerTokenKey = "bearer.token" + var ( _ StoragePlugin = (*grpcClient)(nil) _ ArchiveStoragePlugin = (*grpcClient)(nil) @@ -67,13 +71,13 @@ func composeContextUpgradeFuncs(funcs ...ContextUpgradeFunc) ContextUpgradeFunc // in the request metadata, if the original context has bearer token attached. // Otherwise returns original context. func upgradeContextWithBearerToken(ctx context.Context) context.Context { - bearerToken, hasToken := spanstore.GetBearerToken(ctx) + bearerToken, hasToken := bearertoken.GetBearerToken(ctx) if hasToken { md, ok := metadata.FromOutgoingContext(ctx) if !ok { md = metadata.New(nil) } - md.Set(spanstore.BearerTokenKey, bearerToken) + md.Set(BearerTokenKey, bearerToken) return metadata.NewOutgoingContext(ctx, md) } return ctx diff --git a/plugin/storage/grpc/shared/grpc_client_test.go b/plugin/storage/grpc/shared/grpc_client_test.go index b77536542ed..51e1252a5d9 100644 --- a/plugin/storage/grpc/shared/grpc_client_test.go +++ b/plugin/storage/grpc/shared/grpc_client_test.go @@ -28,6 +28,7 @@ import ( "google.golang.org/grpc/status" "github.com/jaegertracing/jaeger/model" + "github.com/jaegertracing/jaeger/pkg/bearertoken" "github.com/jaegertracing/jaeger/proto-gen/storage_v1" grpcMocks "github.com/jaegertracing/jaeger/proto-gen/storage_v1/mocks" "github.com/jaegertracing/jaeger/storage/spanstore" @@ -108,11 +109,11 @@ func withGRPCClient(fn func(r *grpcClientTest)) { func TestContextUpgradeWithToken(t *testing.T) { testBearerToken := "test-bearer-token" - ctx := spanstore.ContextWithBearerToken(context.Background(), testBearerToken) + ctx := bearertoken.ContextWithBearerToken(context.Background(), testBearerToken) upgradedToken := upgradeContextWithBearerToken(ctx) md, ok := metadata.FromOutgoingContext(upgradedToken) assert.Truef(t, ok, "Expected metadata in context") - bearerTokenFromMetadata := md.Get(spanstore.BearerTokenKey) + bearerTokenFromMetadata := md.Get(BearerTokenKey) assert.Equal(t, []string{testBearerToken}, bearerTokenFromMetadata) }