Skip to content

Commit

Permalink
feat(auth): add idtoken package (#8580)
Browse files Browse the repository at this point in the history
  • Loading branch information
codyoss authored Oct 3, 2023
1 parent 5feb3ea commit a79e693
Show file tree
Hide file tree
Showing 12 changed files with 1,543 additions and 58 deletions.
133 changes: 133 additions & 0 deletions auth/idtoken/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package idtoken

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
)

type cachingClient struct {
client *http.Client

// clock optionally specifies a func to return the current time.
// If nil, time.Now is used.
clock func() time.Time

mu sync.Mutex
certs map[string]*cachedResponse
}

func newCachingClient(client *http.Client) *cachingClient {
return &cachingClient{
client: client,
certs: make(map[string]*cachedResponse, 2),
}
}

type cachedResponse struct {
resp *certResponse
exp time.Time
}

func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse, error) {
if response, ok := c.get(url); ok {
return response, nil
}
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode)
}

certResp := &certResponse{}
if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil {
return nil, err

}
c.set(url, certResp, resp.Header)
return certResp, nil
}

func (c *cachingClient) now() time.Time {
if c.clock != nil {
return c.clock()
}
return time.Now()
}

func (c *cachingClient) get(url string) (*certResponse, bool) {
c.mu.Lock()
defer c.mu.Unlock()
cachedResp, ok := c.certs[url]
if !ok {
return nil, false
}
if c.now().After(cachedResp.exp) {
return nil, false
}
return cachedResp.resp, true
}

func (c *cachingClient) set(url string, resp *certResponse, headers http.Header) {
exp := c.calculateExpireTime(headers)
c.mu.Lock()
c.certs[url] = &cachedResponse{resp: resp, exp: exp}
c.mu.Unlock()
}

// calculateExpireTime will determine the expire time for the cache based on
// HTTP headers. If there is any difficulty reading the headers the fallback is
// to set the cache to expire now.
func (c *cachingClient) calculateExpireTime(headers http.Header) time.Time {
var maxAge int
cc := strings.Split(headers.Get("cache-control"), ",")
for _, v := range cc {
if strings.Contains(v, "max-age") {
ss := strings.Split(v, "=")
if len(ss) < 2 {
return c.now()
}
ma, err := strconv.Atoi(ss[1])
if err != nil {
return c.now()
}
maxAge = ma
}
}
a := headers.Get("age")
if a == "" {
return c.now().Add(time.Duration(maxAge) * time.Second)
}
age, err := strconv.Atoi(a)
if err != nil {
return c.now()
}
return c.now().Add(time.Duration(maxAge-age) * time.Second)
}
82 changes: 82 additions & 0 deletions auth/idtoken/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package idtoken

import (
"net/http"
"sync"
"testing"
"time"
)

type fakeClock struct {
mu sync.Mutex
t time.Time
}

func (c *fakeClock) Now() time.Time {
c.mu.Lock()
defer c.mu.Unlock()
return c.t
}

func (c *fakeClock) Sleep(d time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.t = c.t.Add(d)
}

func TestCacheHit(t *testing.T) {
clock := &fakeClock{t: time.Now()}
fakeResp := &certResponse{
Keys: []jwk{
{
Kid: "123",
},
},
}
cache := newCachingClient(nil)
cache.clock = clock.Now

// Cache should be empty
cert, ok := cache.get(googleSACertsURL)
if ok || cert != nil {
t.Fatal("cache for SA certs should be empty")
}

// Add an item, but make it expire now
cache.set(googleSACertsURL, fakeResp, make(http.Header))
clock.Sleep(time.Nanosecond) // it expires when current time is > expiration, not >=
cert, ok = cache.get(googleSACertsURL)
if ok || cert != nil {
t.Fatal("cache for SA certs should be expired")
}

// Add an item that expires in 1 seconds
h := make(http.Header)
h.Set("age", "0")
h.Set("cache-control", "public, max-age=1, must-revalidate, no-transform")
cache.set(googleSACertsURL, fakeResp, h)
cert, ok = cache.get(googleSACertsURL)
if !ok || cert == nil || cert.Keys[0].Kid != "123" {
t.Fatal("cache for SA certs have a resp")
}
// Wait
clock.Sleep(2 * time.Second)
cert, ok = cache.get(googleSACertsURL)
if ok || cert != nil {
t.Fatal("cache for SA certs should be expired")
}
}
77 changes: 77 additions & 0 deletions auth/idtoken/compute.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package idtoken

import (
"context"
"fmt"
"net/url"
"time"

"cloud.google.com/go/auth"
"cloud.google.com/go/auth/internal"
"cloud.google.com/go/compute/metadata"
)

const identitySuffix = "instance/service-accounts/default/identity"

// computeTokenProvider checks if this code is being run on GCE. If it is, it
// will use the metadata service to build a TokenProvider that fetches ID
// tokens.
func computeTokenProvider(opts *Options) (auth.TokenProvider, error) {
if opts.CustomClaims != nil {
return nil, fmt.Errorf("idtoken: Options.CustomClaims can't be used with the metadata service, please provide a service account if you would like to use this feature")
}
tp := computeIDTokenProvider{
audience: opts.Audience,
format: opts.ComputeTokenFormat,
client: *metadata.NewClient(opts.client()),
}
return auth.NewCachedTokenProvider(tp, &auth.CachedTokenProviderOptions{
ExpireEarly: 5 * time.Minute,
}), nil
}

type computeIDTokenProvider struct {
audience string
format ComputeTokenFormat
client metadata.Client
}

func (c computeIDTokenProvider) Token(ctx context.Context) (*auth.Token, error) {
v := url.Values{}
v.Set("audience", c.audience)
if c.format != ComputeTokenFormatStandard {
v.Set("format", "full")
}
if c.format == ComputeTokenFormatFullWithLicense {
v.Set("licenses", "TRUE")
}
urlSuffix := identitySuffix + "?" + v.Encode()
res, err := c.client.Get(urlSuffix)
if err != nil {
return nil, err
}
if res == "" {
return nil, fmt.Errorf("idtoken: invalid empty response from metadata service")
}
return &auth.Token{
Value: res,
Type: internal.TokenTypeBearer,
// Compute tokens are valid for one hour:
// https://cloud.google.com/iam/docs/create-short-lived-credentials-direct#create-id
Expiry: time.Now().Add(1 * time.Hour),
}, nil
}
102 changes: 102 additions & 0 deletions auth/idtoken/compute_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package idtoken

import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

const metadataHostEnv = "GCE_METADATA_HOST"

func TestComputeTokenSource(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.URL.Path, identitySuffix) {
t.Errorf("got %q, want contains %q", r.URL.Path, identitySuffix)
}
if got, want := r.URL.Query().Get("audience"), "aud"; got != want {
t.Errorf("got %q, want %q", got, want)
}
if got, want := r.URL.Query().Get("format"), "full"; got != want {
t.Errorf("got %q, want %q", got, want)
}
if got, want := r.URL.Query().Get("licenses"), "TRUE"; got != want {
t.Errorf("got %q, want %q", got, want)
}
w.Write([]byte(`fake_token`))
}))
defer ts.Close()
t.Setenv(metadataHostEnv, strings.TrimPrefix(ts.URL, "http://"))
tp, err := computeTokenProvider(&Options{
Audience: "aud",
ComputeTokenFormat: ComputeTokenFormatFullWithLicense,
})
if err != nil {
t.Fatalf("computeTokenProvider() = %v", err)
}
tok, err := tp.Token(context.Background())
if err != nil {
t.Fatalf("tp.Token() = %v", err)
}
if want := "fake_token"; tok.Value != want {
t.Errorf("got %q, want %q", tok.Value, want)
}
}

func TestComputeTokenSource_Standard(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.URL.Path, identitySuffix) {
t.Errorf("got %q, want contains %q", r.URL.Path, identitySuffix)
}
if got, want := r.URL.Query().Get("audience"), "aud"; got != want {
t.Errorf("got %q, want %q", got, want)
}
if got, want := r.URL.Query().Get("format"), ""; got != want {
t.Errorf("got %q, want %q", got, want)
}
if got, want := r.URL.Query().Get("licenses"), ""; got != want {
t.Errorf("got %q, want %q", got, want)
}
w.Write([]byte(`fake_token`))
}))
defer ts.Close()
t.Setenv(metadataHostEnv, strings.TrimPrefix(ts.URL, "http://"))
tp, err := computeTokenProvider(&Options{
Audience: "aud",
ComputeTokenFormat: ComputeTokenFormatStandard,
})
if err != nil {
t.Fatalf("computeTokenProvider() = %v", err)
}
tok, err := tp.Token(context.Background())
if err != nil {
t.Fatalf("tp.Token() = %v", err)
}
if want := "fake_token"; tok.Value != want {
t.Errorf("got %q, want %q", tok.Value, want)
}
}

func TestComputeTokenSource_Invalid(t *testing.T) {
if _, err := computeTokenProvider(&Options{
Audience: "aud",
CustomClaims: map[string]interface{}{"foo": "bar"},
}); err == nil {
t.Fatal("computeTokenProvider() = nil, expected non-nil error", err)
}
}
Loading

0 comments on commit a79e693

Please sign in to comment.