diff --git a/cmd/multus-daemon/main.go b/cmd/multus-daemon/main.go index 5d4e6cfa1..9aedd4e2f 100644 --- a/cmd/multus-daemon/main.go +++ b/cmd/multus-daemon/main.go @@ -30,7 +30,6 @@ import ( "syscall" "time" - utilruntime "k8s.io/apimachinery/pkg/util/runtime" utilwait "k8s.io/apimachinery/pkg/util/wait" "gopkg.in/k8snetworkplumbingwg/multus-cni.v4/pkg/logging" @@ -207,15 +206,8 @@ func startMultusDaemon(ctx context.Context, daemonConfig *srv.ControllerNetConf) return fmt.Errorf("failed to start the CNI server using socket %s. Reason: %+v", api.SocketPath(daemonConfig.SocketDir), err) } - server.SetKeepAlivesEnabled(false) - go func() { - utilwait.UntilWithContext(ctx, func(ctx context.Context) { - logging.Debugf("open for business") - if err := server.Serve(l); err != nil { - utilruntime.HandleError(fmt.Errorf("CNI server Serve() failed: %v", err)) - } - }, 0) - }() + server.Start(ctx, l) + go func() { <-ctx.Done() server.Shutdown(context.Background()) diff --git a/pkg/server/server.go b/pkg/server/server.go index 4389a4e08..f12382569 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -16,6 +16,7 @@ package server import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -38,6 +39,9 @@ import ( "gopkg.in/k8snetworkplumbingwg/multus-cni.v4/pkg/server/api" "gopkg.in/k8snetworkplumbingwg/multus-cni.v4/pkg/server/config" "gopkg.in/k8snetworkplumbingwg/multus-cni.v4/pkg/types" + + utilruntime "k8s.io/apimachinery/pkg/util/runtime" + utilwait "k8s.io/apimachinery/pkg/util/wait" ) const ( @@ -180,6 +184,8 @@ func newCNIServer(rundir string, kubeClient *k8s.ClientInfo, exec invoke.Exec, s ), }, } + s.SetKeepAlivesEnabled(false) + // register metrics prometheus.MustRegister(s.metrics.requestCounter) @@ -249,6 +255,18 @@ func newCNIServer(rundir string, kubeClient *k8s.ClientInfo, exec invoke.Exec, s return s, nil } +// Start starts the server and begins serving on the given listener +func (s *Server) Start(ctx context.Context, l net.Listener) { + go func() { + utilwait.UntilWithContext(ctx, func(ctx context.Context) { + logging.Debugf("open for business") + if err := s.Serve(l); err != nil { + utilruntime.HandleError(fmt.Errorf("CNI server Serve() failed: %v", err)) + } + }, 0) + }() +} + func (s *Server) handleCNIRequest(r *http.Request) ([]byte, error) { var cr api.Request b, err := io.ReadAll(r.Body) diff --git a/pkg/server/thick_cni_test.go b/pkg/server/thick_cni_test.go index cc5fbc74f..4b0591bc7 100644 --- a/pkg/server/thick_cni_test.go +++ b/pkg/server/thick_cni_test.go @@ -30,8 +30,6 @@ import ( "github.com/prometheus/client_golang/prometheus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - utilruntime "k8s.io/apimachinery/pkg/util/runtime" - utilwait "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes/fake" "k8s.io/client-go/tools/record" @@ -102,6 +100,8 @@ var _ = Describe(suiteName, func() { cniServer *Server K8sClient *k8s.ClientInfo netns ns.NetNS + ctx context.Context + cancel context.CancelFunc ) BeforeEach(func() { @@ -109,7 +109,9 @@ var _ = Describe(suiteName, func() { K8sClient = fakeK8sClient() Expect(FilesystemPreRequirements(thickPluginRunDir)).To(Succeed()) - cniServer, err = startCNIServer(thickPluginRunDir, K8sClient, nil) + + ctx, cancel = context.WithCancel(context.TODO()) + cniServer, err = startCNIServer(ctx, thickPluginRunDir, K8sClient, nil) Expect(err).NotTo(HaveOccurred()) netns, err = testutils.NewNS() @@ -121,6 +123,7 @@ var _ = Describe(suiteName, func() { }) AfterEach(func() { + cancel() unregisterMetrics(cniServer) Expect(cniServer.Close()).To(Succeed()) Expect(teardownCNIEnv()).To(Succeed()) @@ -151,6 +154,8 @@ var _ = Describe(suiteName, func() { cniServer *Server K8sClient *k8s.ClientInfo netns ns.NetNS + ctx context.Context + cancel context.CancelFunc ) BeforeEach(func() { @@ -163,7 +168,9 @@ var _ = Describe(suiteName, func() { }` Expect(FilesystemPreRequirements(thickPluginRunDir)).To(Succeed()) - cniServer, err = startCNIServer(thickPluginRunDir, K8sClient, []byte(dummyServerConfig)) + + ctx, cancel = context.WithCancel(context.TODO()) + cniServer, err = startCNIServer(ctx, thickPluginRunDir, K8sClient, []byte(dummyServerConfig)) Expect(err).NotTo(HaveOccurred()) netns, err = testutils.NewNS() @@ -175,6 +182,7 @@ var _ = Describe(suiteName, func() { }) AfterEach(func() { + cancel() unregisterMetrics(cniServer) Expect(cniServer.Close()).To(Succeed()) Expect(teardownCNIEnv()).To(Succeed()) @@ -245,7 +253,7 @@ func createFakePod(k8sClient *k8s.ClientInfo, podName string) error { return err } -func startCNIServer(runDir string, k8sClient *k8s.ClientInfo, servConfig []byte) (*Server, error) { +func startCNIServer(ctx context.Context, runDir string, k8sClient *k8s.ClientInfo, servConfig []byte) (*Server, error) { const period = 0 cniServer, err := newCNIServer(runDir, k8sClient, &fakeExec{}, servConfig) @@ -258,12 +266,8 @@ func startCNIServer(runDir string, k8sClient *k8s.ClientInfo, servConfig []byte) return nil, fmt.Errorf("failed to start the CNI server using socket %s. Reason: %+v", api.SocketPath(runDir), err) } - cniServer.SetKeepAlivesEnabled(false) - go utilwait.Forever(func() { - if err := cniServer.Serve(l); err != nil { - utilruntime.HandleError(fmt.Errorf("CNI server Serve() failed: %v", err)) - } - }, period) + cniServer.Start(ctx, l) + return cniServer, nil }