diff --git a/warden/group/handler.go b/warden/group/handler.go index d11ff6b329e..107d67011c8 100644 --- a/warden/group/handler.go +++ b/warden/group/handler.go @@ -68,6 +68,7 @@ func (h *Handler) SetRoutes(r *httprouter.Router) { r.DELETE(GroupsHandlerPath+"/:id", h.DeleteGroup) r.POST(GroupsHandlerPath+"/:id/members", h.AddGroupMembers) r.DELETE(GroupsHandlerPath+"/:id/members", h.RemoveGroupMembers) + r.PUT(GroupsHandlerPath+"/:id/members", h.UpdateGroupMembers) } // swagger:route GET /warden/groups warden listGroups @@ -101,12 +102,7 @@ func (h *Handler) SetRoutes(r *httprouter.Router) { // 403: genericError // 500: genericError func (h *Handler) ListGroupsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - var ctx = r.Context() - - if _, err := h.W.TokenAllowed(ctx, h.W.TokenFromRequest(r), &firewall.TokenAccessRequest{ - Resource: h.PrefixResource(GroupsResource), - Action: "list", - }, Scope); err != nil { + if err := h.checkRequest(r, h.PrefixResource(GroupsResource), "list"); err != nil { h.H.WriteError(w, r, err) return } @@ -172,19 +168,14 @@ func (h *Handler) FindGroupNames(w http.ResponseWriter, r *http.Request, member // 403: genericError // 500: genericError func (h *Handler) CreateGroup(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - var g Group - var ctx = r.Context() - - if err := json.NewDecoder(r.Body).Decode(&g); err != nil { - h.H.WriteError(w, r, errors.WithStack(err)) + if err := h.checkRequest(r, h.PrefixResource(GroupsResource), "create"); err != nil { + h.H.WriteError(w, r, err) return } - if _, err := h.W.TokenAllowed(ctx, h.W.TokenFromRequest(r), &firewall.TokenAccessRequest{ - Resource: h.PrefixResource(GroupsResource), - Action: "create", - }, Scope); err != nil { - h.H.WriteError(w, r, err) + var g Group + if err := json.NewDecoder(r.Body).Decode(&g); err != nil { + h.H.WriteError(w, r, errors.WithStack(err)) return } @@ -227,19 +218,15 @@ func (h *Handler) CreateGroup(w http.ResponseWriter, r *http.Request, _ httprout // 403: genericError // 500: genericError func (h *Handler) GetGroup(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - var ctx = r.Context() var id = ps.ByName("id") - g, err := h.Manager.GetGroup(id) - if err != nil { + if err := h.checkRequest(r, fmt.Sprintf(h.PrefixResource(GroupResource), id), "get"); err != nil { h.H.WriteError(w, r, err) return } - if _, err := h.W.TokenAllowed(ctx, h.W.TokenFromRequest(r), &firewall.TokenAccessRequest{ - Resource: fmt.Sprintf(h.PrefixResource(GroupResource), id), - Action: "get", - }, Scope); err != nil { + g, err := h.Manager.GetGroup(id) + if err != nil { h.H.WriteError(w, r, err) return } @@ -278,13 +265,9 @@ func (h *Handler) GetGroup(w http.ResponseWriter, r *http.Request, ps httprouter // 403: genericError // 500: genericError func (h *Handler) DeleteGroup(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - var ctx = r.Context() var id = ps.ByName("id") - if _, err := h.W.TokenAllowed(ctx, h.W.TokenFromRequest(r), &firewall.TokenAccessRequest{ - Resource: fmt.Sprintf(h.PrefixResource(GroupResource), id), - Action: "delete", - }, Scope); err != nil { + if err := h.checkRequest(r, fmt.Sprintf(h.PrefixResource(GroupResource), id), "delete"); err != nil { h.H.WriteError(w, r, err) return } @@ -328,20 +311,16 @@ func (h *Handler) DeleteGroup(w http.ResponseWriter, r *http.Request, ps httprou // 403: genericError // 500: genericError func (h *Handler) AddGroupMembers(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - var ctx = r.Context() var id = ps.ByName("id") - var m membersRequest - if err := json.NewDecoder(r.Body).Decode(&m); err != nil { - h.H.WriteError(w, r, errors.WithStack(err)) + if err := h.checkRequest(r, fmt.Sprintf(h.PrefixResource(GroupResource), id), "members.add"); err != nil { + h.H.WriteError(w, r, err) return } - if _, err := h.W.TokenAllowed(ctx, h.W.TokenFromRequest(r), &firewall.TokenAccessRequest{ - Resource: fmt.Sprintf(h.PrefixResource(GroupResource), id), - Action: "members.add", - }, Scope); err != nil { - h.H.WriteError(w, r, err) + var m membersRequest + if err := json.NewDecoder(r.Body).Decode(&m); err != nil { + h.H.WriteError(w, r, errors.WithStack(err)) return } @@ -384,27 +363,88 @@ func (h *Handler) AddGroupMembers(w http.ResponseWriter, r *http.Request, ps htt // 403: genericError // 500: genericError func (h *Handler) RemoveGroupMembers(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - var ctx = r.Context() var id = ps.ByName("id") + if err := h.checkRequest(r, fmt.Sprintf(h.PrefixResource(GroupResource), id), "members.remove"); err != nil { + h.H.WriteError(w, r, err) + return + } + var m membersRequest if err := json.NewDecoder(r.Body).Decode(&m); err != nil { h.H.WriteError(w, r, errors.WithStack(err)) return } - if _, err := h.W.TokenAllowed(ctx, h.W.TokenFromRequest(r), &firewall.TokenAccessRequest{ - Resource: fmt.Sprintf(h.PrefixResource(GroupResource), id), - Action: "members.remove", - }, Scope); err != nil { + if err := h.Manager.RemoveGroupMembers(id, m.Members); err != nil { h.H.WriteError(w, r, err) return } - if err := h.Manager.RemoveGroupMembers(id, m.Members); err != nil { + w.WriteHeader(http.StatusNoContent) +} + +// swagger:route PUT /warden/groups/{id}/members warden replaceMembersInGroup +// +// Replace the members of a group +// +// The subject making the request needs to be assigned to a policy containing: +// +// ``` +// { +// "resources": ["rn:hydra:warden:groups:"], +// "actions": ["members.update"], +// "effect": "allow" +// } +// ``` +// +// Consumes: +// - application/json +// +// Produces: +// - application/json +// +// Schemes: http, https +// +// Security: +// oauth2: hydra.warden.groups +// +// Responses: +// 204: emptyResponse +// 401: genericError +// 403: genericError +// 500: genericError +func (h *Handler) UpdateGroupMembers(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + var id = ps.ByName("id") + + if err := h.checkRequest(r, fmt.Sprintf(h.PrefixResource(GroupResource), id), "members.update"); err != nil { + h.H.WriteError(w, r, err) + return + } + + var m membersRequest + if err := json.NewDecoder(r.Body).Decode(&m); err != nil { + h.H.WriteError(w, r, errors.WithStack(err)) + return + } + + if err := h.Manager.UpdateGroupMembers(id, m.Members); err != nil { h.H.WriteError(w, r, err) return } w.WriteHeader(http.StatusNoContent) } + +func (h *Handler) checkRequest(r *http.Request, resource string, action string) error { + var ctx = r.Context() + + if _, err := h.W.TokenAllowed(ctx, h.W.TokenFromRequest(r), &firewall.TokenAccessRequest{ + Resource: resource, + Action: action, + }, Scope); err != nil { + return errors.WithStack(err) + } + + return nil +} diff --git a/warden/group/manager.go b/warden/group/manager.go index 039854c3143..311850191dd 100644 --- a/warden/group/manager.go +++ b/warden/group/manager.go @@ -32,6 +32,7 @@ type Manager interface { AddGroupMembers(group string, members []string) error RemoveGroupMembers(group string, members []string) error + UpdateGroupMembers(group string, members []string) error FindGroupsByMember(subject string, limit, offset int) ([]Group, error) ListGroups(limit, offset int) ([]Group, error) diff --git a/warden/group/manager_memory.go b/warden/group/manager_memory.go index 64871928ff8..45378316093 100644 --- a/warden/group/manager_memory.go +++ b/warden/group/manager_memory.go @@ -126,3 +126,20 @@ func (m *MemoryManager) ListGroups(limit, offset int) ([]Group, error) { start, end := pagination.Index(limit, offset, len(res)) return res[start:end], nil } + +func (m *MemoryManager) UpdateGroupMembers(group string, members []string) error { + id := group + + if err := m.DeleteGroup(id); err != nil { + return err + } + + if err := m.CreateGroup(&Group{ + ID: id, + Members: members, + }); err != nil { + return err + } + + return nil +} diff --git a/warden/group/manager_sql.go b/warden/group/manager_sql.go index 871218bcdd4..3a0d77eb35e 100644 --- a/warden/group/manager_sql.go +++ b/warden/group/manager_sql.go @@ -57,16 +57,20 @@ func (m *SQLManager) CreateSchemas() (int, error) { return n, nil } +func (m *SQLManager) createGroup(group string) func(tx *sqlx.Tx) error { + return func(tx *sqlx.Tx) error { + _, err := tx.Exec(m.DB.Rebind("INSERT INTO hydra_warden_group (id) VALUES (?)"), group) + + return errors.WithStack(err) + } +} + func (m *SQLManager) CreateGroup(g *Group) error { if g.ID == "" { g.ID = uuid.New() } - if _, err := m.DB.Exec(m.DB.Rebind("INSERT INTO hydra_warden_group (id) VALUES (?)"), g.ID); err != nil { - return errors.WithStack(err) - } - - return m.AddGroupMembers(g.ID, g.Members) + return m.applyInTransaction(m.createGroup(g.ID), m.addGroupMembers(g.ID, g.Members)) } func (m *SQLManager) GetGroup(id string) (*Group, error) { @@ -88,58 +92,48 @@ func (m *SQLManager) GetGroup(id string) (*Group, error) { }, nil } -func (m *SQLManager) DeleteGroup(id string) error { - if _, err := m.DB.Exec(m.DB.Rebind("DELETE FROM hydra_warden_group WHERE id=?"), id); err != nil { +func (m *SQLManager) deleteGroup(id string) func(tx *sqlx.Tx) error { + return func(tx *sqlx.Tx) error { + _, err := tx.Exec(m.DB.Rebind("DELETE FROM hydra_warden_group WHERE id=?"), id) + return errors.WithStack(err) } - return nil } -func (m *SQLManager) AddGroupMembers(group string, subjects []string) error { - tx, err := m.DB.Beginx() - if err != nil { - return errors.Wrap(err, "Could not begin transaction") - } +func (m *SQLManager) DeleteGroup(id string) error { + return m.applyInTransaction(m.deleteGroup(id)) +} - for _, subject := range subjects { - if _, err := tx.Exec(m.DB.Rebind("INSERT INTO hydra_warden_group_member (group_id, member) VALUES (?, ?)"), group, subject); err != nil { - if err := tx.Rollback(); err != nil { +func (m *SQLManager) addGroupMembers(group string, subjects []string) func(tx *sqlx.Tx) error { + return func(tx *sqlx.Tx) error { + for _, subject := range subjects { + if _, err := tx.Exec(m.DB.Rebind("INSERT INTO hydra_warden_group_member (group_id, member) VALUES (?, ?)"), group, subject); err != nil { return errors.WithStack(err) } - return errors.WithStack(err) } - } - if err := tx.Commit(); err != nil { - if err := tx.Rollback(); err != nil { - return errors.WithStack(err) - } - return errors.Wrap(err, "Could not commit transaction") + return nil } - return nil } -func (m *SQLManager) RemoveGroupMembers(group string, subjects []string) error { - tx, err := m.DB.Beginx() - if err != nil { - return errors.Wrap(err, "Could not begin transaction") - } - for _, subject := range subjects { - if _, err := m.DB.Exec(m.DB.Rebind("DELETE FROM hydra_warden_group_member WHERE member=? AND group_id=?"), subject, group); err != nil { - if err := tx.Rollback(); err != nil { +func (m *SQLManager) AddGroupMembers(group string, subjects []string) error { + return m.applyInTransaction(m.addGroupMembers(group, subjects)) +} + +func (m *SQLManager) removeGroupMembers(group string, subjects []string) func(tx *sqlx.Tx) error { + return func(tx *sqlx.Tx) error { + for _, subject := range subjects { + if _, err := tx.Exec(m.DB.Rebind("DELETE FROM hydra_warden_group_member WHERE member=? AND group_id=?"), subject, group); err != nil { return errors.WithStack(err) } - return errors.WithStack(err) } - } - if err := tx.Commit(); err != nil { - if err := tx.Rollback(); err != nil { - return errors.WithStack(err) - } - return errors.Wrap(err, "Could not commit transaction") + return nil } - return nil +} + +func (m *SQLManager) RemoveGroupMembers(group string, subjects []string) error { + return m.applyInTransaction(m.removeGroupMembers(group, subjects)) } func (m *SQLManager) FindGroupsByMember(subject string, limit, offset int) ([]Group, error) { @@ -154,7 +148,7 @@ func (m *SQLManager) FindGroupsByMember(subject string, limit, offset int) ([]Gr for k, id := range ids { group, err := m.GetGroup(id) if err != nil { - return nil, errors.WithStack(err) + return nil, err } groups[k] = *group @@ -175,7 +169,7 @@ func (m *SQLManager) ListGroups(limit, offset int) ([]Group, error) { for k, id := range ids { group, err := m.GetGroup(id) if err != nil { - return nil, errors.WithStack(err) + return nil, err } groups[k] = *group @@ -183,3 +177,33 @@ func (m *SQLManager) ListGroups(limit, offset int) ([]Group, error) { return groups, nil } + +func (m *SQLManager) UpdateGroupMembers(group string, members []string) error { + return m.applyInTransaction(m.deleteGroup(group), m.createGroup(group), m.addGroupMembers(group, members)) +} + +func (m *SQLManager) applyInTransaction(executors ...func(tx *sqlx.Tx) error) error { + tx, err := m.DB.Beginx() + if err != nil { + return errors.Wrap(err, "Could not begin transaction") + } + + for _, exec := range executors { + if err := exec(tx); err != nil { + if err := tx.Rollback(); err != nil { + return errors.WithStack(err) + } + + return err + } + } + + if err := tx.Commit(); err != nil { + if err := tx.Rollback(); err != nil { + return errors.WithStack(err) + } + return errors.Wrap(err, "Could not commit transaction") + } + + return nil +}