From d10f2bdf0240786a94eed67749fd7f5670a4dedb Mon Sep 17 00:00:00 2001 From: Oliver Calder Date: Wed, 29 Nov 2023 15:10:50 -0600 Subject: [PATCH] fix(daemon): small refactor of notices api Most importantly, wait to lock the state until after the optional duration has been parsed, in case an error occurs. Also, moves the types query parsing/sanitization into its own function, since this cleans up the GET notices handler and better matches future handling of user ID filtering. Signed-off-by: Oliver Calder --- internals/daemon/api_notices.go | 44 ++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/internals/daemon/api_notices.go b/internals/daemon/api_notices.go index 2882c599..982a78e4 100644 --- a/internals/daemon/api_notices.go +++ b/internals/daemon/api_notices.go @@ -42,18 +42,8 @@ type addedNotice struct { func v1GetNotices(c *Command, r *http.Request, _ *UserState) Response { query := r.URL.Query() - typeStrs := strutil.MultiCommaSeparatedList(query["types"]) - types := make([]state.NoticeType, 0, len(typeStrs)) - for _, typeStr := range typeStrs { - noticeType := state.NoticeType(typeStr) - if !noticeType.Valid() { - // Ignore invalid notice types (so requests from newer clients - // with unknown types succeed). - continue - } - types = append(types, noticeType) - } - if len(types) == 0 && len(typeStrs) > 0 { + types, err := sanitizeTypesFilter(query["types"]) + if err != nil { // Caller did provide a types filter, but they're all invalid notice types. // Return no notices, rather than the default of all notices. return SyncResponse([]*state.Notice{}) @@ -71,16 +61,18 @@ func v1GetNotices(c *Command, r *http.Request, _ *UserState) Response { Keys: keys, After: after, } - var notices []*state.Notice - - st := c.d.overlord.State() - st.Lock() - defer st.Unlock() timeout, err := parseOptionalDuration(query.Get("timeout")) if err != nil { return statusBadRequest("invalid timeout: %v", err) } + + st := c.d.overlord.State() + st.Lock() + defer st.Unlock() + + var notices []*state.Notice + if timeout != 0 { // Wait up to timeout for notices matching given filter to occur ctx, cancel := context.WithTimeout(r.Context(), timeout) @@ -106,6 +98,24 @@ func v1GetNotices(c *Command, r *http.Request, _ *UserState) Response { return SyncResponse(notices) } +func sanitizeTypesFilter(queryTypes []string) ([]state.NoticeType, error) { + typeStrs := strutil.MultiCommaSeparatedList(queryTypes) + types := make([]state.NoticeType, 0, len(typeStrs)) + for _, typeStr := range typeStrs { + noticeType := state.NoticeType(typeStr) + if !noticeType.Valid() { + // Ignore invalid notice types (so requests from newer clients + // with unknown types succeed). + continue + } + types = append(types, noticeType) + } + if len(types) == 0 && len(typeStrs) > 0 { + return nil, errors.New("all requested notice types invalid") + } + return types, nil +} + func v1PostNotices(c *Command, r *http.Request, _ *UserState) Response { var payload struct { Action string `json:"action"`