diff --git a/internal/runners/manifest/manifest.go b/internal/runners/manifest/manifest.go index 19765d7257..627308b21e 100644 --- a/internal/runners/manifest/manifest.go +++ b/internal/runners/manifest/manifest.go @@ -117,7 +117,7 @@ func (m *Manifest) fetchRequirements() ([]buildscript.Requirement, error) { } } - reqs, err := script.Requirements() + reqs, err := script.Requirements("") if err != nil { return nil, errs.Wrap(err, "Could not get requirements") } diff --git a/internal/runners/platforms/remove.go b/internal/runners/platforms/remove.go index 85a56d33db..9fe28df540 100644 --- a/internal/runners/platforms/remove.go +++ b/internal/runners/platforms/remove.go @@ -73,7 +73,7 @@ func (a *Remove) Run(params RemoveRunParams) (rerr error) { // Prepare updated buildscript script := oldCommit.BuildScript() - platforms, err := script.Platforms() + platforms, err := script.Platforms("") if err != nil { return errs.Wrap(err, "Failed to get platforms") } diff --git a/internal/runners/uninstall/uninstall.go b/internal/runners/uninstall/uninstall.go index be85debf5e..1b779d2a6b 100644 --- a/internal/runners/uninstall/uninstall.go +++ b/internal/runners/uninstall/uninstall.go @@ -165,7 +165,7 @@ func (u *Uninstall) renderUserFacing(reqs requirements) { func (u *Uninstall) resolveRequirements(script *buildscript.BuildScript, pkgs captain.PackagesValue) (requirements, error) { result := requirements{} - reqs, err := script.DependencyRequirements() + reqs, err := script.DependencyRequirements("") if err != nil { return nil, errs.Wrap(err, "Unable to get requirements") } diff --git a/pkg/buildscript/mutations.go b/pkg/buildscript/mutations.go index 6c23b7a0d3..9bb269674d 100644 --- a/pkg/buildscript/mutations.go +++ b/pkg/buildscript/mutations.go @@ -61,7 +61,7 @@ func (b *BuildScript) AddRequirement(requirement types.Requirement) error { obj = append(obj, &Assignment{requirementVersionRequirementsKey, &Value{List: &values}}) } - requirementsNode, err := b.getRequirementsNode() + requirementsNode, err := b.getRequirementsNode("") if err != nil { return errs.Wrap(err, "Could not get requirements node") } @@ -81,7 +81,7 @@ type RequirementNotFoundError struct { // RemoveRequirement will remove any matching requirement. Note that it only operates on the Name and Namespace fields. // It will not verify if revision or version match. func (b *BuildScript) RemoveRequirement(requirement types.Requirement) error { - requirementsNode, err := b.getRequirementsNode() + requirementsNode, err := b.getRequirementsNode("") if err != nil { return errs.Wrap(err, "Could not get requirements node") } @@ -126,7 +126,7 @@ func (b *BuildScript) RemoveRequirement(requirement types.Requirement) error { } func (b *BuildScript) AddPlatform(platformID strfmt.UUID) error { - platformsNode, err := b.getPlatformsNode() + platformsNode, err := b.getPlatformsNode("") if err != nil { return errs.Wrap(err, "Could not get platforms node") } @@ -144,7 +144,7 @@ type PlatformNotFoundError struct { } func (b *BuildScript) RemovePlatform(platformID strfmt.UUID) error { - platformsNode, err := b.getPlatformsNode() + platformsNode, err := b.getPlatformsNode("") if err != nil { return errs.Wrap(err, "Could not get platforms node") } diff --git a/pkg/buildscript/mutations_test.go b/pkg/buildscript/mutations_test.go index f3cff0996f..69a671864e 100644 --- a/pkg/buildscript/mutations_test.go +++ b/pkg/buildscript/mutations_test.go @@ -277,7 +277,7 @@ func TestUpdateRequirements(t *testing.T) { return } - got, err := script.Requirements() + got, err := script.Requirements("") assert.NoError(t, err) gotReqs := []DependencyRequirement{} @@ -361,7 +361,7 @@ func TestUpdatePlatform(t *testing.T) { return } - got, err := script.Platforms() + got, err := script.Platforms("") assert.NoError(t, err) sort.Slice(got, func(i, j int) bool { return got[i] < got[j] }) diff --git a/pkg/buildscript/queries.go b/pkg/buildscript/queries.go index 2f4a1a6c54..9c6e60b678 100644 --- a/pkg/buildscript/queries.go +++ b/pkg/buildscript/queries.go @@ -12,6 +12,7 @@ import ( const ( solveFuncName = "solve" solveLegacyFuncName = "solve_legacy" + srcKey = "src" requirementsKey = "requirements" platformsKey = "platforms" ) @@ -43,8 +44,10 @@ type UnknownRequirement struct { func (r UnknownRequirement) IsRequirement() {} -func (b *BuildScript) Requirements() ([]Requirement, error) { - requirementsNode, err := b.getRequirementsNode() +// Returns the requirements for the given target. +// If the given target is the empty string, uses the default target (i.e. the name assigned to 'main'). +func (b *BuildScript) Requirements(target string) ([]Requirement, error) { + requirementsNode, err := b.getRequirementsNode(target) if err != nil { return nil, errs.Wrap(err, "Could not get requirements node") } @@ -95,8 +98,8 @@ func (b *BuildScript) Requirements() ([]Requirement, error) { // DependencyRequirements is identical to Requirements except that it only considers dependency type requirements, // which are the most common. // ONLY use this when you know you only need to care about dependencies. -func (b *BuildScript) DependencyRequirements() ([]types.Requirement, error) { - reqs, err := b.Requirements() +func (b *BuildScript) DependencyRequirements(target string) ([]types.Requirement, error) { + reqs, err := b.Requirements(target) if err != nil { return nil, errs.Wrap(err, "Could not get requirements") } @@ -109,8 +112,8 @@ func (b *BuildScript) DependencyRequirements() ([]types.Requirement, error) { return deps, nil } -func (b *BuildScript) getRequirementsNode() (*Value, error) { - node, err := b.getSolveNode() +func (b *BuildScript) getRequirementsNode(target string) (*Value, error) { + node, err := b.getSolveNode(target) if err != nil { return nil, errs.Wrap(err, "Could not get solve node") } @@ -147,7 +150,23 @@ func getVersionRequirements(v *Value) []types.VersionRequirement { return reqs } -func (b *BuildScript) getSolveNode() (*Value, error) { +func isSolveFuncName(name string) bool { + return name == solveFuncName || name == solveLegacyFuncName +} + +func (b *BuildScript) getTargetNode(target string) (*Value, error) { + if target == "" { + for _, assignment := range b.raw.Assignments { + if assignment.Key != mainKey { + continue + } + if assignment.Value.Ident != nil { + target = *assignment.Value.Ident + break + } + } + } + var search func([]*Assignment) *Value search = func(assignments []*Assignment) *Value { var nextLet []*Assignment @@ -157,7 +176,13 @@ func (b *BuildScript) getSolveNode() (*Value, error) { continue } - if f := a.Value.FuncCall; f != nil && (f.Name == solveFuncName || f.Name == solveLegacyFuncName) { + if a.Key == target && a.Value.FuncCall != nil { + return a.Value + } + + if f := a.Value.FuncCall; target == "" && f != nil && isSolveFuncName(f.Name) { + // This is coming from a complex build expression with no straightforward way to determine + // a default target. Fall back on a top-level solve node. return a.Value } } @@ -169,15 +194,45 @@ func (b *BuildScript) getSolveNode() (*Value, error) { return nil } + if node := search(b.raw.Assignments); node != nil { return node, nil } + return nil, errNodeNotFound +} + +func (b *BuildScript) getSolveNode(target string) (*Value, error) { + node, err := b.getTargetNode(target) + if err != nil { + return nil, errs.Wrap(err, "Could not get target node") + } + + // If the target is the solve function, we're done. + if isSolveFuncName(node.FuncCall.Name) { + return node, nil + } + + // Otherwise, the "src" key contains a reference to the solve node. Look over the build expression + // again for that referenced node. + for _, arg := range node.FuncCall.Arguments { + if arg.Assignment == nil { + continue + } + a := arg.Assignment + if a.Key == srcKey && a.Value.Ident != nil { + node, err := b.getSolveNode(*a.Value.Ident) + if err != nil { + return nil, errs.Wrap(err, "Could not get solve node from target") + } + return node, nil + } + } return nil, errNodeNotFound } -func (b *BuildScript) getSolveAtTimeValue() (*Value, error) { - node, err := b.getSolveNode() +func (b *BuildScript) getSolveAtTimeValue(target string) (*Value, error) { + node, err := b.getSolveNode(target) if err != nil { return nil, errs.Wrap(err, "Could not get solve node") } @@ -191,8 +246,8 @@ func (b *BuildScript) getSolveAtTimeValue() (*Value, error) { return nil, errValueNotFound } -func (b *BuildScript) Platforms() ([]strfmt.UUID, error) { - node, err := b.getPlatformsNode() +func (b *BuildScript) Platforms(target string) ([]strfmt.UUID, error) { + node, err := b.getPlatformsNode(target) if err != nil { return nil, errs.Wrap(err, "Could not get platform node") } @@ -204,8 +259,8 @@ func (b *BuildScript) Platforms() ([]strfmt.UUID, error) { return list, nil } -func (b *BuildScript) getPlatformsNode() (*Value, error) { - node, err := b.getSolveNode() +func (b *BuildScript) getPlatformsNode(target string) (*Value, error) { + node, err := b.getSolveNode(target) if err != nil { return nil, errs.Wrap(err, "Could not get solve node") } diff --git a/pkg/buildscript/queries_test.go b/pkg/buildscript/queries_test.go index 9a107276ef..6aec6ef388 100644 --- a/pkg/buildscript/queries_test.go +++ b/pkg/buildscript/queries_test.go @@ -113,7 +113,7 @@ func TestRequirements(t *testing.T) { script, err := UnmarshalBuildExpression(data, nil) assert.NoError(t, err) - got, err := script.Requirements() + got, err := script.Requirements("") assert.NoError(t, err) gotReqs := []types.Requirement{} @@ -167,7 +167,7 @@ func TestRevision(t *testing.T) { script, err := UnmarshalBuildExpression(data, nil) assert.NoError(t, err) - got, err := script.Requirements() + got, err := script.Requirements("") assert.NoError(t, err) gotReqs := []RevisionRequirement{} diff --git a/pkg/buildscript/unmarshal.go b/pkg/buildscript/unmarshal.go index 8e3fbe8b03..ffbab27f9e 100644 --- a/pkg/buildscript/unmarshal.go +++ b/pkg/buildscript/unmarshal.go @@ -50,5 +50,15 @@ func Unmarshal(data []byte) (*BuildScript, error) { break } + // Verify there are no duplicate key assignments. + // This is primarily to catch duplicate solve nodes for a given target. + seen := make(map[string]bool) + for _, assignment := range raw.Assignments { + if _, exists := seen[assignment.Key]; exists { + return nil, locale.NewInputError(locale.Tl("err_buildscript_duplicate_keys", "Build script has duplicate '{{.V0}}' assignments", assignment.Key)) + } + seen[assignment.Key] = true + } + return &BuildScript{raw}, nil } diff --git a/pkg/buildscript/unmarshal_buildexpression.go b/pkg/buildscript/unmarshal_buildexpression.go index b157bef1ce..e392ea8669 100644 --- a/pkg/buildscript/unmarshal_buildexpression.go +++ b/pkg/buildscript/unmarshal_buildexpression.go @@ -87,7 +87,7 @@ func UnmarshalBuildExpression(data []byte, atTime *time.Time) (*BuildScript, err // Extract the 'at_time' from the solve node, if it exists, and change its value to be a // reference to "$at_time", which is how we want to show it in AScript format. - if atTimeNode, err := script.getSolveAtTimeValue(); err == nil && atTimeNode.Str != nil && !strings.HasPrefix(strValue(atTimeNode), `$`) { + if atTimeNode, err := script.getSolveAtTimeValue(""); err == nil && atTimeNode.Str != nil && !strings.HasPrefix(strValue(atTimeNode), `$`) { atTime, err := strfmt.ParseDateTime(strValue(atTimeNode)) if err != nil { return nil, errs.Wrap(err, "Invalid timestamp: %s", strValue(atTimeNode)) @@ -107,7 +107,7 @@ func UnmarshalBuildExpression(data []byte, atTime *time.Time) (*BuildScript, err // requirements = [{"name": "", "namespace": ""}, {...}, ...] // then transform them into function call form for the AScript format, e.g. // requirements = [Req(name = "", namespace = ""), Req(...), ...] - requirements, err := script.getRequirementsNode() + requirements, err := script.getRequirementsNode("") if err != nil { return nil, errs.Wrap(err, "Could not get requirements node") } diff --git a/pkg/platform/model/checkpoints.go b/pkg/platform/model/checkpoints.go index 87f491a940..b34554deeb 100644 --- a/pkg/platform/model/checkpoints.go +++ b/pkg/platform/model/checkpoints.go @@ -76,7 +76,7 @@ func FetchLanguagesForCommit(commitID strfmt.UUID, auth *authentication.Auth) ([ // FetchLanguagesForBuildScript fetches a list of language names for the given buildscript func FetchLanguagesForBuildScript(script *buildscript.BuildScript) ([]Language, error) { languages := []Language{} - reqs, err := script.DependencyRequirements() + reqs, err := script.DependencyRequirements("") if err != nil { return nil, errs.Wrap(err, "failed to get dependency requirements") }