diff --git a/internals/daemon/api_notices.go b/internals/daemon/api_notices.go index 2882c599..982a78e4 100644 --- a/internals/daemon/api_notices.go +++ b/internals/daemon/api_notices.go @@ -42,18 +42,8 @@ type addedNotice struct { func v1GetNotices(c *Command, r *http.Request, _ *UserState) Response { query := r.URL.Query() - typeStrs := strutil.MultiCommaSeparatedList(query["types"]) - types := make([]state.NoticeType, 0, len(typeStrs)) - for _, typeStr := range typeStrs { - noticeType := state.NoticeType(typeStr) - if !noticeType.Valid() { - // Ignore invalid notice types (so requests from newer clients - // with unknown types succeed). - continue - } - types = append(types, noticeType) - } - if len(types) == 0 && len(typeStrs) > 0 { + types, err := sanitizeTypesFilter(query["types"]) + if err != nil { // Caller did provide a types filter, but they're all invalid notice types. // Return no notices, rather than the default of all notices. return SyncResponse([]*state.Notice{}) @@ -71,16 +61,18 @@ func v1GetNotices(c *Command, r *http.Request, _ *UserState) Response { Keys: keys, After: after, } - var notices []*state.Notice - - st := c.d.overlord.State() - st.Lock() - defer st.Unlock() timeout, err := parseOptionalDuration(query.Get("timeout")) if err != nil { return statusBadRequest("invalid timeout: %v", err) } + + st := c.d.overlord.State() + st.Lock() + defer st.Unlock() + + var notices []*state.Notice + if timeout != 0 { // Wait up to timeout for notices matching given filter to occur ctx, cancel := context.WithTimeout(r.Context(), timeout) @@ -106,6 +98,24 @@ func v1GetNotices(c *Command, r *http.Request, _ *UserState) Response { return SyncResponse(notices) } +func sanitizeTypesFilter(queryTypes []string) ([]state.NoticeType, error) { + typeStrs := strutil.MultiCommaSeparatedList(queryTypes) + types := make([]state.NoticeType, 0, len(typeStrs)) + for _, typeStr := range typeStrs { + noticeType := state.NoticeType(typeStr) + if !noticeType.Valid() { + // Ignore invalid notice types (so requests from newer clients + // with unknown types succeed). + continue + } + types = append(types, noticeType) + } + if len(types) == 0 && len(typeStrs) > 0 { + return nil, errors.New("all requested notice types invalid") + } + return types, nil +} + func v1PostNotices(c *Command, r *http.Request, _ *UserState) Response { var payload struct { Action string `json:"action"`