diff --git a/loader/full-struct_test.go b/loader/full-struct_test.go index 6ba0327f..cc11967a 100644 --- a/loader/full-struct_test.go +++ b/loader/full-struct_test.go @@ -738,8 +738,8 @@ services: - project_db_1:mysql - project_db_1:postgresql extra_hosts: - - otherhost:50.31.209.229 - - somehost:162.242.195.82 + - otherhost=50.31.209.229 + - somehost=162.242.195.82 hostname: foo healthcheck: test: @@ -1336,8 +1336,8 @@ func fullExampleJSON(workingDir, homeDir string) string { "project_db_1:postgresql" ], "extra_hosts": [ - "otherhost:50.31.209.229", - "somehost:162.242.195.82" + "otherhost=50.31.209.229", + "somehost=162.242.195.82" ], "hostname": "foo", "healthcheck": { diff --git a/types/hostList.go b/types/hostList.go index 007f9ea1..68692b7c 100644 --- a/types/hostList.go +++ b/types/hostList.go @@ -20,28 +20,33 @@ import ( "encoding/json" "fmt" "sort" + "strings" ) // HostsList is a list of colon-separated host-ip mappings type HostsList map[string]string -// AsList return host-ip mappings as a list of colon-separated strings -func (h HostsList) AsList() []string { +// AsList returns host-ip mappings as a list of strings, using the given +// separator. The Docker Engine API expects ':' separators, the original format +// for '--add-hosts'. But an '=' separator is used in YAML/JSON renderings to +// make IPv6 addresses more readable (for example "my-host=::1" instead of +// "my-host:::1"). +func (h HostsList) AsList(sep string) []string { l := make([]string, 0, len(h)) for k, v := range h { - l = append(l, fmt.Sprintf("%s:%s", k, v)) + l = append(l, fmt.Sprintf("%s%s%s", k, sep, v)) } return l } func (h HostsList) MarshalYAML() (interface{}, error) { - list := h.AsList() + list := h.AsList("=") sort.Strings(list) return list, nil } func (h HostsList) MarshalJSON() ([]byte, error) { - list := h.AsList() + list := h.AsList("=") sort.Strings(list) return json.Marshal(list) } @@ -58,9 +63,21 @@ func (h *HostsList) DecodeMapstructure(value interface{}) error { } *h = list case []interface{}: - *h = decodeMapping(v, ":") + *h = decodeMapping(v, "=", ":") default: return fmt.Errorf("unexpected value type %T for mapping", value) } + for host, ip := range *h { + // Check that there is a hostname and that it doesn't contain either + // of the allowed separators, to generate a clearer error than the + // engine would do if it splits the string differently. + if host == "" || strings.ContainsAny(host, ":=") { + return fmt.Errorf("bad host name '%s'", host) + } + // Remove brackets from IP addresses (for example "[::1]" -> "::1"). + if len(ip) > 2 && ip[0] == '[' && ip[len(ip)-1] == ']' { + (*h)[host] = ip[1 : len(ip)-1] + } + } return nil } diff --git a/types/hostList_test.go b/types/hostList_test.go new file mode 100644 index 00000000..0295bdac --- /dev/null +++ b/types/hostList_test.go @@ -0,0 +1,177 @@ +/* + Copyright 2020 The Compose Specification Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package types + +import ( + "sort" + "strings" + "testing" + + "gotest.tools/v3/assert" + is "gotest.tools/v3/assert/cmp" +) + +func TestHostsList(t *testing.T) { + testCases := []struct { + doc string + input map[string]any + expectedError string + expectedOut string + }{ + { + doc: "IPv4", + input: map[string]any{"myhost": "192.168.0.1"}, + expectedOut: "myhost:192.168.0.1", + }, + { + doc: "Weird but permitted, IPv4 with brackets", + input: map[string]any{"myhost": "[192.168.0.1]"}, + expectedOut: "myhost:192.168.0.1", + }, + { + doc: "Host and domain", + input: map[string]any{"host.invalid": "10.0.2.1"}, + expectedOut: "host.invalid:10.0.2.1", + }, + { + doc: "IPv6", + input: map[string]any{"anipv6host": "2003:ab34:e::1"}, + expectedOut: "anipv6host:2003:ab34:e::1", + }, + { + doc: "IPv6, brackets", + input: map[string]any{"anipv6host": "[2003:ab34:e::1]"}, + expectedOut: "anipv6host:2003:ab34:e::1", + }, + { + doc: "IPv6 localhost", + input: map[string]any{"ipv6local": "::1"}, + expectedOut: "ipv6local:::1", + }, + { + doc: "IPv6 localhost, brackets", + input: map[string]any{"ipv6local": "[::1]"}, + expectedOut: "ipv6local:::1", + }, + { + doc: "host-gateway special case", + input: map[string]any{"host.docker.internal": "host-gateway"}, + expectedOut: "host.docker.internal:host-gateway", + }, + { + doc: "multiple inputs", + input: map[string]any{ + "myhost": "192.168.0.1", + "anipv6host": "[2003:ab34:e::1]", + "host.docker.internal": "host-gateway", + }, + expectedOut: "anipv6host:2003:ab34:e::1 host.docker.internal:host-gateway myhost:192.168.0.1", + }, + { + // This won't work, but address validation is left to the engine. + doc: "no ip", + input: map[string]any{"myhost": nil}, + expectedOut: "myhost:", + }, + { + doc: "bad host, colon", + input: map[string]any{":": "::1"}, + expectedError: "bad host name", + }, + { + doc: "bad host, eq", + input: map[string]any{"=": "::1"}, + expectedError: "bad host name", + }, + } + + inputAsList := func(input map[string]any, sep string) []any { + result := make([]any, 0, len(input)) + for host, ip := range input { + if ip == nil { + result = append(result, host+sep) + } else { + result = append(result, host+sep+ip.(string)) + } + } + return result + } + + for _, tc := range testCases { + // Decode the input map, check the output is as-expected. + var hlFromMap HostsList + t.Run(tc.doc+"_map", func(t *testing.T) { + err := hlFromMap.DecodeMapstructure(tc.input) + if tc.expectedError == "" { + assert.NilError(t, err) + actualOut := hlFromMap.AsList(":") + sort.Strings(actualOut) + sortedActualStr := strings.Join(actualOut, " ") + assert.Check(t, is.Equal(sortedActualStr, tc.expectedOut)) + + // The YAML rendering of HostsList should be the same as the AsList() output, but + // with '=' separators. + yamlOut, err := hlFromMap.MarshalYAML() + assert.NilError(t, err) + expYAMLOut := make([]string, len(actualOut)) + for i, s := range actualOut { + expYAMLOut[i] = strings.Replace(s, ":", "=", 1) + } + assert.DeepEqual(t, yamlOut.([]string), expYAMLOut) + + // The JSON rendering of HostsList should also have '=' separators. Same as the + // YAML output, but as a JSON list of strings. + jsonOut, err := hlFromMap.MarshalJSON() + assert.NilError(t, err) + expJSONStrings := make([]string, len(expYAMLOut)) + for i, s := range expYAMLOut { + expJSONStrings[i] = `"` + s + `"` + } + expJSONString := "[" + strings.Join(expJSONStrings, ",") + "]" + assert.Check(t, is.Equal(string(jsonOut), expJSONString)) + } else { + assert.ErrorContains(t, err, tc.expectedError) + } + }) + + // Convert the input into a ':' separated list, check that the result is the same + // as for the map-input. + t.Run(tc.doc+"_colon_sep", func(t *testing.T) { + var hl HostsList + err := hl.DecodeMapstructure(inputAsList(tc.input, ":")) + if tc.expectedError == "" { + assert.NilError(t, err) + assert.DeepEqual(t, hl, hlFromMap) + } else { + assert.ErrorContains(t, err, tc.expectedError) + } + }) + + // Convert the input into a ':' separated list, check that the result is the same + // as for the map-input. + t.Run(tc.doc+"_eq_sep", func(t *testing.T) { + var hl HostsList + err := hl.DecodeMapstructure(inputAsList(tc.input, "=")) + if tc.expectedError == "" { + assert.NilError(t, err) + assert.DeepEqual(t, hl, hlFromMap) + } else { + assert.ErrorContains(t, err, tc.expectedError) + } + }) + } +} diff --git a/types/mapping.go b/types/mapping.go index ea8e137f..de6fb123 100644 --- a/types/mapping.go +++ b/types/mapping.go @@ -195,14 +195,23 @@ func (m *Mapping) DecodeMapstructure(value interface{}) error { return nil } -func decodeMapping(v []interface{}, sep string) map[string]string { +// Generate a mapping by splitting strings at any of seps, which will be tried +// in-order for each input string. (For example, to allow the preferred 'host=ip' +// in 'extra_hosts', as well as 'host:ip' for backwards compatibility.) +func decodeMapping(v []interface{}, seps ...string) map[string]string { mapping := make(Mapping, len(v)) for _, s := range v { - k, e, ok := strings.Cut(fmt.Sprint(s), sep) - if !ok { - e = "" + for i, sep := range seps { + k, e, ok := strings.Cut(fmt.Sprint(s), sep) + if ok { + // Mapping found with this separator, stop here. + mapping[k] = e + break + } else if i == len(seps)-1 { + // No more separators to try, map to empty string. + mapping[k] = "" + } } - mapping[k] = e } return mapping }