diff --git a/cmd/clusterctl/client/cluster/template.go b/cmd/clusterctl/client/cluster/template.go index 213addd9fad3..e52fa4720322 100644 --- a/cmd/clusterctl/client/cluster/template.go +++ b/cmd/clusterctl/client/cluster/template.go @@ -51,6 +51,7 @@ type templateClient struct { configClient config.Client gitHubClientFactory func(configVariablesClient config.VariablesClient) (*github.Client, error) processor yaml.Processor + httpClient *http.Client } // ensure templateClient implements TemplateClient. @@ -70,6 +71,7 @@ func newTemplateClient(input TemplateClientInput) *templateClient { configClient: input.configClient, gitHubClientFactory: getGitHubClient, processor: input.processor, + httpClient: http.DefaultClient, } } @@ -143,8 +145,11 @@ func (t *templateClient) getURLContent(templateURL string) ([]byte, error) { return nil, errors.Wrapf(err, "failed to parse %q", templateURL) } - if rURL.Scheme == "https" && rURL.Host == "github.com" { - return t.getGitHubFileContent(rURL) + if rURL.Scheme == "https" { + if rURL.Host == "github.com" { + return t.getGitHubFileContent(rURL) + } + return t.getRawURLFileContent(templateURL) } if rURL.Scheme == "file" || rURL.Scheme == "" { @@ -210,6 +215,20 @@ func (t *templateClient) getGitHubFileContent(rURL *url.URL) ([]byte, error) { return content, nil } +func (t *templateClient) getRawURLFileContent(rURL string) ([]byte, error) { + res, err := t.httpClient.Get(rURL) + if err != nil { + return nil, err + } + + content, err := io.ReadAll(res.Body) + if err != nil { + return nil, err + } + + return content, nil +} + func getGitHubClient(configVariablesClient config.VariablesClient) (*github.Client, error) { var authenticatingHTTPClient *http.Client if token, err := configVariablesClient.Get(config.GitHubTokenVariable); err == nil { diff --git a/cmd/clusterctl/client/cluster/template_test.go b/cmd/clusterctl/client/cluster/template_test.go index af9ba015eead..bca138fc2f9c 100644 --- a/cmd/clusterctl/client/cluster/template_test.go +++ b/cmd/clusterctl/client/cluster/template_test.go @@ -20,6 +20,7 @@ import ( "encoding/base64" "fmt" "net/http" + "net/http/httptest" "net/url" "os" "path/filepath" @@ -225,6 +226,51 @@ func Test_templateClient_getGitHubFileContent(t *testing.T) { } } +func Test_templateClient_getRawUrlFileContent(t *testing.T) { + fakeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, template) + })) + + defer fakeServer.Close() + + type args struct { + rURL string + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + name: "Return custom template", + args: args{ + rURL: fakeServer.URL, + }, + want: []byte(template), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := NewWithT(t) + + c := &templateClient{ + httpClient: &http.Client{}, + } + got, err := c.getRawURLFileContent(tt.args.rURL) + if tt.wantErr { + g.Expect(err).To(HaveOccurred()) + return + } + + g.Expect(err).NotTo(HaveOccurred()) + + g.Expect(got).To(Equal(tt.want)) + }) + } +} + func Test_templateClient_getLocalFileContent(t *testing.T) { g := NewWithT(t)