diff --git a/github/auth.go b/github/auth.go index ff413efe..3205a9de 100644 --- a/github/auth.go +++ b/github/auth.go @@ -17,16 +17,14 @@ limitations under the License. package github import ( - "context" - "errors" "fmt" "net/http" "github.com/google/go-github/v32/github" - "github.com/gregjones/httpcache" "golang.org/x/oauth2" "github.com/fluxcd/go-git-providers/gitprovider" + "github.com/fluxcd/go-git-providers/gitprovider/cache" ) const ( @@ -39,183 +37,194 @@ const ( patUsername = "git" ) -var ( - // ErrInvalidClientOptions is the error returned when calling NewClient() with - // invalid options (e.g. specifying mutually exclusive options). - ErrInvalidClientOptions = errors.New("invalid options given to NewClient()") - // ErrDestructiveCallDisallowed happens when the client isn't set up with WithDestructiveAPICalls() - // but a destructive action is called. - ErrDestructiveCallDisallowed = errors.New("a destructive call was blocked because it wasn't allowed by the client") -) +// ClientOption is the interface to implement for passing options to NewClient. +// The clientOptions struct is private to force usage of the With... functions. +type ClientOption interface { + // ApplyToGithubClientOptions applies set fields of this object into target. + ApplyToGithubClientOptions(target *clientOptions) error +} -// clientOptions is the struct that tracks data about what options have been set -// It is private so that the user must use the With... functions. +// clientOptions is the struct that tracks data about what options have been set. type clientOptions struct { - // Domain specifies the backing domain, which can be arbitrary if the user uses - // GitHub Enterprise. If unset, defaultDomain will be used. - Domain *string - - // authTransportFactory is a way to acquire a http.RoundTripper with auth credentials configured. - authTransportFactory authTransportFactory - - // RoundTripperFactory is a factory to get a http.RoundTripper that is sitting between the *github.Client's - // internal *http.Client, and the *httpcache.Transport RoundTripper. It can be set - // for doing arbitrary modifications to http requests. - RoundTripperFactory RoundTripperFactory + // clientOptions shares all the common options + gitprovider.CommonClientOptions - // EnableDestructiveAPICalls is a flag to tell whether destructive API calls like - // deleting a repository and such is allowed. Default: false - EnableDestructiveAPICalls *bool + // AuthTransport is a ChainableRoundTripperFunc adding authentication credentials to the transport chain. + AuthTransport gitprovider.ChainableRoundTripperFunc // EnableConditionalRequests will be set if conditional requests should be used. + // TODO: Move this to gitprovider.CommonClientOptions if other providers support this too. // See: https://developer.github.com/v3/#conditional-requests for more info. // Default: false EnableConditionalRequests *bool } -// ClientOption is a function that is mutating a pointer to a clientOptions object -// which holds information of how the Client should be initialized. -type ClientOption func(*clientOptions) error +// ApplyToGithubClientOptions implements ClientOption, and applies the set fields of opts +// into target. If both opts and target has the same specific field set, ErrInvalidClientOptions is returned. +func (opts *clientOptions) ApplyToGithubClientOptions(target *clientOptions) error { + // Apply common values, if any + if err := opts.CommonClientOptions.ApplyToCommonClientOptions(&target.CommonClientOptions); err != nil { + return err + } -// WithOAuth2Token initializes a Client which authenticates with GitHub through an OAuth2 token. -// oauth2Token must not be an empty string. -// WithOAuth2Token is mutually exclusive with WithPersonalAccessToken. -func WithOAuth2Token(oauth2Token string) ClientOption { - return func(opts *clientOptions) error { - // Don't allow an empty value - if len(oauth2Token) == 0 { - return fmt.Errorf("oauth2Token cannot be empty: %w", ErrInvalidClientOptions) - } - // Make sure the user didn't specify auth twice - if opts.authTransportFactory != nil { - return fmt.Errorf("authentication http.Client already configured: %w", ErrInvalidClientOptions) + if opts.AuthTransport != nil { + // Make sure the user didn't specify the AuthTransport twice + if target.AuthTransport != nil { + return fmt.Errorf("option AuthTransport already configured: %w", gitprovider.ErrInvalidClientOptions) } - - opts.authTransportFactory = &oauth2Auth{oauth2Token} - return nil + target.AuthTransport = opts.AuthTransport } -} -// WithPersonalAccessToken initializes a Client which authenticates with GitHub through a personal access token. -// patToken must not be an empty string. -// WithPersonalAccessToken is mutually exclusive with WithOAuth2Token. -func WithPersonalAccessToken(patToken string) ClientOption { - return func(opts *clientOptions) error { - // Don't allow an empty value - if len(patToken) == 0 { - return fmt.Errorf("patToken cannot be empty: %w", ErrInvalidClientOptions) + if opts.EnableConditionalRequests != nil { + // Make sure the user didn't specify the EnableConditionalRequests twice + if target.EnableConditionalRequests != nil { + return fmt.Errorf("option EnableConditionalRequests already configured: %w", gitprovider.ErrInvalidClientOptions) } - // Make sure the user didn't specify auth twice - if opts.authTransportFactory != nil { - return fmt.Errorf("authentication http.Client already configured: %w", ErrInvalidClientOptions) - } - opts.authTransportFactory = &patAuth{patToken} - return nil + target.EnableConditionalRequests = opts.EnableConditionalRequests } + return nil } -// WithRoundTripper initializes a Client with a given authTransportFactory, used for acquiring the *http.Client later. -// authTransportFactory must not be nil. -func WithRoundTripper(roundTripper RoundTripperFactory) ClientOption { - return func(opts *clientOptions) error { - // Don't allow an empty value - if roundTripper == nil { - return fmt.Errorf("roundTripper cannot be nil: %w", ErrInvalidClientOptions) - } - // Make sure the user didn't specify the RoundTripperFactory twice - if opts.RoundTripperFactory != nil { - return fmt.Errorf("roundTripper already configured: %w", ErrInvalidClientOptions) - } - opts.RoundTripperFactory = roundTripper - return nil +// getTransportChain builds the full chain of transports (from left to right, +// as per gitprovider.BuildClientFromTransportChain) of the form described in NewClient. +func (opts *clientOptions) getTransportChain() (chain []gitprovider.ChainableRoundTripperFunc) { + if opts.PostChainTransportHook != nil { + chain = append(chain, opts.PostChainTransportHook) + } + if opts.AuthTransport != nil { + chain = append(chain, opts.AuthTransport) + } + if opts.EnableConditionalRequests != nil && *opts.EnableConditionalRequests { + // TODO: Provide some kind of debug logging if/when the httpcache is used + // One can see if the request hit the cache using: resp.Header[httpcache.XFromCache] + chain = append(chain, cache.NewHTTPCacheTransport) + } + if opts.PreChainTransportHook != nil { + chain = append(chain, opts.PreChainTransportHook) } + return } +// buildCommonOption is a helper for returning a ClientOption out of a common option field. +func buildCommonOption(opt gitprovider.CommonClientOptions) *clientOptions { + return &clientOptions{CommonClientOptions: opt} +} + +// errorOption implements ClientOption, and just wraps an error which is immediately returned. +// This struct can be used through the optionError function, in order to make makeOptions fail +// if there are invalid options given to the With... functions. +type errorOption struct { + err error +} + +// ApplyToGithubClientOptions implements ClientOption, but just returns the internal error. +func (e *errorOption) ApplyToGithubClientOptions(*clientOptions) error { return e.err } + +// optionError is a constructor for errorOption. +func optionError(err error) ClientOption { + return &errorOption{err} +} + +// +// Common options +// + // WithDomain initializes a Client for a custom GitHub Enterprise instance of the given domain. // Only host and port information should be present in domain. domain must not be an empty string. func WithDomain(domain string) ClientOption { - return func(opts *clientOptions) error { - // Don't set an empty value - if len(domain) == 0 { - return fmt.Errorf("domain cannot be empty: %w", ErrInvalidClientOptions) - } - // Make sure the user didn't specify the domain twice - if opts.Domain != nil { - return fmt.Errorf("domain already configured: %w", ErrInvalidClientOptions) - } - opts.Domain = gitprovider.StringVar(domain) - return nil - } + return buildCommonOption(gitprovider.CommonClientOptions{Domain: &domain}) } // WithDestructiveAPICalls tells the client whether it's allowed to do dangerous and possibly destructive // actions, like e.g. deleting a repository. func WithDestructiveAPICalls(destructiveActions bool) ClientOption { - return func(opts *clientOptions) error { - // Make sure the user didn't specify the flag twice - if opts.EnableDestructiveAPICalls != nil { - return fmt.Errorf("destructive actions flag already configured: %w", ErrInvalidClientOptions) - } - opts.EnableDestructiveAPICalls = gitprovider.BoolVar(destructiveActions) - return nil - } + return buildCommonOption(gitprovider.CommonClientOptions{EnableDestructiveAPICalls: &destructiveActions}) } -// WithConditionalRequests instructs the client to use Conditional Requests to GitHub, asking GitHub -// whether a resource has changed (without burning your quota), and using an in-memory cached "database" -// if so. See: https://developer.github.com/v3/#conditional-requests for more information. -func WithConditionalRequests(conditionalRequests bool) ClientOption { - return func(opts *clientOptions) error { - // Make sure the user didn't specify the flag twice - if opts.EnableConditionalRequests != nil { - return fmt.Errorf("conditional requests flag already configured: %w", ErrInvalidClientOptions) - } - opts.EnableConditionalRequests = gitprovider.BoolVar(conditionalRequests) - return nil +// WithPreChainTransportHook registers a ChainableRoundTripperFunc "before" the cache and authentication +// transports in the chain. For more information, see NewClient, and gitprovider.CommonClientOptions.PreChainTransportHook. +func WithPreChainTransportHook(preRoundTripperFunc gitprovider.ChainableRoundTripperFunc) ClientOption { + // Don't allow an empty value + if preRoundTripperFunc == nil { + return optionError(fmt.Errorf("preRoundTripperFunc cannot be nil: %w", gitprovider.ErrInvalidClientOptions)) } -} -type RoundTripperFactory interface { - Transport(rt http.RoundTripper) http.RoundTripper + return buildCommonOption(gitprovider.CommonClientOptions{PreChainTransportHook: preRoundTripperFunc}) } -// authTransportFactory is a way to acquire a http.RoundTripper with auth credentials configured. -type authTransportFactory interface { - // Transport returns a http.RoundTripper with auth credentials configured. - Transport(ctx context.Context) http.RoundTripper +// WithPostChainTransportHook registers a ChainableRoundTripperFunc "after" the cache and authentication +// transports in the chain. For more information, see NewClient, and gitprovider.CommonClientOptions.WithPostChainTransportHook. +func WithPostChainTransportHook(postRoundTripperFunc gitprovider.ChainableRoundTripperFunc) ClientOption { + // Don't allow an empty value + if postRoundTripperFunc == nil { + return optionError(fmt.Errorf("postRoundTripperFunc cannot be nil: %w", gitprovider.ErrInvalidClientOptions)) + } + + return buildCommonOption(gitprovider.CommonClientOptions{PostChainTransportHook: postRoundTripperFunc}) } -// oauth2Auth is an implementation of authTransportFactory. -type oauth2Auth struct { - token string +// +// GitHub-specific options +// + +// WithOAuth2Token initializes a Client which authenticates with GitHub through an OAuth2 token. +// oauth2Token must not be an empty string. +// WithOAuth2Token is mutually exclusive with WithPersonalAccessToken. +func WithOAuth2Token(oauth2Token string) ClientOption { + // Don't allow an empty value + if len(oauth2Token) == 0 { + return optionError(fmt.Errorf("oauth2Token cannot be empty: %w", gitprovider.ErrInvalidClientOptions)) + } + + return &clientOptions{AuthTransport: oauth2Transport(oauth2Token)} } -// Transport returns a http.RoundTripper with auth credentials configured. -func (a *oauth2Auth) Transport(ctx context.Context) http.RoundTripper { - ts := oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: a.token}, - ) - return oauth2.NewClient(ctx, ts).Transport +func oauth2Transport(oauth2Token string) gitprovider.ChainableRoundTripperFunc { + return func(in http.RoundTripper) http.RoundTripper { + // Create a TokenSource of the given access token + ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: oauth2Token}) + // Create a Transport, with "in" as the underlying transport, and the given TokenSource + return &oauth2.Transport{ + Base: in, + Source: oauth2.ReuseTokenSource(nil, ts), + } + } } -// patAuth is an implementation of authTransportFactory. -type patAuth struct { - token string +// WithPersonalAccessToken initializes a Client which authenticates with GitHub through a personal access token. +// patToken must not be an empty string. +// WithPersonalAccessToken is mutually exclusive with WithOAuth2Token. +func WithPersonalAccessToken(patToken string) ClientOption { + // Don't allow an empty value + if len(patToken) == 0 { + return optionError(fmt.Errorf("patToken cannot be empty: %w", gitprovider.ErrInvalidClientOptions)) + } + + return &clientOptions{AuthTransport: patTransport(patToken)} } -// Transport returns a http.RoundTripper with auth credentials configured. -func (a *patAuth) Transport(ctx context.Context) http.RoundTripper { - return &github.BasicAuthTransport{ - Username: patUsername, - Password: a.token, +func patTransport(patToken string) gitprovider.ChainableRoundTripperFunc { + return func(in http.RoundTripper) http.RoundTripper { + return &github.BasicAuthTransport{ + Username: patUsername, + Password: patToken, + Transport: in, + } } } +// WithConditionalRequests instructs the client to use Conditional Requests to GitHub, asking GitHub +// whether a resource has changed (without burning your quota), and using an in-memory cached "database" +// if so. See: https://developer.github.com/v3/#conditional-requests for more information. +func WithConditionalRequests(conditionalRequests bool) ClientOption { + return &clientOptions{EnableConditionalRequests: &conditionalRequests} +} + // makeOptions assembles a clientOptions struct from ClientOption mutator functions. func makeOptions(opts ...ClientOption) (*clientOptions, error) { o := &clientOptions{} for _, opt := range opts { - if err := opt(o); err != nil { + if err := opt.ApplyToGithubClientOptions(o); err != nil { return nil, err } } @@ -230,28 +239,26 @@ func makeOptions(opts ...ClientOption) (*clientOptions, error) { // Basic Auth is not supported because it is deprecated by GitHub, see // https://developer.github.com/changes/2020-02-14-deprecating-password-auth/ // -// GitHub Enterprise can be used if you specify the domain using the WithDomain option. +// GitHub Enterprise can be used if you specify the domain using WithDomain. // -// You can customize low-level HTTP Transport functionality by using WithRoundTripper. +// You can customize low-level HTTP Transport functionality by using the With{Pre,Post}ChainTransportHook options. // You can also use conditional requests (and an in-memory cache) using WithConditionalRequests. // // The chain of transports looks like this: -// github.com API <-> Authentication <-> Cache <-> Custom Roundtripper <-> This Client. -func NewClient(ctx context.Context, optFns ...ClientOption) (gitprovider.Client, error) { +// github.com API <-> "Post Chain" <-> Authentication <-> Cache <-> "Pre Chain" <-> *github.Client. +func NewClient(optFns ...ClientOption) (gitprovider.Client, error) { // Complete the options struct opts, err := makeOptions(optFns...) if err != nil { return nil, err } - transport, err := buildTransportChain(ctx, opts) + // Create a *http.Client using the transport chain + httpClient, err := gitprovider.BuildClientFromTransportChain(opts.getTransportChain()) if err != nil { return nil, err } - // Create a *http.Client using the transport chain - httpClient := &http.Client{Transport: transport} - // Create the GitHub client either for the default github.com domain, or // a custom enterprise domain if opts.Domain is set to something other than // the default. @@ -280,77 +287,3 @@ func NewClient(ctx context.Context, optFns ...ClientOption) (gitprovider.Client, return newClient(gh, domain, destructiveActions), nil } - -// buildTransportChain builds a chain of http.RoundTrippers calling each other as per the -// description in NewClient. -func buildTransportChain(ctx context.Context, opts *clientOptions) (http.RoundTripper, error) { - // transport will be the http.RoundTripper for the *http.Client given to the Github client. - var transport http.RoundTripper - // Get an authenticated http.RoundTripper, if set - if opts.authTransportFactory != nil { - transport = opts.authTransportFactory.Transport(ctx) - } - - // Conditionally enable conditional requests - if opts.EnableConditionalRequests != nil && *opts.EnableConditionalRequests { - // Create a new httpcache high-level Transport - t := httpcache.NewMemoryCacheTransport() - // Make the httpcache high-level transport use the auth transport "underneath" - if transport != nil { - t.Transport = transport - } - // Override the transport with our embedded underlying auth transport - transport = &cacheRoundtripper{t} - } - - // If a custom roundtripper was set, pipe it through the transport too - if opts.RoundTripperFactory != nil { - // TODO: Document usage of a nil transport here - // TODO: Provide some kind of debug logging if/when the httpcache is used - // One can see if the request hit the cache using: resp.Header[httpcache.XFromCache] - customTransport := opts.RoundTripperFactory.Transport(transport) - if customTransport == nil { - // The lint failure here is a false positive, for some (unknown) reason - //nolint:goerr113 - return nil, fmt.Errorf("the RoundTripper returned from the RoundTripperFactory must not be nil: %w", ErrInvalidClientOptions) - } - transport = customTransport - } - - return transport, nil -} - -type cacheRoundtripper struct { - t *httpcache.Transport -} - -// This function follows the same logic as in github.com/gregjones/httpcache to be able -// to implement our custom roundtripper logic below. -func cacheKey(req *http.Request) string { - if req.Method == http.MethodGet { - return req.URL.String() - } - return req.Method + " " + req.URL.String() -} - -// RoundTrip calls the underlying RoundTrip (using the cache), but invalidates the cache on -// non GET/HEAD requests and non-"200 OK" responses. -func (r *cacheRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { - // These two statements are the same as in github.com/gregjones/httpcache Transport.RoundTrip - // to be able to implement our custom roundtripper below - cacheKey := cacheKey(req) - cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" - - // If the object isn't a GET or HEAD request, also invalidate the cache of the GET URL - // as this action will modify the underlying resource (e.g. DELETE/POST/PATCH) - if !cacheable { - r.t.Cache.Delete(req.URL.String()) - } - // Call the underlying roundtrip - resp, err := r.t.RoundTrip(req) - // Don't cache anything but "200 OK" requests - if resp.StatusCode != http.StatusOK { - r.t.Cache.Delete(cacheKey) - } - return resp, err -} diff --git a/github/auth_test.go b/github/auth_test.go new file mode 100644 index 00000000..42728fc7 --- /dev/null +++ b/github/auth_test.go @@ -0,0 +1,214 @@ +/* +Copyright 2020 The Flux CD contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package github + +import ( + "net/http" + "reflect" + "testing" + + "github.com/fluxcd/go-git-providers/gitprovider" + "github.com/fluxcd/go-git-providers/gitprovider/cache" + "github.com/fluxcd/go-git-providers/validation" +) + +func dummyRoundTripper1(http.RoundTripper) http.RoundTripper { return nil } +func dummyRoundTripper2(http.RoundTripper) http.RoundTripper { return nil } +func dummyRoundTripper3(http.RoundTripper) http.RoundTripper { return nil } + +func roundTrippersEqual(a, b gitprovider.ChainableRoundTripperFunc) bool { + if a == nil && b == nil { + return true + } else if (a != nil && b == nil) || (a == nil && b != nil) { + return false + } + // Note that this comparison relies on "undefined behavior" in the Go language spec, see: + // https://stackoverflow.com/questions/9643205/how-do-i-compare-two-functions-for-pointer-equality-in-the-latest-go-weekly + return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer() +} + +func Test_clientOptions_getTransportChain(t *testing.T) { + tests := []struct { + name string + preChain gitprovider.ChainableRoundTripperFunc + postChain gitprovider.ChainableRoundTripperFunc + auth gitprovider.ChainableRoundTripperFunc + cache bool + wantChain []gitprovider.ChainableRoundTripperFunc + }{ + { + name: "all roundtrippers", + preChain: dummyRoundTripper1, + postChain: dummyRoundTripper2, + auth: dummyRoundTripper3, + cache: true, + // expect: "post chain" <-> "auth" <-> "cache" <-> "pre chain" + wantChain: []gitprovider.ChainableRoundTripperFunc{ + dummyRoundTripper2, + dummyRoundTripper3, + cache.NewHTTPCacheTransport, + dummyRoundTripper1, + }, + }, + { + name: "only pre + auth", + preChain: dummyRoundTripper1, + auth: dummyRoundTripper2, + // expect: "auth" <-> "pre chain" + wantChain: []gitprovider.ChainableRoundTripperFunc{ + dummyRoundTripper2, + dummyRoundTripper1, + }, + }, + { + name: "only cache + auth", + cache: true, + auth: dummyRoundTripper1, + // expect: "auth" <-> "cache" + wantChain: []gitprovider.ChainableRoundTripperFunc{ + dummyRoundTripper1, + cache.NewHTTPCacheTransport, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := &clientOptions{ + CommonClientOptions: gitprovider.CommonClientOptions{ + PreChainTransportHook: tt.preChain, + PostChainTransportHook: tt.postChain, + }, + AuthTransport: tt.auth, + EnableConditionalRequests: &tt.cache, + } + gotChain := opts.getTransportChain() + for i := range tt.wantChain { + if !roundTrippersEqual(tt.wantChain[i], gotChain[i]) { + t.Errorf("clientOptions.getTransportChain() = %v, want %v", gotChain, tt.wantChain) + } + break + } + }) + } +} + +func Test_makeOptions(t *testing.T) { + tests := []struct { + name string + opts []ClientOption + want *clientOptions + expectedErrs []error + }{ + { + name: "no options", + want: &clientOptions{}, + }, + { + name: "WithDomain", + opts: []ClientOption{WithDomain("foo")}, + want: buildCommonOption(gitprovider.CommonClientOptions{Domain: gitprovider.StringVar("foo")}), + }, + { + name: "WithDomain, empty", + opts: []ClientOption{WithDomain("")}, + expectedErrs: []error{gitprovider.ErrInvalidClientOptions}, + }, + { + name: "WithDestructiveAPICalls", + opts: []ClientOption{WithDestructiveAPICalls(true)}, + want: buildCommonOption(gitprovider.CommonClientOptions{EnableDestructiveAPICalls: gitprovider.BoolVar(true)}), + }, + { + name: "WithPreChainTransportHook", + opts: []ClientOption{WithPreChainTransportHook(dummyRoundTripper1)}, + want: buildCommonOption(gitprovider.CommonClientOptions{PreChainTransportHook: dummyRoundTripper1}), + }, + { + name: "WithPreChainTransportHook, nil", + opts: []ClientOption{WithPreChainTransportHook(nil)}, + expectedErrs: []error{gitprovider.ErrInvalidClientOptions}, + }, + { + name: "WithPostChainTransportHook", + opts: []ClientOption{WithPostChainTransportHook(dummyRoundTripper2)}, + want: buildCommonOption(gitprovider.CommonClientOptions{PostChainTransportHook: dummyRoundTripper2}), + }, + { + name: "WithPostChainTransportHook, nil", + opts: []ClientOption{WithPostChainTransportHook(nil)}, + expectedErrs: []error{gitprovider.ErrInvalidClientOptions}, + }, + { + name: "WithOAuth2Token", + opts: []ClientOption{WithOAuth2Token("foo")}, + want: &clientOptions{AuthTransport: oauth2Transport("foo")}, + }, + { + name: "WithOAuth2Token, empty", + opts: []ClientOption{WithOAuth2Token("")}, + expectedErrs: []error{gitprovider.ErrInvalidClientOptions}, + }, + { + name: "WithPersonalAccessToken", + opts: []ClientOption{WithPersonalAccessToken("foo")}, + want: &clientOptions{AuthTransport: patTransport("foo")}, + }, + { + name: "WithPersonalAccessToken, empty", + opts: []ClientOption{WithPersonalAccessToken("")}, + expectedErrs: []error{gitprovider.ErrInvalidClientOptions}, + }, + { + name: "WithPersonalAccessToken and WithOAuth2Token, exclusive", + opts: []ClientOption{WithPersonalAccessToken("foo"), WithOAuth2Token("foo")}, + expectedErrs: []error{gitprovider.ErrInvalidClientOptions}, + }, + { + name: "WithConditionalRequests", + opts: []ClientOption{WithConditionalRequests(true)}, + want: &clientOptions{EnableConditionalRequests: gitprovider.BoolVar(true)}, + }, + { + name: "WithConditionalRequests, exclusive", + opts: []ClientOption{WithConditionalRequests(true), WithConditionalRequests(false)}, + expectedErrs: []error{gitprovider.ErrInvalidClientOptions}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := makeOptions(tt.opts...) + validation.TestExpectErrors(t, "makeOptions", err, tt.expectedErrs...) + if tt.want == nil { + return + } + if !roundTrippersEqual(got.AuthTransport, tt.want.AuthTransport) || + !roundTrippersEqual(got.PostChainTransportHook, tt.want.PostChainTransportHook) || + !roundTrippersEqual(got.PreChainTransportHook, tt.want.PreChainTransportHook) { + t.Errorf("makeOptions() = %v, want %v", got, tt.want) + } + got.AuthTransport = nil + got.PostChainTransportHook = nil + got.PreChainTransportHook = nil + tt.want.AuthTransport = nil + tt.want.PostChainTransportHook = nil + tt.want.PreChainTransportHook = nil + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("makeOptions() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/github/example_organization_test.go b/github/example_organization_test.go index dba10b70..cb0c43cb 100644 --- a/github/example_organization_test.go +++ b/github/example_organization_test.go @@ -20,7 +20,7 @@ func checkErr(err error) { func ExampleOrganizationsClient_Get() { // Create a new client ctx := context.Background() - c, err := github.NewClient(ctx) + c, err := github.NewClient() checkErr(err) // Get public information about the fluxcd organization diff --git a/github/example_repository_test.go b/github/example_repository_test.go index 254ffb0f..baa339c8 100644 --- a/github/example_repository_test.go +++ b/github/example_repository_test.go @@ -12,7 +12,7 @@ import ( func ExampleOrgRepositoriesClient_Get() { // Create a new client ctx := context.Background() - c, err := github.NewClient(ctx) + c, err := github.NewClient() checkErr(err) // Parse the URL into an OrgRepositoryRef diff --git a/github/githubclient.go b/github/githubclient.go index ad5fb224..cd769a3a 100644 --- a/github/githubclient.go +++ b/github/githubclient.go @@ -262,7 +262,7 @@ func (c *githubClientImpl) UpdateRepo(ctx context.Context, owner, repo string, r func (c *githubClientImpl) DeleteRepo(ctx context.Context, owner, repo string) error { // Don't allow deleting repositories if the user didn't explicitly allow dangerous API calls. if !c.destructiveActions { - return fmt.Errorf("cannot delete repository: %w", ErrDestructiveCallDisallowed) + return fmt.Errorf("cannot delete repository: %w", gitprovider.ErrDestructiveCallDisallowed) } // DELETE /repos/{owner}/{repo} _, err := c.c.Repositories.Delete(ctx, owner, repo) diff --git a/github/integration_test.go b/github/integration_test.go index b5a24a00..2c4901d5 100644 --- a/github/integration_test.go +++ b/github/integration_test.go @@ -45,6 +45,11 @@ const ( defaultBranch = "master" ) +var ( + // customTransportImpl is a shared instance of a customTransport, allowing counting of cache hits. + customTransportImpl *customTransport +) + func init() { // Call testing.Init() prior to tests.NewParams(), as otherwise -test.* will not be recognised. See also: https://golang.org/doc/go1.13#testing testing.Init() @@ -56,21 +61,17 @@ func TestProvider(t *testing.T) { RunSpecs(t, "GitHub Provider Suite") } -type customTransportFactory struct { - customTransport *customTransport -} - -func (f *customTransportFactory) Transport(transport http.RoundTripper) http.RoundTripper { - if f.customTransport != nil { +func customTransportFactory(transport http.RoundTripper) http.RoundTripper { + if customTransportImpl != nil { panic("didn't expect this function to be called twice") } - f.customTransport = &customTransport{ + customTransportImpl = &customTransport{ transport: transport, countCacheHits: false, cacheHits: 0, mux: &sync.Mutex{}, } - return f.customTransport + return customTransportImpl } type customTransport struct { @@ -125,9 +126,8 @@ func (t *customTransport) countCacheHitsForFunc(fn func()) int { var _ = Describe("GitHub Provider", func() { var ( - ctx context.Context - c gitprovider.Client - transportFactory = &customTransportFactory{} + ctx context.Context = context.Background() + c gitprovider.Client testRepoName string testOrgName string = "fluxcd-testing" @@ -148,13 +148,12 @@ var _ = Describe("GitHub Provider", func() { testOrgName = orgName } - ctx = context.Background() var err error - c, err = NewClient(ctx, + c, err = NewClient( WithPersonalAccessToken(githubToken), WithDestructiveAPICalls(true), WithConditionalRequests(true), - WithRoundTripper(transportFactory), + WithPreChainTransportHook(customTransportFactory), ) Expect(err).ToNot(HaveOccurred()) }) @@ -174,7 +173,7 @@ var _ = Describe("GitHub Provider", func() { } Expect(listedOrg).ToNot(BeNil()) - hits := transportFactory.customTransport.countCacheHitsForFunc(func() { + hits := customTransportImpl.countCacheHitsForFunc(func() { // Do a GET call for that organization getOrg, err = c.Organizations().Get(ctx, listedOrg.Organization()) Expect(err).ToNot(HaveOccurred()) @@ -200,7 +199,7 @@ var _ = Describe("GitHub Provider", func() { Expect(getOrg.Get().Description).To(Equal(internal.Description)) // Expect that when we do the same request a second time, it will hit the cache - hits = transportFactory.customTransport.countCacheHitsForFunc(func() { + hits = customTransportImpl.countCacheHitsForFunc(func() { getOrg2, err := c.Organizations().Get(ctx, listedOrg.Organization()) Expect(err).ToNot(HaveOccurred()) Expect(getOrg2).ToNot(BeNil()) diff --git a/gitprovider/cache/httpcache.go b/gitprovider/cache/httpcache.go new file mode 100644 index 00000000..abba10c4 --- /dev/null +++ b/gitprovider/cache/httpcache.go @@ -0,0 +1,75 @@ +/* +Copyright 2020 The Flux CD contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cache + +import ( + "net/http" + + "github.com/gregjones/httpcache" +) + +// TODO: Implement an unit test for this package. + +// NewHTTPCacheTransport is a gitprovider.ChainableRoundTripperFunc which adds +// HTTP Conditional Requests caching for the backend, if the server supports it. +func NewHTTPCacheTransport(in http.RoundTripper) http.RoundTripper { + // Create a new httpcache high-level Transport + t := httpcache.NewMemoryCacheTransport() + // Configure the httpcache Transport to use in as its underlying Transport. + // If in is nil, http.DefaultTransport will be used. + t.Transport = in + // Set "out" to use a slightly custom variant of the httpcache Transport + // (with more aggressive cache invalidation) + return &cacheRoundtripper{Transport: t} +} + +// cacheRoundtripper is a slight wrapper around *httpcache.Transport that automatically +// invalidates the cache on non-GET/HEAD requests, and non-"200 OK" responses. +type cacheRoundtripper struct { + Transport *httpcache.Transport +} + +// This function follows the same logic as in github.com/gregjones/httpcache to be able +// to implement our custom roundtripper logic below. +func cacheKey(req *http.Request) string { + if req.Method == http.MethodGet { + return req.URL.String() + } + return req.Method + " " + req.URL.String() +} + +// RoundTrip calls the underlying RoundTrip (using the cache), but invalidates the cache on +// non GET/HEAD requests and non-"200 OK" responses. +func (r *cacheRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { + // These two statements are the same as in github.com/gregjones/httpcache Transport.RoundTrip + // to be able to implement our custom roundtripper below + cacheKey := cacheKey(req) + cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" + + // If the object isn't a GET or HEAD request, also invalidate the cache of the GET URL + // as this action will modify the underlying resource (e.g. DELETE/POST/PATCH) + if !cacheable { + r.Transport.Cache.Delete(req.URL.String()) + } + // Call the underlying roundtrip + resp, err := r.Transport.RoundTrip(req) + // Don't cache anything but "200 OK" requests + if resp.StatusCode != http.StatusOK { + r.Transport.Cache.Delete(cacheKey) + } + return resp, err +} diff --git a/gitprovider/client_options.go b/gitprovider/client_options.go new file mode 100644 index 00000000..83bfa879 --- /dev/null +++ b/gitprovider/client_options.go @@ -0,0 +1,111 @@ +/* +Copyright 2020 The Flux CD contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gitprovider + +import ( + "fmt" + "net/http" +) + +// ChainableRoundTripperFunc is a function that returns a higher-level "out" RoundTripper, +// chained to call the "in" RoundTripper internally, with extra logic. This function must be able +// to handle "in" being nil, and use the http.DefaultTransport default RoundTripper in that case. +// "out" must never be nil. +type ChainableRoundTripperFunc func(in http.RoundTripper) (out http.RoundTripper) + +// CommonClientOptions is a struct containing options that are generic to all clients. +type CommonClientOptions struct { + // Domain specifies the target domain for the client. If unset, the default domain for the + // given provider will be used (often exposed as DefaultDomain in respective package). + // The behaviour when setting this flag might vary between providers, read the documentation on + // NewClient for more information. + Domain *string + + // EnableDestructiveAPICalls is a flag specifying whether destructive API calls (like + // deleting a repository) are allowed in the Client. Default: false + EnableDestructiveAPICalls *bool + + // PreChainRoundTripper is a function to get a custom RoundTripper that is given as the Transport + // to the *http.Client given to the provider-specific Client. It can be set for doing arbitrary + // modifications to HTTP requests. "in" might be nil, if so http.DefaultTransport is recommended. + // The "chain" looks like follows: + // Git provider API <-> "Post Chain" <-> Provider Specific (e.g. auth, caching) (in) <-> "Pre Chain" (out) <-> *http.Client + PreChainTransportHook ChainableRoundTripperFunc + + // PostChainTransportHook is a function to get a custom RoundTripper that is the "final" Transport + // in the chain before talking to the backing API. It can be set for doing arbitrary + // modifications to HTTP requests. "in" is always nil. It's recommended to internally use http.DefaultTransport. + // The "chain" looks like follows: + // Git provider API (in==nil) <-> "Post Chain" (out) <-> Provider Specific (e.g. auth, caching) <-> "Pre Chain" <-> *http.Client + PostChainTransportHook ChainableRoundTripperFunc +} + +// ApplyToCommonClientOptions applies the currently set fields in opts to target. If both opts and +// target has the same specific field set, ErrInvalidClientOptions is returned. +func (opts *CommonClientOptions) ApplyToCommonClientOptions(target *CommonClientOptions) error { + if opts.Domain != nil { + // Make sure the user didn't specify the Domain twice + if target.Domain != nil { + return fmt.Errorf("option Domain already configured: %w", ErrInvalidClientOptions) + } + // Don't allow an empty string + if len(*opts.Domain) == 0 { + return fmt.Errorf("option Domain cannot be an empty string: %w", ErrInvalidClientOptions) + } + target.Domain = opts.Domain + } + + if opts.EnableDestructiveAPICalls != nil { + // Make sure the user didn't specify the EnableDestructiveAPICalls twice + if target.EnableDestructiveAPICalls != nil { + return fmt.Errorf("option EnableDestructiveAPICalls already configured: %w", ErrInvalidClientOptions) + } + target.EnableDestructiveAPICalls = opts.EnableDestructiveAPICalls + } + + if opts.PreChainTransportHook != nil { + // Make sure the user didn't specify the PreChainTransportHook twice + if target.PreChainTransportHook != nil { + return fmt.Errorf("option PreChainTransportHook already configured: %w", ErrInvalidClientOptions) + } + target.PreChainTransportHook = opts.PreChainTransportHook + } + + if opts.PostChainTransportHook != nil { + // Make sure the user didn't specify the PostChainTransportHook twice + if target.PostChainTransportHook != nil { + return fmt.Errorf("option PostChainTransportHook already configured: %w", ErrInvalidClientOptions) + } + target.PostChainTransportHook = opts.PostChainTransportHook + } + return nil +} + +// BuildClientFromTransportChain builds a *http.Client from a chain of ChainableRoundTripperFuncs. +// The first function in the chain is called with "in" == nil. "out" of the first function in the chain, +// is passed as "in" to the second function, and so on. "out" of the last function in the chain is used +// as net/http Client.Transport. +func BuildClientFromTransportChain(chain []ChainableRoundTripperFunc) (*http.Client, error) { + var transport http.RoundTripper + for _, rtFunc := range chain { + transport = rtFunc(transport) + if transport == nil { + return nil, ErrInvalidTransportChainReturn + } + } + return &http.Client{Transport: transport}, nil +} diff --git a/gitprovider/client_options_test.go b/gitprovider/client_options_test.go new file mode 100644 index 00000000..be13dc5a --- /dev/null +++ b/gitprovider/client_options_test.go @@ -0,0 +1,147 @@ +/* +Copyright 2020 The Flux CD contributors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gitprovider + +import ( + "net/http" + "reflect" + "testing" + + "github.com/fluxcd/go-git-providers/validation" +) + +func roundTrippersEqual(a, b ChainableRoundTripperFunc) bool { + if a == nil && b == nil { + return true + } else if (a != nil && b == nil) || (a == nil && b != nil) { + return false + } + // Note that this comparison relies on "undefined behavior" in the Go language spec, see: + // https://stackoverflow.com/questions/9643205/how-do-i-compare-two-functions-for-pointer-equality-in-the-latest-go-weekly + return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer() +} + +type commonClientOption interface { + ApplyToCommonClientOptions(*CommonClientOptions) error +} + +func makeOptions(opts ...commonClientOption) (*CommonClientOptions, error) { + o := &CommonClientOptions{} + for _, opt := range opts { + if err := opt.ApplyToCommonClientOptions(o); err != nil { + return nil, err + } + } + return o, nil +} + +func withDomain(domain string) commonClientOption { + return &CommonClientOptions{Domain: &domain} +} + +func withDestructiveAPICalls(destructiveActions bool) commonClientOption { + return &CommonClientOptions{EnableDestructiveAPICalls: &destructiveActions} +} + +func withPreChainTransportHook(preRoundTripperFunc ChainableRoundTripperFunc) commonClientOption { + return &CommonClientOptions{PreChainTransportHook: preRoundTripperFunc} +} + +func withPostChainTransportHook(postRoundTripperFunc ChainableRoundTripperFunc) commonClientOption { + return &CommonClientOptions{PostChainTransportHook: postRoundTripperFunc} +} + +func dummyRoundTripper1(http.RoundTripper) http.RoundTripper { return nil } + +func Test_makeOptions(t *testing.T) { + tests := []struct { + name string + opts []commonClientOption + want *CommonClientOptions + expectedErrs []error + }{ + { + name: "no options", + want: &CommonClientOptions{}, + }, + { + name: "withDomain", + opts: []commonClientOption{withDomain("foo")}, + want: &CommonClientOptions{Domain: StringVar("foo")}, + }, + { + name: "withDomain, empty", + opts: []commonClientOption{withDomain("")}, + expectedErrs: []error{ErrInvalidClientOptions}, + }, + { + name: "withDomain, duplicate", + opts: []commonClientOption{withDomain("foo"), withDomain("bar")}, + expectedErrs: []error{ErrInvalidClientOptions}, + }, + { + name: "withDestructiveAPICalls", + opts: []commonClientOption{withDestructiveAPICalls(true)}, + want: &CommonClientOptions{EnableDestructiveAPICalls: BoolVar(true)}, + }, + { + name: "withDestructiveAPICalls, duplicate", + opts: []commonClientOption{withDestructiveAPICalls(true), withDestructiveAPICalls(false)}, + expectedErrs: []error{ErrInvalidClientOptions}, + }, + { + name: "withPreChainTransportHook", + opts: []commonClientOption{withPreChainTransportHook(dummyRoundTripper1)}, + want: &CommonClientOptions{PreChainTransportHook: dummyRoundTripper1}, + }, + { + name: "withPreChainTransportHook, duplicate", + opts: []commonClientOption{withPreChainTransportHook(dummyRoundTripper1), withPreChainTransportHook(dummyRoundTripper1)}, + expectedErrs: []error{ErrInvalidClientOptions}, + }, + { + name: "withPostChainTransportHook", + opts: []commonClientOption{withPostChainTransportHook(dummyRoundTripper1)}, + want: &CommonClientOptions{PostChainTransportHook: dummyRoundTripper1}, + }, + { + name: "withPostChainTransportHook, duplicate", + opts: []commonClientOption{withPostChainTransportHook(dummyRoundTripper1), withPostChainTransportHook(dummyRoundTripper1)}, + expectedErrs: []error{ErrInvalidClientOptions}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := makeOptions(tt.opts...) + validation.TestExpectErrors(t, "makeOptions", err, tt.expectedErrs...) + if tt.want == nil { + return + } + if !roundTrippersEqual(got.PostChainTransportHook, tt.want.PostChainTransportHook) || + !roundTrippersEqual(got.PreChainTransportHook, tt.want.PreChainTransportHook) { + t.Errorf("makeOptions() = %v, want %v", got, tt.want) + } + got.PostChainTransportHook = nil + got.PreChainTransportHook = nil + tt.want.PostChainTransportHook = nil + tt.want.PreChainTransportHook = nil + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("makeOptions() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/gitprovider/errors.go b/gitprovider/errors.go index 3fb55a47..b4332276 100644 --- a/gitprovider/errors.go +++ b/gitprovider/errors.go @@ -53,6 +53,15 @@ var ( ErrURLInvalid = errors.New("invalid organization, user or repository URL") // ErrURLMissingRepoName is returned if there is no repository name in the URL. ErrURLMissingRepoName = errors.New("missing repository name") + + // ErrInvalidClientOptions is the error returned when calling NewClient() with + // invalid options (e.g. specifying mutually exclusive options). + ErrInvalidClientOptions = errors.New("invalid options given to NewClient()") + // ErrDestructiveCallDisallowed happens when the client isn't set up with WithDestructiveAPICalls() + // but a destructive action is called. + ErrDestructiveCallDisallowed = errors.New("destructive call was blocked, disallowed by client") + // ErrInvalidTransportChainReturn is returned if a ChainableRoundTripperFunc returns nil, which is invalid. + ErrInvalidTransportChainReturn = errors.New("the return value of a ChainableRoundTripperFunc must not be nil") ) // HTTPError is an error that contains context about the HTTP request/response that failed.