Skip to content

Commit

Permalink
Merge pull request #28 from fluxcd/code_cleanups
Browse files Browse the repository at this point in the history
Restructure Options and Transport functionality to become generic
  • Loading branch information
luxas authored Aug 18, 2020
2 parents ee897f0 + ffaa4a4 commit 92426c1
Show file tree
Hide file tree
Showing 10 changed files with 719 additions and 231 deletions.
357 changes: 145 additions & 212 deletions github/auth.go

Large diffs are not rendered by default.

214 changes: 214 additions & 0 deletions github/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/*
Copyright 2020 The Flux CD contributors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package github

import (
"net/http"
"reflect"
"testing"

"github.com/fluxcd/go-git-providers/gitprovider"
"github.com/fluxcd/go-git-providers/gitprovider/cache"
"github.com/fluxcd/go-git-providers/validation"
)

func dummyRoundTripper1(http.RoundTripper) http.RoundTripper { return nil }
func dummyRoundTripper2(http.RoundTripper) http.RoundTripper { return nil }
func dummyRoundTripper3(http.RoundTripper) http.RoundTripper { return nil }

func roundTrippersEqual(a, b gitprovider.ChainableRoundTripperFunc) bool {
if a == nil && b == nil {
return true
} else if (a != nil && b == nil) || (a == nil && b != nil) {
return false
}
// Note that this comparison relies on "undefined behavior" in the Go language spec, see:
// https://stackoverflow.com/questions/9643205/how-do-i-compare-two-functions-for-pointer-equality-in-the-latest-go-weekly
return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer()
}

func Test_clientOptions_getTransportChain(t *testing.T) {
tests := []struct {
name string
preChain gitprovider.ChainableRoundTripperFunc
postChain gitprovider.ChainableRoundTripperFunc
auth gitprovider.ChainableRoundTripperFunc
cache bool
wantChain []gitprovider.ChainableRoundTripperFunc
}{
{
name: "all roundtrippers",
preChain: dummyRoundTripper1,
postChain: dummyRoundTripper2,
auth: dummyRoundTripper3,
cache: true,
// expect: "post chain" <-> "auth" <-> "cache" <-> "pre chain"
wantChain: []gitprovider.ChainableRoundTripperFunc{
dummyRoundTripper2,
dummyRoundTripper3,
cache.NewHTTPCacheTransport,
dummyRoundTripper1,
},
},
{
name: "only pre + auth",
preChain: dummyRoundTripper1,
auth: dummyRoundTripper2,
// expect: "auth" <-> "pre chain"
wantChain: []gitprovider.ChainableRoundTripperFunc{
dummyRoundTripper2,
dummyRoundTripper1,
},
},
{
name: "only cache + auth",
cache: true,
auth: dummyRoundTripper1,
// expect: "auth" <-> "cache"
wantChain: []gitprovider.ChainableRoundTripperFunc{
dummyRoundTripper1,
cache.NewHTTPCacheTransport,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := &clientOptions{
CommonClientOptions: gitprovider.CommonClientOptions{
PreChainTransportHook: tt.preChain,
PostChainTransportHook: tt.postChain,
},
AuthTransport: tt.auth,
EnableConditionalRequests: &tt.cache,
}
gotChain := opts.getTransportChain()
for i := range tt.wantChain {
if !roundTrippersEqual(tt.wantChain[i], gotChain[i]) {
t.Errorf("clientOptions.getTransportChain() = %v, want %v", gotChain, tt.wantChain)
}
break
}
})
}
}

func Test_makeOptions(t *testing.T) {
tests := []struct {
name string
opts []ClientOption
want *clientOptions
expectedErrs []error
}{
{
name: "no options",
want: &clientOptions{},
},
{
name: "WithDomain",
opts: []ClientOption{WithDomain("foo")},
want: buildCommonOption(gitprovider.CommonClientOptions{Domain: gitprovider.StringVar("foo")}),
},
{
name: "WithDomain, empty",
opts: []ClientOption{WithDomain("")},
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
},
{
name: "WithDestructiveAPICalls",
opts: []ClientOption{WithDestructiveAPICalls(true)},
want: buildCommonOption(gitprovider.CommonClientOptions{EnableDestructiveAPICalls: gitprovider.BoolVar(true)}),
},
{
name: "WithPreChainTransportHook",
opts: []ClientOption{WithPreChainTransportHook(dummyRoundTripper1)},
want: buildCommonOption(gitprovider.CommonClientOptions{PreChainTransportHook: dummyRoundTripper1}),
},
{
name: "WithPreChainTransportHook, nil",
opts: []ClientOption{WithPreChainTransportHook(nil)},
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
},
{
name: "WithPostChainTransportHook",
opts: []ClientOption{WithPostChainTransportHook(dummyRoundTripper2)},
want: buildCommonOption(gitprovider.CommonClientOptions{PostChainTransportHook: dummyRoundTripper2}),
},
{
name: "WithPostChainTransportHook, nil",
opts: []ClientOption{WithPostChainTransportHook(nil)},
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
},
{
name: "WithOAuth2Token",
opts: []ClientOption{WithOAuth2Token("foo")},
want: &clientOptions{AuthTransport: oauth2Transport("foo")},
},
{
name: "WithOAuth2Token, empty",
opts: []ClientOption{WithOAuth2Token("")},
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
},
{
name: "WithPersonalAccessToken",
opts: []ClientOption{WithPersonalAccessToken("foo")},
want: &clientOptions{AuthTransport: patTransport("foo")},
},
{
name: "WithPersonalAccessToken, empty",
opts: []ClientOption{WithPersonalAccessToken("")},
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
},
{
name: "WithPersonalAccessToken and WithOAuth2Token, exclusive",
opts: []ClientOption{WithPersonalAccessToken("foo"), WithOAuth2Token("foo")},
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
},
{
name: "WithConditionalRequests",
opts: []ClientOption{WithConditionalRequests(true)},
want: &clientOptions{EnableConditionalRequests: gitprovider.BoolVar(true)},
},
{
name: "WithConditionalRequests, exclusive",
opts: []ClientOption{WithConditionalRequests(true), WithConditionalRequests(false)},
expectedErrs: []error{gitprovider.ErrInvalidClientOptions},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := makeOptions(tt.opts...)
validation.TestExpectErrors(t, "makeOptions", err, tt.expectedErrs...)
if tt.want == nil {
return
}
if !roundTrippersEqual(got.AuthTransport, tt.want.AuthTransport) ||
!roundTrippersEqual(got.PostChainTransportHook, tt.want.PostChainTransportHook) ||
!roundTrippersEqual(got.PreChainTransportHook, tt.want.PreChainTransportHook) {
t.Errorf("makeOptions() = %v, want %v", got, tt.want)
}
got.AuthTransport = nil
got.PostChainTransportHook = nil
got.PreChainTransportHook = nil
tt.want.AuthTransport = nil
tt.want.PostChainTransportHook = nil
tt.want.PreChainTransportHook = nil
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("makeOptions() = %v, want %v", got, tt.want)
}
})
}
}
2 changes: 1 addition & 1 deletion github/example_organization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func checkErr(err error) {
func ExampleOrganizationsClient_Get() {
// Create a new client
ctx := context.Background()
c, err := github.NewClient(ctx)
c, err := github.NewClient()
checkErr(err)

// Get public information about the fluxcd organization
Expand Down
2 changes: 1 addition & 1 deletion github/example_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func ExampleOrgRepositoriesClient_Get() {
// Create a new client
ctx := context.Background()
c, err := github.NewClient(ctx)
c, err := github.NewClient()
checkErr(err)

// Parse the URL into an OrgRepositoryRef
Expand Down
2 changes: 1 addition & 1 deletion github/githubclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ func (c *githubClientImpl) UpdateRepo(ctx context.Context, owner, repo string, r
func (c *githubClientImpl) DeleteRepo(ctx context.Context, owner, repo string) error {
// Don't allow deleting repositories if the user didn't explicitly allow dangerous API calls.
if !c.destructiveActions {
return fmt.Errorf("cannot delete repository: %w", ErrDestructiveCallDisallowed)
return fmt.Errorf("cannot delete repository: %w", gitprovider.ErrDestructiveCallDisallowed)
}
// DELETE /repos/{owner}/{repo}
_, err := c.c.Repositories.Delete(ctx, owner, repo)
Expand Down
31 changes: 15 additions & 16 deletions github/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ const (
defaultBranch = "master"
)

var (
// customTransportImpl is a shared instance of a customTransport, allowing counting of cache hits.
customTransportImpl *customTransport
)

func init() {
// Call testing.Init() prior to tests.NewParams(), as otherwise -test.* will not be recognised. See also: https://golang.org/doc/go1.13#testing
testing.Init()
Expand All @@ -56,21 +61,17 @@ func TestProvider(t *testing.T) {
RunSpecs(t, "GitHub Provider Suite")
}

type customTransportFactory struct {
customTransport *customTransport
}

func (f *customTransportFactory) Transport(transport http.RoundTripper) http.RoundTripper {
if f.customTransport != nil {
func customTransportFactory(transport http.RoundTripper) http.RoundTripper {
if customTransportImpl != nil {
panic("didn't expect this function to be called twice")
}
f.customTransport = &customTransport{
customTransportImpl = &customTransport{
transport: transport,
countCacheHits: false,
cacheHits: 0,
mux: &sync.Mutex{},
}
return f.customTransport
return customTransportImpl
}

type customTransport struct {
Expand Down Expand Up @@ -125,9 +126,8 @@ func (t *customTransport) countCacheHitsForFunc(fn func()) int {

var _ = Describe("GitHub Provider", func() {
var (
ctx context.Context
c gitprovider.Client
transportFactory = &customTransportFactory{}
ctx context.Context = context.Background()
c gitprovider.Client

testRepoName string
testOrgName string = "fluxcd-testing"
Expand All @@ -148,13 +148,12 @@ var _ = Describe("GitHub Provider", func() {
testOrgName = orgName
}

ctx = context.Background()
var err error
c, err = NewClient(ctx,
c, err = NewClient(
WithPersonalAccessToken(githubToken),
WithDestructiveAPICalls(true),
WithConditionalRequests(true),
WithRoundTripper(transportFactory),
WithPreChainTransportHook(customTransportFactory),
)
Expect(err).ToNot(HaveOccurred())
})
Expand All @@ -174,7 +173,7 @@ var _ = Describe("GitHub Provider", func() {
}
Expect(listedOrg).ToNot(BeNil())

hits := transportFactory.customTransport.countCacheHitsForFunc(func() {
hits := customTransportImpl.countCacheHitsForFunc(func() {
// Do a GET call for that organization
getOrg, err = c.Organizations().Get(ctx, listedOrg.Organization())
Expect(err).ToNot(HaveOccurred())
Expand All @@ -200,7 +199,7 @@ var _ = Describe("GitHub Provider", func() {
Expect(getOrg.Get().Description).To(Equal(internal.Description))

// Expect that when we do the same request a second time, it will hit the cache
hits = transportFactory.customTransport.countCacheHitsForFunc(func() {
hits = customTransportImpl.countCacheHitsForFunc(func() {
getOrg2, err := c.Organizations().Get(ctx, listedOrg.Organization())
Expect(err).ToNot(HaveOccurred())
Expect(getOrg2).ToNot(BeNil())
Expand Down
Loading

0 comments on commit 92426c1

Please sign in to comment.