From dfe8c4bf3a1551639d6ba58788d2cd17160eeb6e Mon Sep 17 00:00:00 2001 From: Andrew Cholakian Date: Mon, 20 Apr 2020 16:17:20 -0500 Subject: [PATCH] [Heartbeat] Refactor TCP Monitor (#17549) Refactors the TCP monitor to make the code easier to follow, more testable, and fixes #17123 where TLS server name was not correctly sent. This is important because the code had accrued a lot of cruft and become very hard to follow. There were many wrappers and intermediate variable names that often subtly changed names as they crossed various functions. When debugging #17123 I frequently found myself lost tracing the execution. This new code should be simpler to understand for a few reasons: Less code (almost a 2x reduction) and fewer, simpler, better organized functions Less variable passing/renaming due to use of struct for key config variables More consistent and descriptive variable names Creation of the dialer as late as possible, to remove the confusing partial states, and clarity as to when which dialer layers are used. Adds (frustratingly tricky) integration tests for #17123 using mismatched TLS certs, and also against a real SOCKS5 proxy Adds, for testing only, the ability to override the real network resolver for TCP checks which is invaluable in debugging TLS checks that depend on setting hostnames correctly. In the future if we decide to let users use a custom DNS resolver this will be nice. Reorganized giant TCP test file into multiple files --- CHANGELOG.next.asciidoc | 1 + heartbeat/hbtest/hbtestutil.go | 16 +- .../monitors/active/dialchain/builder.go | 198 ------------- .../active/dialchain/{net.go => dialers.go} | 7 +- heartbeat/monitors/active/dialchain/util.go | 9 - heartbeat/monitors/active/http/task.go | 4 +- heartbeat/monitors/active/icmp/icmp.go | 3 +- heartbeat/monitors/active/tcp/config.go | 12 +- .../active/tcp/{check.go => datacheck.go} | 14 +- heartbeat/monitors/active/tcp/endpoint.go | 108 +++++++ .../monitors/active/tcp/endpoint_test.go | 109 +++++++ heartbeat/monitors/active/tcp/helpers_test.go | 87 ++++++ heartbeat/monitors/active/tcp/socks5_test.go | 134 +++++++++ heartbeat/monitors/active/tcp/task.go | 77 ----- heartbeat/monitors/active/tcp/tcp.go | 271 ++++++++++++------ heartbeat/monitors/active/tcp/tcp_test.go | 243 +++++++--------- heartbeat/monitors/active/tcp/tls_test.go | 169 +++++++++++ heartbeat/monitors/resolver.go | 46 +++ heartbeat/monitors/util.go | 67 +---- 19 files changed, 996 insertions(+), 579 deletions(-) delete mode 100644 heartbeat/monitors/active/dialchain/builder.go rename heartbeat/monitors/active/dialchain/{net.go => dialers.go} (94%) rename heartbeat/monitors/active/tcp/{check.go => datacheck.go} (86%) create mode 100644 heartbeat/monitors/active/tcp/endpoint.go create mode 100644 heartbeat/monitors/active/tcp/endpoint_test.go create mode 100644 heartbeat/monitors/active/tcp/helpers_test.go create mode 100644 heartbeat/monitors/active/tcp/socks5_test.go delete mode 100644 heartbeat/monitors/active/tcp/task.go create mode 100644 heartbeat/monitors/active/tcp/tls_test.go create mode 100644 heartbeat/monitors/resolver.go diff --git a/CHANGELOG.next.asciidoc b/CHANGELOG.next.asciidoc index ba5899343e3..2010bf3735b 100644 --- a/CHANGELOG.next.asciidoc +++ b/CHANGELOG.next.asciidoc @@ -115,6 +115,7 @@ https://github.com/elastic/beats/compare/v7.0.0-alpha2...master[Check the HEAD d *Heartbeat* - Fixed excessive memory usage introduced in 7.5 due to over-allocating memory for HTTP checks. {pull}15639[15639] +- Fixed TCP TLS checks to properly validate hostnames, this broke in 7.x and only worked for IP SANs. {pull}17549[17549] *Journalbeat* diff --git a/heartbeat/hbtest/hbtestutil.go b/heartbeat/hbtest/hbtestutil.go index 9bec84ee0b5..548bc42a0eb 100644 --- a/heartbeat/hbtest/hbtestutil.go +++ b/heartbeat/hbtest/hbtestutil.go @@ -153,10 +153,19 @@ func SummaryChecks(up int, down int) validator.Validator { }) } +// ResolveChecks returns a lookslike matcher for the 'resolve' fields. +func ResolveChecks(ip string) validator.Validator { + return lookslike.MustCompile(map[string]interface{}{ + "resolve": map[string]interface{}{ + "ip": ip, + "rtt.us": isdef.IsDuration, + }, + }) +} + // SimpleURLChecks returns a check for a simple URL // with only a scheme, host, and port func SimpleURLChecks(t *testing.T, scheme string, host string, port uint16) validator.Validator { - hostPort := host if port != 0 { hostPort = fmt.Sprintf("%s:%d", host, port) @@ -165,6 +174,11 @@ func SimpleURLChecks(t *testing.T, scheme string, host string, port uint16) vali u, err := url.Parse(fmt.Sprintf("%s://%s", scheme, hostPort)) require.NoError(t, err) + return URLChecks(t, u) +} + +// URLChecks returns a validator for the given URL's fields +func URLChecks(t *testing.T, u *url.URL) validator.Validator { return lookslike.MustCompile(map[string]interface{}{ "url": wrappers.URLFields(u), }) diff --git a/heartbeat/monitors/active/dialchain/builder.go b/heartbeat/monitors/active/dialchain/builder.go deleted file mode 100644 index 1ba9671ad59..00000000000 --- a/heartbeat/monitors/active/dialchain/builder.go +++ /dev/null @@ -1,198 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more contributor -// license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright -// ownership. Elasticsearch B.V. licenses this file to you under -// the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package dialchain - -import ( - "fmt" - "net" - "net/url" - "time" - - "github.com/elastic/beats/v7/heartbeat/monitors" - "github.com/elastic/beats/v7/heartbeat/monitors/jobs" - "github.com/elastic/beats/v7/heartbeat/monitors/wrappers" - "github.com/elastic/beats/v7/libbeat/beat" - "github.com/elastic/beats/v7/libbeat/common/transport" - "github.com/elastic/beats/v7/libbeat/common/transport/tlscommon" -) - -// Builder maintains a DialerChain for building dialers and dialer based -// monitoring jobs. -// The builder ensures a constant address is being used, for any host -// configured. This ensures the upper network layers (e.g. TLS) correctly see -// and process the original hostname. -type Builder struct { - template *DialerChain - addrIndex int - resolveViaSocks5 bool -} - -// BuilderSettings configures the layers of the dialer chain to be constructed -// by a Builder. -type BuilderSettings struct { - Timeout time.Duration - Socks5 transport.ProxyConfig - TLS *tlscommon.TLSConfig -} - -// Endpoint configures a host with all port numbers to be monitored by a dialer -// based job. -type Endpoint struct { - Host string - Ports []uint16 -} - -// NewBuilder creates a new Builder for constructing dialers. -func NewBuilder(settings BuilderSettings) (*Builder, error) { - d := &DialerChain{ - Net: netDialer(settings.Timeout), - } - resolveViaSocks5 := false - withProxy := settings.Socks5.URL != "" - if withProxy { - d.AddLayer(SOCKS5Layer(&settings.Socks5)) - resolveViaSocks5 = !settings.Socks5.LocalResolve - } - - // insert empty placeholder, so address can be replaced in dialer chain - // by replacing this placeholder dialer - idx := len(d.Layers) - d.AddLayer(IDLayer()) - - // add tls layer doing the TLS handshake based on the original address - if tls := settings.TLS; tls != nil { - d.AddLayer(TLSLayer(tls, settings.Timeout)) - } - - // validate dialerchain - if err := d.TestBuild(); err != nil { - return nil, err - } - - return &Builder{ - template: d, - addrIndex: idx, - resolveViaSocks5: resolveViaSocks5, - }, nil -} - -// AddLayer adds another custom network layer to the dialer chain. -func (b *Builder) AddLayer(l Layer) { - b.template.AddLayer(l) -} - -// Build create a new dialer, that will always use the constant address, no matter -// which address is used to connect using the dialer. -// The dialer chain will add per layer information to the given event. -func (b *Builder) Build(addr string, event *beat.Event) (transport.Dialer, error) { - // clone template, as multiple instance of a dialer can exist at the same time - dchain := b.template.Clone() - - // fix the final dialers TCP-level address - dchain.Layers[b.addrIndex] = ConstAddrLayer(addr) - - // create dialer chain with event to add per network layer information - d, err := dchain.Build(event) - return d, err -} - -// Run executes the given function with a new dialer instance. -func (b *Builder) Run( - event *beat.Event, - addr string, - fn func(*beat.Event, transport.Dialer) error, -) error { - dialer, err := b.Build(addr, event) - if err != nil { - return err - } - - return fn(event, dialer) -} - -// MakeDialerJobs creates a set of monitoring jobs. The jobs behavior depends -// on the builder, endpoint and mode configurations, normally set by user -// configuration. The task to execute the actual 'ping' receives the dialer -// and the address pair (:), required to be used, to ping the -// correctly resolved endpoint. -func MakeDialerJobs( - b *Builder, - scheme string, - endpoints []Endpoint, - mode monitors.IPSettings, - fn func(event *beat.Event, dialer transport.Dialer, addr string) error, -) ([]jobs.Job, error) { - var jobs []jobs.Job - for _, endpoint := range endpoints { - for _, port := range endpoint.Ports { - endpointURL, err := url.Parse(fmt.Sprintf("%s://%s:%d", scheme, endpoint.Host, port)) - if err != nil { - return nil, err - } - endpointJob, err := makeEndpointJob(b, endpointURL, mode, fn) - if err != nil { - return nil, err - } - jobs = append(jobs, wrappers.WithURLField(endpointURL, endpointJob)) - } - - } - - return jobs, nil -} - -func makeEndpointJob( - b *Builder, - endpointURL *url.URL, - mode monitors.IPSettings, - fn func(*beat.Event, transport.Dialer, string) error, -) (jobs.Job, error) { - - // Check if SOCKS5 is configured, with relying on the socks5 proxy - // in resolving the actual IP. - // Create one job for every port number configured. - if b.resolveViaSocks5 { - return wrappers.WithURLField(endpointURL, - jobs.MakeSimpleJob(func(event *beat.Event) error { - hostPort := net.JoinHostPort(endpointURL.Hostname(), endpointURL.Port()) - return b.Run(event, hostPort, func(event *beat.Event, dialer transport.Dialer) error { - return fn(event, dialer, hostPort) - }) - })), nil - } - - // Create job that first resolves one or multiple IP (depending on - // config.Mode) in order to create one continuation Task per IP. - settings := monitors.MakeHostJobSettings(endpointURL.Hostname(), mode) - - job, err := monitors.MakeByHostJob(settings, - monitors.MakePingIPFactory( - func(event *beat.Event, ip *net.IPAddr) error { - // use address from resolved IP - ipPort := net.JoinHostPort(ip.String(), endpointURL.Port()) - cb := func(event *beat.Event, dialer transport.Dialer) error { - return fn(event, dialer, ipPort) - } - err := b.Run(event, ipPort, cb) - return err - })) - if err != nil { - return nil, err - } - return job, nil -} diff --git a/heartbeat/monitors/active/dialchain/net.go b/heartbeat/monitors/active/dialchain/dialers.go similarity index 94% rename from heartbeat/monitors/active/dialchain/net.go rename to heartbeat/monitors/active/dialchain/dialers.go index 5fc41270fb7..ddd34870d8d 100644 --- a/heartbeat/monitors/active/dialchain/net.go +++ b/heartbeat/monitors/active/dialchain/dialers.go @@ -45,7 +45,7 @@ import ( // } // } func TCPDialer(to time.Duration) NetDialer { - return netDialer(to) + return CreateNetDialer(to) } // UDPDialer creates a new NetDialer with constant event fields and default @@ -62,10 +62,11 @@ func TCPDialer(to time.Duration) NetDialer { // } // } func UDPDialer(to time.Duration) NetDialer { - return netDialer(to) + return CreateNetDialer(to) } -func netDialer(timeout time.Duration) NetDialer { +// CreateNetDialer returns a NetDialer with the given timeout. +func CreateNetDialer(timeout time.Duration) NetDialer { return func(event *beat.Event) (transport.Dialer, error) { return makeDialer(func(network, address string) (net.Conn, error) { namespace := "" diff --git a/heartbeat/monitors/active/dialchain/util.go b/heartbeat/monitors/active/dialchain/util.go index 7329cf81be8..7b29ca80b39 100644 --- a/heartbeat/monitors/active/dialchain/util.go +++ b/heartbeat/monitors/active/dialchain/util.go @@ -29,15 +29,6 @@ type timer struct { s, e time.Time } -// IDLayer creates an empty placeholder layer. -func IDLayer() Layer { - return _idLayer -} - -var _idLayer = Layer(func(event *beat.Event, next transport.Dialer) (transport.Dialer, error) { - return next, nil -}) - // ConstAddrLayer introduces a network layer always passing a constant address // to the underlying layer. func ConstAddrLayer(address string) Layer { diff --git a/heartbeat/monitors/active/http/task.go b/heartbeat/monitors/active/http/task.go index e750faf6d12..65d2f1ae62c 100644 --- a/heartbeat/monitors/active/http/task.go +++ b/heartbeat/monitors/active/http/task.go @@ -94,10 +94,8 @@ func newHTTPMonitorIPsJob( return nil, err } - settings := monitors.MakeHostJobSettings(hostname, config.Mode) - pingFactory := createPingFactory(config, port, tls, req, body, validator) - job, err := monitors.MakeByHostJob(settings, pingFactory) + job, err := monitors.MakeByHostJob(hostname, config.Mode, monitors.NewStdResolver(), pingFactory) return job, err } diff --git a/heartbeat/monitors/active/icmp/icmp.go b/heartbeat/monitors/active/icmp/icmp.go index 66bc0b2adc6..1cb19c90798 100644 --- a/heartbeat/monitors/active/icmp/icmp.go +++ b/heartbeat/monitors/active/icmp/icmp.go @@ -71,8 +71,7 @@ func create( pingFactory := monitors.MakePingIPFactory(createPingIPFactory(&config)) for _, host := range config.Hosts { - settings := monitors.MakeHostJobSettings(host, config.Mode) - job, err := monitors.MakeByHostJob(settings, pingFactory) + job, err := monitors.MakeByHostJob(host, config.Mode, monitors.NewStdResolver(), pingFactory) if err != nil { return nil, 0, err diff --git a/heartbeat/monitors/active/tcp/config.go b/heartbeat/monitors/active/tcp/config.go index 45fa4abc6e0..692af5ba65f 100644 --- a/heartbeat/monitors/active/tcp/config.go +++ b/heartbeat/monitors/active/tcp/config.go @@ -26,7 +26,7 @@ import ( "github.com/elastic/beats/v7/libbeat/common/transport/tlscommon" ) -type Config struct { +type config struct { // check all ports if host does not contain port Hosts []string `config:"hosts" validate:"required"` Ports []uint16 `config:"ports"` @@ -45,12 +45,14 @@ type Config struct { ReceiveString string `config:"check.receive"` } -var DefaultConfig = Config{ - Timeout: 16 * time.Second, - Mode: monitors.DefaultIPSettings, +func defaultConfig() config { + return config{ + Timeout: 16 * time.Second, + Mode: monitors.DefaultIPSettings, + } } -func (c *Config) Validate() error { +func (c *config) Validate() error { if c.Socks5.URL != "" { if c.Mode.Mode != monitors.PingAny && !c.Socks5.LocalResolve { return errors.New("ping all ips only supported if proxy_use_local_resolver is enabled`") diff --git a/heartbeat/monitors/active/tcp/check.go b/heartbeat/monitors/active/tcp/datacheck.go similarity index 86% rename from heartbeat/monitors/active/tcp/check.go rename to heartbeat/monitors/active/tcp/datacheck.go index 0f8d631ce56..b465052f6f5 100644 --- a/heartbeat/monitors/active/tcp/check.go +++ b/heartbeat/monitors/active/tcp/datacheck.go @@ -23,18 +23,20 @@ import ( "net" ) -type ConnCheck func(net.Conn) error +// dataCheck executes over an open TCP connection using the send / receive +// parameters the user has defined. +type dataCheck func(net.Conn) error var ( errNoDataReceived = errors.New("no data") errRecvMismatch = errors.New("received string mismatch") ) -func (c ConnCheck) Validate(conn net.Conn) error { +func (c dataCheck) Check(conn net.Conn) error { return c(conn) } -func makeValidateConn(config *Config) ConnCheck { +func makeDataCheck(config *config) dataCheck { send := config.SendString recv := config.ReceiveString @@ -52,7 +54,7 @@ func makeValidateConn(config *Config) ConnCheck { func checkOk(_ net.Conn) error { return nil } -func checkAll(checks ...ConnCheck) ConnCheck { +func checkAll(checks ...dataCheck) dataCheck { return func(conn net.Conn) error { for _, check := range checks { if err := check(conn); err != nil { @@ -63,13 +65,13 @@ func checkAll(checks ...ConnCheck) ConnCheck { } } -func checkSend(buf []byte) ConnCheck { +func checkSend(buf []byte) dataCheck { return func(conn net.Conn) error { return sendBuffer(conn, buf) } } -func checkRecv(expected []byte) ConnCheck { +func checkRecv(expected []byte) dataCheck { return func(conn net.Conn) error { buf := make([]byte, len(expected)) if err := recvBuffer(conn, buf); err != nil { diff --git a/heartbeat/monitors/active/tcp/endpoint.go b/heartbeat/monitors/active/tcp/endpoint.go new file mode 100644 index 00000000000..e3c87531ffc --- /dev/null +++ b/heartbeat/monitors/active/tcp/endpoint.go @@ -0,0 +1,108 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "fmt" + "net" + "net/url" + "strconv" + + "github.com/pkg/errors" +) + +// endpoint configures a host with all port numbers to be monitored by a dialer +// based job. +type endpoint struct { + Scheme string + Hostname string + Ports []uint16 +} + +// perPortURLs returns a list containing one URL per port +func (e endpoint) perPortURLs() (urls []*url.URL) { + for _, port := range e.Ports { + urls = append(urls, &url.URL{ + Scheme: e.Scheme, + Host: net.JoinHostPort(e.Hostname, strconv.Itoa(int(port))), + }) + } + + return urls +} + +// makeEndpoints creates a single endpoint struct for each host/port permutation. +// Set `defaultScheme` to choose which scheme is used if not explicit in the host config. +func makeEndpoints(hosts []string, ports []uint16, defaultScheme string) (endpoints []endpoint, err error) { + for _, h := range hosts { + u, err := url.Parse(h) + + // If h is just a bare hostname like 'localhost' it will be parsed as the URL path, and host will + // be blank + var ep endpoint + if err == nil && u.Host != "" { + ep, err = makeURLEndpoint(u, ports) + if err != nil { + return nil, err + } + } else { + u := &url.URL{Scheme: defaultScheme, Host: h} + ep, err = makeURLEndpoint(u, ports) + if err != nil { + return nil, err + } + } + endpoints = append(endpoints, ep) + } + return endpoints, nil +} + +func makeURLEndpoint(u *url.URL, ports []uint16) (endpoint, error) { + switch u.Scheme { + case "tcp", "plain", "tls", "ssl": + default: + err := fmt.Errorf( + "'%s' is not a supported connection scheme in '%s', supported schemes are tcp, plain, tls, and ssl", + u.Scheme, + u, + ) + return endpoint{}, err + } + + if u.Port() != "" { + pUint, err := strconv.ParseUint(u.Port(), 10, 16) + if err != nil { + return endpoint{}, errors.Wrapf(err, "no port(s) defined for TCP endpoint %s", u) + } + ports = []uint16{uint16(pUint)} + } + + if len(ports) == 0 { + return endpoint{}, fmt.Errorf("host '%s' missing port number", u) + } + + if u.Hostname() == "" || u.Hostname() == ":" { + return endpoint{}, fmt.Errorf("could not parse tcp host '%s'", u) + } + + return endpoint{ + Scheme: u.Scheme, + Hostname: u.Hostname(), + Ports: ports, + }, nil +} diff --git a/heartbeat/monitors/active/tcp/endpoint_test.go b/heartbeat/monitors/active/tcp/endpoint_test.go new file mode 100644 index 00000000000..5e871914b40 --- /dev/null +++ b/heartbeat/monitors/active/tcp/endpoint_test.go @@ -0,0 +1,109 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMakeEndpoints(t *testing.T) { + type args struct { + hosts []string + ports []uint16 + defaultScheme string + } + tests := []struct { + name string + args args + wantEndpoints []endpoint + wantErr bool + }{ + { + "hostname", + args{[]string{"localhost"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tcp", Hostname: "localhost", Ports: []uint16{123}}}, + false, + }, + { + "ipv4", + args{[]string{"1.2.3.4"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tcp", Hostname: "1.2.3.4", Ports: []uint16{123}}}, + false, + }, + { + "unbracketed ipv6", + args{[]string{"::1"}, []uint16{123}, "tcp"}, + []endpoint{}, + true, + }, + { + "bracketed ipv6", + args{[]string{"[::1]"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tcp", Hostname: "::1", Ports: []uint16{123}}}, + false, + }, + { + "url", + args{[]string{"tls://example.net"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tls", Hostname: "example.net", Ports: []uint16{123}}}, + false, + }, + { + "url:port", + args{[]string{"example.net:999"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tcp", Hostname: "example.net", Ports: []uint16{999}}}, + false, + }, + { + "scheme://url:port", + args{[]string{"tls://example.net:999"}, []uint16{123}, "tcp"}, + []endpoint{{Scheme: "tls", Hostname: "example.net", Ports: []uint16{999}}}, + false, + }, + { + "hybrid", + args{[]string{ + "localhost", + "192.168.0.1", + "[2607:f8b0:4004:814::200e]", + "example.net:999", + "tls://example.net:999", + }, []uint16{123}, "tcp"}, + []endpoint{ + {Scheme: "tcp", Hostname: "localhost", Ports: []uint16{123}}, + {Scheme: "tcp", Hostname: "192.168.0.1", Ports: []uint16{123}}, + {Scheme: "tcp", Hostname: "2607:f8b0:4004:814::200e", Ports: []uint16{123}}, + {Scheme: "tcp", Hostname: "example.net", Ports: []uint16{999}}, + {Scheme: "tls", Hostname: "example.net", Ports: []uint16{999}}, + }, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotEndpoints, err := makeEndpoints(tt.args.hosts, tt.args.ports, tt.args.defaultScheme) + if tt.wantErr { + require.Error(t, err) + return + } + require.Equal(t, tt.wantEndpoints, gotEndpoints) + }) + } +} diff --git a/heartbeat/monitors/active/tcp/helpers_test.go b/heartbeat/monitors/active/tcp/helpers_test.go new file mode 100644 index 00000000000..d1a8c1b5bc1 --- /dev/null +++ b/heartbeat/monitors/active/tcp/helpers_test.go @@ -0,0 +1,87 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/pkg/errors" + + "github.com/stretchr/testify/require" + + "github.com/elastic/beats/v7/heartbeat/hbtest" + "github.com/elastic/beats/v7/heartbeat/monitors/wrappers" + "github.com/elastic/beats/v7/heartbeat/scheduler/schedule" + "github.com/elastic/beats/v7/libbeat/beat" + "github.com/elastic/beats/v7/libbeat/common" +) + +func testTCPConfigCheck(t *testing.T, configMap common.MapStr, host string, port uint16) *beat.Event { + config, err := common.NewConfigFrom(configMap) + require.NoError(t, err) + + jobs, endpoints, err := create("tcp", config) + require.NoError(t, err) + + sched, _ := schedule.Parse("@every 1s") + job := wrappers.WrapCommon(jobs, "test", "", "tcp", sched, time.Duration(0))[0] + + event := &beat.Event{} + _, err = job(event) + require.NoError(t, err) + + require.Equal(t, 1, endpoints) + + return event +} + +func setupServer(t *testing.T, serverCreator func(http.Handler) (*httptest.Server, error)) (*httptest.Server, uint16, error) { + server, err := serverCreator(hbtest.HelloWorldHandler(200)) + if err != nil { + return nil, 0, err + } + + port, err := hbtest.ServerPort(server) + if err != nil { + return nil, 0, err + } + + return server, port, nil +} + +// newHostTestServer starts a server listening on the IP resolved from the host arg +// httptest.NewServer() binds explicitly on 127.0.0.1 (or [::1] if ipv4 is not available). +// The IP resolved from `localhost` can be a different one, like 127.0.1.1. +func newHostTestServer(handler http.Handler, host string) (*httptest.Server, error) { + listener, err := net.Listen("tcp", net.JoinHostPort(host, "0")) + if err != nil { + return nil, errors.Wrapf(err, "failed to listen on host '%s'", host) + } + + server := &httptest.Server{ + Listener: listener, + Config: &http.Server{Handler: handler}, + } + server.Start() + + return server, nil +} diff --git a/heartbeat/monitors/active/tcp/socks5_test.go b/heartbeat/monitors/active/tcp/socks5_test.go new file mode 100644 index 00000000000..47c691dbdea --- /dev/null +++ b/heartbeat/monitors/active/tcp/socks5_test.go @@ -0,0 +1,134 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "fmt" + "net" + "net/url" + "strconv" + "sync" + "testing" + + "github.com/armon/go-socks5" + "github.com/stretchr/testify/require" + + "github.com/elastic/beats/v7/heartbeat/hbtest" + "github.com/elastic/beats/v7/libbeat/common" + "github.com/elastic/go-lookslike" + "github.com/elastic/go-lookslike/isdef" + "github.com/elastic/go-lookslike/testslike" +) + +func TestSocks5Job(t *testing.T) { + scenarios := []struct { + name string + localResolver bool + }{ + { + name: "using local resolver", + localResolver: true, + }, + { + name: "not using local resolver", + localResolver: false, + }, + } + + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + host, port, ip, closeEcho, err := startEchoServer(t) + require.NoError(t, err) + defer closeEcho() + + _, proxyPort, proxyIp, closeProxy, err := startSocks5Server(t) + require.NoError(t, err) + defer closeProxy() + + proxyURL := &url.URL{Scheme: "socks5", Host: net.JoinHostPort(proxyIp, fmt.Sprint(proxyPort))} + configMap := common.MapStr{ + "hosts": host, + "ports": port, + "timeout": "1s", + "proxy_url": proxyURL.String(), + "proxy_use_local_resolver": scenario.localResolver, + "check.receive": "echo123", + "check.send": "echo123", + } + event := testTCPConfigCheck(t, configMap, host, port) + + testslike.Test( + t, + lookslike.Strict(lookslike.Compose( + hbtest.BaseChecks(ip, "up", "tcp"), + hbtest.RespondingTCPChecks(), + hbtest.SimpleURLChecks(t, "tcp", host, port), + hbtest.SummaryChecks(1, 0), + hbtest.ResolveChecks(ip), + lookslike.MustCompile(map[string]interface{}{ + "tcp": map[string]interface{}{ + "rtt.validate.us": isdef.IsDuration, + }, + "socks5": map[string]interface{}{ + "rtt.connect.us": isdef.IsDuration, + }, + }), + )), + event.Fields, + ) + }) + } +} + +func startSocks5Server(t *testing.T) (host string, port uint16, ip string, close func() error, err error) { + host = "localhost" + config := &socks5.Config{} + server, err := socks5.New(config) + if err != nil { + return "", 0, "", nil, err + } + + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", 0, "", nil, err + } + ip, portStr, err := net.SplitHostPort(listener.Addr().String()) + portUint64, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + listener.Close() + return "", 0, "", nil, err + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + if err := server.Serve(listener); err != nil { + debugf("Error in SOCKS5 Test Server %v", err) + } + wg.Done() + }() + + return host, uint16(portUint64), ip, func() error { + err := listener.Close() + if err != nil { + return err + } + wg.Wait() + return nil + }, nil +} diff --git a/heartbeat/monitors/active/tcp/task.go b/heartbeat/monitors/active/tcp/task.go deleted file mode 100644 index 3d830c0c7e9..00000000000 --- a/heartbeat/monitors/active/tcp/task.go +++ /dev/null @@ -1,77 +0,0 @@ -// Licensed to Elasticsearch B.V. under one or more contributor -// license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright -// ownership. Elasticsearch B.V. licenses this file to you under -// the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package tcp - -import ( - "time" - - "github.com/elastic/beats/v7/heartbeat/eventext" - "github.com/elastic/beats/v7/heartbeat/look" - "github.com/elastic/beats/v7/heartbeat/reason" - "github.com/elastic/beats/v7/libbeat/beat" - "github.com/elastic/beats/v7/libbeat/common" - "github.com/elastic/beats/v7/libbeat/common/transport" -) - -func pingHost( - event *beat.Event, - dialer transport.Dialer, - addr string, - timeout time.Duration, - validator ConnCheck, -) error { - start := time.Now() - deadline := start.Add(timeout) - - conn, err := dialer.Dial("tcp", addr) - if err != nil { - debugf("dial failed with: %v", err) - return reason.IOFailed(err) - } - defer conn.Close() - if validator == nil { - // no additional validation step => ping success - return nil - } - - if err := conn.SetDeadline(deadline); err != nil { - debugf("setting connection deadline failed with: %v", err) - return reason.IOFailed(err) - } - - validateStart := time.Now() - err = validator.Validate(conn) - if err != nil && err != errRecvMismatch { - debugf("check failed with: %v", err) - return reason.IOFailed(err) - } - - end := time.Now() - eventext.MergeEventFields(event, common.MapStr{ - "tcp": common.MapStr{ - "rtt": common.MapStr{ - "validate": look.RTT(end.Sub(validateStart)), - }, - }, - }) - if err != nil { - return reason.MakeValidateError(err) - } - - return nil -} diff --git a/heartbeat/monitors/active/tcp/tcp.go b/heartbeat/monitors/active/tcp/tcp.go index bf584095fc4..05c687dd65b 100644 --- a/heartbeat/monitors/active/tcp/tcp.go +++ b/heartbeat/monitors/active/tcp/tcp.go @@ -18,14 +18,18 @@ package tcp import ( - "fmt" + "net" "net/url" - "strconv" - "strings" + "time" + + "github.com/elastic/beats/v7/heartbeat/eventext" + "github.com/elastic/beats/v7/heartbeat/look" + "github.com/elastic/beats/v7/heartbeat/reason" "github.com/elastic/beats/v7/heartbeat/monitors" "github.com/elastic/beats/v7/heartbeat/monitors/active/dialchain" "github.com/elastic/beats/v7/heartbeat/monitors/jobs" + "github.com/elastic/beats/v7/heartbeat/monitors/wrappers" "github.com/elastic/beats/v7/libbeat/beat" "github.com/elastic/beats/v7/libbeat/common" "github.com/elastic/beats/v7/libbeat/common/transport" @@ -39,113 +43,220 @@ func init() { var debugf = logp.MakeDebug("tcp") -type connURL struct { - Scheme string - Host string - Ports []uint16 -} - func create( name string, cfg *common.Config, ) (jobs []jobs.Job, endpoints int, err error) { - config := DefaultConfig - if err := cfg.Unpack(&config); err != nil { + return createWithResolver(cfg, monitors.NewStdResolver()) +} + +// Custom resolver is useful for tests against hostnames locally where we don't want to depend on any +// hostnames existing in test environments +func createWithResolver( + cfg *common.Config, + resolver monitors.Resolver, +) (jobs []jobs.Job, endpoints int, err error) { + jc, err := newJobFactory(cfg, resolver) + if err != nil { return nil, 0, err } - tls, err := tlscommon.LoadTLSConfig(config.TLS) + jobs, err = jc.makeJobs() if err != nil { return nil, 0, err } - defaultScheme := "tcp" - if tls != nil { - defaultScheme = "ssl" + return jobs, len(jc.endpoints), nil +} + +// jobFactory is where most of the logic here lives. It provides a common context around +// the complex logic of executing a TCP check. +type jobFactory struct { + config config + tlsConfig *tlscommon.TLSConfig + defaultScheme string + endpoints []endpoint + dataCheck dataCheck + resolver monitors.Resolver +} + +func newJobFactory(commonCfg *common.Config, resolver monitors.Resolver) (*jobFactory, error) { + jf := &jobFactory{config: defaultConfig(), resolver: resolver} + err := jf.loadConfig(commonCfg) + if err != nil { + return nil, err + } + + return jf, nil +} + +// loadConfig parses the YAML config and populates the jobFactory fields. +func (jf *jobFactory) loadConfig(commonCfg *common.Config) error { + var err error + if err = commonCfg.Unpack(&jf.config); err != nil { + return err } - schemeHosts, err := collectHosts(&config, defaultScheme) + jf.tlsConfig, err = tlscommon.LoadTLSConfig(jf.config.TLS) if err != nil { - return nil, 0, err + return err } - timeout := config.Timeout - validator := makeValidateConn(&config) + jf.defaultScheme = "tcp" + if jf.tlsConfig != nil { + jf.defaultScheme = "ssl" + } - for scheme, eps := range schemeHosts { - schemeTLS := tls - if scheme == "tcp" || scheme == "plain" { - schemeTLS = nil - } + jf.endpoints, err = makeEndpoints(jf.config.Hosts, jf.config.Ports, jf.defaultScheme) + if err != nil { + return err + } - db, err := dialchain.NewBuilder(dialchain.BuilderSettings{ - Timeout: timeout, - Socks5: config.Socks5, - TLS: schemeTLS, - }) - if err != nil { - return nil, 0, err - } + jf.dataCheck = makeDataCheck(&jf.config) - epJobs, err := dialchain.MakeDialerJobs(db, scheme, eps, config.Mode, - func(event *beat.Event, dialer transport.Dialer, addr string) error { - return pingHost(event, dialer, addr, timeout, validator) - }) - if err != nil { - return nil, 0, err + return nil +} + +// makeJobs returns the actual schedulable jobs for this monitor. +func (jf *jobFactory) makeJobs() ([]jobs.Job, error) { + var jobs []jobs.Job + for _, endpoint := range jf.endpoints { + for _, url := range endpoint.perPortURLs() { + endpointJob, err := jf.makeEndpointJob(url) + if err != nil { + return nil, err + } + jobs = append(jobs, wrappers.WithURLField(url, endpointJob)) } - jobs = append(jobs, epJobs...) } - numHosts := 0 - for _, hosts := range schemeHosts { - numHosts += len(hosts) + return jobs, nil +} + +// makeEndpointJob makes a job for a single check of a single scheme/host/port combo. +func (jf *jobFactory) makeEndpointJob(endpointURL *url.URL) (jobs.Job, error) { + // Check if SOCKS5 is configured, with relying on the socks5 proxy + // in resolving the actual IP. + // Create one job for every port number configured. + if jf.config.Socks5.URL != "" && !jf.config.Socks5.LocalResolve { + jf.makeSocksLookupEndpointJob(endpointURL) } - return jobs, numHosts, nil + return jf.makeDirectEndpointJob(endpointURL) } -func collectHosts(config *Config, defaultScheme string) (map[string][]dialchain.Endpoint, error) { - endpoints := map[string][]dialchain.Endpoint{} - for _, h := range config.Hosts { - scheme := defaultScheme - host := "" - u, err := url.Parse(h) - - if err != nil || u.Host == "" { - host = h - } else { - scheme = u.Scheme - host = u.Host - } - debugf("Add tcp endpoint '%v://%v'.", scheme, host) +// makeDirectEndpointJob makes jobs that directly lookup the IP of the endpoints, as opposed to using +// a Socks5 proxy. +func (jf *jobFactory) makeDirectEndpointJob(endpointURL *url.URL) (jobs.Job, error) { + // Create job that first resolves one or multiple IPs (depending on + // config.Mode) in order to create one continuation Task per IP. + job, err := monitors.MakeByHostJob( + endpointURL.Hostname(), + jf.config.Mode, + jf.resolver, + monitors.MakePingIPFactory( + func(event *beat.Event, ip *net.IPAddr) error { + // use address from resolved IP + ipPort := net.JoinHostPort(ip.String(), endpointURL.Port()) - switch scheme { - case "tcp", "plain", "tls", "ssl": - default: - err := fmt.Errorf("'%v' is no supported connection scheme in '%v'", scheme, h) - return nil, err - } + return jf.dial(event, ipPort, endpointURL) + })) + if err != nil { + return nil, err + } + return job, nil +} - pair := strings.SplitN(host, ":", 2) - ports := config.Ports - if len(pair) == 2 { - port, err := strconv.ParseUint(pair[1], 10, 16) - if err != nil { - return nil, fmt.Errorf("'%v' is no valid port number in '%v'", pair[1], h) - } +// makeSocksLookupEndpointJob makes jobs that use a Socks5 proxy to perform DNS lookups +func (jf *jobFactory) makeSocksLookupEndpointJob(endpointURL *url.URL) (jobs.Job, error) { + return wrappers.WithURLField(endpointURL, + jobs.MakeSimpleJob(func(event *beat.Event) error { + hostPort := net.JoinHostPort(endpointURL.Hostname(), endpointURL.Port()) + return jf.dial(event, hostPort, endpointURL) + })), nil +} - ports = []uint16{uint16(port)} - host = pair[0] - } else if len(config.Ports) == 0 { - return nil, fmt.Errorf("host '%v' missing port number", h) - } +// dial builds a dialer and executes the network request. +// dialAddr is the host:port that the dialer will connect to, and where an explicit IP should go to. +// canonicalURL is the URL used to determine if TLS is used via the scheme of the URL, and +// also which hostname should be passed to the TLS implementation for validation of the server cert. +func (jf *jobFactory) dial(event *beat.Event, dialAddr string, canonicalURL *url.URL) error { + // First, create a plain dialer that can connect directly to either hostnames or IPs + dc := &dialchain.DialerChain{ + Net: dialchain.CreateNetDialer(jf.config.Timeout), + } + + // If Socks5 is configured make that the next layer, since everything needs to go through the proxy first. + if jf.config.Socks5.URL != "" { + dc.AddLayer(dialchain.SOCKS5Layer(&jf.config.Socks5)) + } + + // Now add the IP or Hostname of the server we want to connect to. + // Usually this is the IP we've resolved in a prior step. + // If we're using a proxy with host lookup enabled the dialAddr should be the + // hostname we want the server to resolve for us. + dc.AddLayer(dialchain.ConstAddrLayer(dialAddr)) + + // If we're using TLS we need to add a fake layer so that the TLS layer knows the hostname we're connecting to + // So, the canonical URL is fixed via a ConstAddrLayer to override the TLS layer's x509 logic so it doesn't + // try and directly match the IP from the prior ConstAddrLayer to the cert. + if canonicalURL.Scheme != "tcp" && canonicalURL.Scheme != "plain" { + dc.AddLayer(dialchain.TLSLayer(jf.tlsConfig, jf.config.Timeout)) + dc.AddLayer(dialchain.ConstAddrLayer(canonicalURL.Host)) + } + + dialer, err := dc.Build(event) + if err != nil { + return err + } + + return jf.execDialer(event, dialer, dialAddr) +} - endpoints[scheme] = append(endpoints[scheme], dialchain.Endpoint{ - Host: host, - Ports: ports, - }) +// exec dialer executes a network request against the given dialer. +func (jf *jobFactory) execDialer( + event *beat.Event, + dialer transport.Dialer, + addr string, +) error { + start := time.Now() + deadline := start.Add(jf.config.Timeout) + + conn, err := dialer.Dial("tcp", addr) + if err != nil { + debugf("dial failed with: %v", err) + return reason.IOFailed(err) } - return endpoints, nil + defer conn.Close() + if jf.dataCheck == nil { + // no additional validation step => ping success + return nil + } + + if err := conn.SetDeadline(deadline); err != nil { + debugf("setting connection deadline failed with: %v", err) + return reason.IOFailed(err) + } + + validateStart := time.Now() + err = jf.dataCheck.Check(conn) + if err != nil && err != errRecvMismatch { + debugf("check failed with: %v", err) + return reason.IOFailed(err) + } + + end := time.Now() + eventext.MergeEventFields(event, common.MapStr{ + "tcp": common.MapStr{ + "rtt": common.MapStr{ + "validate": look.RTT(end.Sub(validateStart)), + }, + }, + }) + if err != nil { + return reason.MakeValidateError(err) + } + + return nil } diff --git a/heartbeat/monitors/active/tcp/tcp_test.go b/heartbeat/monitors/active/tcp/tcp_test.go index 7980510f3f8..ef4746c5e3e 100644 --- a/heartbeat/monitors/active/tcp/tcp_test.go +++ b/heartbeat/monitors/active/tcp/tcp_test.go @@ -18,28 +18,24 @@ package tcp import ( - "crypto/x509" "fmt" "net" "net/http" "net/http/httptest" "net/url" - "os" "strconv" "testing" - "time" "github.com/stretchr/testify/require" "github.com/elastic/beats/v7/heartbeat/hbtest" - "github.com/elastic/beats/v7/heartbeat/monitors/wrappers" - "github.com/elastic/beats/v7/heartbeat/scheduler/schedule" "github.com/elastic/beats/v7/libbeat/beat" "github.com/elastic/beats/v7/libbeat/common" btesting "github.com/elastic/beats/v7/libbeat/testing" "github.com/elastic/go-lookslike" "github.com/elastic/go-lookslike/isdef" "github.com/elastic/go-lookslike/testslike" + "github.com/elastic/go-lookslike/validator" ) func testTCPCheck(t *testing.T, host string, port uint16) *beat.Event { @@ -51,139 +47,73 @@ func testTCPCheck(t *testing.T, host string, port uint16) *beat.Event { return testTCPConfigCheck(t, config, host, port) } -func testTCPConfigCheck(t *testing.T, configMap common.MapStr, host string, port uint16) *beat.Event { - config, err := common.NewConfigFrom(configMap) - require.NoError(t, err) - - jobs, endpoints, err := create("tcp", config) - require.NoError(t, err) - - sched, _ := schedule.Parse("@every 1s") - job := wrappers.WrapCommon(jobs, "test", "", "tcp", sched, time.Duration(0))[0] - - event := &beat.Event{} - _, err = job(event) - require.NoError(t, err) - - require.Equal(t, 1, endpoints) - - return event -} - -func testTLSTCPCheck(t *testing.T, host string, port uint16, certFileName string) *beat.Event { - config, err := common.NewConfigFrom(common.MapStr{ - "hosts": host, - "ports": int64(port), - "ssl": common.MapStr{"certificate_authorities": certFileName}, - "timeout": "1s", - }) - require.NoError(t, err) - - jobs, endpoints, err := create("tcp", config) - require.NoError(t, err) - - sched, _ := schedule.Parse("@every 1s") - job := wrappers.WrapCommon(jobs, "test", "", "tcp", sched, time.Duration(0))[0] - - event := &beat.Event{} - _, err = job(event) - require.NoError(t, err) - - require.Equal(t, 1, endpoints) - - return event -} - -func setupServer(t *testing.T, serverCreator func(http.Handler) *httptest.Server) (*httptest.Server, uint16) { - server := serverCreator(hbtest.HelloWorldHandler(200)) - - port, err := hbtest.ServerPort(server) - require.NoError(t, err) - - return server, port -} - -// newLocalhostTestServer starts a server listening on the IP resolved from `localhost` -// httptest.NewServer() binds explicitly on 127.0.0.1 (or [::1] if ipv4 is not available). -// The IP resolved from `localhost` can be a different one, like 127.0.1.1. -func newLocalhostTestServer(handler http.Handler) *httptest.Server { - listener, err := net.Listen("tcp", "localhost:0") - if err != nil { - panic("failed to listen on localhost: " + err.Error()) +// TestUpEndpointJob tests an up endpoint configured using either direct lookups or IPs +func TestUpEndpointJob(t *testing.T) { + // Test with domain, IPv4 and IPv6 + scenarios := []struct { + name string + hostname string + isIP bool + expectedIP string + }{ + { + name: "localhost", + hostname: "localhost", + isIP: false, + expectedIP: "127.0.0.1", + }, + { + name: "ipv4", + hostname: "127.0.0.1", + isIP: true, + expectedIP: "127.0.0.1", + }, + { + name: "ipv6", + hostname: "::1", + isIP: true, + }, } - server := &httptest.Server{ - Listener: listener, - Config: &http.Server{Handler: handler}, + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + server, port, err := setupServer(t, func(handler http.Handler) (*httptest.Server, error) { + return newHostTestServer(handler, scenario.hostname) + }) + // Some machines don't have ipv6 setup correctly, so we ignore the test + // if we can't bind to the port / setup the server. + if err != nil && scenario.hostname == "::1" { + return + } + require.NoError(t, err) + + defer server.Close() + + hostURL := &url.URL{Scheme: "tcp", Host: net.JoinHostPort(scenario.hostname, strconv.Itoa(int(port)))} + + serverURL, err := url.Parse(server.URL) + require.NoError(t, err) + + event := testTCPCheck(t, hostURL.String(), port) + + validators := []validator.Validator{ + hbtest.BaseChecks(serverURL.Hostname(), "up", "tcp"), + hbtest.SummaryChecks(1, 0), + hbtest.URLChecks(t, hostURL), + hbtest.RespondingTCPChecks(), + } + + if !scenario.isIP { + validators = append(validators, hbtest.ResolveChecks(scenario.expectedIP)) + } + + testslike.Test( + t, + lookslike.Strict(lookslike.Compose(validators...)), + event.Fields, + ) + }) } - server.Start() - - return server -} - -func TestUpEndpointJob(t *testing.T) { - server, port := setupServer(t, newLocalhostTestServer) - defer server.Close() - - serverURL, err := url.Parse(server.URL) - require.NoError(t, err) - - event := testTCPCheck(t, "localhost", port) - - testslike.Test( - t, - lookslike.Strict(lookslike.Compose( - hbtest.BaseChecks(serverURL.Hostname(), "up", "tcp"), - hbtest.SummaryChecks(1, 0), - hbtest.SimpleURLChecks(t, "tcp", "localhost", port), - hbtest.RespondingTCPChecks(), - lookslike.MustCompile(map[string]interface{}{ - "resolve": map[string]interface{}{ - "ip": serverURL.Hostname(), - "rtt.us": isdef.IsDuration, - }, - }), - )), - event.Fields, - ) -} - -func TestTLSConnection(t *testing.T) { - // Start up a TLS Server - server, port := setupServer(t, httptest.NewTLSServer) - defer server.Close() - - // Parse its URL - serverURL, err := url.Parse(server.URL) - require.NoError(t, err) - - // Determine the IP address the server's hostname resolves to - ips, err := net.LookupHost(serverURL.Hostname()) - require.NoError(t, err) - require.Len(t, ips, 1) - ip := ips[0] - - // Parse the cert so we can test against it - cert, err := x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0]) - require.NoError(t, err) - - // Save the server's cert to a file so heartbeat can use it - certFile := hbtest.CertToTempFile(t, cert) - require.NoError(t, certFile.Close()) - defer os.Remove(certFile.Name()) - - event := testTLSTCPCheck(t, ip, port, certFile.Name()) - testslike.Test( - t, - lookslike.Strict(lookslike.Compose( - hbtest.TLSChecks(0, 0, cert), - hbtest.RespondingTCPChecks(), - hbtest.BaseChecks(ip, "up", "tcp"), - hbtest.SummaryChecks(1, 0), - hbtest.SimpleURLChecks(t, "ssl", serverURL.Hostname(), port), - )), - event.Fields, - ) } func TestConnectionRefusedEndpointJob(t *testing.T) { @@ -246,11 +176,8 @@ func TestCheckUp(t *testing.T) { hbtest.RespondingTCPChecks(), hbtest.SimpleURLChecks(t, "tcp", host, port), hbtest.SummaryChecks(1, 0), + hbtest.ResolveChecks(ip), lookslike.MustCompile(map[string]interface{}{ - "resolve": map[string]interface{}{ - "ip": ip, - "rtt.us": isdef.IsDuration, - }, "tcp": map[string]interface{}{ "rtt.validate.us": isdef.IsDuration, }, @@ -282,11 +209,8 @@ func TestCheckDown(t *testing.T) { hbtest.RespondingTCPChecks(), hbtest.SimpleURLChecks(t, "tcp", host, port), hbtest.SummaryChecks(0, 1), + hbtest.ResolveChecks(ip), lookslike.MustCompile(map[string]interface{}{ - "resolve": map[string]interface{}{ - "ip": ip, - "rtt.us": isdef.IsDuration, - }, "tcp": map[string]interface{}{ "rtt.validate.us": isdef.IsDuration, }, @@ -346,3 +270,36 @@ func startEchoServer(t *testing.T) (host string, port uint16, ip string, close f return "localhost", uint16(portUint64), ip, listener.Close, nil } + +// StaticResolver allows for a custom in-memory mapping of hosts to IPs, it ignores network names +// and zones. +type StaticResolver struct { + mapping map[string][]net.IP +} + +func NewStaticResolver(mapping map[string][]net.IP) StaticResolver { + return StaticResolver{mapping} +} + +func (s StaticResolver) ResolveIPAddr(network string, host string) (*net.IPAddr, error) { + found, err := s.LookupIP(host) + if err != nil { + return nil, err + } + return &net.IPAddr{IP: found[0]}, nil +} + +func (s StaticResolver) LookupIP(host string) ([]net.IP, error) { + if found, ok := s.mapping[host]; ok { + return found, nil + } else { + return nil, makeStaticNXDomainErr(host) + } +} + +func makeStaticNXDomainErr(host string) *net.DNSError { + return &net.DNSError{ + IsNotFound: true, + Err: fmt.Sprintf("Hostname '%s' not found in static resolver", host), + } +} diff --git a/heartbeat/monitors/active/tcp/tls_test.go b/heartbeat/monitors/active/tcp/tls_test.go new file mode 100644 index 00000000000..0628b1694b4 --- /dev/null +++ b/heartbeat/monitors/active/tcp/tls_test.go @@ -0,0 +1,169 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package tcp + +import ( + "crypto/x509" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "testing" + "time" + + "github.com/elastic/beats/v7/heartbeat/monitors/wrappers" + "github.com/elastic/beats/v7/heartbeat/scheduler/schedule" + "github.com/elastic/beats/v7/libbeat/beat" + "github.com/elastic/beats/v7/libbeat/common" + + "github.com/stretchr/testify/require" + + "github.com/elastic/beats/v7/heartbeat/hbtest" + "github.com/elastic/beats/v7/heartbeat/monitors" + "github.com/elastic/go-lookslike" + "github.com/elastic/go-lookslike/testslike" +) + +// Tests that we can check a TLS connection with a cert for a SAN IP +func TestTLSSANIPConnection(t *testing.T) { + ip, port, cert, certFile, teardown := setupTLSTestServer(t) + defer teardown() + + event := testTLSTCPCheck(t, ip, port, certFile.Name(), monitors.NewStdResolver()) + testslike.Test( + t, + lookslike.Strict(lookslike.Compose( + hbtest.TLSChecks(0, 0, cert), + hbtest.RespondingTCPChecks(), + hbtest.BaseChecks(ip, "up", "tcp"), + hbtest.SummaryChecks(1, 0), + hbtest.SimpleURLChecks(t, "ssl", ip, port), + )), + event.Fields, + ) +} + +func TestTLSHostname(t *testing.T) { + ip, port, cert, certFile, teardown := setupTLSTestServer(t) + defer teardown() + + hostname := cert.DNSNames[0] // Should be example.com + resolver := NewStaticResolver(map[string][]net.IP{hostname: []net.IP{net.ParseIP(ip)}}) + event := testTLSTCPCheck(t, hostname, port, certFile.Name(), resolver) + testslike.Test( + t, + lookslike.Strict(lookslike.Compose( + hbtest.TLSChecks(0, 0, cert), + hbtest.RespondingTCPChecks(), + hbtest.BaseChecks(ip, "up", "tcp"), + hbtest.SummaryChecks(1, 0), + hbtest.SimpleURLChecks(t, "ssl", hostname, port), + hbtest.ResolveChecks(ip), + )), + event.Fields, + ) +} + +func TestTLSInvalidCert(t *testing.T) { + ip, port, cert, certFile, teardown := setupTLSTestServer(t) + defer teardown() + + mismatchedHostname := "notadomain.elastic.co" + resolver := NewStaticResolver( + map[string][]net.IP{ + cert.DNSNames[0]: {net.ParseIP(ip)}, + mismatchedHostname: {net.ParseIP(ip)}, + }, + ) + event := testTLSTCPCheck(t, mismatchedHostname, port, certFile.Name(), resolver) + + testslike.Test( + t, + lookslike.Strict(lookslike.Compose( + hbtest.RespondingTCPChecks(), + hbtest.BaseChecks(ip, "down", "tcp"), + hbtest.SummaryChecks(0, 1), + hbtest.SimpleURLChecks(t, "ssl", mismatchedHostname, port), + hbtest.ResolveChecks(ip), + lookslike.MustCompile(map[string]interface{}{ + "error": map[string]interface{}{ + "message": x509.HostnameError{Certificate: cert, Host: mismatchedHostname}.Error(), + "type": "io", + }, + }), + )), + event.Fields, + ) +} + +func setupTLSTestServer(t *testing.T) (ip string, port uint16, cert *x509.Certificate, certFile *os.File, teardown func()) { + // Start up a TLS Server + server, port, err := setupServer(t, func(handler http.Handler) (*httptest.Server, error) { + return httptest.NewTLSServer(handler), nil + }) + require.NoError(t, err) + + // Parse its URL + serverURL, err := url.Parse(server.URL) + require.NoError(t, err) + + // Determine the IP address the server's hostname resolves to + ips, err := net.LookupHost(serverURL.Hostname()) + require.NoError(t, err) + require.Len(t, ips, 1) + ip = ips[0] + + // Parse the cert so we can test against it + cert, err = x509.ParseCertificate(server.TLS.Certificates[0].Certificate[0]) + require.NoError(t, err) + + // Save the server's cert to a file so heartbeat can use it + certFile = hbtest.CertToTempFile(t, cert) + require.NoError(t, certFile.Close()) + + return ip, port, cert, certFile, func() { + defer server.Close() + err := os.Remove(certFile.Name()) + require.NoError(t, err) + } +} + +func testTLSTCPCheck(t *testing.T, host string, port uint16, certFileName string, resolver monitors.Resolver) *beat.Event { + config, err := common.NewConfigFrom(common.MapStr{ + "hosts": host, + "ports": int64(port), + "ssl": common.MapStr{"certificate_authorities": certFileName}, + "timeout": "1s", + }) + require.NoError(t, err) + + jobs, endpoints, err := createWithResolver(config, resolver) + require.NoError(t, err) + + sched, _ := schedule.Parse("@every 1s") + job := wrappers.WrapCommon(jobs, "test", "", "tcp", sched, time.Duration(0))[0] + + event := &beat.Event{} + _, err = job(event) + require.NoError(t, err) + + require.Equal(t, 1, endpoints) + + return event +} diff --git a/heartbeat/monitors/resolver.go b/heartbeat/monitors/resolver.go new file mode 100644 index 00000000000..23b96868680 --- /dev/null +++ b/heartbeat/monitors/resolver.go @@ -0,0 +1,46 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package monitors + +import ( + "net" +) + +// Resolver lets us define custom DNS resolvers similar to what the go stdlib provides, but +// potentially with custom functionality +type Resolver interface { + // ResolveIPAddr is an analog of net.ResolveIPAddr + ResolveIPAddr(network string, host string) (*net.IPAddr, error) + // LookupIP is an analog of net.LookupIP + LookupIP(host string) ([]net.IP, error) +} + +// StdResolver uses the go std library to perform DNS resolution. +type StdResolver struct{} + +func NewStdResolver() StdResolver { + return StdResolver{} +} + +func (s StdResolver) ResolveIPAddr(network string, host string) (*net.IPAddr, error) { + return net.ResolveIPAddr(network, host) +} + +func (s StdResolver) LookupIP(host string) ([]net.IP, error) { + return net.LookupIP(host) +} diff --git a/heartbeat/monitors/util.go b/heartbeat/monitors/util.go index 8cf2679c0d8..31191e85e63 100644 --- a/heartbeat/monitors/util.go +++ b/heartbeat/monitors/util.go @@ -39,14 +39,6 @@ type IPSettings struct { Mode PingMode `config:"mode"` } -// HostJobSettings configures a Job including Host lookups and global fields to be added -// to every event. -type HostJobSettings struct { - Host string - IP IPSettings - Fields common.MapStr -} - // PingMode enumeration for configuring `any` or `all` IPs pinging. type PingMode uint8 @@ -111,7 +103,7 @@ func MakeByIPJob( pingFactory func(ip *net.IPAddr) jobs.Job, ) (jobs.Job, error) { // use ResolveIPAddr to parse the ip into net.IPAddr adding a zone info - // if ipv6 is used. + // if ipv6 is used. We intentionally do not use a custom resolver here. addr, err := net.ResolveIPAddr("ip", ip.String()) if err != nil { return nil, err @@ -130,39 +122,40 @@ func MakeByIPJob( // A pingFactory instance is normally build with MakePingIPFactory, // MakePingAllIPFactory or MakePingAllIPPortFactory. func MakeByHostJob( - settings HostJobSettings, + host string, + ipSettings IPSettings, + resolver Resolver, pingFactory func(ip *net.IPAddr) jobs.Job, ) (jobs.Job, error) { - host := settings.Host - if ip := net.ParseIP(host); ip != nil { return MakeByIPJob(ip, pingFactory) } - network := settings.IP.Network() + network := ipSettings.Network() if network == "" { return nil, errors.New("pinging hosts requires ipv4 or ipv6 mode enabled") } - mode := settings.IP.Mode + mode := ipSettings.Mode if mode == PingAny { - return makeByHostAnyIPJob(settings, host, pingFactory), nil + return makeByHostAnyIPJob(host, ipSettings, resolver, pingFactory), nil } - return makeByHostAllIPJob(settings, host, pingFactory), nil + return makeByHostAllIPJob(host, ipSettings, resolver, pingFactory), nil } func makeByHostAnyIPJob( - settings HostJobSettings, host string, + ipSettings IPSettings, + resolver Resolver, pingFactory func(ip *net.IPAddr) jobs.Job, ) jobs.Job { - network := settings.IP.Network() + network := ipSettings.Network() return func(event *beat.Event) ([]jobs.Job, error) { resolveStart := time.Now() - ip, err := net.ResolveIPAddr(network, host) + ip, err := resolver.ResolveIPAddr(network, host) if err != nil { return nil, err } @@ -176,11 +169,12 @@ func makeByHostAnyIPJob( } func makeByHostAllIPJob( - settings HostJobSettings, host string, + ipSettings IPSettings, + resolver Resolver, pingFactory func(ip *net.IPAddr) jobs.Job, ) jobs.Job { - network := settings.IP.Network() + network := ipSettings.Network() filter := makeIPFilter(network) return func(event *beat.Event) ([]jobs.Job, error) { @@ -265,34 +259,3 @@ func filterIPs(ips []net.IP, filt func(net.IP) bool) []net.IP { } return out } - -// MakeHostJobSettings creates a new HostJobSettings structure without any global -// event fields. -func MakeHostJobSettings(host string, ip IPSettings) HostJobSettings { - return HostJobSettings{Host: host, IP: ip} -} - -// WithFields adds new event fields to a Job. Existing fields will be -// overwritten. -// The fields map will be updated (no copy). -func (s HostJobSettings) WithFields(m common.MapStr) HostJobSettings { - s.AddFields(m) - return s -} - -// AddFields adds new event fields to a Job. Existing fields will be -// overwritten. -func (s *HostJobSettings) AddFields(m common.MapStr) { addFields(&s.Fields, m) } - -func addFields(to *common.MapStr, m common.MapStr) { - if m == nil { - return - } - - fields := *to - if fields == nil { - fields = common.MapStr{} - *to = fields - } - fields.DeepUpdate(m) -}