diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index de9fd6de967c..c5c1daa74b55 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -11,6 +11,9 @@ * Added field `SpanFromContext` to `tracing.TracerOptions`. * Added methods `Enabled()`, `SetAttributes()`, and `SpanFromContext()` to `tracing.Tracer`. * Added supporting pipeline policies to include HTTP spans when creating clients. +* Added package `fake` to support generated fakes packages in SDKs. + * The package contains public surface area exposed by fake servers and supporting APIs intended only for use by the fake server implementations. + * Added an internal fake poller implementation. ### Breaking Changes diff --git a/sdk/azcore/fake/example_test.go b/sdk/azcore/fake/example_test.go new file mode 100644 index 000000000000..1fbd12314ea3 --- /dev/null +++ b/sdk/azcore/fake/example_test.go @@ -0,0 +1,146 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package fake_test + +import ( + "errors" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/fake" +) + +// Widget is a hypothetical type used in the following examples. +type Widget struct { + ID int + Shape string +} + +// WidgetResponse is a hypothetical type used in the following examples. +type WidgetResponse struct { + Widget +} + +// WidgetListResponse is a hypothetical type used in the following examples. +type WidgetListResponse struct { + Widgets []Widget +} + +func ExampleNewTokenCredential() { + // create a fake azcore.TokenCredential + // the fake is used as the client credential during testing with fakes. + var _ azcore.TokenCredential = fake.NewTokenCredential() +} + +func ExampleTokenCredential_SetError() { + cred := fake.NewTokenCredential() + + // set an error to be returned during authentication + cred.SetError(errors.New("failed to authenticate")) +} + +func ExampleResponder() { + // for a hypothetical API GetNextWidget(context.Context) (WidgetResponse, error) + + // a Responder is used to build a scalar response + resp := fake.Responder[WidgetResponse]{} + + // here we set the instance of Widget the Responder is to return + resp.Set(WidgetResponse{ + Widget{ID: 123, Shape: "triangle"}, + }) + + // optional HTTP headers can also be included in the raw response + resp.SetHeader("custom-header1", "value1") + resp.SetHeader("custom-header2", "value2") +} + +func ExampleErrorResponder() { + // an ErrorResponder is used to build an error response + errResp := fake.ErrorResponder{} + + // use SetError to return a generic error + errResp.SetError(errors.New("the system is down")) + + // to return an *azcore.ResponseError, use SetResponseError + errResp.SetResponseError("ErrorCodeConflict", http.StatusConflict) + + // ErrorResponder returns a singular error, so calling Set* APIs overwrites any previous value +} + +func ExamplePagerResponder() { + // for a hypothetical API NewListWidgetsPager() *runtime.Pager[WidgetListResponse] + + // a PagerResponder is used to build a sequence of responses for a paged operation + pagerResp := fake.PagerResponder[WidgetListResponse]{} + + // use AddPage to add one or more pages to the response. + // responses are returned in the order in which they were added. + pagerResp.AddPage(WidgetListResponse{ + Widgets: []Widget{ + {ID: 1, Shape: "circle"}, + {ID: 2, Shape: "square"}, + {ID: 3, Shape: "triangle"}, + }, + }, nil) + pagerResp.AddPage(WidgetListResponse{ + Widgets: []Widget{ + {ID: 4, Shape: "rectangle"}, + {ID: 5, Shape: "rhombus"}, + }, + }, nil) + + // errors can also be included in the sequence of responses. + // this can be used to simulate an error during paging. + pagerResp.AddError(errors.New("network too slow")) + + pagerResp.AddPage(WidgetListResponse{ + Widgets: []Widget{ + {ID: 6, Shape: "trapezoid"}, + }, + }, nil) +} + +func ExamplePollerResponder() { + // for a hypothetical API BeginCreateWidget(context.Context) (*runtime.Poller[WidgetResponse], error) + + // a PollerResponder is used to build a sequence of responses for a long-running operation + pollerResp := fake.PollerResponder[WidgetResponse]{} + + // use AddNonTerminalResponse to add one or more non-terminal responses + // to the sequence of responses. this is to simulate polling on a LRO. + // non-terminal responses are optional. exclude them to simulate a LRO + // that synchronously completes. + pollerResp.AddNonTerminalResponse(nil) + + // non-terminal errors can also be included in the sequence of responses. + // use this to simulate an error during polling. + pollerResp.AddNonTerminalError(errors.New("flaky network")) + + // use SetTerminalResponse to successfully terminate the long-running operation. + // the provided value will be returned as the terminal response. + pollerResp.SetTerminalResponse(WidgetResponse{ + Widget: Widget{ + ID: 987, + Shape: "dodecahedron", + }, + }) +} + +func ExamplePollerResponder_SetTerminalError() { + // for a hypothetical API BeginCreateWidget(context.Context) (*runtime.Poller[WidgetResponse], error) + + // a PollerResponder is used to build a sequence of responses for a long-running operation + pollerResp := fake.PollerResponder[WidgetResponse]{} + + // use SetTerminalError to terminate the long-running operation with an error. + // this returns an *azcore.ResponseError as the terminal response. + pollerResp.SetTerminalError("NoMoreWidgets", http.StatusBadRequest) + + // note that SetTerminalResponse and SetTerminalError are meant to be mutually exclusive. + // in the event that both are called, the result from SetTerminalError will be used. +} diff --git a/sdk/azcore/fake/fake.go b/sdk/azcore/fake/fake.go new file mode 100644 index 000000000000..1e5091248620 --- /dev/null +++ b/sdk/azcore/fake/fake.go @@ -0,0 +1,378 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Package fake provides the building blocks for fake servers. +// This includes fakes for authentication, API responses, and more. +// +// Most of the content in this package is intended to be used by +// SDK authors in construction of their fakes. +package fake + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" +) + +// NewTokenCredential creates an instance of the TokenCredential type. +func NewTokenCredential() *TokenCredential { + return &TokenCredential{} +} + +// TokenCredential is a fake credential that implements the azcore.TokenCredential interface. +type TokenCredential struct { + err error +} + +// SetError sets the specified error to be returned from GetToken(). +// Use this to simulate an error during authentication. +func (t *TokenCredential) SetError(err error) { + t.err = &nonRetriableError{err} +} + +// GetToken implements the azcore.TokenCredential for the TokenCredential type. +func (t *TokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) { + if t.err != nil { + return azcore.AccessToken{}, &nonRetriableError{t.err} + } + return azcore.AccessToken{Token: "fake_token", ExpiresOn: time.Now().Add(24 * time.Hour)}, nil +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// Responder represents a scalar response. +type Responder[T any] struct { + h http.Header + resp T +} + +// Set sets the specified value to be returned. +func (r *Responder[T]) Set(b T) { + r.resp = b +} + +// SetHeader sets the specified header key/value pairs to be returned. +// Call multiple times to set multiple headers. +func (r *Responder[T]) SetHeader(key, value string) { + if r.h == nil { + r.h = http.Header{} + } + r.h.Set(key, value) +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// ErrorResponder represents a scalar error response. +type ErrorResponder struct { + err error +} + +// SetError sets the specified error to be returned. +// Use SetResponseError for returning an *azcore.ResponseError. +func (e *ErrorResponder) SetError(err error) { + e.err = &nonRetriableError{err: err} +} + +// SetResponseError sets an *azcore.ResponseError with the specified values to be returned. +func (e *ErrorResponder) SetResponseError(errorCode string, httpStatus int) { + e.err = &nonRetriableError{err: &azcore.ResponseError{ErrorCode: errorCode, StatusCode: httpStatus}} +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// PagerResponder represents a sequence of paged responses. +// Responses are replayed in the order in which they were added. +type PagerResponder[T any] struct { + pages []any +} + +// AddPage adds a page to the sequence of respones. +func (p *PagerResponder[T]) AddPage(page T, o *AddPageOptions) { + p.pages = append(p.pages, page) +} + +// AddError adds an error to the sequence of responses. +// The error is returned from the call to runtime.Pager[T].NextPage(). +func (p *PagerResponder[T]) AddError(err error) { + p.pages = append(p.pages, &nonRetriableError{err: err}) +} + +// AddResponseError adds an *azcore.ResponseError to the sequence of responses. +// The error is returned from the call to runtime.Pager[T].NextPage(). +func (p *PagerResponder[T]) AddResponseError(errorCode string, httpStatus int) { + p.pages = append(p.pages, &nonRetriableError{err: &azcore.ResponseError{ErrorCode: errorCode, StatusCode: httpStatus}}) +} + +// AddPageOptions contains the optional values for PagerResponder[T].AddPage. +type AddPageOptions struct { + // placeholder for future options +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// PollerResponder represents a sequence of responses for a long-running operation. +// Any non-terminal responses are replayed in the order in which they were added. +// The terminal response, success or error, is always the final response. +type PollerResponder[T any] struct { + nonTermResps []nonTermResp + res *T + err *exported.ResponseError +} + +// AddNonTerminalResponse adds a non-terminal response to the sequence of responses. +func (p *PollerResponder[T]) AddNonTerminalResponse(o *AddNonTerminalResponseOptions) { + p.nonTermResps = append(p.nonTermResps, nonTermResp{status: "InProgress"}) +} + +// AddNonTerminalError adds a non-terminal error to the sequence of responses. +// Use this to simulate an error durring polling. +func (p *PollerResponder[T]) AddNonTerminalError(err error) { + p.nonTermResps = append(p.nonTermResps, nonTermResp{err: err}) +} + +// SetTerminalResponse sets the provided value as the successful, terminal response. +func (p *PollerResponder[T]) SetTerminalResponse(result T) { + p.res = &result +} + +// SetTerminalError sets an *azcore.ResponseError with the specified values as the failed terminal response. +func (p *PollerResponder[T]) SetTerminalError(errorCode string, httpStatus int) { + p.err = &exported.ResponseError{ErrorCode: errorCode, StatusCode: httpStatus} +} + +// AddNonTerminalResponseOptions contains the optional values for PollerResponder[T].AddNonTerminalResponse. +type AddNonTerminalResponseOptions struct { + // place holder for future optional values +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +// the following APIs are intended for use by fake servers +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// MarshalResponseAsJSON converts the body into JSON and returns it in a *http.Response. +// This method is typically called by the fake server internals. +func MarshalResponseAsJSON[T any](r Responder[T], req *http.Request) (*http.Response, error) { + body, err := json.Marshal(r.resp) + if err != nil { + return nil, &nonRetriableError{err} + } + resp := newResponse(http.StatusOK, "OK", req, string(body)) + for key := range r.h { + resp.Header.Set(key, r.h.Get(key)) + } + return resp, nil +} + +// UnmarshalRequestAsJSON unmarshalls the request body into an instance of T. +// This method is typically called by the fake server internals. +func UnmarshalRequestAsJSON[T any](req *http.Request) (T, error) { + tt := *new(T) + body, err := io.ReadAll(req.Body) + if err != nil { + return tt, &nonRetriableError{err} + } + req.Body.Close() + if err = json.Unmarshal(body, &tt); err != nil { + err = &nonRetriableError{err} + } + return tt, err +} + +// GetError returns the error for this responder. +// This method is typically called by the fake server internals. +func GetError(e ErrorResponder, req *http.Request) error { + if e.err == nil { + return nil + } + + var respErr *azcore.ResponseError + if errors.As(e.err, &respErr) { + // fix up the raw response + respErr.RawResponse = newErrorResponse(respErr.ErrorCode, respErr.StatusCode, req) + } + return &nonRetriableError{e.err} +} + +// PagerResponderNext returns the next response in the sequence (a T or an error). +// This method is typically called by the fake server internals. +func PagerResponderNext[T any](p *PagerResponder[T], req *http.Request) (*http.Response, error) { + if len(p.pages) == 0 { + return nil, &nonRetriableError{errors.New("paged response has no pages")} + } + + page := p.pages[0] + p.pages = p.pages[1:] + + pageT, ok := page.(T) + if ok { + body, err := json.Marshal(pageT) + if err != nil { + return nil, &nonRetriableError{err} + } + return newResponse(http.StatusOK, "OK", req, string(body)), nil + } + + err := page.(error) + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + // fix up the raw response + respErr.RawResponse = newErrorResponse(respErr.ErrorCode, respErr.StatusCode, req) + } + return nil, &nonRetriableError{err} +} + +// PagerResponderMore returns true if there are more responses for consumption. +// This method is typically called by the fake server internals. +func PagerResponderMore[T any](p *PagerResponder[T]) bool { + return len(p.pages) > 0 +} + +type pageindex[T any] struct { + i int + page T +} + +// PagerResponderInjectNextLinks is used to populate the nextLink field. +// The inject callback is executed for every T in the sequence except for the last one. +// This method is typically called by the fake server internals. +func PagerResponderInjectNextLinks[T any](p *PagerResponder[T], req *http.Request, inject func(page *T, createLink func() string)) { + // first find all the actual pages in the list + pages := make([]pageindex[T], 0, len(p.pages)) + for i := range p.pages { + if pageT, ok := p.pages[i].(T); ok { + pages = append(pages, pageindex[T]{ + i: i, + page: pageT, + }) + } + } + + // now populate the next links + for i := range pages { + if i+1 == len(pages) { + // no nextLink for last page + break + } + + inject(&pages[i].page, func() string { + return fmt.Sprintf("%s://%s%s/page_%d", req.URL.Scheme, req.URL.Host, req.URL.Path, i+1) + }) + + // update the original slice with the modified page + p.pages[pages[i].i] = pages[i].page + } +} + +// PollerResponderMore returns true if there are more responses for consumption. +// This method is typically called by the fake server internals. +func PollerResponderMore[T any](p *PollerResponder[T]) bool { + return len(p.nonTermResps) > 0 || p.err != nil || p.res != nil +} + +// PollerResponderNext returns the next response in the sequence (a *http.Response or an error). +// This method is typically called by the fake server internals. +func PollerResponderNext[T any](p *PollerResponder[T], req *http.Request) (*http.Response, error) { + if len(p.nonTermResps) > 0 { + resp := p.nonTermResps[0] + p.nonTermResps = p.nonTermResps[1:] + + if resp.err != nil { + return nil, &nonRetriableError{resp.err} + } + + httpResp := newResponse(http.StatusOK, "OK", req, "") + httpResp.Header.Set(shared.HeaderFakePollerStatus, resp.status) + + if resp.retryAfter > 0 { + httpResp.Header.Add(shared.HeaderRetryAfter, strconv.Itoa(resp.retryAfter)) + } + + return httpResp, nil + } + + if p.err != nil { + err := p.err + err.RawResponse = newErrorResponse(p.err.ErrorCode, p.err.StatusCode, req) + p.err = nil + return nil, &nonRetriableError{err} + } else if p.res != nil { + body, err := json.Marshal(*p.res) + if err != nil { + return nil, &nonRetriableError{err} + } + p.res = nil + httpResp := newResponse(http.StatusOK, "OK", req, string(body)) + httpResp.Header.Set(shared.HeaderFakePollerStatus, "Succeeded") + return httpResp, nil + } else { + return nil, &nonRetriableError{fmt.Errorf("%T has no terminal response", p)} + } +} + +type nonTermResp struct { + status string + retryAfter int + err error +} + +func newResponse(statusCode int, status string, req *http.Request, body string) *http.Response { + resp := &http.Response{ + Body: http.NoBody, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: req, + Status: status, + StatusCode: statusCode, + } + + if l := int64(len(body)); l > 0 { + resp.Header.Set(shared.HeaderContentType, shared.ContentTypeAppJSON) + resp.ContentLength = l + resp.Body = io.NopCloser(strings.NewReader(body)) + } + + return resp +} + +func newErrorResponse(errorCode string, statusCode int, req *http.Request) *http.Response { + resp := newResponse(statusCode, "Operation Failed", req, "") + resp.Header.Set(shared.HeaderXMSErrorCode, errorCode) + return resp +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +type nonRetriableError struct { + err error +} + +func (p *nonRetriableError) Error() string { + return p.err.Error() +} + +func (*nonRetriableError) NonRetriable() { + // marker method +} + +func (p *nonRetriableError) Unwrap() error { + return p.err +} + +var _ errorinfo.NonRetriable = (*nonRetriableError)(nil) diff --git a/sdk/azcore/fake/fake_test.go b/sdk/azcore/fake/fake_test.go new file mode 100644 index 000000000000..c37bd0feb88f --- /dev/null +++ b/sdk/azcore/fake/fake_test.go @@ -0,0 +1,331 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package fake + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/internal/errorinfo" + "github.com/stretchr/testify/require" +) + +type widget struct { + Name string +} + +type widgets struct { + NextPage *string + Widgets []widget +} + +func TestNewTokenCredential(t *testing.T) { + cred := NewTokenCredential() + require.NotNil(t, cred) + + tk, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{}) + require.NoError(t, err) + require.NotZero(t, tk) + + myErr := errors.New("failed") + cred.SetError(myErr) + tk, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{}) + require.ErrorIs(t, err, myErr) + require.Zero(t, tk) +} + +func TestResponder(t *testing.T) { + respr := Responder[widget]{} + respr.Set(widget{Name: "foo"}) + respr.SetHeader("one", "1") + respr.SetHeader("two", "2") + + req := &http.Request{} + resp, err := MarshalResponseAsJSON(respr, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, req, resp.Request) + require.Equal(t, "1", resp.Header.Get("one")) + require.Equal(t, "2", resp.Header.Get("two")) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + w := widget{} + require.NoError(t, json.Unmarshal(body, &w)) + require.Equal(t, "foo", w.Name) +} + +type badWidget struct { + Count int +} + +func (badWidget) MarshalJSON() ([]byte, error) { + return nil, errors.New("failed") +} + +func (*badWidget) UnmarshalJSON([]byte) error { + return errors.New("failed") +} + +func TestResponderMarshallingError(t *testing.T) { + respr := Responder[badWidget]{} + + req := &http.Request{} + resp, err := MarshalResponseAsJSON(respr, req) + require.Error(t, err) + var nre errorinfo.NonRetriable + require.ErrorAs(t, err, &nre) + require.Nil(t, resp) +} + +func TestErrorResponder(t *testing.T) { + req := &http.Request{} + + errResp := ErrorResponder{} + require.NoError(t, GetError(errResp, req)) + + myErr := errors.New("failed") + errResp.SetError(myErr) + require.ErrorIs(t, GetError(errResp, req), myErr) + + errResp.SetResponseError("ErrorInvalidWidget", http.StatusBadRequest) + var respErr *azcore.ResponseError + require.ErrorAs(t, GetError(errResp, req), &respErr) + require.Equal(t, "ErrorInvalidWidget", respErr.ErrorCode) + require.Equal(t, http.StatusBadRequest, respErr.StatusCode) + require.NotNil(t, respErr.RawResponse) + require.Equal(t, req, respErr.RawResponse.Request) +} + +func unmarshal[T any](resp *http.Response) (T, error) { + var t T + body, err := io.ReadAll(resp.Body) + if err != nil { + return t, err + } + resp.Body.Close() + + err = json.Unmarshal(body, &t) + return t, err +} + +func TestPagerResponder(t *testing.T) { + req := &http.Request{URL: &url.URL{}} + req.URL.Scheme = "http" + req.URL.Host = "fakehost.org" + req.URL.Path = "/lister" + + pagerResp := PagerResponder[widgets]{} + + require.False(t, PagerResponderMore(&pagerResp)) + resp, err := PagerResponderNext(&pagerResp, req) + var nre errorinfo.NonRetriable + require.ErrorAs(t, err, &nre) + require.Nil(t, resp) + + pagerResp.AddError(errors.New("one")) + pagerResp.AddPage(widgets{ + Widgets: []widget{ + {Name: "foo"}, + {Name: "bar"}, + }, + }, nil) + pagerResp.AddError(errors.New("two")) + pagerResp.AddPage(widgets{ + Widgets: []widget{ + {Name: "baz"}, + }, + }, nil) + pagerResp.AddResponseError("ErrorPagerBlewUp", http.StatusBadRequest) + + PagerResponderInjectNextLinks(&pagerResp, req, func(p *widgets, create func() string) { + p.NextPage = to.Ptr(create()) + }) + + iterations := 0 + for PagerResponderMore(&pagerResp) { + resp, err := PagerResponderNext(&pagerResp, req) + switch iterations { + case 0: + require.Error(t, err) + require.Equal(t, "one", err.Error()) + require.Nil(t, resp) + case 1: + require.NoError(t, err) + require.NotNil(t, resp) + page, err := unmarshal[widgets](resp) + require.NoError(t, err) + require.NotNil(t, page.NextPage) + require.Equal(t, []widget{{Name: "foo"}, {Name: "bar"}}, page.Widgets) + case 2: + require.Error(t, err) + require.Equal(t, "two", err.Error()) + require.Nil(t, resp) + case 3: + require.NoError(t, err) + require.NotNil(t, resp) + page, err := unmarshal[widgets](resp) + require.NoError(t, err) + require.Nil(t, page.NextPage) + require.Equal(t, []widget{{Name: "baz"}}, page.Widgets) + case 4: + require.Error(t, err) + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, "ErrorPagerBlewUp", respErr.ErrorCode) + require.Equal(t, http.StatusBadRequest, respErr.StatusCode) + require.Nil(t, resp) + default: + t.Fatalf("unexpected case %d", iterations) + } + iterations++ + } + require.Equal(t, 5, iterations) +} + +func TestPollerResponder(t *testing.T) { + req := &http.Request{URL: &url.URL{}} + req.URL.Scheme = "http" + req.URL.Host = "fakehost.org" + req.URL.Path = "/lro" + + pollerResp := PollerResponder[widget]{} + + require.False(t, PollerResponderMore(&pollerResp)) + resp, err := PollerResponderNext(&pollerResp, req) + var nre errorinfo.NonRetriable + require.ErrorAs(t, err, &nre) + require.Nil(t, resp) + + pollerResp.AddNonTerminalResponse(nil) + pollerResp.AddNonTerminalError(errors.New("network glitch")) + pollerResp.AddNonTerminalResponse(nil) + pollerResp.SetTerminalResponse(widget{Name: "dodo"}) + + iterations := 0 + for PollerResponderMore(&pollerResp) { + resp, err := PollerResponderNext(&pollerResp, req) + switch iterations { + case 0: + require.NoError(t, err) + require.NotNil(t, resp) + case 1: + require.Error(t, err) + require.Nil(t, resp) + case 2: + require.NoError(t, err) + require.NotNil(t, resp) + case 3: + require.NoError(t, err) + require.NotNil(t, resp) + w, err := unmarshal[widget](resp) + require.NoError(t, err) + require.Equal(t, "dodo", w.Name) + default: + t.Fatalf("unexpected case %d", iterations) + } + iterations++ + } + require.Equal(t, 4, iterations) +} + +func TestPollerResponderTerminalFailure(t *testing.T) { + req := &http.Request{URL: &url.URL{}} + req.URL.Scheme = "http" + req.URL.Host = "fakehost.org" + req.URL.Path = "/lro" + + pollerResp := PollerResponder[widget]{} + + require.False(t, PollerResponderMore(&pollerResp)) + resp, err := PollerResponderNext(&pollerResp, req) + var nre errorinfo.NonRetriable + require.ErrorAs(t, err, &nre) + require.Nil(t, resp) + + pollerResp.AddNonTerminalError(errors.New("network glitch")) + pollerResp.AddNonTerminalResponse(nil) + pollerResp.SetTerminalError("ErrorConflictingOperation", http.StatusConflict) + + iterations := 0 + for PollerResponderMore(&pollerResp) { + resp, err := PollerResponderNext(&pollerResp, req) + switch iterations { + case 0: + require.Error(t, err) + require.Nil(t, resp) + case 1: + require.NoError(t, err) + require.NotNil(t, resp) + case 2: + require.Error(t, err) + require.Nil(t, resp) + var respErr *azcore.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, "ErrorConflictingOperation", respErr.ErrorCode) + require.Equal(t, http.StatusConflict, respErr.StatusCode) + require.Equal(t, req, respErr.RawResponse.Request) + default: + t.Fatalf("unexpected case %d", iterations) + } + iterations++ + } + require.Equal(t, 3, iterations) +} + +func TestUnmarshalRequestAsJSON(t *testing.T) { + req, err := http.NewRequest(http.MethodPut, "https://foo.bar/baz", strings.NewReader(`{"Name": "foo"}`)) + require.NoError(t, err) + require.NotNil(t, req) + + w, err := UnmarshalRequestAsJSON[widget](req) + require.NoError(t, err) + require.Equal(t, "foo", w.Name) +} + +func TestUnmarshalRequestAsJSONReadFailure(t *testing.T) { + req, err := http.NewRequest(http.MethodPut, "https://foo.bar/baz", &readFailer{}) + require.NoError(t, err) + require.NotNil(t, req) + + w, err := UnmarshalRequestAsJSON[widget](req) + require.Error(t, err) + require.Zero(t, w) +} + +func TestUnmarshalRequestAsJSONUnmarshalFailure(t *testing.T) { + req, err := http.NewRequest(http.MethodPut, "https://foo.bar/baz", strings.NewReader(`{"Name": "foo"}`)) + require.NoError(t, err) + require.NotNil(t, req) + + w, err := UnmarshalRequestAsJSON[badWidget](req) + require.Error(t, err) + require.Zero(t, w) +} + +type readFailer struct { + wrapped io.ReadCloser +} + +func (r *readFailer) Close() error { + return r.wrapped.Close() +} + +func (r *readFailer) Read(p []byte) (int, error) { + return 0, errors.New("mock read failure") +} diff --git a/sdk/azcore/internal/exported/response_error.go b/sdk/azcore/internal/exported/response_error.go index 7df2f88c1c1a..76a8c068d143 100644 --- a/sdk/azcore/internal/exported/response_error.go +++ b/sdk/azcore/internal/exported/response_error.go @@ -13,6 +13,7 @@ import ( "net/http" "regexp" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/internal/exported" ) @@ -25,7 +26,7 @@ func NewResponseError(resp *http.Response) error { } // prefer the error code in the response header - if ec := resp.Header.Get("x-ms-error-code"); ec != "" { + if ec := resp.Header.Get(shared.HeaderXMSErrorCode); ec != "" { respErr.ErrorCode = ec return respErr } diff --git a/sdk/azcore/internal/exported/response_error_test.go b/sdk/azcore/internal/exported/response_error_test.go index 7b4a44150ef1..97c8bc4d6a4c 100644 --- a/sdk/azcore/internal/exported/response_error_test.go +++ b/sdk/azcore/internal/exported/response_error_test.go @@ -13,6 +13,8 @@ import ( "net/url" "strings" "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" ) func TestNewResponseErrorNoBodyNoErrorCode(t *testing.T) { @@ -59,7 +61,7 @@ func TestNewResponseErrorNoBody(t *testing.T) { } respHeader := http.Header{} const errorCode = "ErrorTooManyCheats" - respHeader.Set("x-ms-error-code", errorCode) + respHeader.Set(shared.HeaderXMSErrorCode, errorCode) err = NewResponseError(&http.Response{ Status: "the system is down", StatusCode: http.StatusInternalServerError, @@ -136,7 +138,7 @@ func TestNewResponseErrorPreferErrorCodeHeader(t *testing.T) { t.Fatal(err) } respHeader := http.Header{} - respHeader.Set("x-ms-error-code", "ErrorTooManyCheats") + respHeader.Set(shared.HeaderXMSErrorCode, "ErrorTooManyCheats") err = NewResponseError(&http.Response{ Status: "the system is down", StatusCode: http.StatusInternalServerError, @@ -317,7 +319,7 @@ func TestNewResponseErrorErrorCodeHeaderXML(t *testing.T) { t.Fatal(err) } respHeader := http.Header{} - respHeader.Set("x-ms-error-code", "ContainerAlreadyExists") + respHeader.Set(shared.HeaderXMSErrorCode, "ContainerAlreadyExists") err = NewResponseError(&http.Response{ Status: "the system is down", StatusCode: http.StatusInternalServerError, @@ -354,7 +356,7 @@ func TestNewResponseErrorErrorCodeHeaderXMLWithNamespace(t *testing.T) { t.Fatal(err) } respHeader := http.Header{} - respHeader.Set("x-ms-error-code", "ContainerAlreadyExists") + respHeader.Set(shared.HeaderXMSErrorCode, "ContainerAlreadyExists") err = NewResponseError(&http.Response{ Status: "the system is down", StatusCode: http.StatusInternalServerError, diff --git a/sdk/azcore/internal/pollers/fake/fake.go b/sdk/azcore/internal/pollers/fake/fake.go new file mode 100644 index 000000000000..15adbee29f09 --- /dev/null +++ b/sdk/azcore/internal/pollers/fake/fake.go @@ -0,0 +1,118 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package fake + +import ( + "context" + "errors" + "fmt" + "net/http" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/poller" +) + +// Applicable returns true if the LRO is a fake. +func Applicable(resp *http.Response) bool { + return resp.Header.Get(shared.HeaderFakePollerStatus) != "" +} + +// CanResume returns true if the token can rehydrate this poller type. +func CanResume(token map[string]interface{}) bool { + _, ok := token["fakeURL"] + return ok +} + +// Poller is an LRO poller that uses the Core-Fake-Poller pattern. +type Poller[T any] struct { + pl exported.Pipeline + + resp *http.Response + + // The API name from CtxAPINameKey + APIName string `json:"apiName"` + + // The URL from Core-Fake-Poller header. + FakeURL string `json:"fakeURL"` + + // The LRO's current state. + FakeStatus string `json:"status"` +} + +// New creates a new Poller from the provided initial response. +// Pass nil for response to create an empty Poller for rehydration. +func New[T any](pl exported.Pipeline, resp *http.Response) (*Poller[T], error) { + if resp == nil { + log.Write(log.EventLRO, "Resuming Core-Fake-Poller poller.") + return &Poller[T]{pl: pl}, nil + } + + log.Write(log.EventLRO, "Using Core-Fake-Poller poller.") + fakeStatus := resp.Header.Get(shared.HeaderFakePollerStatus) + if fakeStatus == "" { + return nil, errors.New("response is missing Fake-Poller-Status header") + } + + ctxVal := resp.Request.Context().Value(shared.CtxAPINameKey{}) + if ctxVal == nil { + return nil, errors.New("missing value for CtxAPINameKey") + } + + apiName, ok := ctxVal.(string) + if !ok { + return nil, fmt.Errorf("expected string for CtxAPINameKey, the type was %T", ctxVal) + } + + p := &Poller[T]{ + pl: pl, + resp: resp, + APIName: apiName, + FakeURL: fmt.Sprintf("%s://%s%s/get/fake/status", resp.Request.URL.Scheme, resp.Request.URL.Host, resp.Request.URL.Path), + FakeStatus: fakeStatus, + } + return p, nil +} + +// Done returns true if the LRO is in a terminal state. +func (p *Poller[T]) Done() bool { + return poller.IsTerminalState(p.FakeStatus) +} + +// Poll retrieves the current state of the LRO. +func (p *Poller[T]) Poll(ctx context.Context) (*http.Response, error) { + ctx = context.WithValue(ctx, shared.CtxAPINameKey{}, p.APIName) + err := pollers.PollHelper(ctx, p.FakeURL, p.pl, func(resp *http.Response) (string, error) { + if !poller.StatusCodeValid(resp) { + p.resp = resp + return "", exported.NewResponseError(resp) + } + fakeStatus := resp.Header.Get(shared.HeaderFakePollerStatus) + if fakeStatus == "" { + return "", errors.New("response is missing Fake-Poller-Status header") + } + p.resp = resp + p.FakeStatus = fakeStatus + return p.FakeStatus, nil + }) + if err != nil { + return nil, err + } + return p.resp, nil +} + +func (p *Poller[T]) Result(ctx context.Context, out *T) error { + if p.resp.StatusCode == http.StatusNoContent { + return nil + } else if poller.Failed(p.FakeStatus) { + return exported.NewResponseError(p.resp) + } + + return pollers.ResultHelper(p.resp, poller.Failed(p.FakeStatus), out) +} diff --git a/sdk/azcore/internal/pollers/fake/fake_test.go b/sdk/azcore/internal/pollers/fake/fake_test.go new file mode 100644 index 000000000000..0a32d6dd3a86 --- /dev/null +++ b/sdk/azcore/internal/pollers/fake/fake_test.go @@ -0,0 +1,185 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package fake + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/internal/poller" + "github.com/stretchr/testify/require" +) + +const ( + fakePollingURL = "https://foo.bar.baz/status" + fakeResourceURL = "https://foo.bar.baz/resource" +) + +func initialResponse(ctx context.Context, method string, resp io.Reader) *http.Response { + req, err := http.NewRequestWithContext(ctx, method, fakeResourceURL, nil) + if err != nil { + panic(err) + } + return &http.Response{ + Body: io.NopCloser(resp), + Header: http.Header{}, + Request: req, + } +} + +func TestApplicable(t *testing.T) { + resp := &http.Response{ + Header: http.Header{}, + } + require.False(t, Applicable(resp), "missing Fake-Poller-Status should not be applicable") + resp.Header.Set(shared.HeaderFakePollerStatus, fakePollingURL) + require.True(t, Applicable(resp), "having Fake-Poller-Status should be applicable") +} + +func TestCanResume(t *testing.T) { + token := map[string]interface{}{} + require.False(t, CanResume(token)) + token["fakeURL"] = fakePollingURL + require.True(t, CanResume(token)) +} + +func TestNew(t *testing.T) { + fp, err := New[struct{}](exported.Pipeline{}, nil) + require.NoError(t, err) + require.Empty(t, fp.FakeStatus) + + fp, err = New[struct{}](exported.Pipeline{}, &http.Response{Header: http.Header{}}) + require.Error(t, err) + require.Nil(t, fp) + + resp := initialResponse(context.Background(), http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderFakePollerStatus, "faking") + fp, err = New[struct{}](exported.Pipeline{}, resp) + require.Error(t, err) + require.Nil(t, fp) + + resp = initialResponse(context.WithValue(context.Background(), shared.CtxAPINameKey{}, 123), http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderFakePollerStatus, "faking") + fp, err = New[struct{}](exported.Pipeline{}, resp) + require.Error(t, err) + require.Nil(t, fp) + + resp = initialResponse(context.WithValue(context.Background(), shared.CtxAPINameKey{}, "FakeAPI"), http.MethodPut, http.NoBody) + resp.Header.Set(shared.HeaderFakePollerStatus, "faking") + fp, err = New[struct{}](exported.Pipeline{}, resp) + require.NoError(t, err) + require.NotNil(t, fp) + require.False(t, fp.Done()) +} + +func TestSynchronousCompletion(t *testing.T) { + resp := initialResponse(context.WithValue(context.Background(), shared.CtxAPINameKey{}, "FakeAPI"), http.MethodPut, http.NoBody) + resp.StatusCode = http.StatusNoContent + resp.Header.Set(shared.HeaderFakePollerStatus, poller.StatusSucceeded) + fp, err := New[struct{}](exported.Pipeline{}, resp) + require.NoError(t, err) + require.Equal(t, poller.StatusSucceeded, fp.FakeStatus) + require.True(t, fp.Done()) + require.NoError(t, fp.Result(context.Background(), nil)) +} + +type widget struct { + Shape string `json:"shape"` +} + +func TestPollSucceeded(t *testing.T) { + pollCtx := context.WithValue(context.Background(), shared.CtxAPINameKey{}, "FakeAPI") + resp := initialResponse(pollCtx, http.MethodPatch, http.NoBody) + resp.Header.Set(shared.HeaderFakePollerStatus, poller.StatusInProgress) + poller, err := New[widget](exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{shared.HeaderFakePollerStatus: []string{"Succeeded"}}, + Body: io.NopCloser(strings.NewReader(`{ "shape": "triangle" }`)), + }, nil + })), resp) + require.NoError(t, err) + require.False(t, poller.Done()) + resp, err = poller.Poll(pollCtx) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.True(t, poller.Done()) + var result widget + require.NoError(t, poller.Result(context.Background(), &result)) + require.EqualValues(t, "triangle", result.Shape) +} + +func TestPollError(t *testing.T) { + pollCtx := context.WithValue(context.Background(), shared.CtxAPINameKey{}, "FakeAPI") + resp := initialResponse(pollCtx, http.MethodPatch, http.NoBody) + resp.Header.Set(shared.HeaderFakePollerStatus, poller.StatusInProgress) + poller, err := New[widget](exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusNotFound, + Header: http.Header{shared.HeaderFakePollerStatus: []string{poller.StatusFailed}}, + Body: io.NopCloser(strings.NewReader(`{ "error": { "code": "NotFound", "message": "the item doesn't exist" } }`)), + }, nil + })), resp) + require.NoError(t, err) + require.False(t, poller.Done()) + resp, err = poller.Poll(pollCtx) + require.Error(t, err) + require.Nil(t, resp) + var respErr *exported.ResponseError + require.ErrorAs(t, err, &respErr) + require.Equal(t, http.StatusNotFound, respErr.StatusCode) + require.False(t, poller.Done()) + var result widget + require.Error(t, poller.Result(context.Background(), &result)) + require.ErrorAs(t, err, &respErr) +} + +func TestPollFailed(t *testing.T) { + pollCtx := context.WithValue(context.Background(), shared.CtxAPINameKey{}, "FakeAPI") + resp := initialResponse(pollCtx, http.MethodPatch, http.NoBody) + resp.Header.Set(shared.HeaderFakePollerStatus, poller.StatusInProgress) + poller, err := New[widget](exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{shared.HeaderFakePollerStatus: []string{poller.StatusFailed}}, + Body: io.NopCloser(strings.NewReader(`{ "error": { "code": "FakeFailure", "message": "couldn't do the thing" } }`)), + }, nil + })), resp) + require.NoError(t, err) + require.False(t, poller.Done()) + resp, err = poller.Poll(pollCtx) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.True(t, poller.Done()) + var result widget + var respErr *exported.ResponseError + err = poller.Result(context.Background(), &result) + require.Error(t, err) + require.ErrorAs(t, err, &respErr) +} + +func TestPollErrorNoHeader(t *testing.T) { + pollCtx := context.WithValue(context.Background(), shared.CtxAPINameKey{}, "FakeAPI") + resp := initialResponse(pollCtx, http.MethodPatch, http.NoBody) + resp.Header.Set(shared.HeaderFakePollerStatus, poller.StatusInProgress) + poller, err := New[widget](exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: io.NopCloser(strings.NewReader(`{ "error": { "code": "NotFound", "message": "the item doesn't exist" } }`)), + }, nil + })), resp) + require.NoError(t, err) + require.False(t, poller.Done()) + resp, err = poller.Poll(pollCtx) + require.Error(t, err) + require.Nil(t, resp) +} diff --git a/sdk/azcore/internal/shared/constants.go b/sdk/azcore/internal/shared/constants.go index dcd3d098b339..01f802537e4c 100644 --- a/sdk/azcore/internal/shared/constants.go +++ b/sdk/azcore/internal/shared/constants.go @@ -17,6 +17,7 @@ const ( HeaderAzureAsync = "Azure-AsyncOperation" HeaderContentLength = "Content-Length" HeaderContentType = "Content-Type" + HeaderFakePollerStatus = "Fake-Poller-Status" HeaderLocation = "Location" HeaderOperationLocation = "Operation-Location" HeaderRetryAfter = "Retry-After" @@ -24,6 +25,7 @@ const ( HeaderWWWAuthenticate = "WWW-Authenticate" HeaderXMSClientRequestID = "x-ms-client-request-id" HeaderXMSRequestID = "x-ms-request-id" + HeaderXMSErrorCode = "x-ms-error-code" ) const BearerTokenPrefix = "Bearer " diff --git a/sdk/azcore/internal/shared/shared.go b/sdk/azcore/internal/shared/shared.go index 9bd054b3643e..69153854c77a 100644 --- a/sdk/azcore/internal/shared/shared.go +++ b/sdk/azcore/internal/shared/shared.go @@ -29,6 +29,9 @@ type CtxIncludeResponseKey struct{} // CtxWithTracingTracer is used as a context key for adding/retrieving tracing.Tracer. type CtxWithTracingTracer struct{} +// CtxAPINameKey is used as a context key for adding/retrieving the API name. +type CtxAPINameKey struct{} + // Delay waits for the duration to elapse or the context to be cancelled. func Delay(ctx context.Context, delay time.Duration) error { select { diff --git a/sdk/azcore/runtime/poller.go b/sdk/azcore/runtime/poller.go index e57ad240dc04..c373f68962e3 100644 --- a/sdk/azcore/runtime/poller.go +++ b/sdk/azcore/runtime/poller.go @@ -22,6 +22,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/async" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/fake" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/loc" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/op" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" @@ -90,7 +91,9 @@ func NewPoller[T any](resp *http.Response, pl exported.Pipeline, options *NewPol // determine the polling method var opr PollingHandler[T] var err error - if async.Applicable(resp) { + if fake.Applicable(resp) { + opr, err = fake.New[T](pl, resp) + } else if async.Applicable(resp) { // async poller must be checked first as it can also have a location header opr, err = async.New[T](pl, resp, options.FinalStateVia) } else if op.Applicable(resp) { @@ -158,7 +161,9 @@ func NewPollerFromResumeToken[T any](token string, pl exported.Pipeline, options opr := options.Handler // now rehydrate the poller based on the encoded poller type - if opr != nil { + if fake.CanResume(asJSON) { + opr, _ = fake.New[T](pl, nil) + } else if opr != nil { log.Writef(log.EventLRO, "Resuming custom poller %T.", opr) } else if async.CanResume(asJSON) { opr, _ = async.New[T](pl, nil, "") diff --git a/sdk/azcore/runtime/poller_test.go b/sdk/azcore/runtime/poller_test.go index 7811f0fe51ab..a16f99f667df 100644 --- a/sdk/azcore/runtime/poller_test.go +++ b/sdk/azcore/runtime/poller_test.go @@ -23,9 +23,11 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/async" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/body" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/pollers/fake" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/Azure/azure-sdk-for-go/sdk/internal/poller" "github.com/stretchr/testify/require" ) @@ -771,8 +773,8 @@ func getPipeline(srv *mock.Server) Pipeline { ) } -func initialResponse(method, u string, resp io.Reader) (*http.Response, mock.TrackedClose) { - req, err := http.NewRequest(method, u, nil) +func initialResponse(ctx context.Context, method, u string, resp io.Reader) (*http.Response, mock.TrackedClose) { + req, err := http.NewRequestWithContext(ctx, method, u, nil) if err != nil { panic(err) } @@ -795,7 +797,7 @@ func TestNewPollerAsync(t *testing.T) { srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) srv.AppendResponse(mock.WithBody([]byte(successResp))) - resp, closed := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp, closed := initialResponse(context.Background(), http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) resp.StatusCode = http.StatusCreated pl := getPipeline(srv) @@ -838,7 +840,7 @@ func TestNewPollerBody(t *testing.T) { defer close() srv.AppendResponse(mock.WithBody([]byte(provStateUpdating)), mock.WithHeader("Retry-After", "1")) srv.AppendResponse(mock.WithBody([]byte(provStateSucceeded))) - resp, closed := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp, closed := initialResponse(context.Background(), http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) resp.StatusCode = http.StatusCreated pl := getPipeline(srv) poller, err := NewPoller[mockType](resp, pl, nil) @@ -874,7 +876,7 @@ func TestNewPollerInitialRetryAfter(t *testing.T) { srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) srv.AppendResponse(mock.WithBody([]byte(successResp))) - resp, closed := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp, closed := initialResponse(context.Background(), http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) resp.Header.Set("Retry-After", "1") resp.StatusCode = http.StatusCreated @@ -903,7 +905,7 @@ func TestNewPollerCanceled(t *testing.T) { defer close() srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) srv.AppendResponse(mock.WithBody([]byte(statusCanceled)), mock.WithStatusCode(http.StatusOK)) - resp, closed := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp, closed := initialResponse(context.Background(), http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) resp.StatusCode = http.StatusCreated pl := getPipeline(srv) @@ -941,7 +943,7 @@ func TestNewPollerFailed(t *testing.T) { srv, close := mock.NewServer() defer close() srv.AppendResponse(mock.WithBody([]byte(provStateFailed))) - resp, closed := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp, closed := initialResponse(context.Background(), http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) resp.StatusCode = http.StatusCreated pl := getPipeline(srv) @@ -966,7 +968,7 @@ func TestNewPollerFailedWithError(t *testing.T) { defer close() srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest)) - resp, closed := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp, closed := initialResponse(context.Background(), http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) resp.StatusCode = http.StatusCreated pl := getPipeline(srv) @@ -991,7 +993,7 @@ func TestNewPollerSuccessNoContent(t *testing.T) { defer close() srv.AppendResponse(mock.WithBody([]byte(provStateUpdating))) srv.AppendResponse(mock.WithStatusCode(http.StatusNoContent)) - resp, closed := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp, closed := initialResponse(context.Background(), http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) resp.StatusCode = http.StatusCreated pl := getPipeline(srv) poller, err := NewPoller[mockType](resp, pl, nil) @@ -1024,7 +1026,7 @@ func TestNewPollerSuccessNoContent(t *testing.T) { func TestNewPollerFail202NoHeaders(t *testing.T) { srv, close := mock.NewServer() defer close() - resp, closed := initialResponse(http.MethodDelete, srv.URL(), http.NoBody) + resp, closed := initialResponse(context.Background(), http.MethodDelete, srv.URL(), http.NoBody) resp.StatusCode = http.StatusAccepted pl := getPipeline(srv) poller, err := NewPoller[mockType](resp, pl, nil) @@ -1049,7 +1051,7 @@ func TestNewPollerWithResponseType(t *testing.T) { defer close() srv.AppendResponse(mock.WithBody([]byte(provStateUpdating)), mock.WithHeader("Retry-After", "1")) srv.AppendResponse(mock.WithBody([]byte(provStateSucceeded))) - resp, closed := initialResponse(http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) + resp, closed := initialResponse(context.Background(), http.MethodPatch, srv.URL(), strings.NewReader(provStateStarted)) resp.StatusCode = http.StatusCreated pl := getPipeline(srv) poller, err := NewPoller[preconstructedMockType](resp, pl, nil) @@ -1146,7 +1148,7 @@ func TestNewPollerWithCustomHandler(t *testing.T) { srv.AppendResponse(mock.WithBody([]byte(statusInProgress))) srv.AppendResponse(mock.WithBody([]byte(statusSucceeded))) srv.AppendResponse(mock.WithBody([]byte(successResp))) - resp, closed := initialResponse(http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) + resp, closed := initialResponse(context.Background(), http.MethodPut, srv.URL(), strings.NewReader(provStateStarted)) resp.Header.Set(shared.HeaderAzureAsync, srv.URL()) resp.StatusCode = http.StatusCreated pl := getPipeline(srv) @@ -1190,3 +1192,28 @@ func TestShortenPollerTypeName(t *testing.T) { result = shortenTypeName("Poller.PollUntilDone") require.EqualValues(t, "Poller.PollUntilDone", result) } + +func TestNewFakePoller(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithHeader(shared.HeaderFakePollerStatus, "FakePollerInProgress")) + srv.AppendResponse(mock.WithHeader(shared.HeaderFakePollerStatus, poller.StatusSucceeded), mock.WithStatusCode(http.StatusNoContent)) + pollCtx := context.WithValue(context.Background(), shared.CtxAPINameKey{}, "FakeAPI") + resp, closed := initialResponse(pollCtx, http.MethodPatch, srv.URL(), http.NoBody) + resp.StatusCode = http.StatusCreated + resp.Header.Set(shared.HeaderFakePollerStatus, "FakePollerInProgress") + pl := getPipeline(srv) + poller, err := NewPoller[mockType](resp, pl, nil) + require.NoError(t, err) + require.True(t, closed()) + if pt := typeOfOpField(poller); pt != reflect.TypeOf((*fake.Poller[mockType])(nil)) { + t.Fatalf("unexpected poller type %s", pt.String()) + } + tk, err := poller.ResumeToken() + require.NoError(t, err) + poller, err = NewPollerFromResumeToken[mockType](tk, pl, nil) + require.NoError(t, err) + result, err := poller.PollUntilDone(context.Background(), &PollUntilDoneOptions{Frequency: time.Millisecond}) + require.NoError(t, err) + require.Nil(t, result.Field) +} diff --git a/sdk/azcore/runtime/request.go b/sdk/azcore/runtime/request.go index 98e00718488e..cdbf8bde60b0 100644 --- a/sdk/azcore/runtime/request.go +++ b/sdk/azcore/runtime/request.go @@ -169,6 +169,9 @@ func SkipBodyDownload(req *policy.Request) { req.SetOperationValue(bodyDownloadPolicyOpValues{Skip: true}) } +// CtxAPINameKey is used as a context key for adding/retrieving the API name. +type CtxAPINameKey = shared.CtxAPINameKey + // returns a clone of the object graph pointed to by v, omitting values of all read-only // fields. if there are no read-only fields in the object graph, no clone is created. func cloneWithoutReadOnlyFields(v interface{}) interface{} {