Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OData pagination handling #7

Merged
merged 2 commits into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 72 additions & 3 deletions base/get.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
package base

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"strings"
)

// GetHttpRequestInput configures a GET request.
type GetHttpRequestInput struct {
ValidStatusCodes []int
ValidStatusFunc ValidStatusFunc
Uri Uri
rawUri string
}

// GetValidStatusCodes returns a []int of status codes considered valid for a GET request.
Expand All @@ -26,17 +31,81 @@ func (i GetHttpRequestInput) GetValidStatusFunc() ValidStatusFunc {
// Get performs a GET request.
func (c Client) Get(ctx context.Context, input GetHttpRequestInput) (*http.Response, int, error) {
var status int
url, err := c.buildUri(input.Uri)
if err != nil {
return nil, status, fmt.Errorf("unable to make request: %v", err)

// Check for a raw uri, else build one from the Uri field
url := input.rawUri
if url == "" {
var err error
url, err = c.buildUri(input.Uri)
if err != nil {
return nil, status, fmt.Errorf("unable to make request: %v", err)
}
}

// Build a new request
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody)
if err != nil {
return nil, status, err
}

// Perform the request
resp, status, err := c.performRequest(req, input)
if err != nil {
return nil, status, err
}

// Check for json content before handling pagination
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
if strings.HasPrefix(contentType, "application/json") {
// Read the response body and close it
respBody, _ := ioutil.ReadAll(resp.Body)
resp.Body.Close()

// Unmarshall odata
var odata OData
if err := json.Unmarshal(respBody, &odata); err != nil {
return nil, status, err
}

if odata.NextLink == nil || odata.Value == nil {
// No more pages, reassign response body and return
resp.Body = ioutil.NopCloser(bytes.NewBuffer(respBody))
return resp, status, nil
}

// Get the next page, recursively
nextInput := input
nextInput.rawUri = *odata.NextLink
nextResp, status, err := c.Get(ctx, nextInput)
if err != nil {
return resp, status, err
}

// Read the next page response body and close it
nextRespBody, _ := ioutil.ReadAll(nextResp.Body)
nextResp.Body.Close()

// Unmarshall odata from the next page
var nextOdata OData
if err := json.Unmarshal(nextRespBody, &nextOdata); err != nil {
return resp, status, err
}

if nextOdata.Value != nil {
// Next page has results, append to current page
value := append(*odata.Value, *nextOdata.Value...)
nextOdata.Value = &value
}

// Marshal the entire result, along with fields from the final page
newJson, err := json.Marshal(nextOdata)
if err != nil {
return resp, status, err
}

// Reassign the response body
resp.Body = ioutil.NopCloser(bytes.NewBuffer(newJson))
}

return resp, status, nil
}
16 changes: 16 additions & 0 deletions base/odata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package base

import "encoding/json"

type OData struct {
Context *string `json:"@odata.context"`
MetadataEtag *string `json:"@odata.metadataEtag"`
Type *string `json:"@odata.type"`
Count *string `json:"@odata.count"`
NextLink *string `json:"@odata.nextLink"`
Delta *string `json:"@odata.delta"`
DeltaLink *string `json:"@odata.deltaLink"`
Id *string `json:"@odata.id"`
Etag *string `json:"@odata.etag"`
Value *[]json.RawMessage `json:"value"`
}
4 changes: 2 additions & 2 deletions clients/applications_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ func testApplicationsClient_List(t *testing.T, c ApplicationsClientTest) (applic
func testApplicationsClient_Get(t *testing.T, c ApplicationsClientTest, id string) (application *models.Application) {
application, status, err := c.client.Get(c.connection.Context, id)
if err != nil {
t.Fatalf("ApplicationsClient.Delete(): %v", err)
t.Fatalf("ApplicationsClient.Get(): %v", err)
}
if status < 200 || status >= 300 {
t.Fatalf("ApplicationsClient.Delete(): invalid status: %d", status)
t.Fatalf("ApplicationsClient.Get(): invalid status: %d", status)
}
if application == nil {
t.Fatal("ApplicationsClient.Get(): application was nil")
Expand Down
4 changes: 2 additions & 2 deletions clients/groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ func testGroupsClient_List(t *testing.T, c GroupsClientTest) (groups *[]models.G
func testGroupsClient_Get(t *testing.T, c GroupsClientTest, id string) (group *models.Group) {
group, status, err := c.client.Get(c.connection.Context, id)
if err != nil {
t.Fatalf("GroupsClient.Delete(): %v", err)
t.Fatalf("GroupsClient.Get(): %v", err)
}
if status < 200 || status >= 300 {
t.Fatalf("GroupsClient.Delete(): invalid status: %d", status)
t.Fatalf("GroupsClient.Get(): invalid status: %d", status)
}
if group == nil {
t.Fatal("GroupsClient.Get(): group was nil")
Expand Down
4 changes: 2 additions & 2 deletions clients/serviceprincipals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ func testServicePrincipalsClient_List(t *testing.T, c ServicePrincipalsClientTes
func testServicePrincipalsClient_Get(t *testing.T, c ServicePrincipalsClientTest, id string) (servicePrincipal *models.ServicePrincipal) {
servicePrincipal, status, err := c.client.Get(c.connection.Context, id)
if err != nil {
t.Fatalf("ServicePrincipalsClient.Delete(): %v", err)
t.Fatalf("ServicePrincipalsClient.Get(): %v", err)
}
if status < 200 || status >= 300 {
t.Fatalf("ServicePrincipalsClient.Delete(): invalid status: %d", status)
t.Fatalf("ServicePrincipalsClient.Get(): invalid status: %d", status)
}
if servicePrincipal == nil {
t.Fatal("ServicePrincipalsClient.Get(): servicePrincipal was nil")
Expand Down
4 changes: 2 additions & 2 deletions clients/users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ func testUsersClient_List(t *testing.T, c UsersClientTest) (users *[]models.User
func testUsersClient_Get(t *testing.T, c UsersClientTest, id string) (user *models.User) {
user, status, err := c.client.Get(c.connection.Context, id)
if err != nil {
t.Fatalf("UsersClient.Delete(): %v", err)
t.Fatalf("UsersClient.Get(): %v", err)
}
if status < 200 || status >= 300 {
t.Fatalf("UsersClient.Delete(): invalid status: %d", status)
t.Fatalf("UsersClient.Get(): invalid status: %d", status)
}
if user == nil {
t.Fatal("UsersClient.Get(): user was nil")
Expand Down