Skip to content

Commit

Permalink
Cosmos DB: Add AAD authentication (Azure#17742)
Browse files Browse the repository at this point in the history
Adding azcosmos.NewClient with support for TokenCredential
  • Loading branch information
ealsur authored Apr 27, 2022
1 parent ffb48f0 commit 3f7acd2
Show file tree
Hide file tree
Showing 12 changed files with 413 additions and 5 deletions.
1 change: 1 addition & 0 deletions sdk/data/azcosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Features Added
* Added single partition query support.
* Added Azure AD authentication support through `azcosmos.NewClient`

### Breaking Changes
* This module now requires Go 1.18
Expand Down
18 changes: 17 additions & 1 deletion sdk/data/azcosmos/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,29 @@ The following section provides several code snippets covering some of the most c

### Create Cosmos Client

The clients support different forms of authentication. The azcosmos library supports authorization via Azure Active Directory or an account key.

**Using Azure Active Directory**

```go
import "github.com/Azure/azure-sdk-for-go/sdk/azidentity"

cred, err := azidentity.NewDefaultAzureCredential(nil)
handle(err)
client, err := azcosmos.NewClient("myAccountEndpointURL", cred, nil)
handle(err)
```

**Using account keys**

```go
const (
cosmosDbEndpoint = "someEndpoint"
cosmosDbKey = "someKey"
)

cred, _ := azcosmos.NewKeyCredential(cosmosDbKey)
cred, err := azcosmos.NewKeyCredential(cosmosDbKey)
handle(err)
client, err := azcosmos.NewClientWithKey(cosmosDbEndpoint, cred, nil)
handle(err)
```
Expand Down
32 changes: 28 additions & 4 deletions sdk/data/azcosmos/cosmos_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@ import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"net/url"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
Expand All @@ -26,31 +29,52 @@ func (c *Client) Endpoint() string {
return c.endpoint
}

// NewClientWithKey creates a new instance of Cosmos client with the specified values. It uses the default pipeline configuration.
// NewClientWithKey creates a new instance of Cosmos client with shared key authentication. It uses the default pipeline configuration.
// endpoint - The cosmos service endpoint to use.
// cred - The credential used to authenticate with the cosmos service.
// options - Optional Cosmos client options. Pass nil to accept default values.
func NewClientWithKey(endpoint string, cred KeyCredential, o *ClientOptions) (*Client, error) {
return &Client{endpoint: endpoint, pipeline: newPipeline(cred, o)}, nil
return &Client{endpoint: endpoint, pipeline: newPipeline([]policy.Policy{newSharedKeyCredPolicy(cred)}, o)}, nil
}

func newPipeline(cred KeyCredential, options *ClientOptions) azruntime.Pipeline {
// NewClient creates a new instance of Cosmos client with Azure AD access token authentication. It uses the default pipeline configuration.
// endpoint - The cosmos service endpoint to use.
// cred - The credential used to authenticate with the cosmos service.
// options - Optional Cosmos client options. Pass nil to accept default values.
func NewClient(endpoint string, cred azcore.TokenCredential, o *ClientOptions) (*Client, error) {
scope, err := createScopeFromEndpoint(endpoint)
if err != nil {
return nil, err
}
return &Client{endpoint: endpoint, pipeline: newPipeline([]policy.Policy{azruntime.NewBearerTokenPolicy(cred, scope, nil), &cosmosBearerTokenPolicy{}}, o)}, nil
}

func newPipeline(authPolicy []policy.Policy, options *ClientOptions) azruntime.Pipeline {
if options == nil {
options = &ClientOptions{}
}

return azruntime.NewPipeline("azcosmos", serviceLibVersion,
azruntime.PipelineOptions{
PerCall: []policy.Policy{
newSharedKeyCredPolicy(cred),
&headerPolicies{
enableContentResponseOnWrite: options.EnableContentResponseOnWrite,
},
},
PerRetry: authPolicy,
},
&options.ClientOptions)
}

func createScopeFromEndpoint(endpoint string) ([]string, error) {
u, err := url.Parse(endpoint)
if err != nil {
return nil, err
}

return []string{fmt.Sprintf("%s://%s/.default", u.Scheme, u.Hostname())}, nil
}

// NewDatabase returns a struct that represents a database and allows database level operations.
// id - The id of the database.
func (c *Client) NewDatabase(id string) (*DatabaseClient, error) {
Expand Down
16 changes: 16 additions & 0 deletions sdk/data/azcosmos/cosmos_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,22 @@ func TestSendQuery(t *testing.T) {
}
}

func TestCreateScopeFromEndpoint(t *testing.T) {
url := "https://foo.documents.azure.com:443/"
scope, err := createScopeFromEndpoint(url)
if err != nil {
t.Fatal(err)
}

if scope[0] != "https://foo.documents.azure.com/.default" {
t.Errorf("Expected %v, but got %v", "https://foo.documents.azure.com/.default", scope[0])
}

if len(scope) != 1 {
t.Errorf("Expected %v, but got %v", 1, len(scope))
}
}

type pipelineVerifier struct {
requests []pipelineVerifierRequest
}
Expand Down
28 changes: 28 additions & 0 deletions sdk/data/azcosmos/cosmos_policy_bearer_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcosmos

import (
"errors"
"fmt"
"net/http"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
)

const lenBearerTokenPrefix = len("Bearer ")

type cosmosBearerTokenPolicy struct {
}

func (b *cosmosBearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) {
currentAuthorization := req.Raw().Header.Get(headerAuthorization)
if currentAuthorization == "" {
return nil, errors.New("authorization header is missing")
}

token := currentAuthorization[lenBearerTokenPrefix:]
req.Raw().Header.Set(headerAuthorization, fmt.Sprintf("type=aad&ver=1.0&sig=%v", token))
return req.Next()
}
58 changes: 58 additions & 0 deletions sdk/data/azcosmos/cosmos_policy_bearer_token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcosmos

import (
"context"
"net/http"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
)

func TestConvertBearerToken(t *testing.T) {
srv, close := mock.NewTLSServer()
defer close()
srv.SetResponse(mock.WithStatusCode(http.StatusOK))

verifier := bearerTokenVerify{}
pl := azruntime.NewPipeline("azcosmostest", "v1.0.0", azruntime.PipelineOptions{PerCall: []policy.Policy{&mockAuthPolicy{}, &cosmosBearerTokenPolicy{}, &verifier}}, &policy.ClientOptions{Transport: srv})
req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL())
req.SetOperationValue(pipelineRequestOptions{
isWriteOperation: true,
})

if err != nil {
t.Fatalf("unexpected error: %v", err)
}

_, err = pl.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if verifier.authHeaderContent != "type=aad&ver=1.0&sig=this is a test token" {
t.Fatalf("Expected auth header content to be 'type=aad&ver=1.0&sig=this is a test token', got %s", verifier.authHeaderContent)
}
}

type bearerTokenVerify struct {
authHeaderContent string
}

func (p *bearerTokenVerify) Do(req *policy.Request) (*http.Response, error) {
p.authHeaderContent = req.Raw().Header.Get(headerAuthorization)

return req.Next()
}

type mockAuthPolicy struct{}

func (p *mockAuthPolicy) Do(req *policy.Request) (*http.Response, error) {
req.Raw().Header.Set(headerAuthorization, "Bearer this is a test token")

return req.Next()
}
13 changes: 13 additions & 0 deletions sdk/data/azcosmos/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ The azcosmos package is capable of:
Creating the Client
Types of Credentials
The clients support different forms of authentication. The azcosmos library supports
authorization via Azure Active Directory or an account key.
Using Azure Active Directory
To create a client, you can use any of the TokenCredential implementations provided by `azidentity`.
cred, err := azidentity.NewClientSecretCredential("tenantId", "clientId", "clientSecret")
handle(err)
client, err := azcosmos.NewClient("myAccountEndpointURL", cred, nil)
handle(err)
Using account keys
To create a client, you will need the account's endpoint URL and a key credential.
cred, err := azcosmos.NewKeyCredential("myAccountKey")
Expand Down
140 changes: 140 additions & 0 deletions sdk/data/azcosmos/emulator_cosmos_aad_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package azcosmos

import (
"context"
"encoding/json"
"testing"
)

func TestAAD(t *testing.T) {
emulatorTests := newEmulatorTests(t)
client := emulatorTests.getClient(t)

database := emulatorTests.createDatabase(t, context.TODO(), client, "aadTest")
defer emulatorTests.deleteDatabase(t, context.TODO(), database)
properties := ContainerProperties{
ID: "aContainer",
PartitionKeyDefinition: PartitionKeyDefinition{
Paths: []string{"/id"},
},
}

_, err := database.CreateContainer(context.TODO(), properties, nil)
if err != nil {
t.Fatalf("Failed to create container: %v", err)
}

aadClient := emulatorTests.getAadClient(t)

item := map[string]string{
"id": "1",
"value": "2",
}

container, _ := aadClient.NewContainer("aadTest", "aContainer")
pk := NewPartitionKeyString("1")

marshalled, err := json.Marshal(item)
if err != nil {
t.Fatal(err)
}

itemResponse, err := container.CreateItem(context.TODO(), pk, marshalled, nil)
if err != nil {
t.Fatalf("Failed to create item: %v", err)
}

if itemResponse.SessionToken == "" {
t.Fatalf("Session token is empty")
}

// No content on write by default
if len(itemResponse.Value) != 0 {
t.Fatalf("Expected empty response, got %v", itemResponse.Value)
}

itemResponse, err = container.ReadItem(context.TODO(), pk, "1", nil)
if err != nil {
t.Fatalf("Failed to read item: %v", err)
}

if len(itemResponse.Value) == 0 {
t.Fatalf("Expected non-empty response, got %v", itemResponse.Value)
}

var itemResponseBody map[string]interface{}
err = json.Unmarshal(itemResponse.Value, &itemResponseBody)
if err != nil {
t.Fatalf("Failed to unmarshal item response: %v", err)
}
if itemResponseBody["id"] != "1" {
t.Fatalf("Expected id to be 1, got %v", itemResponseBody["id"])
}
if itemResponseBody["value"] != "2" {
t.Fatalf("Expected value to be 2, got %v", itemResponseBody["value"])
}

item["value"] = "3"
marshalled, err = json.Marshal(item)
if err != nil {
t.Fatal(err)
}
itemResponse, err = container.ReplaceItem(context.TODO(), pk, "1", marshalled, &ItemOptions{EnableContentResponseOnWrite: true})
if err != nil {
t.Fatalf("Failed to replace item: %v", err)
}

// Explicitly requesting body on write
if len(itemResponse.Value) == 0 {
t.Fatalf("Expected non-empty response, got %v", itemResponse.Value)
}

err = json.Unmarshal(itemResponse.Value, &itemResponseBody)
if err != nil {
t.Fatalf("Failed to unmarshal item response: %v", err)
}
if itemResponseBody["id"] != "1" {
t.Fatalf("Expected id to be 1, got %v", itemResponseBody["id"])
}
if itemResponseBody["value"] != "3" {
t.Fatalf("Expected value to be 3, got %v", itemResponseBody["value"])
}

item["value"] = "4"
marshalled, err = json.Marshal(item)
if err != nil {
t.Fatal(err)
}
itemResponse, err = container.UpsertItem(context.TODO(), pk, marshalled, &ItemOptions{EnableContentResponseOnWrite: true})
if err != nil {
t.Fatalf("Failed to upsert item: %v", err)
}

// Explicitly requesting body on write
if len(itemResponse.Value) == 0 {
t.Fatalf("Expected non-empty response, got %v", itemResponse.Value)
}

err = json.Unmarshal(itemResponse.Value, &itemResponseBody)
if err != nil {
t.Fatalf("Failed to unmarshal item response: %v", err)
}
if itemResponseBody["id"] != "1" {
t.Fatalf("Expected id to be 1, got %v", itemResponseBody["id"])
}
if itemResponseBody["value"] != "4" {
t.Fatalf("Expected value to be 4, got %v", itemResponseBody["value"])
}

itemResponse, err = container.DeleteItem(context.TODO(), pk, "1", nil)
if err != nil {
t.Fatalf("Failed to replace item: %v", err)
}

if len(itemResponse.Value) != 0 {
t.Fatalf("Expected empty response, got %v", itemResponse.Value)
}
}
Loading

0 comments on commit 3f7acd2

Please sign in to comment.