diff --git a/internal/auth/ldap/repository_auth_method_update.go b/internal/auth/ldap/repository_auth_method_update.go new file mode 100644 index 0000000000..0f39bbf0b2 --- /dev/null +++ b/internal/auth/ldap/repository_auth_method_update.go @@ -0,0 +1,571 @@ +package ldap + +import ( + "context" + "fmt" + "net/url" + "strings" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/oplog" + "github.com/hashicorp/go-dbw" + "github.com/hashicorp/go-secure-stdlib/strutil" +) + +const ( + OperationalStateField = "OperationalState" + VersionField = "Version" + IsPrimaryAuthMethodField = "IsPrimaryAuthMethod" + NameField = "Name" + DescriptionField = "Description" + StartTlsField = "StartTls" + InsecureTlsField = "InsecureTls" + DiscoverDnField = "DiscoverDn" + AnonGroupSearchField = "AnonGroupSearch" + UpnDomainField = "UpnDomain" + UrlsField = "Urls" + UserDnField = "UserDn" + UserAttrField = "UserAttr" + UserFilterField = "UserFilter" + GroupDnField = "GroupDn" + GroupAttrField = "GroupAttr" + GroupFilterField = "GroupFilter" + CertificatesField = "Certificates" + ClientCertificateField = "ClientCertificate" + ClientCertificateKeyField = "ClientCertificateKey" + BindDnField = "BindDn" + BindPasswordField = "BindPassword" +) + +// UpdateAuthMethod will retrieve the auth method from the repository, +// and update it based on the field masks provided. +// +// fieldMaskPaths provides field_mask.proto paths for fields that should +// be updated. Fields will be set to NULL if the field is a +// zero value and included in fieldMask. Name, Description, StartTLs, +// DiscoverDn, AnonGroupSearch, UpnDomain, UserDn, UserAttr, UserFilter, +// GroupDn, GroupAttr, GroupFilter, ClientCertificateKey, ClientCertificate, +// BindDn and BindPassword are all updatable fields. The AuthMethod's Value +// Objects of Urls and Certificates are also updatable. If no updatable fields +// are included in the fieldMaskPaths, then an error is returned. +// +// No Options are currently supported. +func (r *Repository) UpdateAuthMethod(ctx context.Context, am *AuthMethod, version uint32, fieldMaskPaths []string, _ ...Option) (*AuthMethod, int, error) { + const op = "ldap.(AuthMethod).Update" + switch { + case am == nil: + return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing auth method") + case am.AuthMethod == nil: + return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing auth method store") + case am.PublicId == "": + return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing public id") + } + if err := validateFieldMask(ctx, fieldMaskPaths); err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + dbMask, nullFields := dbw.BuildUpdatePaths( + map[string]any{ + NameField: am.Name, + DescriptionField: am.Description, + StartTlsField: am.StartTls, + InsecureTlsField: am.InsecureTls, + DiscoverDnField: am.DiscoverDn, + AnonGroupSearchField: am.AnonGroupSearch, + UpnDomainField: am.UpnDomain, + UserDnField: am.UserDn, + UserAttrField: am.UserAttr, + UserFilterField: am.UserFilter, + GroupDnField: am.GroupDn, + GroupAttrField: am.GroupAttr, + GroupFilterField: am.GroupFilter, + CertificatesField: am.Certificates, + ClientCertificateField: am.ClientCertificate, + ClientCertificateKeyField: am.ClientCertificateKey, + BindDnField: am.BindDn, + BindPasswordField: am.BindPassword, + UrlsField: am.Urls, + }, + fieldMaskPaths, + nil, + ) + if len(dbMask) == 0 && len(nullFields) == 0 { + return nil, db.NoRowsAffected, errors.New(ctx, errors.EmptyFieldMask, op, "empty field mask") + } + if strutil.StrListContains(nullFields, UrlsField) { + return nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "missing urls (you cannot delete all of them; there must be at least one)") + } + + origAm, err := r.LookupAuthMethod(ctx, am.PublicId) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + if origAm == nil { + return nil, db.NoRowsAffected, errors.New(ctx, errors.RecordNotFound, op, fmt.Sprintf("auth method %q", am.PublicId)) + } + // there's no reason to continue if another controller has already updated this auth method. + if origAm.Version != version { + return nil, db.NoRowsAffected, errors.New(ctx, errors.VersionMismatch, op, fmt.Sprintf("update version %d doesn't match db version %d", version, origAm.Version)) + } + + dbWrapper, err := r.kms.GetWrapper(ctx, origAm.ScopeId, kms.KeyPurposeDatabase) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper")) + } + addUrls, deleteUrls, err := valueObjectChanges(ctx, origAm.PublicId, UrlVO, am.Urls, origAm.Urls, dbMask, nullFields) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + addCerts, deleteCerts, err := valueObjectChanges(ctx, origAm.PublicId, CertificateVO, am.Certificates, origAm.Certificates, dbMask, nullFields) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + var addUserSearchConf, deleteUserSearchConf any + if strListContainsOneOf(dbMask, UserDnField, UserAttrField, UserAttrField) { + addUserSearchConf, err = NewUserEntrySearchConf(ctx, am.PublicId, WithUserDn(ctx, am.UserDn), WithUserAttr(ctx, am.UserAttr), WithUserFilter(ctx, am.UserFilter)) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + } + combinedMasks := append(dbMask, nullFields...) + if strListContainsOneOf(combinedMasks, UserDnField, UserAttrField, UserAttrField) { + deleteUserSearchConf, err = NewUserEntrySearchConf(ctx, am.PublicId, WithUserDn(ctx, origAm.UserDn), WithUserAttr(ctx, origAm.UserAttr), WithUserFilter(ctx, origAm.UserFilter)) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + } + var addGroupSearchConf, deleteGroupSearchConf any + if strListContainsOneOf(dbMask, GroupDnField, GroupAttrField, GroupAttrField) { + addGroupSearchConf, err = NewGroupEntrySearchConf(ctx, am.PublicId, WithGroupDn(ctx, am.GroupDn), WithGroupAttr(ctx, am.GroupAttr), WithGroupFilter(ctx, am.GroupFilter)) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + } + if strListContainsOneOf(combinedMasks, GroupDnField, GroupAttrField, GroupAttrField) { + deleteGroupSearchConf, err = NewGroupEntrySearchConf(ctx, am.PublicId, WithGroupDn(ctx, origAm.GroupDn), WithGroupAttr(ctx, origAm.GroupAttr), WithGroupFilter(ctx, origAm.GroupFilter)) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + } + var addClientCert, deleteClientCert any + if strListContainsOneOf(dbMask, ClientCertificateField, ClientCertificateKeyField) { + cc, err := NewClientCertificate(ctx, am.PublicId, am.ClientCertificateKey, am.ClientCertificate) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + if err := cc.encrypt(ctx, dbWrapper); err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + addClientCert = cc + } + if strListContainsOneOf(combinedMasks, ClientCertificateField, ClientCertificateKeyField) { + deleteClientCert, err = NewClientCertificate(ctx, am.PublicId, origAm.ClientCertificateKey, origAm.ClientCertificate) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + } + var addBindCred, deleteBindCred any + if strListContainsOneOf(dbMask, BindDnField, BindPasswordField) { + bc, err := NewBindCredential(ctx, am.PublicId, am.BindDn, []byte(am.BindPassword)) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + if err := bc.encrypt(ctx, dbWrapper); err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + addBindCred = bc + } + if strListContainsOneOf(combinedMasks, BindDnField, BindPasswordField) { + deleteBindCred, err = NewBindCredential(ctx, am.PublicId, origAm.BindDn, []byte(origAm.BindPassword)) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + } + + var filteredDbMask, filteredNullFields []string + for _, f := range dbMask { + switch f { + case + UrlsField, + CertificatesField, + UserDnField, UserAttrField, UserFilterField, + GroupDnField, GroupAttrField, GroupFilterField, + ClientCertificateField, ClientCertificateKeyField, + BindDnField, BindPasswordField: + continue + default: + filteredDbMask = append(filteredDbMask, f) + } + } + for _, f := range nullFields { + switch f { + case + StartTlsField, InsecureTlsField, DiscoverDnField, AnonGroupSearchField, + UrlsField, + CertificatesField, + UserDnField, UserAttrField, UserFilterField, + GroupDnField, GroupAttrField, GroupFilterField, + ClientCertificateField, ClientCertificateKeyField, + BindDnField, BindPasswordField: + continue + default: + filteredNullFields = append(filteredNullFields, f) + } + } + + // handle no changes... + if len(filteredDbMask) == 0 && + len(filteredNullFields) == 0 && + len(addUrls) == 0 && + len(deleteUrls) == 0 && + len(addCerts) == 0 && + len(deleteCerts) == 0 && + addUserSearchConf == nil && + deleteUserSearchConf == nil && + addGroupSearchConf == nil && + deleteGroupSearchConf == nil && + addClientCert == nil && + deleteClientCert == nil && + addBindCred == nil && + deleteBindCred == nil { + return origAm, db.NoRowsAffected, nil + } + + oplogWrapper, err := r.kms.GetWrapper(ctx, origAm.ScopeId, kms.KeyPurposeOplog) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get oplog wrapper")) + } + var updatedAm *AuthMethod + var rowsUpdated int + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + msgs := make([]*oplog.Message, 0, 7) // AuthMethod, Algs*2, Certs*2, Audiences*2 + ticket, err := w.GetTicket(ctx, am) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get ticket")) + } + var authMethodOplogMsg oplog.Message + switch { + case len(filteredDbMask) == 0 && len(filteredNullFields) == 0: + // the auth method's fields are not being updated, just it's value objects, so we need to just update the auth + // method's version. + updatedAm = am.clone() + updatedAm.Version = uint32(version) + 1 + rowsUpdated, err = w.Update(ctx, updatedAm, []string{VersionField}, nil, db.NewOplogMsg(&authMethodOplogMsg), db.WithVersion(&version)) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update auth method version")) + } + if rowsUpdated != 1 { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated auth method version and %d rows updated", rowsUpdated)) + } + default: + updatedAm = am.clone() + rowsUpdated, err = w.Update(ctx, updatedAm, filteredDbMask, filteredNullFields, db.NewOplogMsg(&authMethodOplogMsg), db.WithVersion(&version)) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update auth method")) + } + if rowsUpdated != 1 { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated auth method and %d rows updated", rowsUpdated)) + } + } + msgs = append(msgs, &authMethodOplogMsg) + + if len(deleteCerts) > 0 { + deleteCertOplogMsgs := make([]*oplog.Message, 0, len(deleteCerts)) + rowsDeleted, err := w.DeleteItems(ctx, deleteCerts, db.NewOplogMsgs(&deleteCertOplogMsgs)) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete certificates")) + } + if rowsDeleted != len(deleteCerts) { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("certificates deleted %d did not match request for %d", rowsDeleted, len(deleteCerts))) + } + msgs = append(msgs, deleteCertOplogMsgs...) + } + if len(addCerts) > 0 { + addCertsOplogMsgs := make([]*oplog.Message, 0, len(addCerts)) + if err := w.CreateItems(ctx, addCerts, db.NewOplogMsgs(&addCertsOplogMsgs)); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add certificates")) + } + msgs = append(msgs, addCertsOplogMsgs...) + } + + if len(deleteUrls) > 0 { + deleteAudsOplogMsgs := make([]*oplog.Message, 0, len(deleteUrls)) + rowsDeleted, err := w.DeleteItems(ctx, deleteUrls, db.NewOplogMsgs(&deleteAudsOplogMsgs)) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete URLs")) + } + if rowsDeleted != len(deleteUrls) { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("urls deleted %d did not match request for %d", rowsDeleted, len(deleteUrls))) + } + msgs = append(msgs, deleteAudsOplogMsgs...) + } + if len(addUrls) > 0 { + addUrlsOplogMsgs := make([]*oplog.Message, 0, len(addUrls)) + if err := w.CreateItems(ctx, addUrls, db.NewOplogMsgs(&addUrlsOplogMsgs)); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add urls")) + } + msgs = append(msgs, addUrlsOplogMsgs...) + } + + if deleteUserSearchConf != nil { + var deleteUserSearchConfMsg oplog.Message + rowsDeleted, err := w.Delete(ctx, deleteUserSearchConf, db.NewOplogMsg(&deleteUserSearchConfMsg)) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete user search conf")) + } + if rowsDeleted != 1 { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("user search conf deleted %d did not match request for 1", rowsDeleted)) + } + msgs = append(msgs, &deleteUserSearchConfMsg) + } + if addUserSearchConf != nil { + var addUserSearchConfOplogMsg oplog.Message + if err := w.Create(ctx, addUserSearchConf, db.NewOplogMsg(&addUserSearchConfOplogMsg)); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add user search conf")) + } + msgs = append(msgs, &addUserSearchConfOplogMsg) + } + + if deleteGroupSearchConf != nil { + var deleteGroupSearchConfMsg oplog.Message + rowsDeleted, err := w.Delete(ctx, deleteGroupSearchConf, db.NewOplogMsg(&deleteGroupSearchConfMsg)) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete group search conf")) + } + if rowsDeleted != 1 { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("group search conf deleted %d did not match request for 1", rowsDeleted)) + } + msgs = append(msgs, &deleteGroupSearchConfMsg) + } + if addGroupSearchConf != nil { + var addGroupSearchConfOplogMsg oplog.Message + if err := w.Create(ctx, addGroupSearchConf, db.NewOplogMsg(&addGroupSearchConfOplogMsg)); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add group search conf")) + } + msgs = append(msgs, &addGroupSearchConfOplogMsg) + } + + if deleteClientCert != nil { + var deleteClientCertMsg oplog.Message + rowsDeleted, err := w.Delete(ctx, deleteClientCert, db.NewOplogMsg(&deleteClientCertMsg)) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete client cert")) + } + if rowsDeleted != 1 { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("client cert deleted %d did not match request for 1", rowsDeleted)) + } + msgs = append(msgs, &deleteClientCertMsg) + } + if addClientCert != nil { + var addClientCertOplogMsg oplog.Message + if err := w.Create(ctx, addClientCert, db.NewOplogMsg(&addClientCertOplogMsg)); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add client cert")) + } + msgs = append(msgs, &addClientCertOplogMsg) + } + + if deleteBindCred != nil { + var deleteBindCredMsg oplog.Message + rowsDeleted, err := w.Delete(ctx, deleteBindCred, db.NewOplogMsg(&deleteBindCredMsg)) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to delete bind credential conf")) + } + if rowsDeleted != 1 { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("bind credential deleted %d did not match request for 1", rowsDeleted)) + } + msgs = append(msgs, &deleteBindCredMsg) + } + if addBindCred != nil { + var addBindCredOplogMsg oplog.Message + if err := w.Create(ctx, addBindCred, db.NewOplogMsg(&addBindCredOplogMsg)); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add bind credential")) + } + msgs = append(msgs, &addBindCredOplogMsg) + } + + metadata := updatedAm.oplog(oplog.OpType_OP_TYPE_UPDATE) + if err := w.WriteOplogEntryWith(ctx, oplogWrapper, ticket, metadata, msgs); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to write oplog")) + } + // we need a new repo, that's using the same reader/writer as this TxHandler + txRepo := &Repository{ + reader: reader, + writer: w, + kms: r.kms, + // intentionally not setting the defaultLimit, so we'll get all + // the account ids without a limit + } + updatedAm, err = txRepo.lookupAuthMethod(ctx, updatedAm.PublicId) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to lookup auth method after update")) + } + if updatedAm == nil { + return errors.New(ctx, errors.RecordNotFound, op, "unable to lookup auth method after update") + } + return nil + }, + ) + if err != nil { + return nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + return updatedAm, rowsUpdated, nil +} + +// validateFieldMasks ensures that all the fields in the mask are updatable +func validateFieldMask(ctx context.Context, fieldMaskPaths []string) error { + const op = "ldap.validateFieldMasks" + for _, f := range fieldMaskPaths { + switch { + case strings.EqualFold(NameField, f): + case strings.EqualFold(DescriptionField, f): + case strings.EqualFold(StartTlsField, f): + case strings.EqualFold(InsecureTlsField, f): + case strings.EqualFold(DiscoverDnField, f): + case strings.EqualFold(AnonGroupSearchField, f): + case strings.EqualFold(UpnDomainField, f): + case strings.EqualFold(UserDnField, f): + case strings.EqualFold(UserAttrField, f): + case strings.EqualFold(UserFilterField, f): + case strings.EqualFold(GroupDnField, f): + case strings.EqualFold(GroupAttrField, f): + case strings.EqualFold(GroupFilterField, f): + case strings.EqualFold(CertificatesField, f): + case strings.EqualFold(ClientCertificateField, f): + case strings.EqualFold(ClientCertificateKeyField, f): + case strings.EqualFold(BindDnField, f): + case strings.EqualFold(BindPasswordField, f): + case strings.EqualFold(UrlsField, f): + default: + return errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("invalid field mask: %q", f)) + } + } + return nil +} + +// voName represents the names of auth method value objects +type voName string + +const ( + CertificateVO voName = "Certificates" + UrlVO voName = "Urls" +) + +// validVoName decides if the name is valid +func validVoName(name voName) bool { + switch name { + case CertificateVO, UrlVO: + return true + default: + return false + } +} + +// factoryFunc defines a func type for value object factories +type factoryFunc func(ctx context.Context, publicId string, idx int, i any) (any, error) + +// supportedFactories are the currently supported factoryFunc for value objects +var supportedFactories = map[voName]factoryFunc{ + CertificateVO: func(ctx context.Context, publicId string, idx int, i any) (any, error) { + str := fmt.Sprintf("%s", i) + return NewCertificate(ctx, publicId, str) + }, + UrlVO: func(ctx context.Context, publicId string, idx int, i any) (any, error) { + u, err := url.Parse(fmt.Sprintf("%s", i)) + if err != nil { + return nil, errors.Wrap(ctx, err, "ldap.urlFactory") + } + return NewUrl(ctx, publicId, idx+1, u) + }, +} + +// valueObjectChanges takes the new and old list of VOs (value objects) and +// using the dbMasks/nullFields it will return lists of VOs which need to be +// added and deleted in order to reconcile auth method's value objects. +func valueObjectChanges( + ctx context.Context, + publicId string, + valueObjectName voName, + newVOs, + oldVOs, + dbMask, + nullFields []string, +) (add []any, del []any, e error) { + const op = "ldap.valueObjectChanges" + switch { + case publicId == "": + return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing public id") + case !validVoName(valueObjectName): + return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("invalid value object name: %s", valueObjectName)) + case !strutil.StrListContains(dbMask, string(valueObjectName)) && !strutil.StrListContains(nullFields, string(valueObjectName)): + return nil, nil, nil + case len(strutil.RemoveDuplicates(newVOs, false)) != len(newVOs): + return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("duplicate new %s", valueObjectName)) + case len(strutil.RemoveDuplicates(oldVOs, false)) != len(oldVOs): + return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("duplicate old %s", valueObjectName)) + } + + factory, ok := supportedFactories[valueObjectName] + if !ok { + return nil, nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unsupported factory for value object: %s", valueObjectName)) + } + + foundVOs := map[string]int{} + for i, a := range oldVOs { + foundVOs[a] = i + } + var adds []any + var deletes []any + if strutil.StrListContains(nullFields, string(valueObjectName)) { + deletes = make([]any, 0, len(oldVOs)) + for i, v := range oldVOs { + deleteObj, err := factory(ctx, publicId, i, v) + if err != nil { + return nil, nil, errors.Wrap(ctx, err, op) + } + deletes = append(deletes, deleteObj) + delete(foundVOs, v) + } + } + if strutil.StrListContains(dbMask, string(valueObjectName)) { + adds = make([]any, 0, len(newVOs)) + for i, v := range newVOs { + if _, ok := foundVOs[v]; ok { + delete(foundVOs, v) + continue + } + obj, err := factory(ctx, publicId, i, v) + if err != nil { + return nil, nil, errors.Wrap(ctx, err, op) + } + adds = append(adds, obj) + delete(foundVOs, v) + } + } + if len(foundVOs) > 0 { + for v := range foundVOs { + obj, err := factory(ctx, publicId, foundVOs[v], v) + if err != nil { + return nil, nil, errors.Wrap(ctx, err, op) + } + deletes = append(deletes, obj) + delete(foundVOs, v) + } + } + return adds, deletes, nil +} + +func strListContainsOneOf(haystack []string, needles ...string) bool { + for _, item := range haystack { + for _, n := range needles { + if item == n { + return true + } + } + } + return false +} diff --git a/internal/auth/ldap/repository_auth_method_update_test.go b/internal/auth/ldap/repository_auth_method_update_test.go new file mode 100644 index 0000000000..2df7e0e45d --- /dev/null +++ b/internal/auth/ldap/repository_auth_method_update_test.go @@ -0,0 +1,894 @@ +package ldap + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/x509" + "fmt" + "sort" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/oplog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" +) + +func TestRepository_UpdateAuthMethod(t *testing.T) { + t.Parallel() + testCtx := context.Background() + testConn, _ := db.TestSetup(t, "postgres") + testWrapper := db.TestWrapper(t) + testKms := kms.TestKms(t, testConn, testWrapper) + testRw := db.New(testConn) + testRepo, err := NewRepository(testCtx, testRw, testRw, testKms) + require.NoError(t, err) + org, _ := iam.TestScopes(t, iam.TestRepo(t, testConn, testWrapper)) + databaseWrapper, err := testKms.GetWrapper(context.Background(), org.PublicId, kms.KeyPurposeDatabase) + require.NoError(t, err) + testCert, testCertEncoded := testGenerateCA(t, "localhost") + _, testPrivKey, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + derPrivKey, err := x509.MarshalPKCS8PrivateKey(testPrivKey) + require.NoError(t, err) + + _, testCertEncoded2 := testGenerateCA(t, "localhost") + _, testPrivKey2, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + derPrivKey2, err := x509.MarshalPKCS8PrivateKey(testPrivKey2) + require.NoError(t, err) + + tests := []struct { + name string + ctx context.Context + repo *Repository + setup func() *AuthMethod + updateWith func(orig *AuthMethod) *AuthMethod + fieldMasks []string + version uint32 + opt []Option + want func(orig, updateWith *AuthMethod) *AuthMethod + wantErrMatch *errors.Template + wantErrContains string + wantNoRowsUpdated bool + }{ + { + name: "update-everything", + ctx: testCtx, + repo: testRepo, + setup: func() *AuthMethod { + return TestAuthMethod(t, + testConn, databaseWrapper, + org.PublicId, + []string{"ldaps://ldap1", "ldap://ldap2"}, + WithName(testCtx, "update-everything-test-name"), + WithDescription(testCtx, "update-everything-test-description"), + WithUpnDomain(testCtx, "orig.alice.com"), + WithUserDn(testCtx, "orig-user-dn"), + WithUserAttr(testCtx, "orig-user-attr"), + WithUserFilter(testCtx, "orig-user-filter"), + WithGroupDn(testCtx, "orig-group-dn"), + WithGroupAttr(testCtx, "orig-group-attr"), + WithGroupFilter(testCtx, "orig-group-filter"), + WithBindCredential(testCtx, "orig-bind-dn", "orig-bind-password"), + WithCertificates(testCtx, testCert), + WithClientCertificate(testCtx, derPrivKey, testCert), // not a client cert but good enough for this test. + ) + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + am := AllocAuthMethod() + am.PublicId = orig.PublicId + am.Urls = []string{"ldaps://ldap1.alice.com", "ldaps://ldap2.alice.com"} + am.OperationalState = string(InactiveState) + am.Name = "update-everything-updated-name" + am.Description = "update-everything-updated-description" + am.StartTls = true + am.InsecureTls = true + am.DiscoverDn = true + am.AnonGroupSearch = true + am.UpnDomain = "alice.com" + am.UserDn = "user-dn" + am.UserAttr = "user-attr" + am.UserFilter = "user-filter" + am.GroupDn = "group-dn" + am.GroupAttr = "group-attr" + am.GroupFilter = "group-filter" + am.BindDn = "bind-dn" + am.BindPassword = "bind-password" + am.Certificates = []string{testCertEncoded2} + am.ClientCertificate = testCertEncoded2 + am.ClientCertificateKey = derPrivKey2 + return &am + }, + fieldMasks: []string{ + NameField, + DescriptionField, + UrlsField, + StartTlsField, + InsecureTlsField, + DiscoverDnField, + AnonGroupSearchField, + UpnDomainField, + UserDnField, + UserAttrField, + UserFilterField, + GroupDnField, + GroupAttrField, + GroupFilterField, + BindDnField, + BindPasswordField, + CertificatesField, + ClientCertificateField, + }, + version: 1, + want: func(orig, updateWith *AuthMethod) *AuthMethod { + am := orig.clone() + am.Name = updateWith.Name + am.Description = updateWith.Description + am.Urls = updateWith.Urls + am.StartTls = updateWith.StartTls + am.InsecureTls = updateWith.InsecureTls + am.DiscoverDn = updateWith.DiscoverDn + am.AnonGroupSearch = updateWith.AnonGroupSearch + am.UpnDomain = updateWith.UpnDomain + am.UserDn = updateWith.UserDn + am.UserAttr = updateWith.UserAttr + am.UserFilter = updateWith.UserFilter + am.GroupDn = updateWith.GroupDn + am.GroupAttr = updateWith.GroupAttr + am.GroupFilter = updateWith.GroupFilter + am.BindDn = updateWith.BindDn + am.BindPassword = updateWith.BindPassword + am.BindPasswordHmac = updateWith.BindPasswordHmac + am.Certificates = updateWith.Certificates + am.ClientCertificateKey = updateWith.ClientCertificateKey + am.ClientCertificate = updateWith.ClientCertificate + am.ClientCertificateKeyHmac = updateWith.ClientCertificateKeyHmac + return am + }, + }, + { + name: "update-nothing", + ctx: testCtx, + repo: testRepo, + setup: func() *AuthMethod { + return TestAuthMethod(t, + testConn, databaseWrapper, + org.PublicId, + []string{"ldaps://ldap1", "ldap://ldap2"}, + WithName(testCtx, "update-nothing-test-name"), + WithDescription(testCtx, "update-nothing-test-description"), + WithUpnDomain(testCtx, "orig.alice.com"), + WithUserDn(testCtx, "orig-user-dn"), + WithUserAttr(testCtx, "orig-user-attr"), + WithUserFilter(testCtx, "orig-user-filter"), + WithGroupDn(testCtx, "orig-group-dn"), + WithGroupAttr(testCtx, "orig-group-attr"), + WithGroupFilter(testCtx, "orig-group-filter"), + WithBindCredential(testCtx, "orig-bind-dn", "orig-bind-password"), + WithCertificates(testCtx, testCert), + WithClientCertificate(testCtx, derPrivKey, testCert), // not a client cert but good enough for this test. + ) + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig.clone() + }, + fieldMasks: []string{ + NameField, + DescriptionField, + UrlsField, + StartTlsField, + InsecureTlsField, + DiscoverDnField, + AnonGroupSearchField, + UpnDomainField, + UserDnField, + UserAttrField, + UserFilterField, + GroupDnField, + GroupAttrField, + GroupFilterField, + BindDnField, + BindPasswordField, + CertificatesField, + ClientCertificateField, + }, + version: 1, + want: func(orig, updateWith *AuthMethod) *AuthMethod { + return orig.clone() + }, + }, + { + name: "only-update-attributes", + ctx: testCtx, + repo: testRepo, + setup: func() *AuthMethod { + return TestAuthMethod(t, + testConn, databaseWrapper, + org.PublicId, + []string{"ldaps://ldap1", "ldap://ldap2"}, + WithName(testCtx, "only-update-attributes-test-name"), + WithDescription(testCtx, "only-update-attributes-test-description"), + WithUpnDomain(testCtx, "orig.alice.com"), + WithUserDn(testCtx, "orig-user-dn"), + WithUserAttr(testCtx, "orig-user-attr"), + WithUserFilter(testCtx, "orig-user-filter"), + WithGroupDn(testCtx, "orig-group-dn"), + WithGroupAttr(testCtx, "orig-group-attr"), + WithGroupFilter(testCtx, "orig-group-filter"), + WithBindCredential(testCtx, "orig-bind-dn", "orig-bind-password"), + WithCertificates(testCtx, testCert), + WithClientCertificate(testCtx, derPrivKey, testCert), // not a client cert but good enough for this test. + ) + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + am := AllocAuthMethod() + am.PublicId = orig.PublicId + am.OperationalState = string(ActivePublicState) + am.Name = "only-update-attributes-updated-name" + am.Description = "only-update-attributes-updated-description" + am.StartTls = true + am.InsecureTls = true + am.DiscoverDn = true + am.AnonGroupSearch = true + am.UpnDomain = "alice.com" + return &am + }, + fieldMasks: []string{ + NameField, + DescriptionField, + StartTlsField, + InsecureTlsField, + DiscoverDnField, + AnonGroupSearchField, + UpnDomainField, + }, + version: 1, + want: func(orig, updateWith *AuthMethod) *AuthMethod { + am := orig.clone() + am.Name = updateWith.Name + am.Description = updateWith.Description + am.StartTls = updateWith.StartTls + am.InsecureTls = updateWith.InsecureTls + am.DiscoverDn = updateWith.DiscoverDn + am.AnonGroupSearch = updateWith.AnonGroupSearch + am.UpnDomain = updateWith.UpnDomain + return am + }, + }, + { + name: "all-attributes-set-to-null-or-empty", + ctx: testCtx, + repo: testRepo, + setup: func() *AuthMethod { + return TestAuthMethod(t, + testConn, databaseWrapper, + org.PublicId, + []string{"ldaps://ldap1", "ldap://ldap2"}, + WithName(testCtx, "all-attributes-set-to-null-or-empty-test-name"), + WithDescription(testCtx, "all-attributes-set-to-null-or-empty-description"), + WithUpnDomain(testCtx, "orig.alice.com"), + WithUserDn(testCtx, "orig-user-dn"), + WithUserAttr(testCtx, "orig-user-attr"), + WithUserFilter(testCtx, "orig-user-filter"), + WithGroupDn(testCtx, "orig-group-dn"), + WithGroupAttr(testCtx, "orig-group-attr"), + WithGroupFilter(testCtx, "orig-group-filter"), + WithBindCredential(testCtx, "orig-bind-dn", "orig-bind-password"), + WithCertificates(testCtx, testCert), + WithClientCertificate(testCtx, derPrivKey, testCert), // not a client cert but good enough for this test. + ) + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + am := AllocAuthMethod() + am.PublicId = orig.PublicId + return &am + }, + fieldMasks: []string{ + NameField, + DescriptionField, + StartTlsField, + InsecureTlsField, + DiscoverDnField, + AnonGroupSearchField, + UpnDomainField, + }, + version: 1, + want: func(orig, updateWith *AuthMethod) *AuthMethod { + am := orig.clone() + am.Name = updateWith.Name + am.Description = updateWith.Description + am.StartTls = updateWith.StartTls + am.InsecureTls = updateWith.InsecureTls + am.DiscoverDn = updateWith.DiscoverDn + am.AnonGroupSearch = updateWith.AnonGroupSearch + am.UpnDomain = updateWith.UpnDomain + return am + }, + }, + { + name: "only-update-value-objects", + ctx: testCtx, + repo: testRepo, + setup: func() *AuthMethod { + return TestAuthMethod(t, + testConn, databaseWrapper, + org.PublicId, + []string{"ldaps://ldap1", "ldap://ldap2"}, + WithName(testCtx, "only-update-value-objects-test-name"), + WithDescription(testCtx, "orig-test-description"), + WithUpnDomain(testCtx, "orig.alice.com"), + WithUserDn(testCtx, "orig-user-dn"), + WithUserAttr(testCtx, "orig-user-attr"), + WithUserFilter(testCtx, "orig-user-filter"), + WithGroupDn(testCtx, "orig-group-dn"), + WithGroupAttr(testCtx, "orig-group-attr"), + WithGroupFilter(testCtx, "orig-group-filter"), + WithBindCredential(testCtx, "orig-bind-dn", "orig-bind-password"), + WithCertificates(testCtx, testCert), + WithClientCertificate(testCtx, derPrivKey, testCert), // not a client cert but good enough for this test. + ) + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + am := AllocAuthMethod() + am.PublicId = orig.PublicId + am.Urls = []string{"ldaps://ldap3", "ldaps://ldap4"} + am.UserDn = "user-dn" + am.UserAttr = "user-attr" + am.UserFilter = "user-filter" + am.GroupDn = "group-dn" + am.GroupAttr = "group-attr" + am.GroupFilter = "group-filter" + am.BindDn = "bind-dn" + am.BindPassword = "bind-password" + am.Certificates = []string{testCertEncoded} + am.ClientCertificate = testCertEncoded + am.ClientCertificateKey = derPrivKey + return &am + }, + fieldMasks: []string{ + UrlsField, + UserDnField, + UserAttrField, + UserFilterField, + GroupDnField, + GroupAttrField, + GroupFilterField, + BindDnField, + BindPasswordField, + CertificatesField, + ClientCertificateField, + }, + version: 1, + want: func(orig, updateWith *AuthMethod) *AuthMethod { + am := orig.clone() + am.Urls = updateWith.Urls + am.UserDn = updateWith.UserDn + am.UserAttr = updateWith.UserAttr + am.UserFilter = updateWith.UserFilter + am.GroupDn = updateWith.GroupDn + am.GroupAttr = updateWith.GroupAttr + am.GroupFilter = updateWith.GroupFilter + am.BindDn = updateWith.BindDn + am.BindPassword = updateWith.BindPassword + am.BindPasswordHmac = updateWith.BindPasswordHmac + am.ClientCertificateKey = updateWith.ClientCertificateKey + am.ClientCertificate = updateWith.ClientCertificate + am.ClientCertificateKeyHmac = updateWith.ClientCertificateKeyHmac + return am + }, + }, + { + name: "remove-value-objects", + ctx: testCtx, + repo: testRepo, + setup: func() *AuthMethod { + return TestAuthMethod(t, + testConn, databaseWrapper, + org.PublicId, + []string{"ldaps://ldap1", "ldap://ldap2"}, + WithUserDn(testCtx, "orig-user-dn"), + WithUserAttr(testCtx, "orig-user-attr"), + WithUserFilter(testCtx, "orig-user-filter"), + WithGroupDn(testCtx, "orig-group-dn"), + WithGroupAttr(testCtx, "orig-group-attr"), + WithGroupFilter(testCtx, "orig-group-filter"), + WithBindCredential(testCtx, "orig-bind-dn", "orig-bind-password"), + WithCertificates(testCtx, testCert), + WithClientCertificate(testCtx, derPrivKey, testCert), // not a client cert but good enough for this test. + ) + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + am := AllocAuthMethod() + am.PublicId = orig.PublicId + return &am + }, + fieldMasks: []string{ + UserDnField, + UserAttrField, + UserFilterField, + GroupDnField, + GroupAttrField, + GroupFilterField, + BindDnField, + BindPasswordField, + CertificatesField, + ClientCertificateField, + }, + version: 1, + want: func(orig, updateWith *AuthMethod) *AuthMethod { + am := orig.clone() + am.Certificates = updateWith.Certificates + am.UserDn = updateWith.UserDn + am.UserAttr = updateWith.UserAttr + am.UserFilter = updateWith.UserFilter + am.GroupDn = updateWith.GroupDn + am.GroupAttr = updateWith.GroupAttr + am.GroupFilter = updateWith.GroupFilter + am.BindDn = updateWith.BindDn + am.BindPassword = updateWith.BindPassword + am.BindPasswordHmac = updateWith.BindPasswordHmac + am.ClientCertificateKey = updateWith.ClientCertificateKey + am.ClientCertificate = updateWith.ClientCertificate + am.ClientCertificateKeyHmac = updateWith.ClientCertificateKeyHmac + return am + }, + }, + { + name: "missing-auth-method", + ctx: testCtx, + repo: testRepo, + setup: func() *AuthMethod { + return nil + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "missing auth method", + }, + { + name: "missing-auth-method-store", + ctx: testCtx, + repo: testRepo, + setup: func() *AuthMethod { + return &AuthMethod{} + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "missing auth method store", + }, + { + name: "missing-public-id", + ctx: testCtx, + repo: testRepo, + setup: func() *AuthMethod { + am := AllocAuthMethod() + return &am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "missing public id", + }, + { + name: "invalid-field-mask", + ctx: testCtx, + repo: testRepo, + fieldMasks: []string{"CreateTime"}, + setup: func() *AuthMethod { + am := AllocAuthMethod() + am.PublicId = "test-id" + return &am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "invalid field mask: \"CreateTime\"", + }, + { + name: "no-field-mask", + ctx: testCtx, + repo: testRepo, + setup: func() *AuthMethod { + am := AllocAuthMethod() + am.PublicId = "test-id" + return &am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.EmptyFieldMask), + wantErrContains: "empty field mask", + }, + { + name: "missing-urls", + ctx: testCtx, + repo: testRepo, + fieldMasks: []string{"Urls"}, + setup: func() *AuthMethod { + am := AllocAuthMethod() + am.PublicId = "test-id" + return &am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "missing urls (you cannot delete all of them; there must be at least one)", + }, + { + name: "lookup-err", + ctx: testCtx, + repo: func() *Repository { + conn, mock := db.TestSetupWithMock(t) + mock.ExpectQuery(`SELECT`).WillReturnError(fmt.Errorf("lookup-err")) + mockRw := db.New(conn) + testRepo, err := NewRepository(testCtx, mockRw, mockRw, testKms) + require.NoError(t, err) + return testRepo + }(), + fieldMasks: []string{"UserDn"}, + setup: func() *AuthMethod { + am := AllocAuthMethod() + am.PublicId = "test-id" + return &am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.Unknown), + wantErrContains: "lookup-err", + }, + { + name: "not-found", + ctx: testCtx, + repo: testRepo, + fieldMasks: []string{"UserDn"}, + setup: func() *AuthMethod { + am := AllocAuthMethod() + am.PublicId = "test-id" + return &am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.RecordNotFound), + wantErrContains: "auth method \"test-id\": search issue", + }, + { + name: "version-mismatch", + ctx: testCtx, + repo: testRepo, + fieldMasks: []string{"UserDn"}, + setup: func() *AuthMethod { + am := TestAuthMethod(t, testConn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + am.Version += 1 + return am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.Integrity), + wantErrContains: "update version 0 doesn't match db version 1", + }, + { + name: "getWrapper-err", + ctx: testCtx, + repo: func() *Repository { + testKms := &kms.MockGetWrapperer{ + GetErr: fmt.Errorf("getWrapper-err"), + } + testRepo, err := NewRepository(testCtx, testRw, testRw, testKms) + require.NoError(t, err) + return testRepo + }(), + version: 1, + fieldMasks: []string{"UserDn"}, + setup: func() *AuthMethod { + return TestAuthMethod(t, testConn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.Unknown), + wantErrContains: "getWrapper-err", + }, + { + name: "urls-conversion-err", + ctx: testCtx, + repo: testRepo, + version: 1, + fieldMasks: []string{"Urls"}, + setup: func() *AuthMethod { + am := TestAuthMethod(t, testConn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + am.Urls = []string{"https://not-valid.com"} + return am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.Unknown), + wantErrContains: "valueObjectChanges: ldap.NewUrl: scheme \"https\" is not ldap or ldaps", + }, + { + name: "certs-conversion-err", + ctx: testCtx, + repo: testRepo, + version: 1, + fieldMasks: []string{"Certificates"}, + setup: func() *AuthMethod { + am := TestAuthMethod(t, testConn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + am.Certificates = []string{testInvalidPem} + return am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.Unknown), + wantErrContains: "valueObjectChanges: ldap.NewCertificate: failed to parse certificate", + }, + { + name: "user-entry-search-conversion-err", + ctx: testCtx, + repo: testRepo, + version: 1, + fieldMasks: []string{"UserDn"}, + setup: func() *AuthMethod { + am := TestAuthMethod(t, testConn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + return am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.Unknown), + wantErrContains: "ldap.NewUserEntrySearchConf: you must supply either dn, attr, or filter", + }, + { + name: "group-entry-search-conversion-err", + ctx: testCtx, + repo: testRepo, + version: 1, + fieldMasks: []string{"GroupDn"}, + setup: func() *AuthMethod { + am := TestAuthMethod(t, testConn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + return am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.Unknown), + wantErrContains: "ldap.NewGroupEntrySearchConf: you must supply either dn, attr, or filter", + }, + { + name: "client-search-conversion-err", + ctx: testCtx, + repo: testRepo, + version: 1, + fieldMasks: []string{"ClientCertificate"}, + setup: func() *AuthMethod { + am := TestAuthMethod(t, testConn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + return am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.Unknown), + wantErrContains: "ldap.NewClientCertificate: missing key", + }, + { + name: "bind-credential-conversion-err", + ctx: testCtx, + repo: testRepo, + version: 1, + fieldMasks: []string{"BindDn"}, + setup: func() *AuthMethod { + am := TestAuthMethod(t, testConn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + return am + }, + updateWith: func(orig *AuthMethod) *AuthMethod { + return orig + }, + wantErrMatch: errors.T(errors.Unknown), + wantErrContains: "ldap.NewBindCredential: missing dn", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + orig := tc.setup() + updateWith := tc.updateWith(orig) + updated, rowsUpdated, err := tc.repo.UpdateAuthMethod(tc.ctx, updateWith, tc.version, tc.fieldMasks, tc.opt...) + if tc.wantErrMatch != nil { + require.Error(err) + assert.Empty(updated) + assert.Zero(rowsUpdated) + assert.Truef(errors.Match(tc.wantErrMatch, err), "want err code: %q got: %q", tc.wantErrMatch.Code, err) + if tc.wantErrContains != "" { + assert.Contains(err.Error(), tc.wantErrContains) + } + return + } + require.NoError(err) + require.NotNil(updated) + want := tc.want(orig, updateWith) + want.CreateTime = updated.CreateTime + want.UpdateTime = updated.UpdateTime + want.Version = updated.Version + want.BindPasswordHmac = updated.BindPasswordHmac + want.ClientCertificateKeyHmac = updated.ClientCertificateKeyHmac + TestSortAuthMethods(t, []*AuthMethod{want, updated}) + assert.Empty(cmp.Diff(updated.AuthMethod, want.AuthMethod, protocmp.Transform())) + if !tc.wantNoRowsUpdated { + assert.Equal(1, rowsUpdated) + err = db.TestVerifyOplog(t, testRw, updateWith.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second)) + require.NoErrorf(err, "unexpected error verifying oplog entry: %s", err) + } + found, err := tc.repo.LookupAuthMethod(tc.ctx, want.PublicId) + require.NoError(err) + TestSortAuthMethods(t, []*AuthMethod{found}) + assert.Empty(cmp.Diff(found.AuthMethod, want.AuthMethod, protocmp.Transform())) + }) + } +} + +func Test_validateFieldMask(t *testing.T) { + t.Parallel() + tests := []struct { + name string + fieldMask []string + wantErr bool + }{ + { + name: "all-valid-fields", + fieldMask: []string{ + NameField, + DescriptionField, + StartTlsField, + InsecureTlsField, + DiscoverDnField, + AnonGroupSearchField, + UpnDomainField, + UrlsField, + UserDnField, + UserAttrField, + UserFilterField, + GroupDnField, + GroupAttrField, + GroupFilterField, + CertificatesField, + ClientCertificateField, + ClientCertificateKeyField, + BindDnField, + BindPasswordField, + }, + }, + { + name: "invalid", + fieldMask: []string{"Invalid", NameField}, + wantErr: true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + require := require.New(t) + err := validateFieldMask(context.TODO(), tc.fieldMask) + if tc.wantErr { + require.Error(err) + return + } + require.NoError(err) + }) + } +} + +// Test_valueObjectChanges is just being used to test failure conditions primarily +func Test_valueObjectChanges(t *testing.T) { + t.Parallel() + testCtx := context.Background() + _, pem1 := testGenerateCA(t, "localhost") + _, pem2 := testGenerateCA(t, "127.0.0.1") + _, pem3 := testGenerateCA(t, "www.example.com") + + tests := []struct { + name string + ctx context.Context + id string + voName voName + new []string + old []string + dbMask []string + nullFields []string + wantAdd []any + wantDel []any + wantErrMatch *errors.Template + wantErrContains string + }{ + { + name: "missing-public-id", + ctx: testCtx, + voName: CertificateVO, + new: nil, + old: []string{pem1, pem2, pem3}, + nullFields: []string{string(CertificateVO)}, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "missing public id", + }, + { + name: "invalid-vo-name", + ctx: testCtx, + voName: voName("invalid-name"), + id: "am-public-id", + new: nil, + old: []string{pem1, pem2}, + nullFields: []string{string(CertificateVO)}, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "invalid value object name", + }, + { + name: "dup-new", + ctx: testCtx, + id: "am-public-id", + voName: CertificateVO, + new: []string{pem1, pem1}, + old: []string{pem1}, + dbMask: []string{string(CertificateVO)}, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "duplicate new Certificates", + }, + { + name: "dup-old", + ctx: testCtx, + id: "am-public-id", + voName: CertificateVO, + new: []string{pem1}, + old: []string{pem2, pem2}, + dbMask: []string{string(CertificateVO)}, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "duplicate old Certificates", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + gotAdd, gotDel, err := valueObjectChanges(tc.ctx, tc.id, tc.voName, tc.new, tc.old, tc.dbMask, tc.nullFields) + if tc.wantErrMatch != nil { + require.Error(err) + assert.Truef(errors.Match(tc.wantErrMatch, err), "want err code: %q got: %q", tc.wantErrMatch.Code, err) + if tc.wantErrContains != "" { + assert.Contains(err.Error(), tc.wantErrContains) + } + return + } + require.NoError(err) + assert.Equal(tc.wantAdd, gotAdd) + + switch tc.voName { + case CertificateVO: + sort.Slice(gotDel, func(a, b int) bool { + aa := gotDel[a] + bb := gotDel[b] + return aa.(*Certificate).Cert < bb.(*Certificate).Cert + }) + case UrlVO: + sort.Slice(gotDel, func(a, b int) bool { + aa := gotDel[a] + bb := gotDel[b] + return aa.(*Url).ServerUrl < bb.(*Url).ServerUrl + }) + } + assert.Equalf(tc.wantDel, gotDel, "wantDel: %s\ngotDel: %s\n", tc.wantDel, gotDel) + }) + } +}