diff --git a/github/auth.go b/github/auth.go index cdab6e6e..76f093a1 100644 --- a/github/auth.go +++ b/github/auth.go @@ -23,6 +23,7 @@ import ( "net/http" "github.com/google/go-github/v32/github" + "github.com/gregjones/httpcache" "golang.org/x/oauth2" "github.com/fluxcd/go-git-providers/gitprovider" @@ -54,12 +55,22 @@ type clientOptions struct { // GitHub Enterprise. If unset, defaultDomain will be used. Domain *string - // ClientFactory is a way to acquire a *http.Client, possibly with auth credentials - ClientFactory ClientFactory + // authTransportFactory is a way to acquire a http.RoundTripper with auth credentials configured. + authTransportFactory authTransportFactory + + // RoundTripperFactory is a factory to get a http.RoundTripper that is sitting between the *github.Client's + // internal *http.Client, and the *httpcache.Transport RoundTripper. It can be set + // for doing arbitrary modifications to http requests. + RoundTripperFactory RoundTripperFactory // EnableDestructiveAPICalls is a flag to tell whether destructive API calls like // deleting a repository and such is allowed. Default: false EnableDestructiveAPICalls *bool + + // EnableConditionalRequests will be set if conditional requests should be used. + // See: https://developer.github.com/v3/#conditional-requests for more info. + // Default: false + EnableConditionalRequests *bool } // ClientOption is a function that is mutating a pointer to a clientOptions object @@ -68,7 +79,7 @@ type ClientOption func(*clientOptions) error // WithOAuth2Token initializes a Client which authenticates with GitHub through an OAuth2 token. // oauth2Token must not be an empty string. -// WithOAuth2Token is mutually exclusive with WithPersonalAccessToken and WithClientFactory. +// WithOAuth2Token is mutually exclusive with WithPersonalAccessToken. func WithOAuth2Token(oauth2Token string) ClientOption { return func(opts *clientOptions) error { // Don't allow an empty value @@ -76,18 +87,18 @@ func WithOAuth2Token(oauth2Token string) ClientOption { return fmt.Errorf("oauth2Token cannot be empty: %w", ErrInvalidClientOptions) } // Make sure the user didn't specify auth twice - if opts.ClientFactory != nil { + if opts.authTransportFactory != nil { return fmt.Errorf("authentication http.Client already configured: %w", ErrInvalidClientOptions) } - opts.ClientFactory = &oauth2Auth{oauth2Token} + opts.authTransportFactory = &oauth2Auth{oauth2Token} return nil } } // WithPersonalAccessToken initializes a Client which authenticates with GitHub through a personal access token. // patToken must not be an empty string. -// WithPersonalAccessToken is mutually exclusive with WithOAuth2Token and WithClientFactory. +// WithPersonalAccessToken is mutually exclusive with WithOAuth2Token. func WithPersonalAccessToken(patToken string) ClientOption { return func(opts *clientOptions) error { // Don't allow an empty value @@ -95,28 +106,27 @@ func WithPersonalAccessToken(patToken string) ClientOption { return fmt.Errorf("patToken cannot be empty: %w", ErrInvalidClientOptions) } // Make sure the user didn't specify auth twice - if opts.ClientFactory != nil { + if opts.authTransportFactory != nil { return fmt.Errorf("authentication http.Client already configured: %w", ErrInvalidClientOptions) } - opts.ClientFactory = &patAuth{patToken} + opts.authTransportFactory = &patAuth{patToken} return nil } } -// WithClientFactory initializes a Client with a given ClientFactory, used for acquiring the *http.Client later. -// clientFactory must not be nil. -// WithClientFactory is mutually exclusive with WithOAuth2Token and WithPersonalAccessToken. -func WithClientFactory(clientFactory ClientFactory) ClientOption { +// WithRoundTripper initializes a Client with a given authTransportFactory, used for acquiring the *http.Client later. +// authTransportFactory must not be nil. +func WithRoundTripper(roundTripper RoundTripperFactory) ClientOption { return func(opts *clientOptions) error { // Don't allow an empty value - if clientFactory == nil { - return fmt.Errorf("clientFactory cannot be nil: %w", ErrInvalidClientOptions) + if roundTripper == nil { + return fmt.Errorf("roundTripper cannot be nil: %w", ErrInvalidClientOptions) } - // Make sure the user didn't specify auth twice - if opts.ClientFactory != nil { - return fmt.Errorf("authentication http.Client already configured: %w", ErrInvalidClientOptions) + // Make sure the user didn't specify the RoundTripperFactory twice + if opts.RoundTripperFactory != nil { + return fmt.Errorf("roundTripper already configured: %w", ErrInvalidClientOptions) } - opts.ClientFactory = clientFactory + opts.RoundTripperFactory = roundTripper return nil } } @@ -142,7 +152,7 @@ func WithDomain(domain string) ClientOption { // actions, like e.g. deleting a repository. func WithDestructiveAPICalls(destructiveActions bool) ClientOption { return func(opts *clientOptions) error { - // Make sure the user didn't specify the domain twice + // Make sure the user didn't specify the flag twice if opts.EnableDestructiveAPICalls != nil { return fmt.Errorf("destructive actions flag already configured: %w", ErrInvalidClientOptions) } @@ -151,37 +161,54 @@ func WithDestructiveAPICalls(destructiveActions bool) ClientOption { } } -// ClientFactory is a way to acquire a *http.Client, possibly with auth credentials. -type ClientFactory interface { - // Client returns a *http.Client, possibly with auth credentials. - Client(ctx context.Context) *http.Client +// WithConditionalRequests instructs the client to use Conditional Requests to GitHub, asking GitHub +// whether a resource has changed (without burning your quota), and using an in-memory cached "database" +// if so. See: https://developer.github.com/v3/#conditional-requests for more information. +func WithConditionalRequests(conditionalRequests bool) ClientOption { + return func(opts *clientOptions) error { + // Make sure the user didn't specify the flag twice + if opts.EnableConditionalRequests != nil { + return fmt.Errorf("conditional requests flag already configured: %w", ErrInvalidClientOptions) + } + opts.EnableConditionalRequests = gitprovider.BoolVar(conditionalRequests) + return nil + } } -// oauth2Auth is an implementation of ClientFactory. +type RoundTripperFactory interface { + Transport(rt http.RoundTripper) http.RoundTripper +} + +// authTransportFactory is a way to acquire a http.RoundTripper with auth credentials configured. +type authTransportFactory interface { + // Transport returns a http.RoundTripper with auth credentials configured. + Transport(ctx context.Context) http.RoundTripper +} + +// oauth2Auth is an implementation of authTransportFactory. type oauth2Auth struct { token string } -// Client returns a *http.Client, possibly with auth credentials. -func (a *oauth2Auth) Client(ctx context.Context) *http.Client { +// Transport returns a http.RoundTripper with auth credentials configured. +func (a *oauth2Auth) Transport(ctx context.Context) http.RoundTripper { ts := oauth2.StaticTokenSource( &oauth2.Token{AccessToken: a.token}, ) - return oauth2.NewClient(ctx, ts) + return oauth2.NewClient(ctx, ts).Transport } -// patAuth is an implementation of ClientFactory. +// patAuth is an implementation of authTransportFactory. type patAuth struct { token string } -// Client returns a *http.Client, possibly with auth credentials. -func (a *patAuth) Client(ctx context.Context) *http.Client { - auth := github.BasicAuthTransport{ +// Transport returns a http.RoundTripper with auth credentials configured. +func (a *patAuth) Transport(ctx context.Context) http.RoundTripper { + return &github.BasicAuthTransport{ Username: patUsername, Password: a.token, } - return auth.Client() } // makeOptions assembles a clientOptions struct from ClientOption mutator functions. @@ -204,6 +231,12 @@ func makeOptions(opts ...ClientOption) (*clientOptions, error) { // https://developer.github.com/changes/2020-02-14-deprecating-password-auth/ // // GitHub Enterprise can be used if you specify the domain using the WithDomain option. +// +// You can customize low-level HTTP Transport functionality by using WithRoundTripper. +// You can also use conditional requests (and an in-memory cache) using WithConditionalRequests. +// +// The chain of transports looks like this: +// github.com API <-> Authentication <-> Cache <-> Custom Roundtripper <-> This Client. func NewClient(ctx context.Context, optFns ...ClientOption) (gitprovider.Client, error) { // Complete the options struct opts, err := makeOptions(optFns...) @@ -211,14 +244,14 @@ func NewClient(ctx context.Context, optFns ...ClientOption) (gitprovider.Client, return nil, err } - // Get the *http.Client to use for the transport, possibly with authentication. - // If opts.ClientFactory is nil, just leave the httpClient as nil, it will be - // automatically set by the github package. - var httpClient *http.Client - if opts.ClientFactory != nil { - httpClient = opts.ClientFactory.Client(ctx) + transport, err := buildTransportChain(ctx, opts) + if err != nil { + return nil, err } + // Create a *http.Client using the transport chain + httpClient := &http.Client{Transport: transport} + // Create the GitHub client either for the default github.com domain, or // a custom enterprise domain if opts.Domain is set to something other than // the default. @@ -247,3 +280,74 @@ func NewClient(ctx context.Context, optFns ...ClientOption) (gitprovider.Client, return newClient(gh, domain, destructiveActions), nil } + +// buildTransportChain builds a chain of http.RoundTrippers calling each other as per the +// description in NewClient. +func buildTransportChain(ctx context.Context, opts *clientOptions) (http.RoundTripper, error) { + // transport will be the http.RoundTripper for the *http.Client given to the Github client. + var transport http.RoundTripper + // Get an authenticated http.RoundTripper, if set + if opts.authTransportFactory != nil { + transport = opts.authTransportFactory.Transport(ctx) + } + + // Conditionally enable conditional requests + if opts.EnableConditionalRequests != nil && *opts.EnableConditionalRequests { + // Create a new httpcache high-level Transport + t := httpcache.NewMemoryCacheTransport() + // Make the httpcache high-level transport use the auth transport "underneath" + if transport != nil { + t.Transport = transport + } + // Override the transport with our embedded underlying auth transport + transport = &cacheRoundtripper{t} + } + + // If a custom roundtripper was set, pipe it through the transport too + if opts.RoundTripperFactory != nil { + customTransport := opts.RoundTripperFactory.Transport(transport) + if customTransport == nil { + // The lint failure here is a false positive, for some (unknown) reason + //nolint:goerr113 + return nil, fmt.Errorf("the RoundTripper returned from the RoundTripperFactory must not be nil: %w", ErrInvalidClientOptions) + } + transport = customTransport + } + + return transport, nil +} + +type cacheRoundtripper struct { + t *httpcache.Transport +} + +// This function follows the same logic as in github.com/gregjones/httpcache to be able +// to implement our custom roundtripper logic below. +func cacheKey(req *http.Request) string { + if req.Method == http.MethodGet { + return req.URL.String() + } + return req.Method + " " + req.URL.String() +} + +// RoundTrip calls the underlying RoundTrip (using the cache), but invalidates the cache on +// non GET/HEAD requests and non-"200 OK" responses. +func (r *cacheRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) { + // These two statements are the same as in github.com/gregjones/httpcache Transport.RoundTrip + // to be able to implement our custom roundtripper below + cacheKey := cacheKey(req) + cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" + + // If the object isn't a GET or HEAD request, also invalidate the cache of the GET URL + // as this action will modify the underlying resource (e.g. DELETE/POST/PATCH) + if !cacheable { + r.t.Cache.Delete(req.URL.String()) + } + // Call the underlying roundtrip + resp, err := r.t.RoundTrip(req) + // Don't cache anything but "200 OK" requests + if resp.StatusCode != http.StatusOK { + r.t.Cache.Delete(cacheKey) + } + return resp, err +} diff --git a/github/client_organizations.go b/github/client_organizations.go index 85472714..0ef5634a 100644 --- a/github/client_organizations.go +++ b/github/client_organizations.go @@ -45,6 +45,8 @@ func (c *OrganizationsClient) Get(ctx context.Context, ref gitprovider.Organizat // GET /orgs/{org} apiObj, _, err := c.c.Organizations.Get(ctx, ref.Organization) + // TODO: Provide some kind of debug logging if/when the httpcache is used + // One can see if the request hit the cache using: resp.Header[httpcache.XFromCache] if err != nil { return nil, handleHTTPError(err) } diff --git a/github/integration_test.go b/github/integration_test.go index 11d66fba..b5a24a00 100644 --- a/github/integration_test.go +++ b/github/integration_test.go @@ -22,11 +22,14 @@ import ( "fmt" "io/ioutil" "math/rand" + "net/http" "os" + "sync" "testing" "time" "github.com/google/go-github/v32/github" + "github.com/gregjones/httpcache" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -53,10 +56,78 @@ func TestProvider(t *testing.T) { RunSpecs(t, "GitHub Provider Suite") } +type customTransportFactory struct { + customTransport *customTransport +} + +func (f *customTransportFactory) Transport(transport http.RoundTripper) http.RoundTripper { + if f.customTransport != nil { + panic("didn't expect this function to be called twice") + } + f.customTransport = &customTransport{ + transport: transport, + countCacheHits: false, + cacheHits: 0, + mux: &sync.Mutex{}, + } + return f.customTransport +} + +type customTransport struct { + transport http.RoundTripper + countCacheHits bool + cacheHits int + mux *sync.Mutex +} + +func (t *customTransport) RoundTrip(req *http.Request) (*http.Response, error) { + t.mux.Lock() + defer t.mux.Unlock() + + resp, err := t.transport.RoundTrip(req) + // If we should count, count all cache hits whenever found + if t.countCacheHits { + if _, ok := resp.Header[httpcache.XFromCache]; ok { + t.cacheHits++ + } + } + return resp, err +} + +func (t *customTransport) resetCounter() { + t.mux.Lock() + defer t.mux.Unlock() + + t.cacheHits = 0 +} + +func (t *customTransport) setCounter(state bool) { + t.mux.Lock() + defer t.mux.Unlock() + + t.countCacheHits = state +} + +func (t *customTransport) getCacheHits() int { + t.mux.Lock() + defer t.mux.Unlock() + + return t.cacheHits +} + +func (t *customTransport) countCacheHitsForFunc(fn func()) int { + t.setCounter(true) + t.resetCounter() + fn() + t.setCounter(false) + return t.getCacheHits() +} + var _ = Describe("GitHub Provider", func() { var ( - ctx context.Context - c gitprovider.Client + ctx context.Context + c gitprovider.Client + transportFactory = &customTransportFactory{} testRepoName string testOrgName string = "fluxcd-testing" @@ -79,7 +150,12 @@ var _ = Describe("GitHub Provider", func() { ctx = context.Background() var err error - c, err = NewClient(ctx, WithPersonalAccessToken(githubToken), WithDestructiveAPICalls(true)) + c, err = NewClient(ctx, + WithPersonalAccessToken(githubToken), + WithDestructiveAPICalls(true), + WithConditionalRequests(true), + WithRoundTripper(transportFactory), + ) Expect(err).ToNot(HaveOccurred()) }) @@ -89,7 +165,7 @@ var _ = Describe("GitHub Provider", func() { Expect(err).ToNot(HaveOccurred()) // Make sure we find the expected one given as testOrgName - var listedOrg gitprovider.Organization + var listedOrg, getOrg gitprovider.Organization for _, org := range orgs { if org.Organization().Organization == testOrgName { listedOrg = org @@ -98,9 +174,14 @@ var _ = Describe("GitHub Provider", func() { } Expect(listedOrg).ToNot(BeNil()) - // Do a GET call for that organization - getOrg, err := c.Organizations().Get(ctx, listedOrg.Organization()) - Expect(err).ToNot(HaveOccurred()) + hits := transportFactory.customTransport.countCacheHitsForFunc(func() { + // Do a GET call for that organization + getOrg, err = c.Organizations().Get(ctx, listedOrg.Organization()) + Expect(err).ToNot(HaveOccurred()) + }) + // don't expect any cache hit, as we didn't request this before + Expect(hits).To(Equal(0)) + // Expect that the organization's info is the same regardless of method Expect(getOrg.Organization()).To(Equal(listedOrg.Organization())) // We don't expect the name from LIST calls, but we do expect @@ -117,6 +198,14 @@ var _ = Describe("GitHub Provider", func() { internal := getOrg.APIObject().(*github.Organization) Expect(getOrg.Get().Name).To(Equal(internal.Name)) Expect(getOrg.Get().Description).To(Equal(internal.Description)) + + // Expect that when we do the same request a second time, it will hit the cache + hits = transportFactory.customTransport.countCacheHitsForFunc(func() { + getOrg2, err := c.Organizations().Get(ctx, listedOrg.Organization()) + Expect(err).ToNot(HaveOccurred()) + Expect(getOrg2).ToNot(BeNil()) + }) + Expect(hits).To(Equal(1)) }) It("should fail when .Children is called", func() { diff --git a/go.mod b/go.mod index eb4ddc75..774dafd2 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.14 require ( github.com/google/go-github/v32 v32.1.0 + github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 github.com/ktrysmt/go-bitbucket v0.6.2 github.com/onsi/ginkgo v1.14.0 github.com/onsi/gomega v1.10.1 diff --git a/go.sum b/go.sum index 76a6a9b3..fd4b5bf3 100644 --- a/go.sum +++ b/go.sum @@ -23,6 +23,8 @@ github.com/google/go-github/v32 v32.1.0 h1:GWkQOdXqviCPx7Q7Fj+KyPoGm4SwHRh8rheoP github.com/google/go-github/v32 v32.1.0/go.mod h1:rIEpZD9CTDQwDK9GDrtMTycQNA4JU3qBsCizh3q2WCI= github.com/google/go-querystring v1.0.0 h1:Xkwi/a1rcvNg1PPYe5vI8GbeBY/jrVuDX5ASuANWTrk= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= +github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 h1:+ngKgrYPPJrOjhax5N+uePQ0Fh1Z7PheYoUI/0nzkPA= +github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/hashicorp/go-cleanhttp v0.5.1 h1:dH3aiDG9Jvb5r5+bYHsikaOUIpcM0xvgMXVoDkXMzJM= github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI=