diff --git a/internal/server/listener.go b/internal/server/listener.go index 3d6d61c..e46c895 100644 --- a/internal/server/listener.go +++ b/internal/server/listener.go @@ -13,11 +13,12 @@ import ( ) func Serve(ctx context.Context, server *server.Server) error { - listener, ipOrPath, err := ensureListener() + listener, ipOrPath, err := ensureListener(ctx) if err != nil { return err } if listener != nil { + defer listener.Close() return serveSocket(ctx, ipOrPath, listener, server) } return server.ListenAndServe(ctx, config.Steve.HTTPSListenPort, config.Steve.HTTPListenPort, &dynamicserver.ListenOpts{ @@ -45,7 +46,6 @@ func serveSocket(ctx context.Context, socketPath string, listener net.Listener, go func() { <-ctx.Done() _ = socketServer.Shutdown(context.Background()) - _ = listener.Close() }() <-ctx.Done() return ctx.Err() diff --git a/internal/server/listener_unix.go b/internal/server/listener_unix.go index d1df629..655a4d6 100644 --- a/internal/server/listener_unix.go +++ b/internal/server/listener_unix.go @@ -4,14 +4,44 @@ package server import ( + "context" + "errors" "fmt" "net" "net/url" + "os" + "path/filepath" + "strings" + "syscall" "github.com/cnrancher/kube-explorer/internal/config" + "github.com/sirupsen/logrus" ) -func ensureListener() (net.Listener, string, error) { +var _ net.Listener = &closerListener{} + +type closerListener struct { + listener net.Listener + lockFile *os.File +} + +func (l *closerListener) Accept() (net.Conn, error) { + return l.listener.Accept() +} + +func (l *closerListener) Close() error { + return errors.Join( + l.listener.Close(), + l.lockFile.Close(), + os.RemoveAll(l.lockFile.Name()), + ) +} + +func (l *closerListener) Addr() net.Addr { + return l.listener.Addr() +} + +func ensureListener(ctx context.Context) (net.Listener, string, error) { if config.BindAddress == "" { return nil, "", nil } @@ -25,9 +55,45 @@ func ensureListener() (net.Listener, string, error) { case "tcp": return nil, u.Host, nil case "unix": - listener, err := net.Listen("unix", u.Path) + listener, err := createCloserListener(ctx, u.Path) + if err != nil { + return nil, "", err + } return listener, u.Path, err default: return nil, "", fmt.Errorf("Unsupported scheme %s, only tcp and unix are supported in UNIX OS", u.Scheme) } } + +func createCloserListener(ctx context.Context, socketPath string) (net.Listener, error) { + lockFilePath := getLockFileName(socketPath) + lockFile, err := os.OpenFile(lockFilePath, os.O_RDONLY|os.O_CREATE, 0600) + if err != nil { + return nil, err + } + + lockErr := syscall.Flock(int(lockFile.Fd()), syscall.LOCK_EX|syscall.LOCK_NB) + if lockErr != nil { + return nil, fmt.Errorf("Socket file %s is in use, exiting", socketPath) + } + + if _, err := os.Stat(socketPath); err == nil { + logrus.Infof("Removing stale socket file %s", socketPath) + _ = os.Remove(socketPath) + } + + var lc net.ListenConfig + listener, err := lc.Listen(ctx, "unix", socketPath) + if err != nil { + return nil, err + } + + return &closerListener{ + listener: listener, + lockFile: lockFile, + }, nil +} + +func getLockFileName(socketPath string) string { + return strings.TrimSuffix(socketPath, filepath.Ext(socketPath)) + ".lock" +} diff --git a/internal/server/listener_windows.go b/internal/server/listener_windows.go index 6da3637..8b34c61 100644 --- a/internal/server/listener_windows.go +++ b/internal/server/listener_windows.go @@ -4,6 +4,7 @@ package server import ( + "context" "fmt" "net" "net/url" @@ -12,7 +13,7 @@ import ( "github.com/cnrancher/kube-explorer/internal/config" ) -func ensureListener() (net.Listener, string, error) { +func ensureListener(_ context.Context) (net.Listener, string, error) { if config.BindAddress == "" { return nil, "", nil }