Skip to content

Commit

Permalink
fix: close rpc conn when the pooled state is going
Browse files Browse the repository at this point in the history
  • Loading branch information
hpidcock committed Oct 14, 2024
1 parent 7322356 commit 06859f7
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 6 deletions.
3 changes: 3 additions & 0 deletions apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1132,9 +1132,11 @@ func (srv *Server) serveConn(
h *apiHandler
)

var stateClosing <-chan struct{}
st, err := statePool.Get(resolvedModelUUID)
if err == nil {
defer st.Release()
stateClosing = st.Removing()
h, err = newAPIHandler(srv, st.State, conn, modelUUID, connectionID, host)
}
if errors.Is(err, errors.NotFound) {
Expand All @@ -1158,6 +1160,7 @@ func (srv *Server) serveConn(
select {
case <-conn.Dead():
case <-srv.tomb.Dying():
case <-stateClosing:
}
return conn.Close()
}
Expand Down
48 changes: 48 additions & 0 deletions apiserver/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/juju/juju/rpc/params"
"github.com/juju/juju/state"
statetesting "github.com/juju/juju/state/testing"
"github.com/juju/juju/storage"
"github.com/juju/juju/testing"
"github.com/juju/juju/worker/gate"
"github.com/juju/juju/worker/modelcache"
Expand Down Expand Up @@ -496,3 +497,50 @@ func (s *apiserverSuite) assertEmbeddedCommand(c *gc.C, cmdArgs params.CLIComman
Error: resultErr,
})
}

// TestModelRemoveClosesRPC tests that when an RPC connection is opened
// to a model that is being removed, the connection is closed
// gracefully.
func (s *apiserverSuite) TestModelRemoveClosesRPC(c *gc.C) {
uuid, err := utils.NewUUID()
c.Assert(err, jc.ErrorIsNil)
modelConfig := testing.CustomModelConfig(c, testing.Attrs{
"name": "testing",
"uuid": uuid.String(),
})

model, st, err := s.Controller.NewModel(state.ModelArgs{
Type: state.ModelTypeIAAS,
CloudName: "dummy",
CloudRegion: "dummy-region",
Config: modelConfig,
Owner: s.Owner,
StorageProviderRegistry: storage.StaticProviderRegistry{},
})
c.Assert(err, jc.ErrorIsNil)
s.AddCleanup(func(c *gc.C) {
st.Close()
})

apiInfo := s.APIInfo(s.apiServer)
apiInfo.Tag = s.Owner
apiInfo.Password = ownerPassword
apiInfo.Nonce = ""
apiInfo.ModelTag = model.ModelTag()

conn, err := api.Open(apiInfo, api.DialOpts{})
c.Assert(err, jc.ErrorIsNil)
c.Assert(conn, gc.NotNil)
s.AddCleanup(func(c *gc.C) {
conn.Close()
})

removed, err := s.StatePool.Remove(model.UUID())
c.Assert(err, jc.ErrorIsNil)
c.Assert(removed, jc.IsFalse)

time.Sleep(testing.ShortWait)

err = conn.APICall("Pinger", 1, "", "Ping", nil, nil)
c.Assert(err, gc.ErrorMatches, `connection is shut down`)
}
3 changes: 2 additions & 1 deletion apiserver/debuglog.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type debugLogHandlerFunc func(
debugLogParams,
debugLogSocket,
<-chan struct{},
<-chan struct{},
) error

func newDebugLogHandler(
Expand Down Expand Up @@ -119,7 +120,7 @@ func (h *debugLogHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
clock := h.ctxt.srv.clock
maxDuration := h.ctxt.srv.shared.maxDebugLogDuration()

if err := h.handle(clock, maxDuration, st, params, socket, h.ctxt.stop()); err != nil {
if err := h.handle(clock, maxDuration, st, params, socket, h.ctxt.stop(), st.Removing()); err != nil {
if isBrokenPipe(err) {
logger.Tracef("debug-log handler stopped (client disconnected)")
} else {
Expand Down
3 changes: 3 additions & 0 deletions apiserver/debuglog_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ func handleDebugLogDBRequest(
reqParams debugLogParams,
socket debugLogSocket,
stop <-chan struct{},
stateClosing <-chan struct{},
) error {
tailerParams := makeLogTailerParams(reqParams)
tailer, err := newLogTailer(st, tailerParams)
Expand All @@ -47,6 +48,8 @@ func handleDebugLogDBRequest(
var lineCount uint
for {
select {
case <-stateClosing:
return nil
case <-stop:
return nil
case <-timeout:
Expand Down
8 changes: 4 additions & 4 deletions apiserver/debuglog_db_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (s *debugLogDBIntSuite) TestParamConversion(c *gc.C) {

stop := make(chan struct{})
close(stop) // Stop the request immediately.
err := handleDebugLogDBRequest(s.clock, s.timeout, nil, reqParams, s.sock, stop)
err := handleDebugLogDBRequest(s.clock, s.timeout, nil, reqParams, s.sock, stop, nil)
c.Assert(err, jc.ErrorIsNil)
c.Assert(called, jc.IsTrue)
}
Expand All @@ -96,7 +96,7 @@ func (s *debugLogDBIntSuite) TestParamConversionReplay(c *gc.C) {

stop := make(chan struct{})
close(stop) // Stop the request immediately.
err := handleDebugLogDBRequest(s.clock, s.timeout, nil, reqParams, s.sock, stop)
err := handleDebugLogDBRequest(s.clock, s.timeout, nil, reqParams, s.sock, nil, stop)
c.Assert(err, jc.ErrorIsNil)
c.Assert(called, jc.IsTrue)
}
Expand Down Expand Up @@ -187,7 +187,7 @@ func (s *debugLogDBIntSuite) TestRequestStopsWhenTailerStops(c *gc.C) {
return tailer, nil
})

err := handleDebugLogDBRequest(s.clock, s.timeout, nil, debugLogParams{}, s.sock, nil)
err := handleDebugLogDBRequest(s.clock, s.timeout, nil, debugLogParams{}, s.sock, nil, nil)
c.Assert(err, jc.ErrorIsNil)
c.Assert(tailer.stopped, jc.IsTrue)
}
Expand Down Expand Up @@ -225,7 +225,7 @@ func (s *debugLogDBIntSuite) TestMaxLines(c *gc.C) {
func (s *debugLogDBIntSuite) runRequest(params debugLogParams, stop chan struct{}) chan error {
done := make(chan error)
go func() {
done <- handleDebugLogDBRequest(s.clock, s.timeout, &fakeState{}, params, s.sock, stop)
done <- handleDebugLogDBRequest(s.clock, s.timeout, &fakeState{}, params, s.sock, stop, nil)
}()
return done
}
Expand Down
19 changes: 18 additions & 1 deletion state/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type PooledState struct {
isSystemState bool
released bool
itemKey uint64
removing chan struct{}
}

var _ PoolHelper = (*PooledState)(nil)
Expand Down Expand Up @@ -75,6 +76,12 @@ func (ps *PooledState) Release() bool {
return removed
}

// Removing returns a channel that is closed when the PooledState
// should be released by the consumer.
func (ps *PooledState) Removing() <-chan struct{} {
return ps.removing
}

// TODO: implement Close that hides the state.Close for a PooledState?

// Annotate writes the supplied context information back to the pool item.
Expand All @@ -93,6 +100,7 @@ type PoolItem struct {
modelUUID string
referenceSources map[uint64]string
remove bool
removing chan struct{}
}

func (i *PoolItem) refCount() int {
Expand Down Expand Up @@ -242,6 +250,7 @@ func (p *StatePool) Get(modelUUID string) (*PooledState, error) {
item.referenceSources[key] = source
ps := newPooledState(item.state, p, modelUUID, false)
ps.itemKey = key
ps.removing = item.removing
return ps, nil
}

Expand All @@ -256,15 +265,18 @@ func (p *StatePool) Get(modelUUID string) (*PooledState, error) {
if err != nil {
return nil, errors.Trace(err)
}
removing := make(chan struct{})
p.pool[modelUUID] = &PoolItem{
modelUUID: modelUUID,
state: st,
referenceSources: map[uint64]string{
key: source,
},
removing: removing,
}
ps := newPooledState(st, p, modelUUID, false)
ps.itemKey = key
ps.removing = removing
return ps, nil
}

Expand Down Expand Up @@ -355,7 +367,12 @@ func (p *StatePool) Remove(modelUUID string) (bool, error) {
// ignore unknown model uuids.
return false, nil
}
item.remove = true
if !item.remove {
item.remove = true
if item.removing != nil {
close(item.removing)
}
}
return p.maybeRemoveItem(item)
}

Expand Down

0 comments on commit 06859f7

Please sign in to comment.