diff --git a/plugin/install.go b/plugin/install.go index 1a1944aee..c8ca7a00c 100644 --- a/plugin/install.go +++ b/plugin/install.go @@ -16,6 +16,8 @@ import ( "golang.org/x/oauth2" ) +const defaultSourceHost = "github.com" + // InstallConfig is a config for plugin installation. // This is a wrapper for PluginConfig and manages naming conventions // and directory names for installation. @@ -149,9 +151,12 @@ func (c *InstallConfig) fetchReleaseAssets() (map[string]*github.ReleaseAsset, e assets := map[string]*github.ReleaseAsset{} ctx := context.Background() - client := newGitHubClient(ctx) - log.Printf("[DEBUG] Request to https://api.github.com/repos/%s/%s/releases/tags/%s", c.SourceOwner, c.SourceRepo, c.TagName()) + client, err := newGitHubClient(ctx, c) + if err != nil { + return assets, err + } + release, _, err := client.Repositories.GetReleaseByTag(ctx, c.SourceOwner, c.SourceRepo, c.TagName()) if err != nil { return assets, err @@ -172,9 +177,12 @@ func (c *InstallConfig) downloadToTempFile(asset *github.ReleaseAsset) (*os.File } ctx := context.Background() - client := newGitHubClient(ctx) - log.Printf("[DEBUG] Request to https://api.github.com/repos/%s/%s/releases/assets/%d", c.SourceOwner, c.SourceRepo, asset.GetID()) + client, err := newGitHubClient(ctx, c) + if err != nil { + return nil, err + } + downloader, _, err := client.Repositories.DownloadReleaseAsset(ctx, c.SourceOwner, c.SourceRepo, asset.GetID(), http.DefaultClient) if err != nil { return nil, err @@ -237,18 +245,27 @@ func extractFileFromZipFile(zipFile *os.File, savePath string) error { return nil } -func newGitHubClient(ctx context.Context) *github.Client { - token := os.Getenv("GITHUB_TOKEN") - if token == "" { - return github.NewClient(nil) +func newGitHubClient(ctx context.Context, config *InstallConfig) (*github.Client, error) { + hc := &http.Client{ + Transport: http.DefaultTransport, } - log.Printf("[DEBUG] GITHUB_TOKEN set, plugin requests to the GitHub API will be authenticated") + if t := os.Getenv("GITHUB_TOKEN"); t != "" { + log.Printf("[DEBUG] GITHUB_TOKEN set, plugin requests to the GitHub API will be authenticated") - ts := oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: token}, - ) - return github.NewClient(oauth2.NewClient(ctx, ts)) + hc = oauth2.NewClient(ctx, oauth2.StaticTokenSource(&oauth2.Token{ + AccessToken: t, + })) + } + + hc.Transport = &requestLoggingTransport{hc.Transport} + + if config.SourceHost == defaultSourceHost { + return github.NewClient(hc), nil + } + + baseURL := fmt.Sprintf("https://%s/", config.SourceHost) + return github.NewEnterpriseClient(baseURL, baseURL, hc) } func fileExt() string { @@ -257,3 +274,13 @@ func fileExt() string { } return "" } + +// requestLoggingTransport wraps an existing RoundTripper and prints DEBUG logs before each request +type requestLoggingTransport struct { + http.RoundTripper +} + +func (s *requestLoggingTransport) RoundTrip(r *http.Request) (*http.Response, error) { + log.Printf("[DEBUG] Request to %s", r.URL) + return s.RoundTripper.RoundTrip(r) +} diff --git a/plugin/install_test.go b/plugin/install_test.go index 1a1c38c37..5657e9af8 100644 --- a/plugin/install_test.go +++ b/plugin/install_test.go @@ -1,6 +1,7 @@ package plugin import ( + "context" "os" "testing" @@ -17,6 +18,7 @@ func Test_Install(t *testing.T) { Enabled: true, Version: "0.4.0", Source: "github.com/terraform-linters/tflint-ruleset-aws", + SourceHost: "github.com", SourceOwner: "terraform-linters", SourceRepo: "tflint-ruleset-aws", }) @@ -40,3 +42,45 @@ func Test_Install(t *testing.T) { t.Fatalf("Installed binary name is invalid: expected=%s, got=%s", expected, info.Name()) } } + +func TestNewGitHubClient(t *testing.T) { + cases := []struct { + name string + config *InstallConfig + expected string + }{ + { + name: "default", + config: &InstallConfig{ + PluginConfig: &tflint.PluginConfig{ + SourceHost: "github.com", + }, + }, + expected: "https://api.github.com/", + }, + { + name: "enterprise", + config: &InstallConfig{ + PluginConfig: &tflint.PluginConfig{ + SourceHost: "github.example.com", + }, + }, + expected: "https://github.example.com/api/v3/", + }, + } + + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + client, err := newGitHubClient(context.Background(), tc.config) + if err != nil { + t.Fatalf("Failed to create client: %s", err) + } + + if client.BaseURL.String() != tc.expected { + t.Fatalf("Unexpected API URL: want %s, got %s", tc.expected, client.BaseURL.String()) + } + }) + } +} diff --git a/tflint/config.go b/tflint/config.go index 34edf3821..6496a9139 100644 --- a/tflint/config.go +++ b/tflint/config.go @@ -101,6 +101,7 @@ type PluginConfig struct { Body hcl.Body `hcl:",remain"` // Parsed source attributes + SourceHost string SourceOwner string SourceRepo string } @@ -477,11 +478,10 @@ func (c *PluginConfig) validate() error { parts := strings.Split(c.Source, "/") // Expected `github.com/owner/repo` format if len(parts) != 3 { - return fmt.Errorf("plugin `%s`: `source` is invalid. Must be in the format `github.com/owner/repo`", c.Name) - } - if parts[0] != "github.com" { - return fmt.Errorf("plugin `%s`: `source` is invalid. Hostname must be `github.com`", c.Name) + return fmt.Errorf("plugin `%s`: `source` is invalid. Must be a GitHub reference in the format `${host}/${owner}/${repo}`", c.Name) } + + c.SourceHost = parts[0] c.SourceOwner = parts[1] c.SourceRepo = parts[2] } diff --git a/tflint/config_test.go b/tflint/config_test.go index 33372d561..5d1727bb7 100644 --- a/tflint/config_test.go +++ b/tflint/config_test.go @@ -107,6 +107,7 @@ plugin "baz" { Version: "0.1.0", Source: "github.com/foo/bar", SigningKey: "SIGNING_KEY", + SourceHost: "github.com", SourceOwner: "foo", SourceRepo: "bar", }, @@ -276,24 +277,46 @@ plugin "foo" { }`, }, errCheck: func(err error) bool { - return err == nil || err.Error() != "plugin `foo`: `source` is invalid. Must be in the format `github.com/owner/repo`" + return err == nil || err.Error() != "plugin `foo`: `source` is invalid. Must be a GitHub reference in the format `${host}/${owner}/${repo}`" }, }, { - name: "plugin with invalid source host", - file: "plugin_with_invalid_source_host.hcl", + name: "plugin with GHES source host", + file: "plugin_with_ghes_source_host.hcl", files: map[string]string{ - "plugin_with_invalid_source_host.hcl": ` + "plugin_with_ghes_source_host.hcl": ` plugin "foo" { enabled = true version = "0.1.0" - source = "gitlab.com/foo/bar" + source = "github.example.com/foo/bar" }`, }, - errCheck: func(err error) bool { - return err == nil || err.Error() != "plugin `foo`: `source` is invalid. Hostname must be `github.com`" + want: &Config{ + Module: false, + Force: false, + IgnoreModules: map[string]bool{}, + Varfiles: []string{}, + Variables: []string{}, + DisabledByDefault: false, + Rules: map[string]*RuleConfig{}, + Plugins: map[string]*PluginConfig{ + "foo": { + Name: "foo", + Enabled: true, + Version: "0.1.0", + Source: "github.example.com/foo/bar", + SourceHost: "github.example.com", + SourceOwner: "foo", + SourceRepo: "bar", + }, + "terraform": { + Name: "terraform", + Enabled: true, + }, + }, }, + errCheck: neverHappend, }, }