diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index e66baa8e1300..944b9a466ab5 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -3,6 +3,7 @@ ## 0.21.1 (Unreleased) ### Features Added +* Added `runtime.Pager[T any]` and supporting types for a central, generic, pager implementation. ### Breaking Changes diff --git a/sdk/azcore/runtime/pager.go b/sdk/azcore/runtime/pager.go new file mode 100644 index 000000000000..3df5c950c904 --- /dev/null +++ b/sdk/azcore/runtime/pager.go @@ -0,0 +1,77 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "encoding/json" + "errors" +) + +// PageProcessor contains the required data for constructing a Pager. +type PageProcessor[T any] struct { + // More returns a boolean indicating if there are more pages to fetch. + // It uses the provided page to make the determination. + More func(T) bool + + // Fetcher fetches the first and subsequent pages. + Fetcher func(context.Context, *T) (T, error) +} + +// Pager provides operations for iterating over paged responses. +type Pager[T any] struct { + current *T + processor PageProcessor[T] + firstPage bool +} + +// NewPager creates an instance of Pager using the specified PageProcessor. +// Pass a non-nil T for firstPage if the first page has already been retrieved. +func NewPager[T any](processor PageProcessor[T]) *Pager[T] { + return &Pager[T]{ + processor: processor, + firstPage: true, + } +} + +// More returns true if there are more pages to retrieve. +func (p *Pager[T]) More() bool { + if p.current != nil { + return p.processor.More(*p.current) + } + return true +} + +// NextPage advances the pager to the next page. +func (p *Pager[T]) NextPage(ctx context.Context) (T, error) { + var resp T + var err error + if p.current != nil { + if p.firstPage { + // we get here if it's an LRO-pager, we already have the first page + p.firstPage = false + return *p.current, nil + } else if !p.processor.More(*p.current) { + return *new(T), errors.New("no more pages") + } + resp, err = p.processor.Fetcher(ctx, p.current) + } else { + // non-LRO case, first page + p.firstPage = false + resp, err = p.processor.Fetcher(ctx, nil) + } + if err != nil { + return *new(T), err + } + p.current = &resp + return *p.current, nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface for Pager[T]. +func (p *Pager[T]) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &p.current) +} diff --git a/sdk/azcore/runtime/pager_test.go b/sdk/azcore/runtime/pager_test.go new file mode 100644 index 000000000000..81b7d89151b5 --- /dev/null +++ b/sdk/azcore/runtime/pager_test.go @@ -0,0 +1,259 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package runtime + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pipeline" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/stretchr/testify/require" +) + +type PageResponse struct { + Values []int `json:"values"` + NextPage bool `json:"next"` +} + +func pageResponseFetcher(ctx context.Context, pl Pipeline, endpoint string) (PageResponse, error) { + req, err := NewRequest(ctx, http.MethodGet, endpoint) + if err != nil { + return PageResponse{}, err + } + resp, err := pl.Do(req) + if err != nil { + return PageResponse{}, err + } + if !HasStatusCode(resp, http.StatusOK) { + return PageResponse{}, shared.NewResponseError(resp) + } + pr := PageResponse{} + if err := UnmarshalAsJSON(resp, &pr); err != nil { + return PageResponse{}, err + } + return pr, nil +} + +func TestPagerSinglePage(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [1, 2, 3, 4, 5]}`))) + pl := pipeline.NewPipeline(srv) + + pager := NewPager(PageProcessor[PageResponse]{ + More: func(current PageResponse) bool { + return current.NextPage + }, + Fetcher: func(ctx context.Context, current *PageResponse) (PageResponse, error) { + return pageResponseFetcher(ctx, pl, srv.URL()) + }, + }) + require.True(t, pager.firstPage) + + pageCount := 0 + for pager.More() { + page, err := pager.NextPage(context.Background()) + require.NoError(t, err) + require.Equal(t, []int{1, 2, 3, 4, 5}, page.Values) + require.Empty(t, page.NextPage) + pageCount++ + } + require.Equal(t, 1, pageCount) + page, err := pager.NextPage(context.Background()) + require.Error(t, err) + require.Empty(t, page) +} + +func TestPagerMultiplePages(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [1, 2, 3, 4, 5], "next": true}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [6, 7, 8], "next": true}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [9, 0, 1, 2]}`))) + pl := pipeline.NewPipeline(srv) + + pageCount := 0 + pager := NewPager(PageProcessor[PageResponse]{ + More: func(current PageResponse) bool { + return current.NextPage + }, + Fetcher: func(ctx context.Context, current *PageResponse) (PageResponse, error) { + if pageCount == 1 { + require.Nil(t, current) + } else { + require.NotNil(t, current) + } + return pageResponseFetcher(ctx, pl, srv.URL()) + }, + }) + require.True(t, pager.firstPage) + + for pager.More() { + pageCount++ + page, err := pager.NextPage(context.Background()) + require.NoError(t, err) + switch pageCount { + case 1: + require.Equal(t, []int{1, 2, 3, 4, 5}, page.Values) + require.True(t, page.NextPage) + case 2: + require.Equal(t, []int{6, 7, 8}, page.Values) + require.True(t, page.NextPage) + case 3: + require.Equal(t, []int{9, 0, 1, 2}, page.Values) + require.False(t, page.NextPage) + } + } + require.Equal(t, 3, pageCount) + page, err := pager.NextPage(context.Background()) + require.Error(t, err) + require.Empty(t, page) +} + +func TestPagerLROMultiplePages(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [6, 7, 8]}`))) + pl := pipeline.NewPipeline(srv) + + pager := NewPager(PageProcessor[PageResponse]{ + More: func(current PageResponse) bool { + return current.NextPage + }, + Fetcher: func(ctx context.Context, current *PageResponse) (PageResponse, error) { + return pageResponseFetcher(ctx, pl, srv.URL()) + }, + }) + require.True(t, pager.firstPage) + + require.NoError(t, json.Unmarshal([]byte(`{"values": [1, 2, 3, 4, 5], "next": true}`), pager)) + + pageCount := 0 + for pager.More() { + pageCount++ + page, err := pager.NextPage(context.Background()) + require.NoError(t, err) + switch pageCount { + case 1: + require.Equal(t, []int{1, 2, 3, 4, 5}, page.Values) + require.True(t, page.NextPage) + case 2: + require.Equal(t, []int{6, 7, 8}, page.Values) + require.False(t, page.NextPage) + } + } + require.Equal(t, 2, pageCount) + page, err := pager.NextPage(context.Background()) + require.Error(t, err) + require.Empty(t, page) +} + +func TestPagerFetcherError(t *testing.T) { + pager := NewPager(PageProcessor[PageResponse]{ + More: func(current PageResponse) bool { + return current.NextPage + }, + Fetcher: func(ctx context.Context, current *PageResponse) (PageResponse, error) { + return PageResponse{}, errors.New("fetcher failed") + }, + }) + require.True(t, pager.firstPage) + + page, err := pager.NextPage(context.Background()) + require.Error(t, err) + require.Empty(t, page) +} + +func TestPagerPipelineError(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.SetError(errors.New("pipeline failed")) + pl := pipeline.NewPipeline(srv) + + pager := NewPager(PageProcessor[PageResponse]{ + More: func(current PageResponse) bool { + return current.NextPage + }, + Fetcher: func(ctx context.Context, current *PageResponse) (PageResponse, error) { + return pageResponseFetcher(ctx, pl, srv.URL()) + }, + }) + require.True(t, pager.firstPage) + + page, err := pager.NextPage(context.Background()) + require.Error(t, err) + require.Empty(t, page) +} + +func TestPagerSecondPageError(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`{"values": [1, 2, 3, 4, 5], "next": true}`))) + srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest), mock.WithBody([]byte(`{"message": "didn't work", "code": "PageError"}`))) + pl := pipeline.NewPipeline(srv) + + pageCount := 0 + pager := NewPager(PageProcessor[PageResponse]{ + More: func(current PageResponse) bool { + return current.NextPage + }, + Fetcher: func(ctx context.Context, current *PageResponse) (PageResponse, error) { + if pageCount == 1 { + require.Nil(t, current) + } else { + require.NotNil(t, current) + } + return pageResponseFetcher(ctx, pl, srv.URL()) + }, + }) + require.True(t, pager.firstPage) + + for pager.More() { + pageCount++ + page, err := pager.NextPage(context.Background()) + switch pageCount { + case 1: + require.NoError(t, err) + require.Equal(t, []int{1, 2, 3, 4, 5}, page.Values) + require.True(t, page.NextPage) + case 2: + require.Error(t, err) + var respErr *shared.ResponseError + require.True(t, errors.As(err, &respErr)) + require.Equal(t, "PageError", respErr.ErrorCode) + goto ExitLoop + } + } +ExitLoop: + require.Equal(t, 2, pageCount) +} + +func TestPagerResponderError(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(`incorrect JSON response`))) + pl := pipeline.NewPipeline(srv) + + pager := NewPager(PageProcessor[PageResponse]{ + More: func(current PageResponse) bool { + return current.NextPage + }, + Fetcher: func(ctx context.Context, current *PageResponse) (PageResponse, error) { + return pageResponseFetcher(ctx, pl, srv.URL()) + }, + }) + require.True(t, pager.firstPage) + + page, err := pager.NextPage(context.Background()) + require.Error(t, err) + require.Empty(t, page) +}