Skip to content

Commit

Permalink
Add auth token propagation for metrics reader (#3341)
Browse files Browse the repository at this point in the history
* Add auth token propagation for metrics reader

Signed-off-by: albertteoh <[email protected]>

* Consolidate token propagation code

Signed-off-by: albertteoh <[email protected]>

* Rename to use PropagationHandler for consistency

Signed-off-by: albertteoh <[email protected]>

* Split and move common transports

Signed-off-by: albertteoh <[email protected]>

* Give bearerToken value less meaning

Signed-off-by: albertteoh <[email protected]>

* Fix import grouping

Signed-off-by: albertteoh <[email protected]>

* Give token value less meaning

Signed-off-by: albertteoh <[email protected]>

* make fmt

Signed-off-by: albertteoh <[email protected]>

* Remove functional options

Signed-off-by: albertteoh <[email protected]>

* Fix build

Signed-off-by: albertteoh <[email protected]>

* Address review comments

Signed-off-by: albertteoh <[email protected]>

* Fix staticcheck err

Signed-off-by: albertteoh <[email protected]>
  • Loading branch information
albertteoh authored Nov 2, 2021
1 parent c4a3904 commit 4aa436b
Show file tree
Hide file tree
Showing 15 changed files with 247 additions and 76 deletions.
3 changes: 2 additions & 1 deletion cmd/query/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (

"github.com/jaegertracing/jaeger/cmd/query/app/apiv3"
"github.com/jaegertracing/jaeger/cmd/query/app/querysvc"
"github.com/jaegertracing/jaeger/pkg/bearertoken"
"github.com/jaegertracing/jaeger/pkg/healthcheck"
"github.com/jaegertracing/jaeger/pkg/netutils"
"github.com/jaegertracing/jaeger/pkg/recoveryhandler"
Expand Down Expand Up @@ -158,7 +159,7 @@ func createHTTPServer(querySvc *querysvc.QueryService, metricsQuerySvc querysvc.
var handler http.Handler = r
handler = additionalHeadersHandler(handler, queryOpts.AdditionalHeaders)
if queryOpts.BearerTokenPropagation {
handler = bearerTokenPropagationHandler(logger, handler)
handler = bearertoken.PropagationHandler(logger, handler)
}
handler = handlers.CompressHandler(handler)
recoveryHandler := recoveryhandler.NewRecoveryHandler(logger, true)
Expand Down
4 changes: 2 additions & 2 deletions cmd/query/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ import (
"github.com/jaegertracing/jaeger/cmd/query/app"
"github.com/jaegertracing/jaeger/cmd/query/app/querysvc"
"github.com/jaegertracing/jaeger/cmd/status"
"github.com/jaegertracing/jaeger/pkg/bearertoken"
"github.com/jaegertracing/jaeger/pkg/config"
"github.com/jaegertracing/jaeger/pkg/version"
metricsPlugin "github.com/jaegertracing/jaeger/plugin/metrics"
"github.com/jaegertracing/jaeger/plugin/storage"
"github.com/jaegertracing/jaeger/ports"
"github.com/jaegertracing/jaeger/storage/spanstore"
storageMetrics "github.com/jaegertracing/jaeger/storage/spanstore/metrics"
)

Expand Down Expand Up @@ -95,7 +95,7 @@ func main() {
opentracing.SetGlobalTracer(tracer)
queryOpts := new(app.QueryOptions).InitFromViper(v, logger)
// TODO: Need to figure out set enable/disable propagation on storage plugins.
v.Set(spanstore.StoragePropagationKey, queryOpts.BearerTokenPropagation)
v.Set(bearertoken.StoragePropagationKey, queryOpts.BearerTokenPropagation)
storageFactory.InitFromViper(v, logger)
if err := storageFactory.Initialize(baseFactory, logger); err != nil {
logger.Fatal("Failed to init storage factory", zap.Error(err))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019 The Jaeger Authors.
// Copyright (c) 2021 The Jaeger Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,30 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package spanstore
package bearertoken

import "context"

type contextKey string
type contextKeyType int

// BearerTokenKey is the string literal used internally in the implementation of this context.
const BearerTokenKey = "bearer.token"
const bearerToken = contextKey(BearerTokenKey)
const contextKey = contextKeyType(iota)

// StoragePropagationKey is a key for viper configuration to pass this option to storage plugins.
const StoragePropagationKey = "storage.propagate.token"

// ContextWithBearerToken set bearer token in context
// ContextWithBearerToken set bearer token in context.
func ContextWithBearerToken(ctx context.Context, token string) context.Context {
if token == "" {
return ctx
}
return context.WithValue(ctx, bearerToken, token)

return context.WithValue(ctx, contextKey, token)
}

// GetBearerToken from context, or empty string if there is no token
// GetBearerToken from context, or empty string if there is no token.
func GetBearerToken(ctx context.Context) (string, bool) {
val, ok := ctx.Value(bearerToken).(string)
val, ok := ctx.Value(contextKey).(string)
return val, ok
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2019 The Jaeger Authors.
// Copyright (c) 2021 The Jaeger Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package spanstore
package bearertoken

import (
"context"
Expand All @@ -22,7 +22,7 @@ import (
)

func Test_GetBearerToken(t *testing.T) {
token := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJhZG1pbiIsIm5hbWUiOiJKb2huIERvZSIsImlhdCI"
const token = "blah"
ctx := context.Background()
ctx = ContextWithBearerToken(ctx, token)
contextToken, ok := GetBearerToken(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package app
package bearertoken

import (
"net/http"
"strings"

"go.uber.org/zap"

"github.com/jaegertracing/jaeger/storage/spanstore"
)

func bearerTokenPropagationHandler(logger *zap.Logger, h http.Handler) http.Handler {
// PropagationHandler returns a http.Handler containing the logic to extract
// the Bearer token from the Authorization header of the http.Request and insert it into request.Context
// for propagation. The token can be accessed via GetBearerToken.
func PropagationHandler(logger *zap.Logger, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
authHeaderValue := r.Header.Get("Authorization")
Expand All @@ -40,15 +41,14 @@ func bearerTokenPropagationHandler(logger *zap.Logger, h http.Handler) http.Hand
token = headerValue[1]
}
} else if len(headerValue) == 1 {
// Tread all value as a token
// Treat the entire value as a token.
token = authHeaderValue
} else {
logger.Warn("Invalid authorization header value, skipping token propagation")
}
h.ServeHTTP(w, r.WithContext(spanstore.ContextWithBearerToken(ctx, token)))
h.ServeHTTP(w, r.WithContext(ContextWithBearerToken(ctx, token)))
} else {
h.ServeHTTP(w, r.WithContext(ctx))
}
})

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,41 +12,44 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package app
package bearertoken

import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"go.uber.org/zap"

"github.com/jaegertracing/jaeger/storage/spanstore"
)

func Test_bearTokenPropagationHandler(t *testing.T) {
func Test_PropagationHandler(t *testing.T) {
httpClient := &http.Client{
Timeout: 2 * time.Second,
}

logger := zap.NewNop()
bearerToken := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJhZG1pbiIsIm5hbWUiOiJKb2huIERvZSIsImlhdCI"
const bearerToken = "blah"

validTokenHandler := func(stop *sync.WaitGroup) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, ok := spanstore.GetBearerToken(ctx)
token, ok := GetBearerToken(ctx)
assert.Equal(t, token, bearerToken)
assert.True(t, ok)
stop.Done()
})
}
}

emptyHandler := func(stop *sync.WaitGroup) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, _ := spanstore.GetBearerToken(ctx)
token, _ := GetBearerToken(ctx)
assert.Empty(t, token, bearerToken)
stop.Done()
})
}
}

testCases := []struct {
Expand All @@ -68,7 +71,7 @@ func Test_bearTokenPropagationHandler(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
stop := sync.WaitGroup{}
stop.Add(1)
r := bearerTokenPropagationHandler(logger, testCase.handler(&stop))
r := PropagationHandler(logger, testCase.handler(&stop))
server := httptest.NewServer(r)
defer server.Close()
req, err := http.NewRequest("GET", server.URL, nil)
Expand All @@ -81,5 +84,4 @@ func Test_bearTokenPropagationHandler(t *testing.T) {
stop.Wait()
})
}

}
52 changes: 52 additions & 0 deletions pkg/bearertoken/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) 2021 The Jaeger Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bearertoken

import (
"errors"
"net/http"
)

// RoundTripper wraps another http.RoundTripper and injects
// an authentication header with bearer token into requests.
type RoundTripper struct {
// Transport is the underlying http.RoundTripper being wrapped. Required.
Transport http.RoundTripper

// StaticToken is the pre-configured bearer token. Optional.
StaticToken string

// OverrideFromCtx enables reading bearer token from Context.
OverrideFromCtx bool
}

// RoundTrip injects the outbound Authorization header with the
// token provided in the inbound request.
func (tr RoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
if tr.Transport == nil {
return nil, errors.New("no http.RoundTripper provided")
}
token := tr.StaticToken
if tr.OverrideFromCtx {
headerToken, _ := GetBearerToken(r.Context())
if headerToken != "" {
token = headerToken
}
}
if token != "" {
r.Header.Set("Authorization", "Bearer "+token)
}
return tr.Transport.RoundTrip(r)
}
111 changes: 111 additions & 0 deletions pkg/bearertoken/transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) 2021 The Jaeger Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bearertoken

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type roundTripFunc func(r *http.Request) (*http.Response, error)

func (s roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return s(r)
}

func TestRoundTripper(t *testing.T) {
for _, tc := range []struct {
name string
staticToken string
overrideFromCtx bool
wrappedTransport http.RoundTripper
requestContext context.Context
wantError bool
}{
{
name: "Default RoundTripper and request context set should have empty Bearer token",
wrappedTransport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Empty(t, r.Header.Get("Authorization"))
return &http.Response{
StatusCode: http.StatusOK,
}, nil
}),
requestContext: ContextWithBearerToken(context.Background(), "tokenFromContext"),
},
{
name: "Override from context provided, and request context set should use request context token",
overrideFromCtx: true,
wrappedTransport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Bearer tokenFromContext", r.Header.Get("Authorization"))
return &http.Response{
StatusCode: http.StatusOK,
}, nil
}),
requestContext: ContextWithBearerToken(context.Background(), "tokenFromContext"),
},
{
name: "Allow override from context and token provided, and request context unset should use defaultToken",
overrideFromCtx: true,
staticToken: "initToken",
wrappedTransport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Bearer initToken", r.Header.Get("Authorization"))
return &http.Response{}, nil
}),
requestContext: context.Background(),
},
{
name: "Allow override from context and token provided, and request context set should use context token",
overrideFromCtx: true,
staticToken: "initToken",
wrappedTransport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Bearer tokenFromContext", r.Header.Get("Authorization"))
return &http.Response{}, nil
}),
requestContext: ContextWithBearerToken(context.Background(), "tokenFromContext"),
},
{
name: "Nil roundTripper provided should return an error",
requestContext: context.Background(),
wantError: true,
},
} {
t.Run(tc.name, func(t *testing.T) {
server := httptest.NewServer(nil)
defer server.Close()
req, err := http.NewRequestWithContext(tc.requestContext, "GET", server.URL, nil)
require.NoError(t, err)

tr := RoundTripper{
Transport: tc.wrappedTransport,
OverrideFromCtx: tc.overrideFromCtx,
StaticToken: tc.staticToken,
}
resp, err := tr.RoundTrip(req)

if tc.wantError {
assert.Nil(t, resp)
assert.Error(t, err)
} else {
assert.NotNil(t, resp)
assert.NoError(t, err)
}
})
}
}
Loading

0 comments on commit 4aa436b

Please sign in to comment.