diff --git a/cmd/agent/subcommands/flare/command_test.go b/cmd/agent/subcommands/flare/command_test.go index f199a9a3a59b8..7af1aa5ba7c5b 100644 --- a/cmd/agent/subcommands/flare/command_test.go +++ b/cmd/agent/subcommands/flare/command_test.go @@ -29,14 +29,25 @@ import ( type commandTestSuite struct { suite.Suite sysprobeSocketPath string + tcpServer *httptest.Server + unixServer *httptest.Server } func (c *commandTestSuite) SetupSuite() { t := c.T() c.sysprobeSocketPath = path.Join(t.TempDir(), "sysprobe.sock") + c.tcpServer, c.unixServer = c.getPprofTestServer() } -func getPprofTestServer(t *testing.T, utsPath string) (tcpServer *httptest.Server, unixServer *httptest.Server) { +func (c *commandTestSuite) TearDownSuite() { + c.tcpServer.Close() + if c.unixServer != nil { + c.unixServer.Close() + } +} + +func (c *commandTestSuite) getPprofTestServer() (tcpServer *httptest.Server, unixServer *httptest.Server) { + t := c.T() handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/debug/pprof/heap": @@ -61,14 +72,12 @@ func getPprofTestServer(t *testing.T, utsPath string) (tcpServer *httptest.Serve if runtime.GOOS == "linux" { unixServer = httptest.NewUnstartedServer(handler) var err error - unixServer.Listener, err = net.Listen("unix", utsPath) - require.NoError(t, err, "could not create listener for unix socket on %s", utsPath) + unixServer.Listener, err = net.Listen("unix", c.sysprobeSocketPath) + require.NoError(t, err, "could not create listener for unix socket on %s", c.sysprobeSocketPath) unixServer.Start() - - return tcpServer, unixServer } - return tcpServer, tcpServer + return tcpServer, unixServer } func TestCommandTestSuite(t *testing.T) { @@ -77,13 +86,7 @@ func TestCommandTestSuite(t *testing.T) { func (c *commandTestSuite) TestReadProfileData() { t := c.T() - ts, uts := getPprofTestServer(t, c.sysprobeSocketPath) - t.Cleanup(func() { - ts.Close() - uts.Close() - }) - - u, err := url.Parse(ts.URL) + u, err := url.Parse(c.tcpServer.URL) require.NoError(t, err) port := u.Port() @@ -151,13 +154,7 @@ func (c *commandTestSuite) TestReadProfileData() { func (c *commandTestSuite) TestReadProfileDataNoTraceAgent() { t := c.T() - ts, uts := getPprofTestServer(t, c.sysprobeSocketPath) - t.Cleanup(func() { - ts.Close() - uts.Close() - }) - - u, err := url.Parse(ts.URL) + u, err := url.Parse(c.tcpServer.URL) require.NoError(t, err) port := u.Port()