Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(oidc): OIDC Prompt #4053

Merged
merged 3 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/authmethods/oidc_auth_method_attributes.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions api/authmethods/option.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

32 changes: 31 additions & 1 deletion internal/auth/oidc/auth_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ func NewAuthMethod(ctx context.Context, scopeId string, clientId string, clientS
a.SigningAlgs = append(a.SigningAlgs, string(alg))
}
}
if len(opts.withPrompts) > 0 {
a.Prompts = make([]string, 0, len(opts.withPrompts))
for _, prompts := range opts.withPrompts {
a.Prompts = append(a.Prompts, string(prompts))
}
}
if len(opts.withAccountClaimMap) > 0 {
a.AccountClaimMaps = make([]string, 0, len(opts.withAccountClaimMap))
for k, v := range opts.withAccountClaimMap {
Expand Down Expand Up @@ -282,6 +288,7 @@ type convertedValues struct {
Certs []any
ClaimsScopes []any
AccountClaimMaps []any
Prompts []any
}

// convertValueObjects converts the embedded value objects. It will return an
Expand All @@ -292,7 +299,7 @@ func (am *AuthMethod) convertValueObjects(ctx context.Context) (*convertedValues
return nil, errors.New(ctx, errors.InvalidPublicId, op, "missing public id")
}
var err error
var addAlgs, addAuds, addCerts, addScopes, addAccountClaimMaps []any
var addAlgs, addAuds, addCerts, addScopes, addAccountClaimMaps, addPrompts []any
if addAlgs, err = am.convertSigningAlgs(ctx); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
Expand All @@ -308,12 +315,16 @@ func (am *AuthMethod) convertValueObjects(ctx context.Context) (*convertedValues
if addAccountClaimMaps, err = am.convertAccountClaimMaps(ctx); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
if addPrompts, err = am.convertPrompts(ctx); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
return &convertedValues{
Algs: addAlgs,
Auds: addAuds,
Certs: addCerts,
ClaimsScopes: addScopes,
AccountClaimMaps: addAccountClaimMaps,
Prompts: addPrompts,
}, nil
}

Expand Down Expand Up @@ -458,3 +469,22 @@ func ParseAccountClaimMaps(ctx context.Context, m ...string) ([]ClaimMap, error)
}
return claimMap, nil
}

// convertPrompts converts the embedded prompts from []string
// to []interface{} where each slice element is a *Prompt. It will return an
// error if the AuthMethod's public id is not set.
func (am *AuthMethod) convertPrompts(ctx context.Context) ([]any, error) {
const op = "oidc.(AuthMethod).convertPrompts"
if am.PublicId == "" {
return nil, errors.New(ctx, errors.InvalidPublicId, op, "missing public id")
}
newInterfaces := make([]any, 0, len(am.Prompts))
for _, a := range am.Prompts {
obj, err := NewPrompt(ctx, am.PublicId, PromptParam(a))
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
newInterfaces = append(newInterfaces, obj)
}
return newInterfaces, nil
}
20 changes: 20 additions & 0 deletions internal/auth/oidc/auth_method_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,14 @@ func Test_convertValueObjects(t *testing.T) {
testAccountClaimMaps = append(testAccountClaimMaps, obj)
}

testPrompts := []string{"consent", "select_account"}
testExpectedPrompts := make([]any, 0, len(testPrompts))
for _, a := range testPrompts {
obj, err := NewPrompt(ctx, testPublicId, PromptParam(a))
require.NoError(t, err)
testExpectedPrompts = append(testExpectedPrompts, obj)
}

tests := []struct {
name string
authMethodId string
Expand All @@ -599,6 +607,7 @@ func Test_convertValueObjects(t *testing.T) {
certs []string
scopes []string
maps []string
prompts []string
wantValues *convertedValues
wantErrMatch *errors.Template
wantErrContains string
Expand All @@ -611,12 +620,14 @@ func Test_convertValueObjects(t *testing.T) {
certs: testCerts,
scopes: testScopes,
maps: testClaimMaps,
prompts: testPrompts,
wantValues: &convertedValues{
Algs: testSigningAlgs,
Auds: testAudiences,
Certs: testCertificates,
ClaimsScopes: testClaimsScopes,
AccountClaimMaps: testAccountClaimMaps,
Prompts: testExpectedPrompts,
},
},
{
Expand All @@ -636,6 +647,7 @@ func Test_convertValueObjects(t *testing.T) {
Certificates: tt.certs,
ClaimsScopes: tt.scopes,
AccountClaimMaps: tt.maps,
Prompts: tt.prompts,
},
}

Expand Down Expand Up @@ -693,6 +705,14 @@ func Test_convertValueObjects(t *testing.T) {
assert.Equal(want, got)
}

convertedPrompts, err := am.convertPrompts(ctx)
if tt.wantErrMatch != nil {
require.Error(err)
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted err %q and got: %+v", tt.wantErrMatch.Code, err)
} else {
assert.Equal(tt.wantValues.Prompts, convertedPrompts)
}

values, err := am.convertValueObjects(ctx)
if tt.wantErrMatch != nil {
require.Error(err)
Expand Down
77 changes: 77 additions & 0 deletions internal/auth/oidc/immutable_fields_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,80 @@ func TestAccount_ImmutableFields(t *testing.T) {
})
}
}

func TestPrompt_ImmutableFields(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
kmsCache := kms.TestKms(t, conn, wrapper)
org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
rw := db.New(conn)
ctx := context.Background()
databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
require.NoError(t, err)

ts := timestamp.Timestamp{Timestamp: &timestamppb.Timestamp{Seconds: 0, Nanos: 0}}

am := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, InactiveState, "alice_rp", "my-dogs-name",
WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]), WithPrompts(SelectAccount))

new := AllocPrompt()
require.NoError(t, rw.LookupWhere(ctx, &new, "oidc_method_id = ? and prompt = ?", []any{am.PublicId, SelectAccount}))

tests := []struct {
name string
update *Prompt
fieldMask []string
}{
{
name: "oidc_method_id",
update: func() *Prompt {
cp := new.Clone()
cp.OidcMethodId = "p_thisIsNotAValidId"
return cp
}(),
fieldMask: []string{"PublicId"},
},
{
name: "create time",
update: func() *Prompt {
cp := new.Clone()
cp.CreateTime = &ts
return cp
}(),
fieldMask: []string{"CreateTime"},
},
{
name: "prompt",
update: func() *Prompt {
cp := new.Clone()
cp.PromptParam = string(Consent)
return cp
}(),
fieldMask: []string{"PromptParam"},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)

orig := new.Clone()
orig.SetTableName(defaultAuthMethodTableName)
require.NoError(rw.LookupWhere(ctx, &new, "oidc_method_id = ? and prompt = ?", []any{orig.OidcMethodId, orig.PromptParam}))

require.NoError(err)

tt.update.SetTableName(defaultAuthMethodTableName)
rowsUpdated, err := rw.Update(context.Background(), tt.update, tt.fieldMask, nil, db.WithSkipVetForWrite(true))
require.Error(err)
assert.Equal(0, rowsUpdated)

after := new.Clone()
after.SetTableName(defaultAuthMethodTableName)
require.NoError(rw.LookupWhere(ctx, &new, "oidc_method_id = ? and prompt = ?", []any{after.OidcMethodId, after.PromptParam}))

assert.True(proto.Equal(orig, after))
})
}
}
8 changes: 8 additions & 0 deletions internal/auth/oidc/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type options struct {
withAudClaims []string
withSigningAlgs []Alg
withClaimsScopes []string
withPrompts []PromptParam
withEmail string
withFullName string
withOrderByCreateTime bool
Expand Down Expand Up @@ -232,3 +233,10 @@ func WithReader(reader db.Reader) Option {
o.withReader = reader
}
}

// WithPrompts provides optional prompts
func WithPrompts(prompt ...PromptParam) Option {
return func(o *options) {
o.withPrompts = prompt
}
}
117 changes: 117 additions & 0 deletions internal/auth/oidc/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package oidc

import (
"context"
"fmt"

"github.com/hashicorp/boundary/internal/auth/oidc/store"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/cap/oidc"
"google.golang.org/protobuf/proto"
)

// Prompt represents OIDC authentication prompt
type PromptParam string

const (
// Prompt values defined by OpenID specs.
// See: https://openid.net/specs/openid-connect-basic-1_0.html#RequestParameters
None PromptParam = "none"
Login PromptParam = "login"
Consent PromptParam = "consent"
SelectAccount PromptParam = "select_account"
)

var supportedPrompts = map[PromptParam]bool{
None: true,
Login: true,
Consent: true,
SelectAccount: true,
}

// SupportedPrompt returns true if the provided prompt is supported
// by boundary.
func SupportedPrompt(p PromptParam) bool {
return supportedPrompts[p]
}

// defaultPromptTableName defines the default table name for a Prompt
const defaultPromptTableName = "auth_oidc_prompt"

// Prompt defines an prompt supported by an OIDC auth method.
// It is assigned to an OIDC AuthMethod and updates/deletes to that AuthMethod
// are cascaded to its Prompts. Prompts are value objects of an AuthMethod,
// therefore there's no need for oplog metadata, since only the AuthMethod will have
// metadata because it's the root aggregate.
type Prompt struct {
*store.Prompt
tableName string
}

// NewPrompt creates a new in memory prompt assigned to an OIDC
// AuthMethod. It supports no options.
func NewPrompt(ctx context.Context, authMethodId string, p PromptParam) (*Prompt, error) {
const op = "oidc.NewPrompt"
prompt := &Prompt{
Prompt: &store.Prompt{
OidcMethodId: authMethodId,
PromptParam: string(p),
},
}
if err := prompt.validate(ctx, op); err != nil {
return nil, err // intentionally not wrapped
}
return prompt, nil
}

// validate the Prompt. On success, it will return nil.
func (s *Prompt) validate(ctx context.Context, caller errors.Op) error {
if s.OidcMethodId == "" {
return errors.New(ctx, errors.InvalidParameter, caller, "missing oidc auth method id")
}
if _, ok := supportedPrompts[PromptParam(s.PromptParam)]; !ok {
return errors.New(ctx, errors.InvalidParameter, caller, fmt.Sprintf("unsupported prompt: %s", s.Prompt))
}
return nil
}

func convertToOIDCPrompts(ctx context.Context, p []string) []oidc.Prompt {
prompts := make([]oidc.Prompt, 0, len(p))
for _, a := range p {
prompt := oidc.Prompt(a)
prompts = append(prompts, prompt)
}

return prompts
}

// AllocPrompt makes an empty one in memory
func AllocPrompt() Prompt {
return Prompt{
Prompt: &store.Prompt{},
}
}

// Clone a Prompt
func (s *Prompt) Clone() *Prompt {
cp := proto.Clone(s.Prompt)
return &Prompt{
Prompt: cp.(*store.Prompt),
}
}

// TableName returns the table name.
func (s *Prompt) TableName() string {
if s.tableName != "" {
return s.tableName
}
return defaultPromptTableName
}

// SetTableName sets the table name.
func (s *Prompt) SetTableName(n string) {
s.tableName = n
}
Loading
Loading