diff --git a/README.md b/README.md index 28e665d..bb14a5c 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ https://docs.ansible.com/ansible/latest/user_guide/intro_inventory.html - [X] Variables - [X] Host patterns - [X] Nested groups +- [X] Load variables from `group_vars` and `host_vars` ## Public API ```godoc @@ -24,22 +25,30 @@ type Group struct { Hosts map[string]*Host Children map[string]*Group Parents map[string]*Group + + // Has unexported fields. } Group represents ansible group func GroupMapListValues(mymap map[string]*Group) []*Group - GroupMapListValues transforms map of Groups into Group list + GroupMapListValues transforms map of Groups into Group list in lexical order + +func (group Group) String() string type Host struct { Name string Port int Vars map[string]string Groups map[string]*Group + + // Has unexported fields. } Host represents ansible host func HostMapListValues(mymap map[string]*Host) []*Host - HostMapListValues transforms map of Hosts into Host list + HostMapListValues transforms map of Hosts into Host list in lexical order + +func (host Host) String() string type InventoryData struct { Groups map[string]*Group @@ -57,6 +66,15 @@ func ParseFile(f string) (*InventoryData, error) func ParseString(input string) (*InventoryData, error) ParseString parses Inventory represented as a string +func (inventory *InventoryData) AddVars(path string) error + AddVars take a path that contains group_vars and host_vars directories and + adds these variables to the InventoryData + +func (inventory *InventoryData) AddVarsLowerCased(path string) error + AddVarsLowerCased does the same as AddVars, but converts hostnames and + groups name to lowercase Use this function if you've executed + `inventory.HostsToLower` or `inventory.GroupsToLower` + func (inventory *InventoryData) GroupsToLower() GroupsToLower transforms all group names to lowercase @@ -67,7 +85,14 @@ func (inventory *InventoryData) Match(m string) []*Host Match looks for a hosts that match the pattern func (inventory *InventoryData) Reconcile() - Reconcile ensures inventory basic rules, run after updates + Reconcile ensures inventory basic rules, run after updates After initial + inventory file processing, only direct relationships are set + + This method: + + * (re)sets Children and Parents for hosts and groups + * ensures that mandatory groups exist + * calculates variables for hosts and groups ``` diff --git a/aini.go b/aini.go index 3eda844..250ad0b 100644 --- a/aini.go +++ b/aini.go @@ -6,6 +6,7 @@ import ( "io" "io/ioutil" "path" + "sort" "strings" ) @@ -17,13 +18,22 @@ type InventoryData struct { } // Group represents ansible group -// Note: Hosts field lists only direct members of the group, members of children groups are not included type Group struct { Name string Vars map[string]string Hosts map[string]*Host Children map[string]*Group Parents map[string]*Group + + directParents map[string]*Group + // Vars set in inventory + inventoryVars map[string]string + // Vars set in group_vars + fileVars map[string]string + // Projection of all parent inventory variables + allInventoryVars map[string]string + // Projection of all parent group_vars variables + allFileVars map[string]string } // Host represents ansible host @@ -32,6 +42,12 @@ type Host struct { Port int Vars map[string]string Groups map[string]*Group + + directGroups map[string]*Group + // Vars set in inventory + inventoryVars map[string]string + // Vars set in host_vars + fileVars map[string]string } // ParseFile parses Inventory represented as a file @@ -72,7 +88,7 @@ func (inventory *InventoryData) Match(m string) []*Host { return matchedHosts } -// GroupMapListValues transforms map of Groups into Group list +// GroupMapListValues transforms map of Groups into Group list in lexical order func GroupMapListValues(mymap map[string]*Group) []*Group { values := make([]*Group, len(mymap)) @@ -81,10 +97,13 @@ func GroupMapListValues(mymap map[string]*Group) []*Group { values[i] = v i++ } + sort.Slice(values, func(i, j int) bool { + return values[i].Name < values[j].Name + }) return values } -// HostMapListValues transforms map of Hosts into Host list +// HostMapListValues transforms map of Hosts into Host list in lexical order func HostMapListValues(mymap map[string]*Host) []*Host { values := make([]*Host, len(mymap)) @@ -93,6 +112,9 @@ func HostMapListValues(mymap map[string]*Host) []*Host { values[i] = v i++ } + sort.Slice(values, func(i, j int) bool { + return values[i].Name < values[j].Name + }) return values } @@ -120,16 +142,26 @@ func hostMapToLower(hosts map[string]*Host, keysOnly bool) map[string]*Host { func (inventory *InventoryData) GroupsToLower() { inventory.Groups = groupMapToLower(inventory.Groups, false) for _, host := range inventory.Hosts { + host.directGroups = groupMapToLower(host.directGroups, true) host.Groups = groupMapToLower(host.Groups, true) } } +func (group Group) String() string { + return group.Name +} + +func (host Host) String() string { + return host.Name +} + func groupMapToLower(groups map[string]*Group, keysOnly bool) map[string]*Group { newGroups := make(map[string]*Group, len(groups)) for groupname, group := range groups { groupname = strings.ToLower(groupname) if !keysOnly { group.Name = groupname + group.directParents = groupMapToLower(group.directParents, true) group.Parents = groupMapToLower(group.Parents, true) group.Children = groupMapToLower(group.Children, true) } diff --git a/aini_test.go b/aini_test.go index 1d1ccd7..a7e4447 100644 --- a/aini_test.go +++ b/aini_test.go @@ -3,85 +3,16 @@ package aini import ( "fmt" "testing" + + "github.com/stretchr/testify/assert" ) func parseString(t *testing.T, input string) *InventoryData { v, err := ParseString(input) - assert(t, err == nil, fmt.Sprintf("Error occurred while parsing: %s", err)) + assert.Nil(t, err, fmt.Sprintf("Error occurred while parsing: %s", err)) return v } -func (inventory *InventoryData) assertGroupExists(t *testing.T, group string) { - if inventory.Groups[group] == nil { - t.Errorf("Cannot find group \"%s\" in %v", group, inventory.Groups) - } -} - -func (inventory *InventoryData) assertGroupNotExists(t *testing.T, group string) { - if inventory.Groups[group] != nil { - t.Errorf("Group \"%s\" should not exist in %v", group, inventory.Groups) - } -} - -func (host *Host) assertGroupExists(t *testing.T, group string) { - if host.Groups[group] == nil { - t.Errorf("Cannot find group \"%s\" in %v", group, host.Groups) - } -} - -func (host *Host) assertGroupNotExists(t *testing.T, group string) { - if host.Groups[group] != nil { - t.Errorf("Group \"%s\" should not exist in %v", group, host.Groups) - } -} - -func (host *Host) assertVar(t *testing.T, name string, value string) { - if host.Vars[name] != value { - t.Errorf("Host %s doesn't have expected variable %s. Expected value: %s, Actual value: %s", host.Name, name, value, host.Vars[name]) - } -} - -func (group *Group) assertChildGroupExists(t *testing.T, child string) { - if group.Children[child] == nil { - t.Errorf("Cannot find child group \"%s\" in %v", child, group.Parents) - } -} -func (group *Group) assertParentGroupExists(t *testing.T, parent string) { - if group.Parents[parent] == nil { - t.Errorf("Cannot find child group \"%s\" in %v", parent, group.Parents) - } -} - -func (inventory *InventoryData) assertHostExists(t *testing.T, host string) { - if inventory.Hosts[host] == nil { - t.Errorf("Cannot find host \"%s\" in %v", host, inventory.Hosts) - } -} - -func (inventory *InventoryData) assertHostNotExists(t *testing.T, host string) { - if inventory.Hosts[host] != nil { - t.Errorf("Host \"%s\" should not exist in %v", host, inventory.Hosts) - } -} - -func (group *Group) assertHostExists(t *testing.T, host string) { - if group.Hosts[host] == nil { - t.Errorf("Cannot find host \"%s\" in %v", host, group.Hosts) - } -} - -func (group *Group) assertHostNotExists(t *testing.T, host string) { - if group.Hosts[host] != nil { - t.Errorf("Host \"%s\" should not exist in %v", host, group.Hosts) - } -} - -func assert(t *testing.T, cond bool, msg string) { - if !cond { - t.Error(msg) - } -} - func TestBelongToBasicGroups(t *testing.T) { v := parseString(t, ` host1:2221 # Comments @@ -90,46 +21,46 @@ func TestBelongToBasicGroups(t *testing.T) { # ignored `) - assert(t, len(v.Hosts) == 2, "Exactly two hosts expected") - assert(t, len(v.Groups) == 3, fmt.Sprintf("Expected three groups \"web\", \"all\" and \"ungrouped\", got: %v", v.Groups)) + assert.Len(t, v.Hosts, 2, "Exactly two hosts expected") + assert.Len(t, v.Groups, 3, "Expected three groups: web, all and ungrouped") - v.assertGroupExists(t, "web") - v.assertGroupExists(t, "all") - v.assertGroupExists(t, "ungrouped") + assert.Contains(t, v.Groups, "web") + assert.Contains(t, v.Groups, "all") + assert.Contains(t, v.Groups, "ungrouped") - v.assertHostExists(t, "host1") - assert(t, len(v.Hosts["host1"].Groups) == 2, "Host1 must belong to two groups: ungrouped and all") - assert(t, v.Hosts["host1"].Groups["all"] != nil, "Host1 must belong to all group") - assert(t, v.Hosts["host1"].Groups["ungrouped"] != nil, "Host1 must belong to ungrouped group") + assert.Contains(t, v.Hosts, "host1") + assert.Len(t, v.Hosts["host1"].Groups, 2, "Host1 must belong to two groups: ungrouped and all") + assert.NotNil(t, 2, v.Hosts["host1"].Groups["all"], "Host1 must belong to two groups: ungrouped and all") + assert.NotNil(t, 2, v.Hosts["host1"].Groups["ungrouped"], "Host1 must belong to ungrouped group") - v.assertHostExists(t, "host2") - assert(t, len(v.Hosts["host2"].Groups) == 2, "Host2 must belong to two groups: web and all") - assert(t, v.Hosts["host2"].Groups["all"] != nil, "Host2 must belong to all group") - assert(t, v.Hosts["host2"].Groups["web"] != nil, "Host1 must belong to web group") + assert.Contains(t, v.Hosts, "host2") + assert.Len(t, v.Hosts["host2"].Groups, 2, "Host2 must belong to two groups: ungrouped and all") + assert.NotNil(t, 2, v.Hosts["host2"].Groups["all"], "Host2 must belong to two groups: ungrouped and all") + assert.NotNil(t, 2, v.Hosts["host2"].Groups["ungrouped"], "Host2 must belong to ungrouped group") - assert(t, len(v.Groups["all"].Hosts) == 2, "Group all must contain two hosts") - v.Groups["all"].assertHostExists(t, "host1") - v.Groups["all"].assertHostExists(t, "host2") + assert.Equal(t, 2, len(v.Groups["all"].Hosts), "Group all must contain two hosts") + assert.Contains(t, v.Groups["all"].Hosts, "host1") + assert.Contains(t, v.Groups["all"].Hosts, "host2") - assert(t, len(v.Groups["web"].Hosts) == 1, "Group web must contain one host") - v.Groups["web"].assertHostExists(t, "host2") + assert.Len(t, v.Groups["web"].Hosts, 1, "Group web must contain one host") + assert.Contains(t, v.Groups["web"].Hosts, "host2") - assert(t, len(v.Groups["ungrouped"].Hosts) == 1, "Group ungrouped must contain one host") - v.Groups["ungrouped"].assertHostExists(t, "host1") - v.Groups["ungrouped"].assertHostNotExists(t, "host2") + assert.Len(t, v.Groups["ungrouped"].Hosts, 1, "Group ungrouped must contain one host") + assert.Contains(t, v.Groups["ungrouped"].Hosts, "host1") + assert.NotContains(t, v.Groups["ungrouped"].Hosts, "host2") - assert(t, v.Hosts["host1"].Port == 2221, "Host1 ports doesn't match") - assert(t, v.Hosts["host2"].Port == 22, "Host2 ports doesn't match") + assert.Equal(t, 2221, v.Hosts["host1"].Port, "Host1 port is set") + assert.Equal(t, 22, v.Hosts["host2"].Port, "Host2 port is set") } func TestGroupStructure(t *testing.T) { v := parseString(t, ` host5 - + [web:children] nginx apache - + [web] host1 host2 @@ -144,29 +75,31 @@ func TestGroupStructure(t *testing.T) { host6 `) - v.assertGroupExists(t, "web") - v.assertGroupExists(t, "apache") - v.assertGroupExists(t, "nginx") + assert.Contains(t, v.Groups, "web") + assert.Contains(t, v.Groups, "apache") + assert.Contains(t, v.Groups, "nginx") + assert.Contains(t, v.Groups, "all") + assert.Contains(t, v.Groups, "ungrouped") - assert(t, len(v.Groups) == 5, "Five groups must present: web, apache, nginx, all, ungrouped") + assert.Len(t, v.Groups, 5, "Five groups must be present: web, apache, nginx, all, ungrouped") - v.Groups["web"].assertChildGroupExists(t, "nginx") - v.Groups["web"].assertChildGroupExists(t, "apache") - v.Groups["nginx"].assertParentGroupExists(t, "web") - v.Groups["apache"].assertParentGroupExists(t, "web") + assert.Contains(t, v.Groups["web"].Children, "nginx") + assert.Contains(t, v.Groups["web"].Children, "apache") + assert.Contains(t, v.Groups["nginx"].Parents, "web") + assert.Contains(t, v.Groups["apache"].Parents, "web") - v.Groups["web"].assertHostExists(t, "host1") - v.Groups["web"].assertHostExists(t, "host2") - v.Groups["web"].assertHostExists(t, "host3") - v.Groups["web"].assertHostExists(t, "host4") - v.Groups["web"].assertHostExists(t, "host5") + assert.Contains(t, v.Groups["web"].Hosts, "host1") + assert.Contains(t, v.Groups["web"].Hosts, "host2") + assert.Contains(t, v.Groups["web"].Hosts, "host3") + assert.Contains(t, v.Groups["web"].Hosts, "host4") + assert.Contains(t, v.Groups["web"].Hosts, "host5") - v.Groups["nginx"].assertHostExists(t, "host1") + assert.Contains(t, v.Groups["nginx"].Hosts, "host1") - v.Hosts["host1"].assertGroupExists(t, "web") - v.Hosts["host1"].assertGroupExists(t, "nginx") + assert.Contains(t, v.Hosts["host1"].Groups, "web") + assert.Contains(t, v.Hosts["host1"].Groups, "nginx") - assert(t, len(v.Groups["ungrouped"].Hosts) == 0, "Group ungrouped should be empty") + assert.Empty(t, v.Groups["ungrouped"].Hosts) } func TestGroupNotExplicitlyDefined(t *testing.T) { @@ -178,82 +111,130 @@ func TestGroupNotExplicitlyDefined(t *testing.T) { host1 `) - v.assertGroupExists(t, "web") - v.assertGroupExists(t, "nginx") + assert.Contains(t, v.Groups, "web") + assert.Contains(t, v.Groups, "nginx") + assert.Contains(t, v.Groups, "all") + assert.Contains(t, v.Groups, "ungrouped") - assert(t, len(v.Groups) == 4, "Five groups must present: web, nginx, all, ungrouped") + assert.Len(t, v.Groups, 4, "Four groups must present: web, nginx, all, ungrouped") - v.Groups["web"].assertChildGroupExists(t, "nginx") - v.Groups["nginx"].assertParentGroupExists(t, "web") + assert.Contains(t, v.Groups["web"].Children, "nginx") + assert.Contains(t, v.Groups["nginx"].Parents, "web") - v.Groups["web"].assertHostExists(t, "host1") + assert.Contains(t, v.Groups["web"].Hosts, "host1") + assert.Contains(t, v.Groups["nginx"].Hosts, "host1") - v.Groups["nginx"].assertHostExists(t, "host1") + assert.Contains(t, v.Hosts["host1"].Groups, "web") + assert.Contains(t, v.Hosts["host1"].Groups, "nginx") - v.Hosts["host1"].assertGroupExists(t, "web") - v.Hosts["host1"].assertGroupExists(t, "nginx") + assert.Empty(t, v.Groups["ungrouped"].Hosts, "Group ungrouped should be empty") +} + +func TestAllGroup(t *testing.T) { + v := parseString(t, ` + host7 + host5 + + [web:children] + nginx + apache - assert(t, len(v.Groups["ungrouped"].Hosts) == 0, "Group ungrouped should be empty") + [web] + host1 + host2 + + [nginx] + host1 + host3 + host4 + + [apache] + host5 + host6 + `) + + allGroup := v.Groups["all"] + assert.NotNil(t, allGroup) + assert.Empty(t, allGroup.Parents) + assert.NotContains(t, allGroup.Children, "all") + assert.Len(t, allGroup.Children, 4) + assert.Len(t, allGroup.Hosts, 7) + for _, group := range v.Groups { + if group.Name == "all" { + continue + } + assert.Contains(t, allGroup.Children, group.Name) + assert.Contains(t, group.Parents, allGroup.Name) + } + for _, host := range v.Hosts { + assert.Contains(t, allGroup.Hosts, host.Name) + assert.Contains(t, host.Groups, allGroup.Name) + + } } func TestHostExpansionFullNumericPattern(t *testing.T) { v := parseString(t, ` host-[001:015:3]-web:23 `) - assert(t, len(v.Hosts) == 5, fmt.Sprintf("There must be 5 hosts in the list, found: %d", len(v.Hosts))) - v.assertHostExists(t, "host-001-web") - v.assertHostExists(t, "host-004-web") - v.assertHostExists(t, "host-007-web") - v.assertHostExists(t, "host-010-web") - v.assertHostExists(t, "host-013-web") - - assert(t, v.Hosts["host-007-web"].Port == 23, "host-007-web ports doesn't match") + + assert.Contains(t, v.Hosts, "host-001-web") + assert.Contains(t, v.Hosts, "host-004-web") + assert.Contains(t, v.Hosts, "host-007-web") + assert.Contains(t, v.Hosts, "host-010-web") + assert.Contains(t, v.Hosts, "host-013-web") + assert.Len(t, v.Hosts, 5) + + for _, host := range v.Hosts { + assert.Equalf(t, 23, host.Port, "%s port is set", host.Name) + } } func TestHostExpansionFullAlphabeticPattern(t *testing.T) { v := parseString(t, ` host-[a:o:3]-web `) - v.assertHostExists(t, "host-a-web") - v.assertHostExists(t, "host-d-web") - v.assertHostExists(t, "host-g-web") - v.assertHostExists(t, "host-j-web") - v.assertHostExists(t, "host-m-web") + assert.Contains(t, v.Hosts, "host-a-web") + assert.Contains(t, v.Hosts, "host-d-web") + assert.Contains(t, v.Hosts, "host-g-web") + assert.Contains(t, v.Hosts, "host-j-web") + assert.Contains(t, v.Hosts, "host-m-web") + assert.Len(t, v.Hosts, 5) } func TestHostExpansionShortNumericPattern(t *testing.T) { v := parseString(t, ` host-[:05]-web `) - assert(t, len(v.Hosts) == 6, fmt.Sprintf("There must be 6 hosts in the list, found: %d", len(v.Hosts))) - v.assertHostExists(t, "host-00-web") - v.assertHostExists(t, "host-01-web") - v.assertHostExists(t, "host-02-web") - v.assertHostExists(t, "host-03-web") - v.assertHostExists(t, "host-04-web") - v.assertHostExists(t, "host-05-web") + assert.Contains(t, v.Hosts, "host-00-web") + assert.Contains(t, v.Hosts, "host-01-web") + assert.Contains(t, v.Hosts, "host-02-web") + assert.Contains(t, v.Hosts, "host-03-web") + assert.Contains(t, v.Hosts, "host-04-web") + assert.Contains(t, v.Hosts, "host-05-web") + assert.Len(t, v.Hosts, 6) } func TestHostExpansionShortAlphabeticPattern(t *testing.T) { v := parseString(t, ` host-[a:c]-web `) - assert(t, len(v.Hosts) == 3, fmt.Sprintf("There must be 3 hosts in the list, found: %d", len(v.Hosts))) - v.assertHostExists(t, "host-a-web") - v.assertHostExists(t, "host-b-web") - v.assertHostExists(t, "host-c-web") + assert.Contains(t, v.Hosts, "host-a-web") + assert.Contains(t, v.Hosts, "host-b-web") + assert.Contains(t, v.Hosts, "host-c-web") + assert.Len(t, v.Hosts, 3) } func TestHostExpansionMultiplePatterns(t *testing.T) { v := parseString(t, ` host-[1:2]-[a:b]-web `) - assert(t, len(v.Hosts) == 4, fmt.Sprintf("There must be 4 hosts in the list, found: %d", len(v.Hosts))) - v.assertHostExists(t, "host-1-a-web") - v.assertHostExists(t, "host-1-b-web") - v.assertHostExists(t, "host-2-a-web") - v.assertHostExists(t, "host-2-b-web") + assert.Contains(t, v.Hosts, "host-1-a-web") + assert.Contains(t, v.Hosts, "host-1-b-web") + assert.Contains(t, v.Hosts, "host-2-a-web") + assert.Contains(t, v.Hosts, "host-2-b-web") + assert.Len(t, v.Hosts, 4) } func TestVariablesPriority(t *testing.T) { @@ -281,12 +262,12 @@ func TestVariablesPriority(t *testing.T) { x=f `) - v.Hosts["host-nginx-with-x"].assertVar(t, "x", "e") - v.Hosts["host-nginx"].assertVar(t, "x", "d") - v.Hosts["host-web"].assertVar(t, "x", "b") - v.Hosts["host-ungrouped-with-x"].assertVar(t, "x", "a") - v.Hosts["host-ungrouped"].assertVar(t, "x", "f") - + assert.Equal(t, "a", v.Hosts["host-ungrouped-with-x"].Vars["x"]) + assert.Equal(t, "b", v.Hosts["host-web"].Vars["x"]) + assert.Equal(t, "c", v.Groups["web"].Vars["x"]) + assert.Equal(t, "d", v.Hosts["host-nginx"].Vars["x"]) + assert.Equal(t, "e", v.Hosts["host-nginx-with-x"].Vars["x"]) + assert.Equal(t, "f", v.Hosts["host-ungrouped"].Vars["x"]) } func TestHostsToLower(t *testing.T) { @@ -300,20 +281,24 @@ func TestHostsToLower(t *testing.T) { tomcat-1 cat `) - v.assertHostExists(t, "CatFish") - v.Groups["ungrouped"].assertHostExists(t, "CatFish") - v.assertHostExists(t, "TomCat") + assert.Contains(t, v.Hosts, "CatFish") + assert.Contains(t, v.Groups["ungrouped"].Hosts, "CatFish") + assert.Contains(t, v.Hosts, "TomCat") + v.HostsToLower() - v.assertHostNotExists(t, "CatFish") - v.assertHostExists(t, "catfish") - assert(t, v.Hosts["catfish"].Name == "catfish", "Host catfish should have matching name") - v.assertHostNotExists(t, "TomCat") - v.assertHostExists(t, "tomcat") - assert(t, v.Hosts["tomcat"].Name == "tomcat", "Host catfish should have matching name") - v.Groups["ungrouped"].assertHostNotExists(t, "CatFish") - v.Groups["ungrouped"].assertHostExists(t, "catfish") - v.Groups["web"].assertHostNotExists(t, "TomCat") - v.Groups["web"].assertHostExists(t, "tomcat") + + assert.NotContains(t, v.Hosts, "CatFish") + assert.Contains(t, v.Hosts, "catfish") + assert.Equal(t, "catfish", v.Hosts["catfish"].Name, "Host catfish should have a matching name") + + assert.NotContains(t, v.Hosts, "TomCat") + assert.Contains(t, v.Hosts, "tomcat") + assert.Equal(t, "tomcat", v.Hosts["tomcat"].Name, "Host tomcat should have a matching name") + + assert.NotContains(t, v.Groups["ungrouped"].Hosts, "CatFish") + assert.Contains(t, v.Groups["ungrouped"].Hosts, "catfish") + assert.NotContains(t, v.Groups["web"].Hosts, "TomCat") + assert.Contains(t, v.Groups["web"].Hosts, "tomcat") } func TestGroupsToLower(t *testing.T) { @@ -329,22 +314,22 @@ func TestGroupsToLower(t *testing.T) { tomcat-1 cat `) - v.assertGroupExists(t, "Web") - v.assertGroupExists(t, "TomCat") + assert.Contains(t, v.Groups, "Web") + assert.Contains(t, v.Groups, "TomCat") v.GroupsToLower() - v.assertGroupNotExists(t, "Web") - v.assertGroupNotExists(t, "TomCat") - v.assertGroupExists(t, "web") - v.assertGroupExists(t, "tomcat") - - assert(t, v.Groups["web"].Name == "web", "Group web should have matching name") - v.Groups["web"].assertChildGroupExists(t, "tomcat") - v.Groups["web"].assertHostExists(t, "TomCat") - - assert(t, v.Groups["tomcat"].Name == "tomcat", "Group tomcat should have matching name") - v.Groups["tomcat"].assertHostExists(t, "TomCat") - v.Groups["tomcat"].assertHostExists(t, "tomcat-1") - v.Groups["tomcat"].assertHostExists(t, "cat") + assert.NotContains(t, v.Groups, "Web") + assert.NotContains(t, v.Groups, "TomCat") + assert.Contains(t, v.Groups, "web") + assert.Contains(t, v.Groups, "tomcat") + + assert.Equal(t, "web", v.Groups["web"].Name, "Group web should have matching name") + assert.Contains(t, v.Groups["web"].Children, "tomcat") + assert.Contains(t, v.Groups["web"].Hosts, "TomCat") + + assert.Equal(t, "tomcat", v.Groups["tomcat"].Name, "Group tomcat should have matching name") + assert.Contains(t, v.Groups["tomcat"].Hosts, "TomCat") + assert.Contains(t, v.Groups["tomcat"].Hosts, "tomcat-1") + assert.Contains(t, v.Groups["tomcat"].Hosts, "cat") } func TestGroupsAndHostsToLower(t *testing.T) { @@ -359,40 +344,61 @@ func TestGroupsAndHostsToLower(t *testing.T) { TomCat tomcat-1 `) - v.assertGroupExists(t, "Web") - v.assertGroupExists(t, "TomCat") + assert.Contains(t, v.Groups, "Web") + assert.Contains(t, v.Groups, "TomCat") - v.assertHostExists(t, "CatFish") - v.assertHostExists(t, "TomCat") - v.assertHostExists(t, "tomcat-1") + assert.Contains(t, v.Hosts, "CatFish") + assert.Contains(t, v.Hosts, "TomCat") + assert.Contains(t, v.Hosts, "tomcat-1") v.GroupsToLower() v.HostsToLower() - v.assertGroupNotExists(t, "Web") - v.assertGroupNotExists(t, "TomCat") - v.assertGroupExists(t, "web") - v.assertGroupExists(t, "tomcat") - - v.assertHostNotExists(t, "CatFish") - v.assertHostNotExists(t, "TomCat") - v.assertHostExists(t, "catfish") - v.assertHostExists(t, "tomcat") - v.assertHostExists(t, "tomcat-1") - - v.Groups["web"].assertHostExists(t, "catfish") - v.Groups["web"].assertChildGroupExists(t, "tomcat") - v.Groups["tomcat"].assertHostExists(t, "tomcat") - v.Groups["tomcat"].assertHostExists(t, "tomcat-1") + assert.NotContains(t, v.Groups, "Web") + assert.NotContains(t, v.Groups, "TomCat") + assert.Contains(t, v.Groups, "web") + assert.Contains(t, v.Groups, "tomcat") + + assert.NotContains(t, v.Hosts, "CatFish") + assert.NotContains(t, v.Hosts, "TomCat") + assert.Contains(t, v.Hosts, "catfish") + assert.Contains(t, v.Hosts, "tomcat") + assert.Contains(t, v.Hosts, "tomcat-1") + + assert.Contains(t, v.Groups["web"].Hosts, "catfish") + assert.Contains(t, v.Groups["web"].Children, "tomcat") + assert.Contains(t, v.Groups["tomcat"].Hosts, "tomcat") + assert.Contains(t, v.Groups["tomcat"].Hosts, "tomcat-1") +} + +func TestGroupLoops(t *testing.T) { + v := parseString(t, ` + [group1] + host1 + + [group1:children] + group2 + + [group2:children] + group1 + `) + + assert.Contains(t, v.Groups, "group1") + assert.Contains(t, v.Groups, "group2") + assert.Contains(t, v.Groups["group1"].Parents, "all") + assert.Contains(t, v.Groups["group1"].Parents, "group2") + assert.NotContains(t, v.Groups["group1"].Parents, "group1") + assert.Len(t, v.Groups["group1"].Parents, 2) + assert.Contains(t, v.Groups["group2"].Parents, "group1") } func TestVariablesEscaping(t *testing.T) { v := parseString(t, ` host ansible_ssh_common_args="-o ProxyCommand='ssh -W %h:%p somehost'" other_var_same_value="-o ProxyCommand='ssh -W %h:%p somehost'" # comment `) - v.assertHostExists(t, "host") - v.Hosts["host"].assertVar(t, "ansible_ssh_common_args", "-o ProxyCommand='ssh -W %h:%p somehost'") - v.Hosts["host"].assertVar(t, "other_var_same_value", "-o ProxyCommand='ssh -W %h:%p somehost'") + assert.Contains(t, v.Hosts, "host") + assert.Equal(t, "-o ProxyCommand='ssh -W %h:%p somehost'", v.Hosts["host"].Vars["ansible_ssh_common_args"]) + assert.Equal(t, "-o ProxyCommand='ssh -W %h:%p somehost'", v.Hosts["host"].Vars["other_var_same_value"]) } func TestComments(t *testing.T) { @@ -407,19 +413,18 @@ func TestComments(t *testing.T) { tomcat-1 # Small indention comment cat # Big indention comment `) - v.assertGroupExists(t, "web") - v.assertGroupExists(t, "tomcat") - v.Groups["web"].assertChildGroupExists(t, "tomcat") - - v.assertHostExists(t, "tomcat") - v.assertHostExists(t, "tomcat-1") - v.assertHostExists(t, "cat") - v.Groups["tomcat"].assertHostExists(t, "tomcat") - v.Groups["tomcat"].assertHostExists(t, "tomcat-1") - v.Groups["tomcat"].assertHostExists(t, "cat") - v.assertHostExists(t, "catfish") - v.Groups["ungrouped"].assertHostExists(t, "catfish") - + assert.Contains(t, v.Groups, "web") + assert.Contains(t, v.Groups, "tomcat") + assert.Contains(t, v.Groups["web"].Children, "tomcat") + + assert.Contains(t, v.Hosts, "tomcat") + assert.Contains(t, v.Hosts, "tomcat-1") + assert.Contains(t, v.Hosts, "cat") + assert.Contains(t, v.Groups["tomcat"].Hosts, "tomcat") + assert.Contains(t, v.Groups["tomcat"].Hosts, "tomcat-1") + assert.Contains(t, v.Groups["tomcat"].Hosts, "cat") + assert.Contains(t, v.Hosts, "catfish") + assert.Contains(t, v.Groups["ungrouped"].Hosts, "catfish") } func TestHostMatching(t *testing.T) { @@ -434,14 +439,33 @@ func TestHostMatching(t *testing.T) { cat `) hosts := v.Match("*cat*") - assert(t, len(hosts) == 4, fmt.Sprintf("Should be 4, got: %d\n%v", len(hosts), getNames(hosts))) + assert.Len(t, hosts, 4) +} +func TestHostMapListValues(t *testing.T) { + v := parseString(t, ` + host1 + host2 + host3 + `) + + hosts := HostMapListValues(v.Hosts) + assert.Len(t, hosts, 3) + for _, v := range hosts { + assert.Contains(t, hosts, v) + } } -func getNames(hosts []*Host) []string { - var result []string - for _, host := range hosts { - result = append(result, host.Name) +func TestGroupMapListValues(t *testing.T) { + v := parseString(t, ` + [group1] + [group2] + [group3] + `) + + groups := GroupMapListValues(v.Groups) + assert.Len(t, groups, 5) + for _, v := range groups { + assert.Contains(t, groups, v) } - return result } diff --git a/go.mod b/go.mod index 51147a3..32eefbe 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,8 @@ module github.com/relex/aini go 1.13 -require github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 +require ( + github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 + github.com/stretchr/testify v1.7.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect +) diff --git a/go.sum b/go.sum index 0be9157..59db3a0 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,14 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/inventory.go b/inventory.go index 3ad6603..0be1658 100644 --- a/inventory.go +++ b/inventory.go @@ -2,33 +2,87 @@ package aini // Inventory-related helper methods -// Reconcile ensures inventory basic rules, run after updates +// Reconcile ensures inventory basic rules, run after updates. +// After initial inventory file processing, only direct relationships are set. +// +// This method: +// * (re)sets Children and Parents for hosts and groups +// * ensures that mandatory groups exist +// * calculates variables for hosts and groups func (inventory *InventoryData) Reconcile() { + // Clear all computed data + for _, host := range inventory.Hosts { + host.clearData() + } + // a group can be empty (with no hosts in it), so the previous method will not clean it + // on the other hand, a group could have been attached to a host by a user, but not added to the inventory.Groups map + // so it's safer just to clean everything + for _, group := range inventory.Groups { + group.clearData(make(map[string]struct{}, len(inventory.Groups))) + } + allGroup := inventory.getOrCreateGroup("all") - allGroup.Hosts = inventory.Hosts - allGroup.Children = inventory.Groups + ungroupedGroup := inventory.getOrCreateGroup("ungrouped") + ungroupedGroup.directParents[allGroup.Name] = allGroup + // First, ensure that inventory.Groups contains all the groups for _, host := range inventory.Hosts { - for _, group := range host.Groups { - host.setVarsIfNotExist(group.Vars) + for _, group := range host.directGroups { + inventory.Groups[group.Name] = group for _, ancestor := range group.getAncestors() { - ancestor.Hosts[host.Name] = host - ancestor.Children[group.Name] = group - host.Groups[ancestor.Name] = ancestor - for k, v := range ancestor.Vars { - if _, ok := host.Vars[k]; !ok { - host.Vars[k] = v - } - if _, ok := group.Vars[k]; !ok { - group.Vars[k] = v - } - } + inventory.Groups[ancestor.Name] = ancestor + } + } + } + + // Calculate intergroup relationships + for _, group := range inventory.Groups { + group.directParents[allGroup.Name] = allGroup + for _, ancestor := range group.getAncestors() { + group.Parents[ancestor.Name] = ancestor + ancestor.Children[group.Name] = group + } + } + + // Now set hosts for groups and groups for hosts + for _, host := range inventory.Hosts { + host.Groups[allGroup.Name] = allGroup + for _, group := range host.directGroups { + group.Hosts[host.Name] = host + host.Groups[group.Name] = group + for _, parent := range group.Parents { + group.Parents[parent.Name] = parent + parent.Children[group.Name] = group + parent.Hosts[host.Name] = host + host.Groups[parent.Name] = parent } } - host.setVarsIfNotExist(allGroup.Vars) - host.Groups["all"] = allGroup } - inventory.Groups["all"] = allGroup + inventory.reconcileVars() +} + +func (host *Host) clearData() { + host.Groups = make(map[string]*Group) + host.Vars = make(map[string]string) + for _, group := range host.directGroups { + group.clearData(make(map[string]struct{}, len(host.Groups))) + } +} + +func (group *Group) clearData(visited map[string]struct{}) { + if _, ok := visited[group.Name]; ok { + return + } + group.Hosts = make(map[string]*Host) + group.Parents = make(map[string]*Group) + group.Children = make(map[string]*Group) + group.Vars = make(map[string]string) + group.allInventoryVars = nil + group.allFileVars = nil + visited[group.Name] = struct{}{} + for _, parent := range group.directParents { + parent.clearData(visited) + } } // getOrCreateGroup return group from inventory if exists or creates empty Group with given name @@ -38,10 +92,14 @@ func (inventory *InventoryData) getOrCreateGroup(groupName string) *Group { } g := &Group{ Name: groupName, - Hosts: make(map[string]*Host, 0), - Vars: make(map[string]string, 0), - Children: make(map[string]*Group, 0), - Parents: make(map[string]*Group, 0), + Hosts: make(map[string]*Host), + Vars: make(map[string]string), + Children: make(map[string]*Group), + Parents: make(map[string]*Group), + + directParents: make(map[string]*Group), + inventoryVars: make(map[string]string), + fileVars: make(map[string]string), } inventory.Groups[groupName] = g return g @@ -55,8 +113,12 @@ func (inventory *InventoryData) getOrCreateHost(hostName string) *Host { h := &Host{ Name: hostName, Port: 22, - Groups: make(map[string]*Group, 0), - Vars: make(map[string]string, 0), + Groups: make(map[string]*Group), + Vars: make(map[string]string), + + directGroups: make(map[string]*Group), + inventoryVars: make(map[string]string), + fileVars: make(map[string]string), } inventory.Hosts[hostName] = h return h @@ -65,13 +127,24 @@ func (inventory *InventoryData) getOrCreateHost(hostName string) *Host { // getAncestors returns all Ancestors of a given group in level order func (group *Group) getAncestors() []*Group { result := make([]*Group, 0) + if len(group.directParents) == 0 { + return result + } + visited := map[string]struct{}{group.Name: {}} - for queue := []*Group{group}; ; { + for queue := GroupMapListValues(group.directParents); ; { group := queue[0] - parentList := GroupMapListValues(group.Parents) - result = append(result, parentList...) copy(queue, queue[1:]) queue = queue[:len(queue)-1] + if _, ok := visited[group.Name]; ok { + if len(queue) == 0 { + return result + } + continue + } + visited[group.Name] = struct{}{} + parentList := GroupMapListValues(group.directParents) + result = append(result, group) queue = append(queue, parentList...) if len(queue) == 0 { @@ -80,19 +153,16 @@ func (group *Group) getAncestors() []*Group { } } -// setVarsIfNotExist sets Var for host if it doesn't have it already -func (host *Host) setVarsIfNotExist(vars map[string]string) { - for k, v := range vars { - if _, ok := host.Vars[k]; !ok { - host.Vars[k] = v - } +// addValues fills `to` map with values from `from` map +func addValues(to map[string]string, from map[string]string) { + for k, v := range from { + to[k] = v } } -func addValuesFromMap(m1 map[string]string, m2 map[string]string) { - for k, v := range m2 { - if m1[k] == "" { - m1[k] = v - } - } +// copyStringMap creates a non-deep copy of the map +func copyStringMap(from map[string]string) map[string]string { + result := make(map[string]string, len(from)) + addValues(result, from) + return result } diff --git a/parser.go b/parser.go index bd3c30d..0dc46d5 100644 --- a/parser.go +++ b/parser.go @@ -57,12 +57,12 @@ func (inventory *InventoryData) parse(reader *bufio.Reader) error { activeGroup = inventory.getOrCreateGroup(matches[0][1]) var ok bool if activeState, ok = getState(matches[0][2]); !ok { - return fmt.Errorf("Section [%s] has unknown type: %s", line, matches[0][2]) + return fmt.Errorf("section [%s] has unknown type: %s", line, matches[0][2]) } continue } else if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { - return fmt.Errorf("Invalid section entry: '%s'. Please make sure that there are no spaces in the section entry, and that there are no other invalid characters", line) + return fmt.Errorf("invalid section entry: '%s'. Make sure that there are no spaces or other characters in the section entry", line) } if activeState == hostsState { @@ -70,14 +70,12 @@ func (inventory *InventoryData) parse(reader *bufio.Reader) error { if err != nil { return err } - for k, v := range hosts { - activeGroup.Hosts[k] = v - if activeGroup.Name != "ungrouped" { - delete(inventory.Groups["ungrouped"].Hosts, k) - } - } for _, host := range hosts { + host.directGroups[activeGroup.Name] = activeGroup inventory.Hosts[host.Name] = host + if activeGroup.Name != "ungrouped" { + delete(host.directGroups, "ungrouped") + } } } if activeState == childrenState { @@ -87,7 +85,7 @@ func (inventory *InventoryData) parse(reader *bufio.Reader) error { } groupName := parsed[0] newGroup := inventory.getOrCreateGroup(groupName) - newGroup.Parents[activeGroup.Name] = activeGroup + newGroup.directParents[activeGroup.Name] = activeGroup inventory.Groups[line] = newGroup } if activeState == varsState { @@ -95,7 +93,7 @@ func (inventory *InventoryData) parse(reader *bufio.Reader) error { if err != nil { return err } - activeGroup.Vars[k] = v + activeGroup.inventoryVars[k] = v } } inventory.Groups[activeGroup.Name] = activeGroup @@ -130,8 +128,8 @@ func (inventory *InventoryData) getHosts(line string, group *Group) (map[string] host := inventory.getOrCreateHost(hostname) host.Port = port - host.Groups[group.Name] = group - addValuesFromMap(host.Vars, vars) + host.directGroups[group.Name] = group + addValues(host.inventoryVars, vars) result[host.Name] = host } @@ -142,7 +140,7 @@ func (inventory *InventoryData) getHosts(line string, group *Group) (map[string] func splitKV(kv string) (string, string, error) { keyval := strings.SplitN(kv, "=", 2) if len(keyval) == 1 { - return "", "", fmt.Errorf("Bad key=value pair supplied: %s", kv) + return "", "", fmt.Errorf("bad key=value pair supplied: %s", kv) } return strings.TrimSpace(keyval[0]), strings.TrimSpace(keyval[1]), nil } @@ -174,13 +172,13 @@ func expandHostPattern(hostpattern string) ([]string, error) { return []string{hostpattern}, nil } if len(parts) != 3 { - return nil, fmt.Errorf("Wrong host pattern: %s", hostpattern) + return nil, fmt.Errorf("wrong host pattern: %s", hostpattern) } head, nrange, tail := parts[0], parts[1], parts[2] bounds := strings.Split(nrange, ":") if len(bounds) < 2 || len(bounds) > 3 { - return nil, fmt.Errorf("Wrong host pattern: %s", hostpattern) + return nil, fmt.Errorf("wrong host pattern: %s", hostpattern) } var begin, end []rune @@ -195,7 +193,7 @@ func expandHostPattern(hostpattern string) ([]string, error) { format := fmt.Sprintf("%%0%dd", len(end)) begin = []rune(fmt.Sprintf(format, 0)) } else { - return nil, fmt.Errorf("Skipping range start in not allowed with alphabetical range: %s", hostpattern) + return nil, fmt.Errorf("skipping range start in not allowed with alphabetical range: %s", hostpattern) } } else { begin = []rune(bounds[0]) @@ -220,7 +218,7 @@ func expandHostPattern(hostpattern string) ([]string, error) { } if len(chars) == 0 { - return nil, fmt.Errorf("Bad range specified: %s", nrange) + return nil, fmt.Errorf("bad range specified: %s", nrange) } var hosts []string diff --git a/test_data/group_vars/empty/.gitkeep b/test_data/group_vars/empty/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/test_data/group_vars/nginx.yml b/test_data/group_vars/nginx.yml new file mode 100644 index 0000000..91cc73b --- /dev/null +++ b/test_data/group_vars/nginx.yml @@ -0,0 +1,7 @@ +--- +nginx_int_var: 1 +nginx_string_var: string +nginx_bool_var: true +nginx_object_var: + this: + is: object diff --git a/test_data/group_vars/tomcat.yml b/test_data/group_vars/tomcat.yml new file mode 100644 index 0000000..1357ed3 --- /dev/null +++ b/test_data/group_vars/tomcat.yml @@ -0,0 +1,3 @@ +--- +# File name's case doesn't match group name's case in inventory +tomcat_string_var: string diff --git a/test_data/group_vars/web/any_vars.yml b/test_data/group_vars/web/any_vars.yml new file mode 100644 index 0000000..34eb181 --- /dev/null +++ b/test_data/group_vars/web/any_vars.yml @@ -0,0 +1,7 @@ +--- +# This variable will be overwritten since the file is earlier in lexical order +web_int_var: 0 +web_string_var: string1 +web_object_var: + this: + is: object? diff --git a/test_data/group_vars/web/junk_file.txt b/test_data/group_vars/web/junk_file.txt new file mode 100644 index 0000000..0036405 --- /dev/null +++ b/test_data/group_vars/web/junk_file.txt @@ -0,0 +1 @@ +This file should not be read diff --git a/test_data/group_vars/web/some_vars.yml b/test_data/group_vars/web/some_vars.yml new file mode 100644 index 0000000..ef3c346 --- /dev/null +++ b/test_data/group_vars/web/some_vars.yml @@ -0,0 +1,6 @@ +--- +web_int_var: 1 +web_string_var: string +web_object_var: + this: + is: object diff --git a/test_data/host_vars/empty/.gitkeep b/test_data/host_vars/empty/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/test_data/host_vars/host1.yml b/test_data/host_vars/host1.yml new file mode 100644 index 0000000..a7a73b7 --- /dev/null +++ b/test_data/host_vars/host1.yml @@ -0,0 +1,6 @@ +--- +host1_int_var: 1 +host1_string_var: string +host1_object_var: + this: + is: object diff --git a/test_data/host_vars/host2/any_vars.yml b/test_data/host_vars/host2/any_vars.yml new file mode 100644 index 0000000..8e67a7b --- /dev/null +++ b/test_data/host_vars/host2/any_vars.yml @@ -0,0 +1,7 @@ +--- +# This variable will be overwritten since the file is earlier in lexical order +host2_int_var: 0 +host2_string_var: string1 +host2_object_var: + this: + is: object? diff --git a/test_data/host_vars/host2/junk_file.txt b/test_data/host_vars/host2/junk_file.txt new file mode 100644 index 0000000..0036405 --- /dev/null +++ b/test_data/host_vars/host2/junk_file.txt @@ -0,0 +1 @@ +This file should not be read diff --git a/test_data/host_vars/host2/some_file.yml b/test_data/host_vars/host2/some_file.yml new file mode 100644 index 0000000..8890e20 --- /dev/null +++ b/test_data/host_vars/host2/some_file.yml @@ -0,0 +1,6 @@ +--- +host2_int_var: 1 +host2_string_var: string +host2_object_var: + this: + is: object diff --git a/test_data/host_vars/host7.yml b/test_data/host_vars/host7.yml new file mode 100644 index 0000000..28a71c8 --- /dev/null +++ b/test_data/host_vars/host7.yml @@ -0,0 +1,3 @@ +--- +# File name's case doesn't match group name's case in inventory +host7_string_var: string diff --git a/test_data/inventory b/test_data/inventory new file mode 100644 index 0000000..78eed7a --- /dev/null +++ b/test_data/inventory @@ -0,0 +1,25 @@ +host5 + +[web:children] +nginx +apache + +[web:vars] +web_string_var=should be overwritten +web_inventory_string_var=present + +[web] +host1 +host2 + +[nginx] +host1 host1_string_var="should be overwritten" host1_inventory_string_var="present" +host3 +host4 + +[apache] +host5 +host6 + +[TomCat] +Host7 diff --git a/vars.go b/vars.go new file mode 100644 index 0000000..17ef223 --- /dev/null +++ b/vars.go @@ -0,0 +1,195 @@ +package aini + +import ( + "encoding/json" + "fmt" + "io/fs" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" + + "gopkg.in/yaml.v3" +) + +// AddVars take a path that contains group_vars and host_vars directories +// and adds these variables to the InventoryData +func (inventory *InventoryData) AddVars(path string) error { + return inventory.doAddVars(path, false) +} + +// AddVarsLowerCased does the same as AddVars, but converts hostnames and groups name to lowercase. +// Use this function if you've executed `inventory.HostsToLower` or `inventory.GroupsToLower` +func (inventory *InventoryData) AddVarsLowerCased(path string) error { + return inventory.doAddVars(path, true) +} + +func (inventory *InventoryData) doAddVars(path string, lowercased bool) error { + _, err := os.Stat(path) + if err != nil { + return err + } + walk(path, "group_vars", inventory.getGroupsMap(), lowercased) + walk(path, "host_vars", inventory.getHostsMap(), lowercased) + inventory.reconcileVars() + return nil +} + +type fileVarsGetter interface { + getFileVars() map[string]string +} + +func (host *Host) getFileVars() map[string]string { + return host.fileVars +} + +func (group *Group) getFileVars() map[string]string { + return group.fileVars +} + +func (inventory InventoryData) getHostsMap() map[string]fileVarsGetter { + result := make(map[string]fileVarsGetter, len(inventory.Hosts)) + for k, v := range inventory.Hosts { + result[k] = v + } + return result +} + +func (inventory InventoryData) getGroupsMap() map[string]fileVarsGetter { + result := make(map[string]fileVarsGetter, len(inventory.Groups)) + for k, v := range inventory.Groups { + result[k] = v + } + return result +} + +func walk(root string, subdir string, m map[string]fileVarsGetter, lowercased bool) error { + path := filepath.Join(root, subdir) + _, err := os.Stat(path) + // If the dir doesn't exist we can just skip it + if err != nil { + return nil + } + f := getWalkerFn(path, m, lowercased) + return filepath.WalkDir(path, f) +} + +func getWalkerFn(root string, m map[string]fileVarsGetter, lowercased bool) fs.WalkDirFunc { + var currentVars map[string]string + return func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil + } + if filepath.Dir(path) == root { + filename := filepath.Base(path) + ext := filepath.Ext(path) + itemName := strings.TrimSuffix(filename, ext) + if lowercased { + itemName = strings.ToLower(itemName) + } + if currentItem, ok := m[itemName]; ok { + currentVars = currentItem.getFileVars() + } else { + return nil + } + } + if d.IsDir() { + return nil + } + return addVarsFromFile(currentVars, path) + } +} + +func addVarsFromFile(currentVars map[string]string, path string) error { + if currentVars == nil { + // Group or Host doesn't exist in the inventory, ignoring + return nil + } + ext := filepath.Ext(path) + if ext != ".yaml" && ext != ".yml" { + return nil + } + f, err := ioutil.ReadFile(path) + if err != nil { + return err + } + vars := make(map[string]interface{}) + err = yaml.Unmarshal(f, &vars) + if err != nil { + return err + } + for k, v := range vars { + switch v := v.(type) { + case string: + currentVars[k] = v + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + currentVars[k] = fmt.Sprint(v) + case bool: + currentVars[k] = strconv.FormatBool(v) + default: + data, err := json.Marshal(v) + if err != nil { + return err + } + currentVars[k] = string(data) + } + } + return nil +} + +func (inventory *InventoryData) reconcileVars() { + /* + Priority of variables is defined here: https://docs.ansible.com/ansible/latest/user_guide/playbooks_variables.html#understanding-variable-precedence + Distilled list looks like this: + 1. inventory file group vars + 2. group_vars/* + 3. inventory file host vars + 4. inventory host_vars/* + */ + for _, group := range inventory.Groups { + group.allInventoryVars = nil + group.allFileVars = nil + } + for _, group := range inventory.Groups { + group.Vars = make(map[string]string) + group.populateInventoryVars() + group.populateFileVars() + // At this point we already "populated" all parent's inventory and file vars + // So it's fine to build Vars right away, without needing the second pass + group.Vars = copyStringMap(group.allInventoryVars) + addValues(group.Vars, group.allFileVars) + } + for _, host := range inventory.Hosts { + host.Vars = make(map[string]string) + for _, group := range GroupMapListValues(host.directGroups) { + addValues(host.Vars, group.Vars) + } + addValues(host.Vars, host.inventoryVars) + addValues(host.Vars, host.fileVars) + } +} + +func (group *Group) populateInventoryVars() { + if group.allInventoryVars != nil { + return + } + group.allInventoryVars = make(map[string]string) + for _, parent := range GroupMapListValues(group.directParents) { + parent.populateInventoryVars() + addValues(group.allInventoryVars, parent.allInventoryVars) + } + addValues(group.allInventoryVars, group.inventoryVars) +} + +func (group *Group) populateFileVars() { + if group.allFileVars != nil { + return + } + group.allFileVars = make(map[string]string) + for _, parent := range GroupMapListValues(group.directParents) { + parent.populateFileVars() + addValues(group.allFileVars, parent.allFileVars) + } + addValues(group.allFileVars, group.fileVars) +} diff --git a/vars_test.go b/vars_test.go new file mode 100644 index 0000000..f369106 --- /dev/null +++ b/vars_test.go @@ -0,0 +1,57 @@ +package aini + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddVars(t *testing.T) { + v, err := ParseFile("test_data/inventory") + assert.Nil(t, err) + + assert.Equal(t, "present", v.Groups["web"].Vars["web_inventory_string_var"]) + assert.Equal(t, "should be overwritten", v.Groups["web"].Vars["web_string_var"]) + + assert.Equal(t, "present", v.Hosts["host1"].Vars["host1_inventory_string_var"]) + assert.Equal(t, "should be overwritten", v.Hosts["host1"].Vars["host1_string_var"]) + + err = v.AddVars("test_data") + assert.Nil(t, err) + + assert.Equal(t, "1", v.Groups["web"].Vars["web_int_var"]) + assert.Equal(t, "string", v.Groups["web"].Vars["web_string_var"]) + assert.Equal(t, "{\"this\":{\"is\":\"object\"}}", v.Groups["web"].Vars["web_object_var"]) + assert.Equal(t, "present", v.Groups["web"].Vars["web_inventory_string_var"]) + + assert.Equal(t, "1", v.Groups["nginx"].Vars["nginx_int_var"]) + assert.Equal(t, "string", v.Groups["nginx"].Vars["nginx_string_var"]) + assert.Equal(t, "true", v.Groups["nginx"].Vars["nginx_bool_var"]) + assert.Equal(t, "{\"this\":{\"is\":\"object\"}}", v.Groups["nginx"].Vars["nginx_object_var"]) + + assert.Equal(t, "1", v.Hosts["host1"].Vars["host1_int_var"]) + assert.Equal(t, "string", v.Hosts["host1"].Vars["host1_string_var"]) + assert.Equal(t, "{\"this\":{\"is\":\"object\"}}", v.Hosts["host1"].Vars["host1_object_var"]) + assert.Equal(t, "present", v.Hosts["host1"].Vars["host1_inventory_string_var"]) + + assert.Equal(t, "1", v.Hosts["host2"].Vars["host2_int_var"]) + assert.Equal(t, "string", v.Hosts["host2"].Vars["host2_string_var"]) + assert.Equal(t, "{\"this\":{\"is\":\"object\"}}", v.Hosts["host2"].Vars["host2_object_var"]) + + assert.NotContains(t, v.Groups, "tomcat") + assert.NotContains(t, v.Hosts, "host7") +} + +func TestAddVarsLowerCased(t *testing.T) { + v, err := ParseFile("test_data/inventory") + assert.Nil(t, err) + + v.HostsToLower() + v.GroupsToLower() + v.AddVarsLowerCased("test_data") + + assert.Contains(t, v.Groups, "tomcat") + assert.Contains(t, v.Hosts, "host7") + assert.Equal(t, "string", v.Groups["tomcat"].Vars["tomcat_string_var"]) + assert.Equal(t, "string", v.Hosts["host7"].Vars["host7_string_var"]) +}