From 06859f7ab4f5c936ebf57da404c927275950001b Mon Sep 17 00:00:00 2001 From: Harry Pidcock Date: Fri, 11 Oct 2024 14:35:29 +1000 Subject: [PATCH] fix: close rpc conn when the pooled state is going --- apiserver/apiserver.go | 3 ++ apiserver/apiserver_test.go | 48 ++++++++++++++++++++++++++ apiserver/debuglog.go | 3 +- apiserver/debuglog_db.go | 3 ++ apiserver/debuglog_db_internal_test.go | 8 ++--- state/pool.go | 19 +++++++++- 6 files changed, 78 insertions(+), 6 deletions(-) diff --git a/apiserver/apiserver.go b/apiserver/apiserver.go index e2bc22c8b38..5d0c5aabf61 100644 --- a/apiserver/apiserver.go +++ b/apiserver/apiserver.go @@ -1132,9 +1132,11 @@ func (srv *Server) serveConn( h *apiHandler ) + var stateClosing <-chan struct{} st, err := statePool.Get(resolvedModelUUID) if err == nil { defer st.Release() + stateClosing = st.Removing() h, err = newAPIHandler(srv, st.State, conn, modelUUID, connectionID, host) } if errors.Is(err, errors.NotFound) { @@ -1158,6 +1160,7 @@ func (srv *Server) serveConn( select { case <-conn.Dead(): case <-srv.tomb.Dying(): + case <-stateClosing: } return conn.Close() } diff --git a/apiserver/apiserver_test.go b/apiserver/apiserver_test.go index 5c9308c42d8..d8bbe09178f 100644 --- a/apiserver/apiserver_test.go +++ b/apiserver/apiserver_test.go @@ -45,6 +45,7 @@ import ( "github.com/juju/juju/rpc/params" "github.com/juju/juju/state" statetesting "github.com/juju/juju/state/testing" + "github.com/juju/juju/storage" "github.com/juju/juju/testing" "github.com/juju/juju/worker/gate" "github.com/juju/juju/worker/modelcache" @@ -496,3 +497,50 @@ func (s *apiserverSuite) assertEmbeddedCommand(c *gc.C, cmdArgs params.CLIComman Error: resultErr, }) } + +// TestModelRemoveClosesRPC tests that when an RPC connection is opened +// to a model that is being removed, the connection is closed +// gracefully. +func (s *apiserverSuite) TestModelRemoveClosesRPC(c *gc.C) { + uuid, err := utils.NewUUID() + c.Assert(err, jc.ErrorIsNil) + modelConfig := testing.CustomModelConfig(c, testing.Attrs{ + "name": "testing", + "uuid": uuid.String(), + }) + + model, st, err := s.Controller.NewModel(state.ModelArgs{ + Type: state.ModelTypeIAAS, + CloudName: "dummy", + CloudRegion: "dummy-region", + Config: modelConfig, + Owner: s.Owner, + StorageProviderRegistry: storage.StaticProviderRegistry{}, + }) + c.Assert(err, jc.ErrorIsNil) + s.AddCleanup(func(c *gc.C) { + st.Close() + }) + + apiInfo := s.APIInfo(s.apiServer) + apiInfo.Tag = s.Owner + apiInfo.Password = ownerPassword + apiInfo.Nonce = "" + apiInfo.ModelTag = model.ModelTag() + + conn, err := api.Open(apiInfo, api.DialOpts{}) + c.Assert(err, jc.ErrorIsNil) + c.Assert(conn, gc.NotNil) + s.AddCleanup(func(c *gc.C) { + conn.Close() + }) + + removed, err := s.StatePool.Remove(model.UUID()) + c.Assert(err, jc.ErrorIsNil) + c.Assert(removed, jc.IsFalse) + + time.Sleep(testing.ShortWait) + + err = conn.APICall("Pinger", 1, "", "Ping", nil, nil) + c.Assert(err, gc.ErrorMatches, `connection is shut down`) +} diff --git a/apiserver/debuglog.go b/apiserver/debuglog.go index bf24d3df8fe..64a4e5017c8 100644 --- a/apiserver/debuglog.go +++ b/apiserver/debuglog.go @@ -41,6 +41,7 @@ type debugLogHandlerFunc func( debugLogParams, debugLogSocket, <-chan struct{}, + <-chan struct{}, ) error func newDebugLogHandler( @@ -119,7 +120,7 @@ func (h *debugLogHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { clock := h.ctxt.srv.clock maxDuration := h.ctxt.srv.shared.maxDebugLogDuration() - if err := h.handle(clock, maxDuration, st, params, socket, h.ctxt.stop()); err != nil { + if err := h.handle(clock, maxDuration, st, params, socket, h.ctxt.stop(), st.Removing()); err != nil { if isBrokenPipe(err) { logger.Tracef("debug-log handler stopped (client disconnected)") } else { diff --git a/apiserver/debuglog_db.go b/apiserver/debuglog_db.go index 472d11f6384..0efd515a6a0 100644 --- a/apiserver/debuglog_db.go +++ b/apiserver/debuglog_db.go @@ -31,6 +31,7 @@ func handleDebugLogDBRequest( reqParams debugLogParams, socket debugLogSocket, stop <-chan struct{}, + stateClosing <-chan struct{}, ) error { tailerParams := makeLogTailerParams(reqParams) tailer, err := newLogTailer(st, tailerParams) @@ -47,6 +48,8 @@ func handleDebugLogDBRequest( var lineCount uint for { select { + case <-stateClosing: + return nil case <-stop: return nil case <-timeout: diff --git a/apiserver/debuglog_db_internal_test.go b/apiserver/debuglog_db_internal_test.go index 5661397e39f..af42f3cf9c9 100644 --- a/apiserver/debuglog_db_internal_test.go +++ b/apiserver/debuglog_db_internal_test.go @@ -73,7 +73,7 @@ func (s *debugLogDBIntSuite) TestParamConversion(c *gc.C) { stop := make(chan struct{}) close(stop) // Stop the request immediately. - err := handleDebugLogDBRequest(s.clock, s.timeout, nil, reqParams, s.sock, stop) + err := handleDebugLogDBRequest(s.clock, s.timeout, nil, reqParams, s.sock, stop, nil) c.Assert(err, jc.ErrorIsNil) c.Assert(called, jc.IsTrue) } @@ -96,7 +96,7 @@ func (s *debugLogDBIntSuite) TestParamConversionReplay(c *gc.C) { stop := make(chan struct{}) close(stop) // Stop the request immediately. - err := handleDebugLogDBRequest(s.clock, s.timeout, nil, reqParams, s.sock, stop) + err := handleDebugLogDBRequest(s.clock, s.timeout, nil, reqParams, s.sock, nil, stop) c.Assert(err, jc.ErrorIsNil) c.Assert(called, jc.IsTrue) } @@ -187,7 +187,7 @@ func (s *debugLogDBIntSuite) TestRequestStopsWhenTailerStops(c *gc.C) { return tailer, nil }) - err := handleDebugLogDBRequest(s.clock, s.timeout, nil, debugLogParams{}, s.sock, nil) + err := handleDebugLogDBRequest(s.clock, s.timeout, nil, debugLogParams{}, s.sock, nil, nil) c.Assert(err, jc.ErrorIsNil) c.Assert(tailer.stopped, jc.IsTrue) } @@ -225,7 +225,7 @@ func (s *debugLogDBIntSuite) TestMaxLines(c *gc.C) { func (s *debugLogDBIntSuite) runRequest(params debugLogParams, stop chan struct{}) chan error { done := make(chan error) go func() { - done <- handleDebugLogDBRequest(s.clock, s.timeout, &fakeState{}, params, s.sock, stop) + done <- handleDebugLogDBRequest(s.clock, s.timeout, &fakeState{}, params, s.sock, stop, nil) }() return done } diff --git a/state/pool.go b/state/pool.go index 9fdb87e4388..a2550d5c4f4 100644 --- a/state/pool.go +++ b/state/pool.go @@ -42,6 +42,7 @@ type PooledState struct { isSystemState bool released bool itemKey uint64 + removing chan struct{} } var _ PoolHelper = (*PooledState)(nil) @@ -75,6 +76,12 @@ func (ps *PooledState) Release() bool { return removed } +// Removing returns a channel that is closed when the PooledState +// should be released by the consumer. +func (ps *PooledState) Removing() <-chan struct{} { + return ps.removing +} + // TODO: implement Close that hides the state.Close for a PooledState? // Annotate writes the supplied context information back to the pool item. @@ -93,6 +100,7 @@ type PoolItem struct { modelUUID string referenceSources map[uint64]string remove bool + removing chan struct{} } func (i *PoolItem) refCount() int { @@ -242,6 +250,7 @@ func (p *StatePool) Get(modelUUID string) (*PooledState, error) { item.referenceSources[key] = source ps := newPooledState(item.state, p, modelUUID, false) ps.itemKey = key + ps.removing = item.removing return ps, nil } @@ -256,15 +265,18 @@ func (p *StatePool) Get(modelUUID string) (*PooledState, error) { if err != nil { return nil, errors.Trace(err) } + removing := make(chan struct{}) p.pool[modelUUID] = &PoolItem{ modelUUID: modelUUID, state: st, referenceSources: map[uint64]string{ key: source, }, + removing: removing, } ps := newPooledState(st, p, modelUUID, false) ps.itemKey = key + ps.removing = removing return ps, nil } @@ -355,7 +367,12 @@ func (p *StatePool) Remove(modelUUID string) (bool, error) { // ignore unknown model uuids. return false, nil } - item.remove = true + if !item.remove { + item.remove = true + if item.removing != nil { + close(item.removing) + } + } return p.maybeRemoveItem(item) }