diff --git a/users.go b/users.go index eff2b060a..8bbd7d77d 100644 --- a/users.go +++ b/users.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/url" "strconv" + "time" ) const ( @@ -346,12 +347,19 @@ func (api *Client) GetUsers() ([]User, error) { // GetUsersContext returns the list of users (with their detailed information) with a custom context func (api *Client) GetUsersContext(ctx context.Context) (results []User, err error) { - var ( - p UserPagination - ) - - for p = api.GetUsersPaginated(); !p.Done(err); p, err = p.Next(ctx) { - results = append(results, p.Users...) + p := api.GetUsersPaginated() + for err == nil { + p, err = p.Next(ctx) + if err == nil { + results = append(results, p.Users...) + } else if rateLimitedError, ok := err.(*RateLimitedError); ok { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(rateLimitedError.RetryAfter): + err = nil + } + } } return results, p.Failure(err) diff --git a/users_test.go b/users_test.go index a30f86e6c..725fa42e6 100644 --- a/users_test.go +++ b/users_test.go @@ -50,9 +50,9 @@ func getTestUserProfile() UserProfile { } } -func getTestUser() User { +func getTestUserWithId(id string) User { return User{ - ID: "UXXXXXXXX", + ID: id, Name: "Test User", Deleted: false, Color: "9f69e7", @@ -72,6 +72,10 @@ func getTestUser() User { } } +func getTestUser() User { + return getTestUserWithId("UXXXXXXXX") +} + func getUserIdentity(rw http.ResponseWriter, r *http.Request) { rw.Header().Set("Content-Type", "application/json") response := []byte(`{ @@ -303,8 +307,8 @@ func testUnsetUserCustomStatus(api *Client, up *UserProfile, t *testing.T) { } func TestGetUsers(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) http.HandleFunc("/users.list", getUserPage(4)) - expectedUser := getTestUser() once.Do(startServer) api := New("testing-token", OptionAPIURL("http://"+serverAddr+"/")) @@ -315,7 +319,12 @@ func TestGetUsers(t *testing.T) { return } - if !reflect.DeepEqual([]User{expectedUser, expectedUser, expectedUser, expectedUser}, users) { + if !reflect.DeepEqual([]User{ + getTestUserWithId("U000"), + getTestUserWithId("U001"), + getTestUserWithId("U002"), + getTestUserWithId("U003"), + }, users) { t.Fatal(ErrIncorrectResponse) } } @@ -329,7 +338,45 @@ func getUserPage(max int64) func(rw http.ResponseWriter, r *http.Request) { Ok: true, } members := []User{ - getTestUser(), + getTestUserWithId(fmt.Sprintf("U%03d", n)), + } + rw.Header().Set("Content-Type", "application/json") + if cpage = atomic.AddInt64(&n, 1); cpage == max { + response, _ := json.Marshal(userResponseFull{ + SlackResponse: sresp, + Members: members, + }) + rw.Write(response) + return + } + response, _ := json.Marshal(userResponseFull{ + SlackResponse: sresp, + Members: members, + Metadata: ResponseMetadata{Cursor: strconv.Itoa(int(cpage))}, + }) + rw.Write(response) + } +} + +// returns n pages of users and sends rate limited errors in between successful pages. +func getUserPagesWithRateLimitErrors(max int64) func(rw http.ResponseWriter, r *http.Request) { + var n int64 + doRateLimit := false + return func(rw http.ResponseWriter, r *http.Request) { + defer func() { + doRateLimit = !doRateLimit + }() + if doRateLimit { + rw.Header().Set("Retry-After", "1") + rw.WriteHeader(http.StatusTooManyRequests) + return + } + var cpage int64 + sresp := SlackResponse{ + Ok: true, + } + members := []User{ + getTestUserWithId(fmt.Sprintf("U%03d", n)), } rw.Header().Set("Content-Type", "application/json") if cpage = atomic.AddInt64(&n, 1); cpage == max { @@ -553,3 +600,48 @@ func TestUserProfileCustomFieldsSetMap(t *testing.T) { t.Fatalf(`fields.fields = %v, wanted %v`, fields.fields, m) } } + +func TestGetUsersHandlesRateLimit(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) + http.HandleFunc("/users.list", getUserPagesWithRateLimitErrors(4)) + + once.Do(startServer) + api := New("testing-token", OptionAPIURL("http://"+serverAddr+"/")) + + users, err := api.GetUsers() + if err != nil { + t.Errorf("Unexpected error: %s", err) + return + } + + if !reflect.DeepEqual([]User{ + getTestUserWithId("U000"), + getTestUserWithId("U001"), + getTestUserWithId("U002"), + getTestUserWithId("U003"), + }, users) { + t.Fatal(ErrIncorrectResponse) + } +} + +func TestGetUsersReturnsServerError(t *testing.T) { + http.DefaultServeMux = new(http.ServeMux) + http.HandleFunc("/users.list", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) + + once.Do(startServer) + api := New("testing-token", OptionAPIURL("http://"+serverAddr+"/")) + + _, err := api.GetUsers() + + if err == nil { + t.Errorf("Expected error but got nil") + return + } + + expectedErr := "slack server error: 500 Internal Server Error" + if err.Error() != expectedErr { + t.Errorf("Expected: %s. Got: %s", expectedErr, err.Error()) + } +}