From 861364b2ec885826130592c239524bdd09af5026 Mon Sep 17 00:00:00 2001 From: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> Date: Fri, 28 Oct 2022 18:03:12 -0400 Subject: [PATCH] GODRIVER-2620 Fix hostname parsing for SRV polling. (#1112) * GODRIVER-2620 Fix hostname parsing for SRV polling. Co-authored-by: Preston Vasquez <24281431+prestonvasquez@users.noreply.github.com> Co-authored-by: Kevin Albertson --- .../topology/polling_srv_records_test.go | 15 ++++++++++- x/mongo/driver/topology/topology.go | 27 ++++++++++++------- 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/x/mongo/driver/topology/polling_srv_records_test.go b/x/mongo/driver/topology/polling_srv_records_test.go index 205e701cad..2fe0bc8149 100644 --- a/x/mongo/driver/topology/polling_srv_records_test.go +++ b/x/mongo/driver/topology/polling_srv_records_test.go @@ -125,9 +125,22 @@ func compareHosts(t *testing.T, received []description.Server, expected []string } func TestPollingSRVRecordsSpec(t *testing.T) { + for _, uri := range []string{ + "mongodb+srv://test1.test.build.10gen.cc/?heartbeatFrequencyMS=100", + // Test with user:pass as a regression test for GODRIVER-2620 + "mongodb+srv://user:pass@test1.test.build.10gen.cc/?heartbeatFrequencyMS=100", + } { + t.Run(uri, func(t *testing.T) { + testPollingSRVRecordsSpec(t, uri) + }) + } +} + +func testPollingSRVRecordsSpec(t *testing.T, uri string) { + t.Helper() for _, tt := range srvPollingTests { t.Run(tt.name, func(t *testing.T) { - cs, err := connstring.ParseAndValidate("mongodb+srv://test1.test.build.10gen.cc/?heartbeatFrequencyMS=100") + cs, err := connstring.ParseAndValidate(uri) require.NoError(t, err, "Problem parsing the uri: %v", err) topo, err := New( WithConnString(func(connstring.ConnString) connstring.ConnString { return cs }), diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index f02e5dea86..7e6ec17bca 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -14,6 +14,8 @@ import ( "context" "errors" "fmt" + "net" + "net/url" "strings" "sync" "sync/atomic" @@ -230,7 +232,21 @@ func (t *Topology) Connect() error { t.serversLock.Unlock() if t.pollingRequired { - go t.pollSRVRecords() + uri, err := url.Parse(t.cfg.uri) + if err != nil { + return err + } + // sanity check before passing the hostname to resolver + if parsedHosts := strings.Split(uri.Host, ","); len(parsedHosts) != 1 { + return fmt.Errorf("URI with SRV must include one and only one hostname") + } + _, _, err = net.SplitHostPort(uri.Host) + if err == nil { + // we were able to successfully extract a port from the host, + // but should not be able to when using SRV + return fmt.Errorf("URI with srv must not include a port number") + } + go t.pollSRVRecords(uri.Host) t.pollingwg.Add(1) } @@ -552,7 +568,7 @@ func (t *Topology) selectServerFromDescription(desc description.Topology, return suitable, nil } -func (t *Topology) pollSRVRecords() { +func (t *Topology) pollSRVRecords(hosts string) { defer t.pollingwg.Done() serverConfig := newServerConfig(t.cfg.serverOpts...) @@ -569,13 +585,6 @@ func (t *Topology) pollSRVRecords() { } }() - // remove the scheme - uri := t.cfg.uri[14:] - hosts := uri - if idx := strings.IndexAny(uri, "/?@"); idx != -1 { - hosts = uri[:idx] - } - for { select { case <-pollTicker.C: