Skip to content

Commit

Permalink
feat: extra_templates allows users to control templates (#1046)
Browse files Browse the repository at this point in the history
* feat: allow variables to specify values in main.tf

* filter out nonexistent variables

* loop

* feat: generate provider configs for kubectl & helm

* remove custom provider

* tidy

* conditionals

* index

* feat: extra_templates allows users to control templates

* commit from ci -- updated golden files

* no template header

* Update config/v2/config.go

Co-authored-by: Hayden Spitzley <[email protected]>

* maps

---------

Co-authored-by: jakeyheath <[email protected]>
Co-authored-by: Hayden Spitzley <[email protected]>
  • Loading branch information
3 people authored Apr 29, 2024
1 parent a43b364 commit f13687a
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 79 deletions.
69 changes: 54 additions & 15 deletions apply/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,41 +298,76 @@ func applyRepo(fs afero.Fs, p *plan.Plan, repoTemplates, commonTemplates fs.FS)
return applyTree(fs, repoTemplates, commonTemplates, "", p)
}

func applyExtraTemplates(fs afero.Fs, p plan.ComponentCommon, commonBox fs.FS, path string) error {
for filename, templateCfg := range p.ExtraTemplates {
target := getTargetPath(path, filename)
_, err := fs.Stat(target)
if err == nil && !templateCfg.Overwrite {
// file exists and we don't want to overwrite
continue
}

err = applyTemplate(strings.NewReader(templateCfg.Content), commonBox, fs, target, p)
if err != nil {
return errs.WrapUser(err, "applying extra templates")
}

if filepath.Ext(filename) == ".tf" {
err = fmtHcl(fs, target, true)
if err != nil {
return errs.WrapUser(err, "formating HCL of extra templates")
}
}
}

return nil
}

func applyGlobal(fs afero.Fs, p plan.Component, repoBox, commonBox fs.FS) error {
logrus.Debug("applying global")
path := fmt.Sprintf("%s/global", rootPath)
e := fs.MkdirAll(path, 0755)
if e != nil {
return errs.WrapUserf(e, "unable to make directory %s", path)
}
return applyTree(fs, repoBox, commonBox, path, p)
err := applyTree(fs, repoBox, commonBox, path, p)
if err != nil {
return err
}

return applyExtraTemplates(fs, p.ComponentCommon, commonBox, path)
}

func applyAccounts(fs afero.Fs, p *plan.Plan, accountBox, commonBox fs.FS) (e error) {
func applyAccounts(fs afero.Fs, p *plan.Plan, accountBox, commonBox fs.FS) error {
for account, accountPlan := range p.Accounts {
path := fmt.Sprintf("%s/accounts/%s", rootPath, account)
e = fs.MkdirAll(path, 0755)
if e != nil {
return errs.WrapUser(e, "unable to make directories for accounts")
err := fs.MkdirAll(path, 0755)
if err != nil {
return errs.WrapUser(err, "unable to make directories for accounts")
}
e = applyTree(fs, accountBox, commonBox, path, accountPlan)
if e != nil {
return errs.WrapUser(e, "unable to apply templates to account")
err = applyTree(fs, accountBox, commonBox, path, accountPlan)
if err != nil {
return errs.WrapUser(err, "unable to apply templates to account")
}

err = applyExtraTemplates(fs, accountPlan.ComponentCommon, commonBox, path)
if err != nil {
return errs.WrapUser(err, "apply extra templates")
}
}
return nil
}

func applyModules(fs afero.Fs, p map[string]plan.Module, moduleBox, commonBox fs.FS) (e error) {
func applyModules(fs afero.Fs, p map[string]plan.Module, moduleBox, commonBox fs.FS) error {
for module, modulePlan := range p {
path := fmt.Sprintf("%s/modules/%s", rootPath, module)
e = fs.MkdirAll(path, 0755)
if e != nil {
return errs.WrapUserf(e, "unable to make path %s", path)
err := fs.MkdirAll(path, 0755)
if err != nil {
return errs.WrapUserf(err, "unable to make path %s", path)
}
e = applyTree(fs, moduleBox, commonBox, path, modulePlan)
if e != nil {
return errs.WrapUser(e, "unable to apply tree")
err = applyTree(fs, moduleBox, commonBox, path, modulePlan)
if err != nil {
return errs.WrapUser(err, "unable to apply tree")
}
}
return nil
Expand Down Expand Up @@ -375,6 +410,10 @@ func applyEnvs(
if err != nil {
return errs.WrapUser(err, "unable to apply templates for component")
}
err = applyExtraTemplates(fs, componentPlan.ComponentCommon, commonBox, path)
if err != nil {
return errs.WrapUser(err, "apply extra templates")
}

if componentPlan.ModuleSource != nil {
downloader, err := util.MakeDownloader(*componentPlan.ModuleSource)
Expand Down
24 changes: 15 additions & 9 deletions config/v2/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,22 @@ type TFE struct {
AdditionalGithubRequiredChecks *[]string `yaml:"additional_gh_required_checks,omitempty"`
}

type ExtraTemplate struct {
Overwrite *bool
Content *string
}

type Common struct {
Backend *Backend `yaml:"backend,omitempty"`
ExtraVars map[string]string `yaml:"extra_vars,omitempty"`
Owner *string `yaml:"owner,omitempty"`
Project *string `yaml:"project,omitempty"`
Providers *Providers `yaml:"providers,omitempty"`
DependsOn *DependsOn `yaml:"depends_on,omitempty"`
TerraformVersion *string `yaml:"terraform_version,omitempty"`
Tools *Tools `yaml:"tools,omitempty"`
NeedsAWSAccountsVariable *bool `yaml:"needs_aws_accounts_variable,omitempty"`
Backend *Backend `yaml:"backend,omitempty"`
ExtraVars map[string]string `yaml:"extra_vars,omitempty"`
Owner *string `yaml:"owner,omitempty"`
Project *string `yaml:"project,omitempty"`
Providers *Providers `yaml:"providers,omitempty"`
DependsOn *DependsOn `yaml:"depends_on,omitempty"`
TerraformVersion *string `yaml:"terraform_version,omitempty"`
Tools *Tools `yaml:"tools,omitempty"`
NeedsAWSAccountsVariable *bool `yaml:"needs_aws_accounts_variable,omitempty"`
ExtraTemplates *map[string]ExtraTemplate `yaml:"extra_templates,omitempty"`
}

type Defaults struct {
Expand Down
108 changes: 53 additions & 55 deletions config/v2/resolvers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,12 @@ import (
"github.com/chanzuckerberg/fogg/util"
)

// lastNonNilBool, despite its name can return nil if all results are nil
func lastNonNilBool(getter func(Common) *bool, commons ...Common) *bool {
var s *bool
for _, c := range commons {
t := getter(c)
if t != nil {
s = t
}
}
return s
type Nillable interface {
*bool | *float64 | *int64 | *string | []string | *[]ExtraTemplate
}

// lastNonNil, despite its name can return nil if all results are nil
func lastNonNil(getter func(Common) *string, commons ...Common) *string {
var s *string
for _, c := range commons {
t := getter(c)
if t != nil {
s = t
}
}
return s
}

// lastNonNilInt64, despite its name can return nil if all results are nil
func lastNonNilInt64(getter func(Common) *int64, commons ...Common) *int64 {
var s *int64
for _, c := range commons {
t := getter(c)
if t != nil {
s = t
}
}
return s
}

// lastNonNilStringSlice, despite its name can return nil if all results are nil
func lastNonNilStringSlice(getter func(Common) []string, commons ...Common) []string {
var s []string
func lastNonNil[T Nillable](getter func(Common) T, commons ...Common) T {
var s T
for _, c := range commons {
t := getter(c)
if t != nil {
Expand All @@ -61,15 +28,15 @@ func ResolveRequiredString(getter func(Common) *string, commons ...Common) strin

// ResolveRequiredInt64 will resolve the value and panic if it is nil. Only to be used after validations are run.
func ResolveRequiredInt64(getter func(Common) *int64, commons ...Common) int64 {
return *lastNonNilInt64(getter, commons...)
return *lastNonNil(getter, commons...)
}

func ResolveOptionalString(getter func(Common) *string, commons ...Common) *string {
return lastNonNil(getter, commons...)
}

func ResolveOptionalStringSlice(getter func(Common) []string, commons ...Common) []string {
return lastNonNilStringSlice(getter, commons...)
return lastNonNil(getter, commons...)
}

func ResolveStringArray(def []string, override []string) []string {
Expand Down Expand Up @@ -302,7 +269,7 @@ func ResolveGithubProvider(commons ...Common) *GithubProvider {
BaseURL: lastNonNil(GithubProviderBaseURLGetter, commons...),
CommonProvider: CommonProvider{
Enabled: enabled,
CustomProvider: lastNonNilBool(GithubProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(GithubProviderCustomProviderGetter, commons...),
Version: lastNonNil(GithubProviderVersionGetter, commons...),
},
}
Expand All @@ -316,13 +283,44 @@ func AWSAccountsNeededGetter(comm Common) *bool {
}

func ResolveAWSAccountsNeeded(commons ...Common) bool {
accountsNeeded := lastNonNilBool(AWSAccountsNeededGetter, commons...)
accountsNeeded := lastNonNil(AWSAccountsNeededGetter, commons...)
if accountsNeeded == nil {
return true
}
return *accountsNeeded
}

func ResolveExtraTemplates(commons ...Common) map[string]ExtraTemplate {
templates := map[string]ExtraTemplate{}
for _, common := range commons {
if common.ExtraTemplates == nil {
continue
}

for filename, cfg := range *common.ExtraTemplates {
if _, exists := templates[filename]; !exists {
templates[filename] = cfg
continue
}

prevTempl := ExtraTemplate{
Overwrite: templates[filename].Overwrite,
Content: templates[filename].Content,
}

if cfg.Overwrite != nil {
prevTempl.Overwrite = cfg.Overwrite
}
if cfg.Content != nil {
prevTempl.Content = cfg.Content
}
templates[filename] = prevTempl
}
}

return templates
}

func ResolveSnowflakeProvider(commons ...Common) *SnowflakeProvider {
account := lastNonNil(SnowflakeProviderAccountGetter, commons...)
role := lastNonNil(SnowflakeProviderRoleGetter, commons...)
Expand All @@ -335,7 +333,7 @@ func ResolveSnowflakeProvider(commons ...Common) *SnowflakeProvider {
Role: role,
Region: region,
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(SnowflakeProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(SnowflakeProviderCustomProviderGetter, commons...),
Enabled: defaultEnabled(true),
Version: version,
},
Expand All @@ -359,7 +357,7 @@ func ResolveOktaProvider(commons ...Common) *OktaProvider {
BaseURL: baseURL,
RegistryNamespace: registryNamespace,
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(OktaProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(OktaProviderCustomProviderGetter, commons...),
Enabled: defaultEnabled(true),
Version: lastNonNil(OktaProviderVersionGetter, commons...),
},
Expand All @@ -381,7 +379,7 @@ func ResolveBlessProvider(commons ...Common) *BlessProvider {
AWSRegion: region,
RoleArn: roleArn,
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(BlessProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(BlessProviderCustomProviderGetter, commons...),
Enabled: defaultEnabled(true),
Version: lastNonNil(BlessProviderVersionGetter, commons...),
},
Expand All @@ -406,7 +404,7 @@ func ResolveHerokuProvider(commons ...Common) *HerokuProvider {
if version != nil {
return &HerokuProvider{
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(HerokuProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(HerokuProviderCustomProviderGetter, commons...),
Enabled: defaultEnabled(true),
Version: version,
},
Expand All @@ -432,7 +430,7 @@ func ResolveDatadogProvider(commons ...Common) *DatadogProvider {
if version != nil {
return &DatadogProvider{
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(DatadogProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(DatadogProviderCustomProviderGetter, commons...),
Enabled: defaultEnabled(true),
Version: version,
},
Expand All @@ -458,7 +456,7 @@ func ResolvePagerdutyProvider(commons ...Common) *PagerdutyProvider {
if version != nil {
return &PagerdutyProvider{
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(PagerDutyProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(PagerDutyProviderCustomProviderGetter, commons...),
Enabled: defaultEnabled(true),
Version: version,
},
Expand All @@ -484,7 +482,7 @@ func ResolveOpsGenieProvider(commons ...Common) *OpsGenieProvider {
if version != nil {
return &OpsGenieProvider{
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(OpsGenieProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(OpsGenieProviderCustomProviderGetter, commons...),
Enabled: defaultEnabled(true),
Version: version,
},
Expand All @@ -507,7 +505,7 @@ func ResolveDatabricksProvider(commons ...Common) *DatabricksProvider {
if version != nil {
return &DatabricksProvider{
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(DatabricksProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(DatabricksProviderCustomProviderGetter, commons...),
Enabled: defaultEnabled(true),
Version: version,
},
Expand All @@ -534,7 +532,7 @@ func ResolveSentryProvider(commons ...Common) *SentryProvider {
if version != nil {
return &SentryProvider{
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(SentryProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(SentryProviderCustomProviderGetter, commons...),
Enabled: defaultEnabled(true),
Version: version,
},
Expand Down Expand Up @@ -580,7 +578,7 @@ func ResolveTfeProvider(commons ...Common) *TfeProvider {
if version != nil {
return &TfeProvider{
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(TFEProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(TFEProviderCustomProviderGetter, commons...),
Enabled: enabled,
Version: version,
},
Expand Down Expand Up @@ -623,7 +621,7 @@ func ResolveKubernetesProvider(commons ...Common) *KubernetesProvider {
return &KubernetesProvider{
ClusterComponentName: clusterComponentName,
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(KubernetesProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(KubernetesProviderCustomProviderGetter, commons...),
Enabled: enabled,
Version: version,
},
Expand Down Expand Up @@ -662,7 +660,7 @@ func ResolveHelmProvider(commons ...Common) *HelmProvider {
if version != nil {
return &HelmProvider{
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(HelmProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(HelmProviderCustomProviderGetter, commons...),
Enabled: enabled,
Version: version,
},
Expand Down Expand Up @@ -702,7 +700,7 @@ func ResolveKubectlProvider(commons ...Common) *KubectlProvider {
if version != nil {
return &KubectlProvider{
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(KubectlProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(KubectlProviderCustomProviderGetter, commons...),
Enabled: enabled,
Version: version,
},
Expand Down Expand Up @@ -742,7 +740,7 @@ func ResolveGrafanaProvider(commons ...Common) *GrafanaProvider {
if version != nil {
return &GrafanaProvider{
CommonProvider: CommonProvider{
CustomProvider: lastNonNilBool(GrafanaProviderCustomProviderGetter, commons...),
CustomProvider: lastNonNil(GrafanaProviderCustomProviderGetter, commons...),
Enabled: enabled,
Version: version,
},
Expand Down
Loading

0 comments on commit f13687a

Please sign in to comment.