Skip to content

Commit

Permalink
Build scripts can distinguish between solve nodes for different targets.
Browse files Browse the repository at this point in the history
Clients need to pass the target when querying requirements, platforms, etc.
Use "" for the default target (i.e. the name assigned to 'main').
  • Loading branch information
mitchell-as committed Oct 7, 2024
1 parent 0d07973 commit 27d4e9d
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 28 deletions.
2 changes: 1 addition & 1 deletion internal/runners/manifest/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (m *Manifest) fetchRequirements() ([]buildscript.Requirement, error) {
}
}

reqs, err := script.Requirements()
reqs, err := script.Requirements("")
if err != nil {
return nil, errs.Wrap(err, "Could not get requirements")
}
Expand Down
2 changes: 1 addition & 1 deletion internal/runners/platforms/remove.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (a *Remove) Run(params RemoveRunParams) (rerr error) {

// Prepare updated buildscript
script := oldCommit.BuildScript()
platforms, err := script.Platforms()
platforms, err := script.Platforms("")
if err != nil {
return errs.Wrap(err, "Failed to get platforms")
}
Expand Down
2 changes: 1 addition & 1 deletion internal/runners/uninstall/uninstall.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (u *Uninstall) renderUserFacing(reqs requirements) {
func (u *Uninstall) resolveRequirements(script *buildscript.BuildScript, pkgs captain.PackagesValue) (requirements, error) {
result := requirements{}

reqs, err := script.DependencyRequirements()
reqs, err := script.DependencyRequirements("")
if err != nil {
return nil, errs.Wrap(err, "Unable to get requirements")
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/buildscript/mutations.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (b *BuildScript) AddRequirement(requirement types.Requirement) error {
obj = append(obj, &Assignment{requirementVersionRequirementsKey, &Value{List: &values}})
}

requirementsNode, err := b.getRequirementsNode()
requirementsNode, err := b.getRequirementsNode("")
if err != nil {
return errs.Wrap(err, "Could not get requirements node")
}
Expand All @@ -81,7 +81,7 @@ type RequirementNotFoundError struct {
// RemoveRequirement will remove any matching requirement. Note that it only operates on the Name and Namespace fields.
// It will not verify if revision or version match.
func (b *BuildScript) RemoveRequirement(requirement types.Requirement) error {
requirementsNode, err := b.getRequirementsNode()
requirementsNode, err := b.getRequirementsNode("")
if err != nil {
return errs.Wrap(err, "Could not get requirements node")
}
Expand Down Expand Up @@ -126,7 +126,7 @@ func (b *BuildScript) RemoveRequirement(requirement types.Requirement) error {
}

func (b *BuildScript) AddPlatform(platformID strfmt.UUID) error {
platformsNode, err := b.getPlatformsNode()
platformsNode, err := b.getPlatformsNode("")
if err != nil {
return errs.Wrap(err, "Could not get platforms node")
}
Expand All @@ -144,7 +144,7 @@ type PlatformNotFoundError struct {
}

func (b *BuildScript) RemovePlatform(platformID strfmt.UUID) error {
platformsNode, err := b.getPlatformsNode()
platformsNode, err := b.getPlatformsNode("")
if err != nil {
return errs.Wrap(err, "Could not get platforms node")
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/buildscript/mutations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func TestUpdateRequirements(t *testing.T) {
return
}

got, err := script.Requirements()
got, err := script.Requirements("")
assert.NoError(t, err)

gotReqs := []DependencyRequirement{}
Expand Down Expand Up @@ -361,7 +361,7 @@ func TestUpdatePlatform(t *testing.T) {
return
}

got, err := script.Platforms()
got, err := script.Platforms("")
assert.NoError(t, err)

sort.Slice(got, func(i, j int) bool { return got[i] < got[j] })
Expand Down
83 changes: 69 additions & 14 deletions pkg/buildscript/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
const (
solveFuncName = "solve"
solveLegacyFuncName = "solve_legacy"
srcKey = "src"
requirementsKey = "requirements"
platformsKey = "platforms"
)
Expand Down Expand Up @@ -43,8 +44,10 @@ type UnknownRequirement struct {

func (r UnknownRequirement) IsRequirement() {}

func (b *BuildScript) Requirements() ([]Requirement, error) {
requirementsNode, err := b.getRequirementsNode()
// Returns the requirements for the given target.
// If the given target is the empty string, uses the default target (i.e. the name assigned to 'main').
func (b *BuildScript) Requirements(target string) ([]Requirement, error) {
requirementsNode, err := b.getRequirementsNode(target)
if err != nil {
return nil, errs.Wrap(err, "Could not get requirements node")
}
Expand Down Expand Up @@ -95,8 +98,8 @@ func (b *BuildScript) Requirements() ([]Requirement, error) {
// DependencyRequirements is identical to Requirements except that it only considers dependency type requirements,
// which are the most common.
// ONLY use this when you know you only need to care about dependencies.
func (b *BuildScript) DependencyRequirements() ([]types.Requirement, error) {
reqs, err := b.Requirements()
func (b *BuildScript) DependencyRequirements(target string) ([]types.Requirement, error) {
reqs, err := b.Requirements(target)
if err != nil {
return nil, errs.Wrap(err, "Could not get requirements")
}
Expand All @@ -109,8 +112,8 @@ func (b *BuildScript) DependencyRequirements() ([]types.Requirement, error) {
return deps, nil
}

func (b *BuildScript) getRequirementsNode() (*Value, error) {
node, err := b.getSolveNode()
func (b *BuildScript) getRequirementsNode(target string) (*Value, error) {
node, err := b.getSolveNode(target)
if err != nil {
return nil, errs.Wrap(err, "Could not get solve node")
}
Expand Down Expand Up @@ -147,7 +150,23 @@ func getVersionRequirements(v *Value) []types.VersionRequirement {
return reqs
}

func (b *BuildScript) getSolveNode() (*Value, error) {
func isSolveFuncName(name string) bool {
return name == solveFuncName || name == solveLegacyFuncName
}

func (b *BuildScript) getTargetNode(target string) (*Value, error) {
if target == "" {
for _, assignment := range b.raw.Assignments {
if assignment.Key != mainKey {
continue
}
if assignment.Value.Ident != nil {
target = *assignment.Value.Ident
break
}
}
}

var search func([]*Assignment) *Value
search = func(assignments []*Assignment) *Value {
var nextLet []*Assignment
Expand All @@ -157,7 +176,13 @@ func (b *BuildScript) getSolveNode() (*Value, error) {
continue
}

if f := a.Value.FuncCall; f != nil && (f.Name == solveFuncName || f.Name == solveLegacyFuncName) {
if a.Key == target && a.Value.FuncCall != nil {
return a.Value
}

if f := a.Value.FuncCall; target == "" && f != nil && isSolveFuncName(f.Name) {
// This is coming from a complex build expression with no straightforward way to determine
// a default target. Fall back on a top-level solve node.
return a.Value
}
}
Expand All @@ -169,15 +194,45 @@ func (b *BuildScript) getSolveNode() (*Value, error) {

return nil
}

if node := search(b.raw.Assignments); node != nil {
return node, nil
}
return nil, errNodeNotFound
}

func (b *BuildScript) getSolveNode(target string) (*Value, error) {
node, err := b.getTargetNode(target)
if err != nil {
return nil, errs.Wrap(err, "Could not get target node")
}

// If the target is the solve function, we're done.
if isSolveFuncName(node.FuncCall.Name) {
return node, nil
}

// Otherwise, the "src" key contains a reference to the solve node. Look over the build expression
// again for that referenced node.
for _, arg := range node.FuncCall.Arguments {
if arg.Assignment == nil {
continue
}
a := arg.Assignment
if a.Key == srcKey && a.Value.Ident != nil {
node, err := b.getSolveNode(*a.Value.Ident)
if err != nil {
return nil, errs.Wrap(err, "Could not get solve node from target")
}
return node, nil
}
}

return nil, errNodeNotFound
}

func (b *BuildScript) getSolveAtTimeValue() (*Value, error) {
node, err := b.getSolveNode()
func (b *BuildScript) getSolveAtTimeValue(target string) (*Value, error) {
node, err := b.getSolveNode(target)
if err != nil {
return nil, errs.Wrap(err, "Could not get solve node")
}
Expand All @@ -191,8 +246,8 @@ func (b *BuildScript) getSolveAtTimeValue() (*Value, error) {
return nil, errValueNotFound
}

func (b *BuildScript) Platforms() ([]strfmt.UUID, error) {
node, err := b.getPlatformsNode()
func (b *BuildScript) Platforms(target string) ([]strfmt.UUID, error) {
node, err := b.getPlatformsNode(target)
if err != nil {
return nil, errs.Wrap(err, "Could not get platform node")
}
Expand All @@ -204,8 +259,8 @@ func (b *BuildScript) Platforms() ([]strfmt.UUID, error) {
return list, nil
}

func (b *BuildScript) getPlatformsNode() (*Value, error) {
node, err := b.getSolveNode()
func (b *BuildScript) getPlatformsNode(target string) (*Value, error) {
node, err := b.getSolveNode(target)
if err != nil {
return nil, errs.Wrap(err, "Could not get solve node")
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/buildscript/queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func TestRequirements(t *testing.T) {
script, err := UnmarshalBuildExpression(data, nil)
assert.NoError(t, err)

got, err := script.Requirements()
got, err := script.Requirements("")
assert.NoError(t, err)

gotReqs := []types.Requirement{}
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestRevision(t *testing.T) {
script, err := UnmarshalBuildExpression(data, nil)
assert.NoError(t, err)

got, err := script.Requirements()
got, err := script.Requirements("")
assert.NoError(t, err)

gotReqs := []RevisionRequirement{}
Expand Down
10 changes: 10 additions & 0 deletions pkg/buildscript/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,15 @@ func Unmarshal(data []byte) (*BuildScript, error) {
break
}

// Verify there are no duplicate key assignments.
// This is primarily to catch duplicate solve nodes for a given target.
seen := make(map[string]bool)
for _, assignment := range raw.Assignments {
if _, exists := seen[assignment.Key]; exists {
return nil, locale.NewInputError(locale.Tl("err_buildscript_duplicate_keys", "Build script has duplicate '{{.V0}}' assignments", assignment.Key))
}
seen[assignment.Key] = true
}

return &BuildScript{raw}, nil
}
4 changes: 2 additions & 2 deletions pkg/buildscript/unmarshal_buildexpression.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func UnmarshalBuildExpression(data []byte, atTime *time.Time) (*BuildScript, err

// Extract the 'at_time' from the solve node, if it exists, and change its value to be a
// reference to "$at_time", which is how we want to show it in AScript format.
if atTimeNode, err := script.getSolveAtTimeValue(); err == nil && atTimeNode.Str != nil && !strings.HasPrefix(strValue(atTimeNode), `$`) {
if atTimeNode, err := script.getSolveAtTimeValue(""); err == nil && atTimeNode.Str != nil && !strings.HasPrefix(strValue(atTimeNode), `$`) {
atTime, err := strfmt.ParseDateTime(strValue(atTimeNode))
if err != nil {
return nil, errs.Wrap(err, "Invalid timestamp: %s", strValue(atTimeNode))
Expand All @@ -107,7 +107,7 @@ func UnmarshalBuildExpression(data []byte, atTime *time.Time) (*BuildScript, err
// requirements = [{"name": "<name>", "namespace": "<name>"}, {...}, ...]
// then transform them into function call form for the AScript format, e.g.
// requirements = [Req(name = "<name>", namespace = "<name>"), Req(...), ...]
requirements, err := script.getRequirementsNode()
requirements, err := script.getRequirementsNode("")
if err != nil {
return nil, errs.Wrap(err, "Could not get requirements node")
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/platform/model/checkpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func FetchLanguagesForCommit(commitID strfmt.UUID, auth *authentication.Auth) ([
// FetchLanguagesForBuildScript fetches a list of language names for the given buildscript
func FetchLanguagesForBuildScript(script *buildscript.BuildScript) ([]Language, error) {
languages := []Language{}
reqs, err := script.DependencyRequirements()
reqs, err := script.DependencyRequirements("")
if err != nil {
return nil, errs.Wrap(err, "failed to get dependency requirements")
}
Expand Down

0 comments on commit 27d4e9d

Please sign in to comment.