diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index ba5899343e3..2010bf3735b 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -115,6 +115,7 @@ https://github.com/elastic/beats/compare/v7.0.0-alpha2...master[Check the HEAD d *Heartbeat* - Fixed excessive memory usage introduced in 7.5 due to over-allocating memory for HTTP checks. {pull}15639[15639] +- Fixed TCP TLS checks to properly validate hostnames, this broke in 7.x and only worked for IP SANs. {pull}17549[17549] *Journalbeat* diff --git a/heartbeat/hbtest/hbtestutil.go b/heartbeat/hbtest/hbtestutil.go index 9bec84ee0b5..548bc42a0eb 100644 --- a/heartbeat/hbtest/hbtestutil.go +++ b/heartbeat/hbtest/hbtestutil.go @@ -153,10 +153,19 @@ func SummaryChecks(up int, down int) validator.Validator { }) } +// ResolveChecks returns a lookslike matcher for the 'resolve' fields. +func ResolveChecks(ip string) validator.Validator { + return lookslike.MustCompile(map[string]interface{}{ + "resolve": map[string]interface{}{ + "ip": ip, + "rtt.us": isdef.IsDuration, + }, + }) +} + // SimpleURLChecks returns a check for a simple URL // with only a scheme, host, and port func SimpleURLChecks(t *testing.T, scheme string, host string, port uint16) validator.Validator { - hostPort := host if port != 0 { hostPort = fmt.Sprintf("%s:%d", host, port) @@ -165,6 +174,11 @@ func SimpleURLChecks(t *testing.T, scheme string, host string, port uint16) vali u, err := url.Parse(fmt.Sprintf("%s://%s", scheme, hostPort)) require.NoError(t, err) + return URLChecks(t, u) +} + +// URLChecks returns a validator for the given URL's fields +func URLChecks(t *testing.T, u *url.URL) validator.Validator { return lookslike.MustCompile(map[string]interface{}{ "url": wrappers.URLFields(u), }) diff --git a/heartbeat/monitors/active/dialchain/builder.go b/heartbeat/monitors/active/dialchain/builder.go deleted file mode 100644 index 1ba9671ad59..00000000000 --- a/heartbeat/monitors/active/dialchain/builder.go +++ /dev/null @@ -1,198 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more contributor -// license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright -// ownership. Elasticsearch B.V. licenses this file to you under -// the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package dialchain - -import ( - "fmt" - "net" - "net/url" - "time" - - "github.com/elastic/beats/v7/heartbeat/monitors" - "github.com/elastic/beats/v7/heartbeat/monitors/jobs" - "github.com/elastic/beats/v7/heartbeat/monitors/wrappers" - "github.com/elastic/beats/v7/libbeat/beat" - "github.com/elastic/beats/v7/libbeat/common/transport" - "github.com/elastic/beats/v7/libbeat/common/transport/tlscommon" -) - -// Builder maintains a DialerChain for building dialers and dialer based -// monitoring jobs. -// The builder ensures a constant address is being used, for any host -// configured. This ensures the upper network layers (e.g. TLS) correctly see -// and process the original hostname. -type Builder struct { - template *DialerChain - addrIndex int - resolveViaSocks5 bool -} - -// BuilderSettings configures the layers of the dialer chain to be constructed -// by a Builder. -type BuilderSettings struct { - Timeout time.Duration - Socks5 transport.ProxyConfig - TLS *tlscommon.TLSConfig -} - -// Endpoint configures a host with all port numbers to be monitored by a dialer -// based job. -type Endpoint struct { - Host string - Ports []uint16 -} - -// NewBuilder creates a new Builder for constructing dialers. -func NewBuilder(settings BuilderSettings) (*Builder, error) { - d := &DialerChain{ - Net: netDialer(settings.Timeout), - } - resolveViaSocks5 := false - withProxy := settings.Socks5.URL != "" - if withProxy { - d.AddLayer(SOCKS5Layer(&settings.Socks5)) - resolveViaSocks5 = !settings.Socks5.LocalResolve - } - - // insert empty placeholder, so address can be replaced in dialer chain - // by replacing this placeholder dialer - idx := len(d.Layers) - d.AddLayer(IDLayer()) - - // add tls layer doing the TLS handshake based on the original address - if tls := settings.TLS; tls != nil { - d.AddLayer(TLSLayer(tls, settings.Timeout)) - } - - // validate dialerchain - if err := d.TestBuild(); err != nil { - return nil, err - } - - return &Builder{ - template: d, - addrIndex: idx, - resolveViaSocks5: resolveViaSocks5, - }, nil -} - -// AddLayer adds another custom network layer to the dialer chain. -func (b *Builder) AddLayer(l Layer) { - b.template.AddLayer(l) -} - -// Build create a new dialer, that will always use the constant address, no matter -// which address is used to connect using the dialer. -// The dialer chain will add per layer information to the given event. -func (b *Builder) Build(addr string, event *beat.Event) (transport.Dialer, error) { - // clone template, as multiple instance of a dialer can exist at the same time - dchain := b.template.Clone() - - // fix the final dialers TCP-level address - dchain.Layers[b.addrIndex] = ConstAddrLayer(addr) - - // create dialer chain with event to add per network layer information - d, err := dchain.Build(event) - return d, err -} - -// Run executes the given function with a new dialer instance. -func (b *Builder) Run( - event *beat.Event, - addr string, - fn func(*beat.Event, transport.Dialer) error, -) error { - dialer, err := b.Build(addr, event) - if err != nil { - return err - } - - return fn(event, dialer) -} - -// MakeDialerJobs creates a set of monitoring jobs. The jobs behavior depends -// on the builder, endpoint and mode configurations, normally set by user -// configuration. The task to execute the actual 'ping' receives the dialer -// and the address pair (:), required to be used, to ping the -// correctly resolved endpoint. -func MakeDialerJobs( - b *Builder, - scheme string, - endpoints []Endpoint, - mode monitors.IPSettings, - fn func(event *beat.Event, dialer transport.Dialer, addr string) error, -) ([]jobs.Job, error) { - var jobs []jobs.Job - for _, endpoint := range endpoints { - for _, port := range endpoint.Ports { - endpointURL, err := url.Parse(fmt.Sprintf("%s://%s:%d", scheme, endpoint.Host, port)) - if err != nil { - return nil, err - } - endpointJob, err := makeEndpointJob(b, endpointURL, mode, fn) - if err != nil { - return nil, err - } - jobs = append(jobs, wrappers.WithURLField(endpointURL, endpointJob)) - } - - } - - return jobs, nil -} - -func makeEndpointJob( - b *Builder, - endpointURL *url.URL, - mode monitors.IPSettings, - fn func(*beat.Event, transport.Dialer, string) error, -) (jobs.Job, error) { - - // Check if SOCKS5 is configured, with relying on the socks5 proxy - // in resolving the actual IP. - // Create one job for every port number configured. - if b.resolveViaSocks5 { - return wrappers.WithURLField(endpointURL, - jobs.MakeSimpleJob(func(event *beat.Event) error { - hostPort := net.JoinHostPort(endpointURL.Hostname(), endpointURL.Port()) - return b.Run(event, hostPort, func(event *beat.Event, dialer transport.Dialer) error { - return fn(event, dialer, hostPort) - }) - })), nil - } - - // Create job that first resolves one or multiple IP (depending on - // config.Mode) in order to create one continuation Task per IP. - settings := monitors.MakeHostJobSettings(endpointURL.Hostname(), mode) - - job, err := monitors.MakeByHostJob(settings, - monitors.MakePingIPFactory( - func(event *beat.Event, ip *net.IPAddr) error { - // use address from resolved IP - ipPort := net.JoinHostPort(ip.String(), endpointURL.Port()) - cb := func(event *beat.Event, dialer transport.Dialer) error { - return fn(event, dialer, ipPort) - } - err := b.Run(event, ipPort, cb) - return err - })) - if err != nil { - return nil, err - } - return job, nil -} diff --git a/heartbeat/monitors/active/dialchain/net.go b/heartbeat/monitors/active/dialchain/dialers.go similarity index 94% rename from heartbeat/monitors/active/dialchain/net.go rename to heartbeat/monitors/active/dialchain/dialers.go index 5fc41270fb7..ddd34870d8d 100644 --- a/heartbeat/monitors/active/dialchain/net.go +++ b/heartbeat/monitors/active/dialchain/dialers.go @@ -45,7 +45,7 @@ import ( // } // } func TCPDialer(to time.Duration) NetDialer { - return netDialer(to) + return CreateNetDialer(to) } // UDPDialer creates a new NetDialer with constant event fields and default @@ -62,10 +62,11 @@ func TCPDialer(to time.Duration) NetDialer { // } // } func UDPDialer(to time.Duration) NetDialer { - return netDialer(to) + return CreateNetDialer(to) } -func netDialer(timeout time.Duration) NetDialer { +// CreateNetDialer returns a NetDialer with the given timeout. +func CreateNetDialer(timeout time.Duration) NetDialer { return func(event *beat.Event) (transport.Dialer, error) { return makeDialer(func(network, address string) (net.Conn, error) { namespace := "" diff --git a/heartbeat/monitors/active/dialchain/util.go b/heartbeat/monitors/active/dialchain/util.go index 7329cf81be8..7b29ca80b39 100644 --- a/heartbeat/monitors/active/dialchain/util.go +++ b/heartbeat/monitors/active/dialchain/util.go @@ -29,15 +29,6 @@ type timer struct { s, e time.Time } -// IDLayer creates an empty placeholder layer. -func IDLayer() Layer { - return _idLayer -} - -var _idLayer = Layer(func(event *beat.Event, next transport.Dialer) (transport.Dialer, error) { - return next, nil -}) - // ConstAddrLayer introduces a network layer always passing a constant address // to the underlying layer. func ConstAddrLayer(address string) Layer { diff --git a/heartbeat/monitors/active/http/task.go b/heartbeat/monitors/active/http/task.go index e750faf6d12..65d2f1ae62c 100644 --- a/heartbeat/monitors/active/http/task.go +++ b/heartbeat/monitors/active/http/task.go @@ -94,10 +94,8 @@ func newHTTPMonitorIPsJob( return nil, err } - settings := monitors.MakeHostJobSettings(hostname, config.Mode) - pingFactory := createPingFactory(config, port, tls, req, body, validator) - job, err := monitors.MakeByHostJob(settings, pingFactory) + job, err := monitors.MakeByHostJob(hostname, config.Mode, monitors.NewStdResolver(), pingFactory) return job, err } diff --git a/heartbeat/monitors/active/icmp/icmp.go b/heartbeat/monitors/active/icmp/icmp.go index 66bc0b2adc6..1cb19c90798 100644 --- a/heartbeat/monitors/active/icmp/icmp.go +++ b/heartbeat/monitors/active/icmp/icmp.go @@ -71,8 +71,7 @@ func create( pingFactory := monitors.MakePingIPFactory(createPingIPFactory(&config)) for _, host := range config.Hosts { - settings := monitors.MakeHostJobSettings(host, config.Mode) - job, err := monitors.MakeByHostJob(settings, pingFactory) + job, err := monitors.MakeByHostJob(host, config.Mode, monitors.NewStdResolver(), pingFactory) if err != nil { return nil, 0, err diff --git a/heartbeat/monitors/active/tcp/config.go b/heartbeat/monitors/active/tcp/config.go index 45fa4abc6e0..692af5ba65f 100644 --- a/heartbeat/monitors/active/tcp/config.go +++ b/heartbeat/monitors/active/tcp/config.go @@ -26,7 +26,7 @@ import ( "github.com/elastic/beats/v7/libbeat/common/transport/tlscommon" ) -type Config struct { +type config struct { // check all ports if host does not contain port Hosts []string `config:"hosts" validate:"required"` Ports []uint16 `config:"ports"` @@ -45,12 +45,14 @@ type Config struct { ReceiveString string `config:"check.receive"` } -var DefaultConfig = Config{ - Timeout: 16 * time.Second, - Mode: monitors.DefaultIPSettings, +func defaultConfig() config { + return config{ + Timeout: 16 * time.Second, + Mode: monitors.DefaultIPSettings, + } } -func (c *Config) Validate() error { +func (c *config) Validate() error { if c.Socks5.URL != "" { if c.Mode.Mode != monitors.PingAny && !c.Socks5.LocalResolve { return errors.New("ping all ips only supported if proxy_use_local_resolver is enabled`") diff --git a/heartbeat/monitors/active/tcp/check.go b/heartbeat/monitors/active/tcp/datacheck.go similarity index 86% rename from heartbeat/monitors/active/tcp/check.go rename to heartbeat/monitors/active/tcp/datacheck.go index 0f8d631ce56..b465052f6f5 100644 --- a/heartbeat/monitors/active/tcp/check.go +++ b/heartbeat/monitors/active/tcp/datacheck.go @@ -23,18 +23,20 @@ import ( "net" ) -type ConnCheck func(net.Conn) error +// dataCheck executes over an open TCP connection using the send / receive +// parameters the user has defined. +type dataCheck func(net.Conn) error var ( errNoDataReceived = errors.New("no data") errRecvMismatch = errors.New("received string mismatch") ) -func (c ConnCheck) Validate(conn net.Conn) error { +func (c dataCheck) Check(conn net.Conn) error { return c(conn) } -func makeValidateConn(config *Config) ConnCheck { +func makeDataCheck(config *config) dataCheck { send := config.SendString recv := config.ReceiveString @@ -52,7 +54,7 @@ func makeValidateConn(config *Config) ConnCheck { func checkOk(_ net.Conn) error { return nil } -func checkAll(checks ...ConnCheck) ConnCheck { +func checkAll(checks ...dataCheck) dataCheck { return func(conn net.Conn) error { for _, check := range checks { if err := check(conn); err != nil { @@ -63,13 +65,13 @@ func checkAll(checks ...ConnCheck) ConnCheck { } } -func checkSend(buf []byte) ConnCheck { +func checkSend(buf []byte) dataCheck { return func(conn net.Conn) error { return sendBuffer(conn, buf) } } -func checkRecv(expected []byte) ConnCheck { +func checkRecv(expected []byte) dataCheck { return func(conn net.Conn) error { buf := make([]byte, len(expected)) if err := recvBuffer(conn, buf); err != nil { diff --git a/heartbeat/monitors/active/tcp/endpoint.go b/heartbeat/monitors/active/tcp/endpoint.go new file mode 100644 index 00000000000..e3c87531ffc --- /dev/null +++ b/heartbeat/monitors/active/tcp/endpoint.go @@ -0,0 +1,108 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "fmt" + "net" + "net/url" + "strconv" + + "github.com/pkg/errors" +) + +// endpoint configures a host with all port numbers to be monitored by a dialer +// based job. +type endpoint struct { + Scheme string + Hostname string + Ports []uint16 +} + +// perPortURLs returns a list containing one URL per port +func (e endpoint) perPortURLs() (urls []*url.URL) { + for _, port := range e.Ports { + urls = append(urls, &url.URL{ + Scheme: e.Scheme, + Host: net.JoinHostPort(e.Hostname, strconv.Itoa(int(port))), + }) + } + + return urls +} + +// makeEndpoints creates a single endpoint struct for each host/port permutation. +// Set `defaultScheme` to choose which scheme is used if not explicit in the host config. +func makeEndpoints(hosts []string, ports []uint16, defaultScheme string) (endpoints []endpoint, err error) { + for _, h := range hosts { + u, err := url.Parse(h) + + // If h is just a bare hostname like 'localhost' it will be parsed as the URL path, and host will + // be blank + var ep endpoint + if err == nil && u.Host != "" { + ep, err = makeURLEndpoint(u, ports) + if err != nil { + return nil, err + } + } else { + u := &url.URL{Scheme: defaultScheme, Host: h} + ep, err = makeURLEndpoint(u, ports) + if err != nil { + return nil, err + } + } + endpoints = append(endpoints, ep) + } + return endpoints, nil +} + +func makeURLEndpoint(u *url.URL, ports []uint16) (endpoint, error) { + switch u.Scheme { + case "tcp", "plain", "tls", "ssl": + default: + err := fmt.Errorf( + "'%s' is not a supported connection scheme in '%s', supported schemes are tcp, plain, tls, and ssl", + u.Scheme, + u, + ) + return endpoint{}, err + } + + if u.Port() != "" { + pUint, err := strconv.ParseUint(u.Port(), 10, 16) + if err != nil { + return endpoint{}, errors.Wrapf(err, "no port(s) defined for TCP endpoint %s", u) + } + ports = []uint16{uint16(pUint)} + } + + if len(ports) == 0 { + return endpoint{}, fmt.Errorf("host '%s' missing port number", u) + } + + if u.Hostname() == "" || u.Hostname() == ":" { + return endpoint{}, fmt.Errorf("could not parse tcp host '%s'", u) + } + + return endpoint{ + Scheme: u.Scheme, + Hostname: u.Hostname(), + Ports: ports, + }, nil +} diff --git a/heartbeat/monitors/active/tcp/endpoint_test.go b/heartbeat/monitors/active/tcp/endpoint_test.go new file mode 100644 index 00000000000..5e871914b40 --- /dev/null +++ b/heartbeat/monitors/active/tcp/endpoint_test.go @@ -0,0 +1,109 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMakeEndpoints(t *testing.T) { + type args struct { + hosts []string + ports []uint16 + defaultScheme string + } + tests := []struct { + name string + args args + wantEndpoints []endpoint + wantErr bool + }{ + { + "hostname", + args{[]string{"localhost"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tcp", Hostname: "localhost", Ports: []uint16{123}}}, + false, + }, + { + "ipv4", + args{[]string{"1.2.3.4"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tcp", Hostname: "1.2.3.4", Ports: []uint16{123}}}, + false, + }, + { + "unbracketed ipv6", + args{[]string{"::1"}, []uint16{123}, "tcp"}, + []endpoint{}, + true, + }, + { + "bracketed ipv6", + args{[]string{"[::1]"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tcp", Hostname: "::1", Ports: []uint16{123}}}, + false, + }, + { + "url", + args{[]string{"tls://example.net"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tls", Hostname: "example.net", Ports: []uint16{123}}}, + false, + }, + { + "url:port", + args{[]string{"example.net:999"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tcp", Hostname: "example.net", Ports: []uint16{999}}}, + false, + }, + { + "scheme://url:port", + args{[]string{"tls://example.net:999"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tls", Hostname: "example.net", Ports: []uint16{999}}}, + false, + }, + { + "hybrid", + args{[]string{ + "localhost", + "192.168.0.1", + "[2607:f8b0:4004:814::200e]", + "example.net:999", + "tls://example.net:999", + }, []uint16{123}, "tcp"}, + []endpoint{ + {Scheme: "tcp", Hostname: "localhost", Ports: []uint16{123}}, + {Scheme: "tcp", Hostname: "192.168.0.1", Ports: []uint16{123}}, + {Scheme: "tcp", Hostname: "2607:f8b0:4004:814::200e", Ports: []uint16{123}}, + {Scheme: "tcp", Hostname: "example.net", Ports: []uint16{999}}, + {Scheme: "tls", Hostname: "example.net", Ports: []uint16{999}}, + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotEndpoints, err := makeEndpoints(tt.args.hosts, tt.args.ports, tt.args.defaultScheme) + if tt.wantErr { + require.Error(t, err) + return + } + require.Equal(t, tt.wantEndpoints, gotEndpoints) + }) + } +} diff --git a/heartbeat/monitors/active/tcp/helpers_test.go b/heartbeat/monitors/active/tcp/helpers_test.go new file mode 100644 index 00000000000..d1a8c1b5bc1 --- /dev/null +++ b/heartbeat/monitors/active/tcp/helpers_test.go @@ -0,0 +1,87 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/pkg/errors" + + "github.com/stretchr/testify/require" + + "github.com/elastic/beats/v7/heartbeat/hbtest" + "github.com/elastic/beats/v7/heartbeat/monitors/wrappers" + "github.com/elastic/beats/v7/heartbeat/scheduler/schedule" + "github.com/elastic/beats/v7/libbeat/beat" + "github.com/elastic/beats/v7/libbeat/common" +) + +func testTCPConfigCheck(t *testing.T, configMap common.MapStr, host string, port uint16) *beat.Event { + config, err := common.NewConfigFrom(configMap) + require.NoError(t, err) + + jobs, endpoints, err := create("tcp", config) + require.NoError(t, err) + + sched, _ := schedule.Parse("@every 1s") + job := wrappers.WrapCommon(jobs, "test", "", "tcp", sched, time.Duration(0))[0] + + event := &beat.Event{} + _, err = job(event) + require.NoError(t, err) + + require.Equal(t, 1, endpoints) + + return event +} + +func setupServer(t *testing.T, serverCreator func(http.Handler) (*httptest.Server, error)) (*httptest.Server, uint16, error) { + server, err := serverCreator(hbtest.HelloWorldHandler(200)) + if err != nil { + return nil, 0, err + } + + port, err := hbtest.ServerPort(server) + if err != nil { + return nil, 0, err + } + + return server, port, nil +} + +// newHostTestServer starts a server listening on the IP resolved from the host arg +// httptest.NewServer() binds explicitly on 127.0.0.1 (or [::1] if ipv4 is not available). +// The IP resolved from `localhost` can be a different one, like 127.0.1.1. +func newHostTestServer(handler http.Handler, host string) (*httptest.Server, error) { + listener, err := net.Listen("tcp", net.JoinHostPort(host, "0")) + if err != nil { + return nil, errors.Wrapf(err, "failed to listen on host '%s'", host) + } + + server := &httptest.Server{ + Listener: listener, + Config: &http.Server{Handler: handler}, + } + server.Start() + + return server, nil +} diff --git a/heartbeat/monitors/active/tcp/socks5_test.go b/heartbeat/monitors/active/tcp/socks5_test.go new file mode 100644 index 00000000000..47c691dbdea --- /dev/null +++ b/heartbeat/monitors/active/tcp/socks5_test.go @@ -0,0 +1,134 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "fmt" + "net" + "net/url" + "strconv" + "sync" + "testing" + + "github.com/armon/go-socks5" + "github.com/stretchr/testify/require" + + "github.com/elastic/beats/v7/heartbeat/hbtest" + "github.com/elastic/beats/v7/libbeat/common" + "github.com/elastic/go-lookslike" + "github.com/elastic/go-lookslike/isdef" + "github.com/elastic/go-lookslike/testslike" +) + +func TestSocks5Job(t *testing.T) { + scenarios := []struct { + name string + localResolver bool + }{ + { + name: "using local resolver", + localResolver: true, + }, + { + name: "not using local resolver", + localResolver: false, + }, + } + + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + host, port, ip, closeEcho, err := startEchoServer(t) + require.NoError(t, err) + defer closeEcho() + + _, proxyPort, proxyIp, closeProxy, err := startSocks5Server(t) + require.NoError(t, err) + defer closeProxy() + + proxyURL := &url.URL{Scheme: "socks5", Host: net.JoinHostPort(proxyIp, fmt.Sprint(proxyPort))} + configMap := common.MapStr{ + "hosts": host, + "ports": port, + "timeout": "1s", + "proxy_url": proxyURL.String(), + "proxy_use_local_resolver": scenario.localResolver, + "check.receive": "echo123", + "check.send": "echo123", + } + event := testTCPConfigCheck(t, configMap, host, port) + + testslike.Test( + t, + lookslike.Strict(lookslike.Compose( + hbtest.BaseChecks(ip, "up", "tcp"), + hbtest.RespondingTCPChecks(), + hbtest.SimpleURLChecks(t, "tcp", host, port), + hbtest.SummaryChecks(1, 0), + hbtest.ResolveChecks(ip), + lookslike.MustCompile(map[string]interface{}{ + "tcp": map[string]interface{}{ + "rtt.validate.us": isdef.IsDuration, + }, + "socks5": map[string]interface{}{ + "rtt.connect.us": isdef.IsDuration, + }, + }), + )), + event.Fields, + ) + }) + } +} + +func startSocks5Server(t *testing.T) (host string, port uint16, ip string, close func() error, err error) { + host = "localhost" + config := &socks5.Config{} + server, err := socks5.New(config) + if err != nil { + return "", 0, "", nil, err + } + + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", 0, "", nil, err + } + ip, portStr, err := net.SplitHostPort(listener.Addr().String()) + portUint64, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + listener.Close() + return "", 0, "", nil, err + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + if err := server.Serve(listener); err != nil { + debugf("Error in SOCKS5 Test Server %v", err) + } + wg.Done() + }() + + return host, uint16(portUint64), ip, func() error { + err := listener.Close() + if err != nil { + return err + } + wg.Wait() + return nil + }, nil +} diff --git a/heartbeat/monitors/active/tcp/task.go b/heartbeat/monitors/active/tcp/task.go deleted file mode 100644 index 3d830c0c7e9..00000000000 --- a/heartbeat/monitors/active/tcp/task.go +++ /dev/null @@ -1,77 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more contributor -// license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright -// ownership. Elasticsearch B.V. licenses this file to you under -// the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package tcp - -import ( - "time" - - "github.com/elastic/beats/v7/heartbeat/eventext" - "github.com/elastic/beats/v7/heartbeat/look" - "github.com/elastic/beats/v7/heartbeat/reason" - "github.com/elastic/beats/v7/libbeat/beat" - "github.com/elastic/beats/v7/libbeat/common" - "github.com/elastic/beats/v7/libbeat/common/transport" -) - -func pingHost( - event *beat.Event, - dialer transport.Dialer, - addr string, - timeout time.Duration, - validator ConnCheck, -) error { - start := time.Now() - deadline := start.Add(timeout) - - conn, err := dialer.Dial("tcp", addr) - if err != nil { - debugf("dial failed with: %v", err) - return reason.IOFailed(err) - } - defer conn.Close() - if validator == nil { - // no additional validation step => ping success - return nil - } - - if err := conn.SetDeadline(deadline); err != nil { - debugf("setting connection deadline failed with: %v", err) - return reason.IOFailed(err) - } - - validateStart := time.Now() - err = validator.Validate(conn) - if err != nil && err != errRecvMismatch { - debugf("check failed with: %v", err) - return reason.IOFailed(err) - } - - end := time.Now() - eventext.MergeEventFields(event, common.MapStr{ - "tcp": common.MapStr{ - "rtt": common.MapStr{ - "validate": look.RTT(end.Sub(validateStart)), - }, - }, - }) - if err != nil { - return reason.MakeValidateError(err) - } - - return nil -} diff --git a/heartbeat/monitors/active/tcp/tcp.go b/heartbeat/monitors/active/tcp/tcp.go index bf584095fc4..05c687dd65b 100644 --- a/heartbeat/monitors/active/tcp/tcp.go +++ b/heartbeat/monitors/active/tcp/tcp.go @@ -18,14 +18,18 @@ package tcp import ( - "fmt" + "net" "net/url" - "strconv" - "strings" + "time" + + "github.com/elastic/beats/v7/heartbeat/eventext" + "github.com/elastic/beats/v7/heartbeat/look" + "github.com/elastic/beats/v7/heartbeat/reason" "github.com/elastic/beats/v7/heartbeat/monitors" "github.com/elastic/beats/v7/heartbeat/monitors/active/dialchain" "github.com/elastic/beats/v7/heartbeat/monitors/jobs" + "github.com/elastic/beats/v7/heartbeat/monitors/wrappers" "github.com/elastic/beats/v7/libbeat/beat" "github.com/elastic/beats/v7/libbeat/common" "github.com/elastic/beats/v7/libbeat/common/transport" @@ -39,113 +43,220 @@ func init() { var debugf = logp.MakeDebug("tcp") -type connURL struct { - Scheme string - Host string - Ports []uint16 -} - func create( name string, cfg *common.Config, ) (jobs []jobs.Job, endpoints int, err error) { - config := DefaultConfig - if err := cfg.Unpack(&config); err != nil { + return createWithResolver(cfg, monitors.NewStdResolver()) +} + +// Custom resolver is useful for tests against hostnames locally where we don't want to depend on any +// hostnames existing in test environments +func createWithResolver( + cfg *common.Config, + resolver monitors.Resolver, +) (jobs []jobs.Job, endpoints int, err error) { + jc, err := newJobFactory(cfg, resolver) + if err != nil { return nil, 0, err } - tls, err := tlscommon.LoadTLSConfig(config.TLS) + jobs, err = jc.makeJobs() if err != nil { return nil, 0, err } - defaultScheme := "tcp" - if tls != nil { - defaultScheme = "ssl" + return jobs, len(jc.endpoints), nil +} + +// jobFactory is where most of the logic here lives. It provides a common context around +// the complex logic of executing a TCP check. +type jobFactory struct { + config config + tlsConfig *tlscommon.TLSConfig + defaultScheme string + endpoints []endpoint + dataCheck dataCheck + resolver monitors.Resolver +} + +func newJobFactory(commonCfg *common.Config, resolver monitors.Resolver) (*jobFactory, error) { + jf := &jobFactory{config: defaultConfig(), resolver: resolver} + err := jf.loadConfig(commonCfg) + if err != nil { + return nil, err + } + + return jf, nil +} + +// loadConfig parses the YAML config and populates the jobFactory fields. +func (jf *jobFactory) loadConfig(commonCfg *common.Config) error { + var err error + if err = commonCfg.Unpack(&jf.config); err != nil { + return err } - schemeHosts, err := collectHosts(&config, defaultScheme) + jf.tlsConfig, err = tlscommon.LoadTLSConfig(jf.config.TLS) if err != nil { - return nil, 0, err + return err } - timeout := config.Timeout - validator := makeValidateConn(&config) + jf.defaultScheme = "tcp" + if jf.tlsConfig != nil { + jf.defaultScheme = "ssl" + } - for scheme, eps := range schemeHosts { - schemeTLS := tls - if scheme == "tcp" || scheme == "plain" { - schemeTLS = nil - } + jf.endpoints, err = makeEndpoints(jf.config.Hosts, jf.config.Ports, jf.defaultScheme) + if err != nil { + return err + } - db, err := dialchain.NewBuilder(dialchain.BuilderSettings{ - Timeout: timeout, - Socks5: config.Socks5, - TLS: schemeTLS, - }) - if err != nil { - return nil, 0, err - } + jf.dataCheck = makeDataCheck(&jf.config) - epJobs, err := dialchain.MakeDialerJobs(db, scheme, eps, config.Mode, - func(event *beat.Event, dialer transport.Dialer, addr string) error { - return pingHost(event, dialer, addr, timeout, validator) - }) - if err != nil { - return nil, 0, err + return nil +} + +// makeJobs returns the actual schedulable jobs for this monitor. +func (jf *jobFactory) makeJobs() ([]jobs.Job, error) { + var jobs []jobs.Job + for _, endpoint := range jf.endpoints { + for _, url := range endpoint.perPortURLs() { + endpointJob, err := jf.makeEndpointJob(url) + if err != nil { + return nil, err + } + jobs = append(jobs, wrappers.WithURLField(url, endpointJob)) } - jobs = append(jobs, epJobs...) } - numHosts := 0 - for _, hosts := range schemeHosts { - numHosts += len(hosts) + return jobs, nil +} + +// makeEndpointJob makes a job for a single check of a single scheme/host/port combo. +func (jf *jobFactory) makeEndpointJob(endpointURL *url.URL) (jobs.Job, error) { + // Check if SOCKS5 is configured, with relying on the socks5 proxy + // in resolving the actual IP. + // Create one job for every port number configured. + if jf.config.Socks5.URL != "" && !jf.config.Socks5.LocalResolve { + jf.makeSocksLookupEndpointJob(endpointURL) } - return jobs, numHosts, nil + return jf.makeDirectEndpointJob(endpointURL) } -func collectHosts(config *Config, defaultScheme string) (map[string][]dialchain.Endpoint, error) { - endpoints := map[string][]dialchain.Endpoint{} - for _, h := range config.Hosts { - scheme := defaultScheme - host := "" - u, err := url.Parse(h) - - if err != nil || u.Host == "" { - host = h - } else { - scheme = u.Scheme - host = u.Host - } - debugf("Add tcp endpoint '%v://%v'.", scheme, host) +// makeDirectEndpointJob makes jobs that directly lookup the IP of the endpoints, as opposed to using +// a Socks5 proxy. +func (jf *jobFactory) makeDirectEndpointJob(endpointURL *url.URL) (jobs.Job, error) { + // Create job that first resolves one or multiple IPs (depending on + // config.Mode) in order to create one continuation Task per IP. + job, err := monitors.MakeByHostJob( + endpointURL.Hostname(), + jf.config.Mode, + jf.resolver, + monitors.MakePingIPFactory( + func(event *beat.Event, ip *net.IPAddr) error { + // use address from resolved IP + ipPort := net.JoinHostPort(ip.String(), endpointURL.Port()) - switch scheme { - case "tcp", "plain", "tls", "ssl": - default: - err := fmt.Errorf("'%v' is no supported connection scheme in '%v'", scheme, h) - return nil, err - } + return jf.dial(event, ipPort, endpointURL) + })) + if err != nil { + return nil, err + } + return job, nil +} - pair := strings.SplitN(host, ":", 2) - ports := config.Ports - if len(pair) == 2 { - port, err := strconv.ParseUint(pair[1], 10, 16) - if err != nil { - return nil, fmt.Errorf("'%v' is no valid port number in '%v'", pair[1], h) - } +// makeSocksLookupEndpointJob makes jobs that use a Socks5 proxy to perform DNS lookups +func (jf *jobFactory) makeSocksLookupEndpointJob(endpointURL *url.URL) (jobs.Job, error) { + return wrappers.WithURLField(endpointURL, + jobs.MakeSimpleJob(func(event *beat.Event) error { + hostPort := net.JoinHostPort(endpointURL.Hostname(), endpointURL.Port()) + return jf.dial(event, hostPort, endpointURL) + })), nil +} - ports = []uint16{uint16(port)} - host = pair[0] - } else if len(config.Ports) == 0 { - return nil, fmt.Errorf("host '%v' missing port number", h) - } +// dial builds a dialer and executes the network request. +// dialAddr is the host:port that the dialer will connect to, and where an explicit IP should go to. +// canonicalURL is the URL used to determine if TLS is used via the scheme of the URL, and +// also which hostname should be passed to the TLS implementation for validation of the server cert. +func (jf *jobFactory) dial(event *beat.Event, dialAddr string, canonicalURL *url.URL) error { + // First, create a plain dialer that can connect directly to either hostnames or IPs + dc := &dialchain.DialerChain{ + Net: dialchain.CreateNetDialer(jf.config.Timeout), + } + + // If Socks5 is configured make that the next layer, since everything needs to go through the proxy first. + if jf.config.Socks5.URL != "" { + dc.AddLayer(dialchain.SOCKS5Layer(&jf.config.Socks5)) + } + + // Now add the IP or Hostname of the server we want to connect to. + // Usually this is the IP we've resolved in a prior step. + // If we're using a proxy with host lookup enabled the dialAddr should be the + // hostname we want the server to resolve for us. + dc.AddLayer(dialchain.ConstAddrLayer(dialAddr)) + + // If we're using TLS we need to add a fake layer so that the TLS layer knows the hostname we're connecting to + // So, the canonical URL is fixed via a ConstAddrLayer to override the TLS layer's x509 logic so it doesn't + // try and directly match the IP from the prior ConstAddrLayer to the cert. + if canonicalURL.Scheme != "tcp" && canonicalURL.Scheme != "plain" { + dc.AddLayer(dialchain.TLSLayer(jf.tlsConfig, jf.config.Timeout)) + dc.AddLayer(dialchain.ConstAddrLayer(canonicalURL.Host)) + } + + dialer, err := dc.Build(event) + if err != nil { + return err + } + + return jf.execDialer(event, dialer, dialAddr) +} - endpoints[scheme] = append(endpoints[scheme], dialchain.Endpoint{ - Host: host, - Ports: ports, - }) +// exec dialer executes a network request against the given dialer. +func (jf *jobFactory) execDialer( + event *beat.Event, + dialer transport.Dialer, + addr string, +) error { + start := time.Now() + deadline := start.Add(jf.config.Timeout) + + conn, err := dialer.Dial("tcp", addr) + if err != nil { + debugf("dial failed with: %v", err) + return reason.IOFailed(err) } - return endpoints, nil + defer conn.Close() + if jf.dataCheck == nil { + // no additional validation step => ping success + return nil + } + + if err := conn.SetDeadline(deadline); err != nil { + debugf("setting connection deadline failed with: %v", err) + return reason.IOFailed(err) + } + + validateStart := time.Now() + err = jf.dataCheck.Check(conn) + if err != nil && err != errRecvMismatch { + debugf("check failed with: %v", err) + return reason.IOFailed(err) + } + + end := time.Now() + eventext.MergeEventFields(event, common.MapStr{ + "tcp": common.MapStr{ + "rtt": common.MapStr{ + "validate": look.RTT(end.Sub(validateStart)), + }, + }, + }) + if err != nil { + return reason.MakeValidateError(err) + } + + return nil } diff --git a/heartbeat/monitors/active/tcp/tcp_test.go b/heartbeat/monitors/active/tcp/tcp_test.go index 7980510f3f8..ef4746c5e3e 100644 --- a/heartbeat/monitors/active/tcp/tcp_test.go +++ b/heartbeat/monitors/active/tcp/tcp_test.go @@ -18,28 +18,24 @@ package tcp import ( - "crypto/x509" "fmt" "net" "net/http" "net/http/httptest" "net/url" - "os" "strconv" "testing" - "time" "github.com/stretchr/testify/require" "github.com/elastic/beats/v7/heartbeat/hbtest" - "github.com/elastic/beats/v7/heartbeat/monitors/wrappers" - "github.com/elastic/beats/v7/heartbeat/scheduler/schedule" "github.com/elastic/beats/v7/libbeat/beat" "github.com/elastic/beats/v7/libbeat/common" btesting "github.com/elastic/beats/v7/libbeat/testing" "github.com/elastic/go-lookslike" "github.com/elastic/go-lookslike/isdef" "github.com/elastic/go-lookslike/testslike" + "github.com/elastic/go-lookslike/validator" ) func testTCPCheck(t *testing.T, host string, port uint16) *beat.Event { @@ -51,139 +47,73 @@ func testTCPCheck(t *testing.T, host string, port uint16) *beat.Event { return testTCPConfigCheck(t, config, host, port) } -func testTCPConfigCheck(t *testing.T, configMap common.MapStr, host string, port uint16) *beat.Event { - config, err := common.NewConfigFrom(configMap) - require.NoError(t, err) - - jobs, endpoints, err := create("tcp", config) - require.NoError(t, err) - - sched, _ := schedule.Parse("@every 1s") - job := wrappers.WrapCommon(jobs, "test", "", "tcp", sched, time.Duration(0))[0] - - event := &beat.Event{} - _, err = job(event) - require.NoError(t, err) - - require.Equal(t, 1, endpoints) - - return event -} - -func testTLSTCPCheck(t *testing.T, host string, port uint16, certFileName string) *beat.Event { - config, err := common.NewConfigFrom(common.MapStr{ - "hosts": host, - "ports": int64(port), - "ssl": common.MapStr{"certificate_authorities": certFileName}, - "timeout": "1s", - }) - require.NoError(t, err) - - jobs, endpoints, err := create("tcp", config) - require.NoError(t, err) - - sched, _ := schedule.Parse("@every 1s") - job := wrappers.WrapCommon(jobs, "test", "", "tcp", sched, time.Duration(0))[0] - - event := &beat.Event{} - _, err = job(event) - require.NoError(t, err) - - require.Equal(t, 1, endpoints) - - return event -} - -func setupServer(t *testing.T, serverCreator func(http.Handler) *httptest.Server) (*httptest.Server, uint16) { - server := serverCreator(hbtest.HelloWorldHandler(200)) - - port, err := hbtest.ServerPort(server) - require.NoError(t, err) - - return server, port -} - -// newLocalhostTestServer starts a server listening on the IP resolved from `localhost` -// httptest.NewServer() binds explicitly on 127.0.0.1 (or [::1] if ipv4 is not available). -// The IP resolved from `localhost` can be a different one, like 127.0.1.1. -func newLocalhostTestServer(handler http.Handler) *httptest.Server { - listener, err := net.Listen("tcp", "localhost:0") - if err != nil { - panic("failed to listen on localhost: " + err.Error()) +// TestUpEndpointJob tests an up endpoint configured using either direct lookups or IPs +func TestUpEndpointJob(t *testing.T) { + // Test with domain, IPv4 and IPv6 + scenarios := []struct { + name string + hostname string + isIP bool + expectedIP string + }{ + { + name: "localhost", + hostname: "localhost", + isIP: false, + expectedIP: "127.0.0.1", + }, + { + name: "ipv4", + hostname: "127.0.0.1", + isIP: true, + expectedIP: "127.0.0.1", + }, + { + name: "ipv6", + hostname: "::1", + isIP: true, + }, } - server := &httptest.Server{ - Listener: listener, - Config: &http.Server{Handler: handler}, + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + server, port, err := setupServer(t, func(handler http.Handler) (*httptest.Server, error) { + return newHostTestServer(handler, scenario.hostname) + }) + // Some machines don't have ipv6 setup correctly, so we ignore the test + // if we can't bind to the port / setup the server. + if err != nil && scenario.hostname == "::1" { + return + } + require.NoError(t, err) + + defer server.Close() + + hostURL := &url.URL{Scheme: "tcp", Host: net.JoinHostPort(scenario.hostname, strconv.Itoa(int(port)))} + + serverURL, err := url.Parse(server.URL) + require.NoError(t, err) + + event := testTCPCheck(t, hostURL.String(), port) + + validators := []validator.Validator{ + hbtest.BaseChecks(serverURL.Hostname(), "up", "tcp"), + hbtest.SummaryChecks(1, 0), + hbtest.URLChecks(t, hostURL), + hbtest.RespondingTCPChecks(), + } + + if !scenario.isIP { + validators = append(validators, hbtest.ResolveChecks(scenario.expectedIP)) + } + + testslike.Test( + t, + lookslike.Strict(lookslike.Compose(validators...)), + event.Fields, + ) + }) } - server.Start() - - return server -} - -func TestUpEndpointJob(t *testing.T) { - server, port := setupServer(t, newLocalhostTestServer) - defer server.Close() - - serverURL, err := url.Parse(server.URL) - require.NoError(t, err) - - event := testTCPCheck(t, "localhost", port) - - testslike.Test( - t, - lookslike.Strict(lookslike.Compose( - hbtest.BaseChecks(serverURL.Hostname(), "up", "tcp"), - hbtest.SummaryChecks(1, 0), - hbtest.SimpleURLChecks(t, "tcp", "localhost", port), - hbtest.RespondingTCPChecks(), - lookslike.MustCompile(map[string]interface{}{ - "resolve": map[string]interface{}{ - "ip": serverURL.Hostname(), - "rtt.us": isdef.IsDuration, - }, - }), - )), - event.Fields, - ) -} - -func TestTLSConnection(t *testing.T) { - // Start up a TLS Server - server, port := setupServer(t, httptest.NewTLSServer) - defer server.Close() - - // Parse its URL - serverURL, err := url.Parse(server.URL) - require.NoError(t, err) - - // Determine the IP address the server's hostname resolves to - ips, err := net.LookupHost(serverURL.Hostname()) - require.NoError(t, err) - require.Len(t, ips, 1) - ip := ips[0] - - // Parse the cert so we can test against it - cert, err := x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0]) - require.NoError(t, err) - - // Save the server's cert to a file so heartbeat can use it - certFile := hbtest.CertToTempFile(t, cert) - require.NoError(t, certFile.Close()) - defer os.Remove(certFile.Name()) - - event := testTLSTCPCheck(t, ip, port, certFile.Name()) - testslike.Test( - t, - lookslike.Strict(lookslike.Compose( - hbtest.TLSChecks(0, 0, cert), - hbtest.RespondingTCPChecks(), - hbtest.BaseChecks(ip, "up", "tcp"), - hbtest.SummaryChecks(1, 0), - hbtest.SimpleURLChecks(t, "ssl", serverURL.Hostname(), port), - )), - event.Fields, - ) } func TestConnectionRefusedEndpointJob(t *testing.T) { @@ -246,11 +176,8 @@ func TestCheckUp(t *testing.T) { hbtest.RespondingTCPChecks(), hbtest.SimpleURLChecks(t, "tcp", host, port), hbtest.SummaryChecks(1, 0), + hbtest.ResolveChecks(ip), lookslike.MustCompile(map[string]interface{}{ - "resolve": map[string]interface{}{ - "ip": ip, - "rtt.us": isdef.IsDuration, - }, "tcp": map[string]interface{}{ "rtt.validate.us": isdef.IsDuration, }, @@ -282,11 +209,8 @@ func TestCheckDown(t *testing.T) { hbtest.RespondingTCPChecks(), hbtest.SimpleURLChecks(t, "tcp", host, port), hbtest.SummaryChecks(0, 1), + hbtest.ResolveChecks(ip), lookslike.MustCompile(map[string]interface{}{ - "resolve": map[string]interface{}{ - "ip": ip, - "rtt.us": isdef.IsDuration, - }, "tcp": map[string]interface{}{ "rtt.validate.us": isdef.IsDuration, }, @@ -346,3 +270,36 @@ func startEchoServer(t *testing.T) (host string, port uint16, ip string, close f return "localhost", uint16(portUint64), ip, listener.Close, nil } + +// StaticResolver allows for a custom in-memory mapping of hosts to IPs, it ignores network names +// and zones. +type StaticResolver struct { + mapping map[string][]net.IP +} + +func NewStaticResolver(mapping map[string][]net.IP) StaticResolver { + return StaticResolver{mapping} +} + +func (s StaticResolver) ResolveIPAddr(network string, host string) (*net.IPAddr, error) { + found, err := s.LookupIP(host) + if err != nil { + return nil, err + } + return &net.IPAddr{IP: found[0]}, nil +} + +func (s StaticResolver) LookupIP(host string) ([]net.IP, error) { + if found, ok := s.mapping[host]; ok { + return found, nil + } else { + return nil, makeStaticNXDomainErr(host) + } +} + +func makeStaticNXDomainErr(host string) *net.DNSError { + return &net.DNSError{ + IsNotFound: true, + Err: fmt.Sprintf("Hostname '%s' not found in static resolver", host), + } +} diff --git a/heartbeat/monitors/active/tcp/tls_test.go b/heartbeat/monitors/active/tcp/tls_test.go new file mode 100644 index 00000000000..0628b1694b4 --- /dev/null +++ b/heartbeat/monitors/active/tcp/tls_test.go @@ -0,0 +1,169 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "crypto/x509" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "testing" + "time" + + "github.com/elastic/beats/v7/heartbeat/monitors/wrappers" + "github.com/elastic/beats/v7/heartbeat/scheduler/schedule" + "github.com/elastic/beats/v7/libbeat/beat" + "github.com/elastic/beats/v7/libbeat/common" + + "github.com/stretchr/testify/require" + + "github.com/elastic/beats/v7/heartbeat/hbtest" + "github.com/elastic/beats/v7/heartbeat/monitors" + "github.com/elastic/go-lookslike" + "github.com/elastic/go-lookslike/testslike" +) + +// Tests that we can check a TLS connection with a cert for a SAN IP +func TestTLSSANIPConnection(t *testing.T) { + ip, port, cert, certFile, teardown := setupTLSTestServer(t) + defer teardown() + + event := testTLSTCPCheck(t, ip, port, certFile.Name(), monitors.NewStdResolver()) + testslike.Test( + t, + lookslike.Strict(lookslike.Compose( + hbtest.TLSChecks(0, 0, cert), + hbtest.RespondingTCPChecks(), + hbtest.BaseChecks(ip, "up", "tcp"), + hbtest.SummaryChecks(1, 0), + hbtest.SimpleURLChecks(t, "ssl", ip, port), + )), + event.Fields, + ) +} + +func TestTLSHostname(t *testing.T) { + ip, port, cert, certFile, teardown := setupTLSTestServer(t) + defer teardown() + + hostname := cert.DNSNames[0] // Should be example.com + resolver := NewStaticResolver(map[string][]net.IP{hostname: []net.IP{net.ParseIP(ip)}}) + event := testTLSTCPCheck(t, hostname, port, certFile.Name(), resolver) + testslike.Test( + t, + lookslike.Strict(lookslike.Compose( + hbtest.TLSChecks(0, 0, cert), + hbtest.RespondingTCPChecks(), + hbtest.BaseChecks(ip, "up", "tcp"), + hbtest.SummaryChecks(1, 0), + hbtest.SimpleURLChecks(t, "ssl", hostname, port), + hbtest.ResolveChecks(ip), + )), + event.Fields, + ) +} + +func TestTLSInvalidCert(t *testing.T) { + ip, port, cert, certFile, teardown := setupTLSTestServer(t) + defer teardown() + + mismatchedHostname := "notadomain.elastic.co" + resolver := NewStaticResolver( + map[string][]net.IP{ + cert.DNSNames[0]: {net.ParseIP(ip)}, + mismatchedHostname: {net.ParseIP(ip)}, + }, + ) + event := testTLSTCPCheck(t, mismatchedHostname, port, certFile.Name(), resolver) + + testslike.Test( + t, + lookslike.Strict(lookslike.Compose( + hbtest.RespondingTCPChecks(), + hbtest.BaseChecks(ip, "down", "tcp"), + hbtest.SummaryChecks(0, 1), + hbtest.SimpleURLChecks(t, "ssl", mismatchedHostname, port), + hbtest.ResolveChecks(ip), + lookslike.MustCompile(map[string]interface{}{ + "error": map[string]interface{}{ + "message": x509.HostnameError{Certificate: cert, Host: mismatchedHostname}.Error(), + "type": "io", + }, + }), + )), + event.Fields, + ) +} + +func setupTLSTestServer(t *testing.T) (ip string, port uint16, cert *x509.Certificate, certFile *os.File, teardown func()) { + // Start up a TLS Server + server, port, err := setupServer(t, func(handler http.Handler) (*httptest.Server, error) { + return httptest.NewTLSServer(handler), nil + }) + require.NoError(t, err) + + // Parse its URL + serverURL, err := url.Parse(server.URL) + require.NoError(t, err) + + // Determine the IP address the server's hostname resolves to + ips, err := net.LookupHost(serverURL.Hostname()) + require.NoError(t, err) + require.Len(t, ips, 1) + ip = ips[0] + + // Parse the cert so we can test against it + cert, err = x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0]) + require.NoError(t, err) + + // Save the server's cert to a file so heartbeat can use it + certFile = hbtest.CertToTempFile(t, cert) + require.NoError(t, certFile.Close()) + + return ip, port, cert, certFile, func() { + defer server.Close() + err := os.Remove(certFile.Name()) + require.NoError(t, err) + } +} + +func testTLSTCPCheck(t *testing.T, host string, port uint16, certFileName string, resolver monitors.Resolver) *beat.Event { + config, err := common.NewConfigFrom(common.MapStr{ + "hosts": host, + "ports": int64(port), + "ssl": common.MapStr{"certificate_authorities": certFileName}, + "timeout": "1s", + }) + require.NoError(t, err) + + jobs, endpoints, err := createWithResolver(config, resolver) + require.NoError(t, err) + + sched, _ := schedule.Parse("@every 1s") + job := wrappers.WrapCommon(jobs, "test", "", "tcp", sched, time.Duration(0))[0] + + event := &beat.Event{} + _, err = job(event) + require.NoError(t, err) + + require.Equal(t, 1, endpoints) + + return event +} diff --git a/heartbeat/monitors/resolver.go b/heartbeat/monitors/resolver.go new file mode 100644 index 00000000000..23b96868680 --- /dev/null +++ b/heartbeat/monitors/resolver.go @@ -0,0 +1,46 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package monitors + +import ( + "net" +) + +// Resolver lets us define custom DNS resolvers similar to what the go stdlib provides, but +// potentially with custom functionality +type Resolver interface { + // ResolveIPAddr is an analog of net.ResolveIPAddr + ResolveIPAddr(network string, host string) (*net.IPAddr, error) + // LookupIP is an analog of net.LookupIP + LookupIP(host string) ([]net.IP, error) +} + +// StdResolver uses the go std library to perform DNS resolution. +type StdResolver struct{} + +func NewStdResolver() StdResolver { + return StdResolver{} +} + +func (s StdResolver) ResolveIPAddr(network string, host string) (*net.IPAddr, error) { + return net.ResolveIPAddr(network, host) +} + +func (s StdResolver) LookupIP(host string) ([]net.IP, error) { + return net.LookupIP(host) +} diff --git a/heartbeat/monitors/util.go b/heartbeat/monitors/util.go index 8cf2679c0d8..31191e85e63 100644 --- a/heartbeat/monitors/util.go +++ b/heartbeat/monitors/util.go @@ -39,14 +39,6 @@ type IPSettings struct { Mode PingMode `config:"mode"` } -// HostJobSettings configures a Job including Host lookups and global fields to be added -// to every event. -type HostJobSettings struct { - Host string - IP IPSettings - Fields common.MapStr -} - // PingMode enumeration for configuring `any` or `all` IPs pinging. type PingMode uint8 @@ -111,7 +103,7 @@ func MakeByIPJob( pingFactory func(ip *net.IPAddr) jobs.Job, ) (jobs.Job, error) { // use ResolveIPAddr to parse the ip into net.IPAddr adding a zone info - // if ipv6 is used. + // if ipv6 is used. We intentionally do not use a custom resolver here. addr, err := net.ResolveIPAddr("ip", ip.String()) if err != nil { return nil, err @@ -130,39 +122,40 @@ func MakeByIPJob( // A pingFactory instance is normally build with MakePingIPFactory, // MakePingAllIPFactory or MakePingAllIPPortFactory. func MakeByHostJob( - settings HostJobSettings, + host string, + ipSettings IPSettings, + resolver Resolver, pingFactory func(ip *net.IPAddr) jobs.Job, ) (jobs.Job, error) { - host := settings.Host - if ip := net.ParseIP(host); ip != nil { return MakeByIPJob(ip, pingFactory) } - network := settings.IP.Network() + network := ipSettings.Network() if network == "" { return nil, errors.New("pinging hosts requires ipv4 or ipv6 mode enabled") } - mode := settings.IP.Mode + mode := ipSettings.Mode if mode == PingAny { - return makeByHostAnyIPJob(settings, host, pingFactory), nil + return makeByHostAnyIPJob(host, ipSettings, resolver, pingFactory), nil } - return makeByHostAllIPJob(settings, host, pingFactory), nil + return makeByHostAllIPJob(host, ipSettings, resolver, pingFactory), nil } func makeByHostAnyIPJob( - settings HostJobSettings, host string, + ipSettings IPSettings, + resolver Resolver, pingFactory func(ip *net.IPAddr) jobs.Job, ) jobs.Job { - network := settings.IP.Network() + network := ipSettings.Network() return func(event *beat.Event) ([]jobs.Job, error) { resolveStart := time.Now() - ip, err := net.ResolveIPAddr(network, host) + ip, err := resolver.ResolveIPAddr(network, host) if err != nil { return nil, err } @@ -176,11 +169,12 @@ func makeByHostAnyIPJob( } func makeByHostAllIPJob( - settings HostJobSettings, host string, + ipSettings IPSettings, + resolver Resolver, pingFactory func(ip *net.IPAddr) jobs.Job, ) jobs.Job { - network := settings.IP.Network() + network := ipSettings.Network() filter := makeIPFilter(network) return func(event *beat.Event) ([]jobs.Job, error) { @@ -265,34 +259,3 @@ func filterIPs(ips []net.IP, filt func(net.IP) bool) []net.IP { } return out } - -// MakeHostJobSettings creates a new HostJobSettings structure without any global -// event fields. -func MakeHostJobSettings(host string, ip IPSettings) HostJobSettings { - return HostJobSettings{Host: host, IP: ip} -} - -// WithFields adds new event fields to a Job. Existing fields will be -// overwritten. -// The fields map will be updated (no copy). -func (s HostJobSettings) WithFields(m common.MapStr) HostJobSettings { - s.AddFields(m) - return s -} - -// AddFields adds new event fields to a Job. Existing fields will be -// overwritten. -func (s *HostJobSettings) AddFields(m common.MapStr) { addFields(&s.Fields, m) } - -func addFields(to *common.MapStr, m common.MapStr) { - if m == nil { - return - } - - fields := *to - if fields == nil { - fields = common.MapStr{} - *to = fields - } - fields.DeepUpdate(m) -}