Skip to content

Commit

Permalink
introduce WithRootNodesAndDown to walk the graph from specified nodes…
Browse files Browse the repository at this point in the history
… and down

Signed-off-by: Nicolas De Loof <[email protected]>
  • Loading branch information
ndeloof committed May 12, 2023
1 parent b6b537e commit a12b727
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 10 deletions.
53 changes: 49 additions & 4 deletions pkg/compose/dependencies.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ const (
)

type graphTraversal struct {
mu sync.Mutex
seen map[string]struct{}
mu sync.Mutex
seen map[string]struct{}
ignored map[string]struct{}

extremityNodesFn func(*Graph) []*Vertex // leaves or roots
adjacentNodesFn func(*Vertex) []*Vertex // getParents or getChildren
Expand Down Expand Up @@ -87,15 +88,46 @@ func InDependencyOrder(ctx context.Context, project *types.Project, fn func(cont
}

// InReverseDependencyOrder applies the function to the services of the project in reverse order of dependencies
func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error) error {
func InReverseDependencyOrder(ctx context.Context, project *types.Project, fn func(context.Context, string) error, options ...func(*graphTraversal)) error {
graph, err := NewGraph(project.Services, ServiceStarted)
if err != nil {
return err
}
t := downDirectionTraversal(fn)
for _, option := range options {
option(t)
}
return t.visit(ctx, graph)
}

func WithRootNodesAndDown(nodes []string) func(*graphTraversal) {
return func(t *graphTraversal) {
if len(nodes) == 0 {
return
}
originalFn := t.extremityNodesFn
t.extremityNodesFn = func(graph *Graph) []*Vertex {
var want []string
for _, node := range nodes {
vertex := graph.Vertices[node]
want = append(want, vertex.Service)
for _, v := range getAncestors(vertex) {
want = append(want, v.Service)
}
}

t.ignored = map[string]struct{}{}
for k, _ := range graph.Vertices {
if !utils.Contains(want, k) {
t.ignored[k] = struct{}{}
}
}

return originalFn(graph)
}
}
}

func (t *graphTraversal) visit(ctx context.Context, g *Graph) error {
expect := len(g.Vertices)
if expect == 0 {
Expand Down Expand Up @@ -142,7 +174,10 @@ func (t *graphTraversal) run(ctx context.Context, graph *Graph, eg *errgroup.Gro
}

eg.Go(func() error {
err := t.visitorFn(ctx, node.Service)
var err error
if _, ignore := t.ignored[node.Service]; !ignore {
err = t.visitorFn(ctx, node.Service)
}
if err == nil {
graph.UpdateStatus(node.Key, t.targetServiceStatus)
}
Expand Down Expand Up @@ -197,6 +232,16 @@ func getChildren(v *Vertex) []*Vertex {
return v.GetChildren()
}

// getAncestors return all descendents for a vertex, might contain duplicates
func getAncestors(v *Vertex) []*Vertex {
var descendents []*Vertex
for _, parent := range v.GetParents() {
descendents = append(descendents, parent)
descendents = append(descendents, getAncestors(parent)...)
}
return descendents
}

// GetChildren returns a slice with the child vertices of the a Vertex
func (v *Vertex) GetChildren() []*Vertex {
var res []*Vertex
Expand Down
91 changes: 91 additions & 0 deletions pkg/compose/dependencies_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ package compose
import (
"context"
"fmt"
"sort"
"sync"
"testing"

"github.com/compose-spec/compose-go/types"
"github.com/docker/compose/v2/pkg/utils"
testify "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gotest.tools/v3/assert"
Expand Down Expand Up @@ -297,3 +300,91 @@ func isVertexEqual(a, b Vertex) bool {
childrenEquality &&
parentEquality
}

func TestWith_RootNodesAndUp(t *testing.T) {
graph := &Graph{
lock: sync.RWMutex{},
Vertices: map[string]*Vertex{},
}

/** graph topology:
A B
/ \ / \
G C E
\ /
D
|
F
*/

graph.AddVertex("A", "A", 0)
graph.AddVertex("B", "B", 0)
graph.AddVertex("C", "C", 0)
graph.AddVertex("D", "D", 0)
graph.AddVertex("E", "E", 0)
graph.AddVertex("F", "F", 0)
graph.AddVertex("G", "G", 0)

_ = graph.AddEdge("C", "A")
_ = graph.AddEdge("C", "B")
_ = graph.AddEdge("E", "B")
_ = graph.AddEdge("D", "C")
_ = graph.AddEdge("D", "E")
_ = graph.AddEdge("F", "D")
_ = graph.AddEdge("G", "A")

type args struct {
nodes []string
}
tests := []struct {
name string
nodes []string
want []string
}{
{
name: "whole graph",
nodes: []string{"A", "B"},
want: []string{"A", "B", "C", "D", "E", "F", "G"},
},
{
name: "only leaves",
nodes: []string{"F", "G"},
want: []string{"F", "G"},
},
{
name: "simple dependent",
nodes: []string{"D"},
want: []string{"D", "F"},
},
{
name: "diamond dependents",
nodes: []string{"B"},
want: []string{"B", "C", "D", "E", "F"},
},
{
name: "partial graph",
nodes: []string{"A"},
want: []string{"A", "C", "D", "F", "G"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mx := sync.Mutex{}
expected := utils.Set[string]{}
expected.AddAll("C", "G", "D", "F")
var visited []string

gt := downDirectionTraversal(func(ctx context.Context, s string) error {
mx.Lock()
defer mx.Unlock()
visited = append(visited, s)
return nil
})
WithRootNodesAndDown(tt.nodes)(gt)
err := gt.visit(context.TODO(), graph)
assert.NilError(t, err)
sort.Strings(visited)
assert.DeepEqual(t, tt.want, visited)
})
}
}
20 changes: 16 additions & 4 deletions pkg/compose/down.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,30 @@ func (s *composeService) down(ctx context.Context, projectName string, options a
}
}

// Check requested services exists in model
var services []string
for _, service := range options.Services {
_, err := project.GetService(service)
if err != nil {
if options.Project != nil {
// ran with an explicit compose.yaml file, so we should not ignore
return err
}
} else {
services = append(services, service)
}
}
options.Services = services

if len(containers) > 0 {
resourceToRemove = true
}

err = InReverseDependencyOrder(ctx, project, func(c context.Context, service string) error {
if len(options.Services) > 0 && !utils.StringContains(options.Services, service) {
return nil
}
serviceContainers := containers.filter(isService(service))
err := s.removeContainers(ctx, w, serviceContainers, options.Timeout, options.Volumes)
return err
})
}, WithRootNodesAndDown(options.Services))
if err != nil {
return err
}
Expand Down
14 changes: 12 additions & 2 deletions pkg/utils/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,18 @@ func (s Set[T]) Add(v T) {
s[v] = struct{}{}
}

func (s Set[T]) Remove(v T) {
delete(s, v)
func (s Set[T]) AddAll(v ...T) {
for _, e := range v {
s[e] = struct{}{}
}
}

func (s Set[T]) Remove(v T) bool {
_, ok := s[v]
if ok {
delete(s, v)
}
return ok
}

func (s Set[T]) Clear() {
Expand Down

0 comments on commit a12b727

Please sign in to comment.