Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent overwriting existing host_uuid file #48012

Merged
merged 1 commit into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ import (
"github.com/gravitational/teleport/lib/utils"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
"github.com/gravitational/teleport/lib/utils/cert"
"github.com/gravitational/teleport/lib/utils/hostid"
logutils "github.com/gravitational/teleport/lib/utils/log"
vc "github.com/gravitational/teleport/lib/versioncontrol"
"github.com/gravitational/teleport/lib/versioncontrol/endpoint"
Expand Down Expand Up @@ -2934,7 +2935,7 @@ func (process *TeleportProcess) initSSH() error {
storagePresence := local.NewPresenceService(process.storage.BackendStorage)

// read the host UUID:
serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir)
serverID, err := hostid.ReadOrCreateFile(cfg.DataDir)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -4439,7 +4440,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
}

// read the host UUID:
serverID, err := utils.ReadOrMakeHostUUID(cfg.DataDir)
serverID, err := hostid.ReadOrCreateFile(cfg.DataDir)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -6498,7 +6499,7 @@ func readOrGenerateHostID(ctx context.Context, cfg *servicecfg.Config, kubeBacke
if err := persistHostIDToStorages(ctx, cfg, kubeBackend); err != nil {
return trace.Wrap(err)
}
} else if kubeBackend != nil && utils.HostUUIDExistsLocally(cfg.DataDir) {
} else if kubeBackend != nil && hostid.ExistsLocally(cfg.DataDir) {
// This case is used when loading a Teleport pre-11 agent with storage attached.
// In this case, we have to copy the "host_uuid" from the agent to the secret
// in case storage is removed later.
Expand Down Expand Up @@ -6537,14 +6538,14 @@ func readHostIDFromStorages(ctx context.Context, dataDir string, kubeBackend kub
}
// Even if running in Kubernetes fallback to local storage if `host_uuid` was
// not found in secret.
hostID, err := utils.ReadHostUUID(dataDir)
hostID, err := hostid.ReadFile(dataDir)
return hostID, trace.Wrap(err)
}

// persistHostIDToStorages writes the cfg.HostUUID to local data and to
// Kubernetes Secret if this process is running on a Kubernetes Cluster.
func persistHostIDToStorages(ctx context.Context, cfg *servicecfg.Config, kubeBackend kubernetesBackend) error {
if err := utils.WriteHostUUID(cfg.DataDir, cfg.HostUUID); err != nil {
if err := hostid.WriteFile(cfg.DataDir, cfg.HostUUID); err != nil {
if errors.Is(err, fs.ErrPermission) {
cfg.Logger.ErrorContext(ctx, "Teleport does not have permission to write to the data directory. Ensure that you are running as a user with appropriate permissions.", "data_dir", cfg.DataDir)
}
Expand All @@ -6563,7 +6564,7 @@ func persistHostIDToStorages(ctx context.Context, cfg *servicecfg.Config, kubeBa
// loadHostIDFromKubeSecret reads the host_uuid from the Kubernetes secret with
// the expected key: `/host_uuid`.
func loadHostIDFromKubeSecret(ctx context.Context, kubeBackend kubernetesBackend) (string, error) {
item, err := kubeBackend.Get(ctx, backend.NewKey(utils.HostUUIDFile))
item, err := kubeBackend.Get(ctx, backend.NewKey(hostid.FileName))
if err != nil {
return "", trace.Wrap(err)
}
Expand All @@ -6576,7 +6577,7 @@ func writeHostIDToKubeSecret(ctx context.Context, kubeBackend kubernetesBackend,
_, err := kubeBackend.Put(
ctx,
backend.Item{
Key: backend.NewKey(utils.HostUUIDFile),
Key: backend.NewKey(hostid.FileName),
Value: []byte(id),
},
)
Expand Down
3 changes: 2 additions & 1 deletion lib/service/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import (
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/hostid"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -1167,7 +1168,7 @@ func Test_readOrGenerateHostID(t *testing.T) {
dataDir := t.TempDir()
// write host_uuid file to temp dir.
if len(tt.args.hostIDContent) > 0 {
err := utils.WriteHostUUID(dataDir, tt.args.hostIDContent)
err := hostid.WriteFile(dataDir, tt.args.hostIDContent)
require.NoError(t, err)
}

Expand Down
3 changes: 2 additions & 1 deletion lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ import (
"github.com/gravitational/teleport/lib/sshutils/x11"
"github.com/gravitational/teleport/lib/teleagent"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/hostid"
)

var log = logrus.WithFields(logrus.Fields{
Expand Down Expand Up @@ -724,7 +725,7 @@ func New(
options ...ServerOption,
) (*Server, error) {
// read the host UUID:
uuid, err := utils.ReadOrMakeHostUUID(dataDir)
uuid, err := hostid.ReadOrCreateFile(dataDir)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
7 changes: 4 additions & 3 deletions lib/teleterm/services/connectmycomputer/connectmycomputer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/teleterm/clusters"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/hostid"
)

type RoleSetup struct {
Expand Down Expand Up @@ -395,7 +396,7 @@ func (n *NodeJoinWait) getNodeNameFromHostUUIDFile(ctx context.Context, cluster
// the file is empty.
//
// Here we need to be able to distinguish between both of those two cases.
out, err := utils.ReadPath(utils.GetHostUUIDPath(dataDir))
out, err := utils.ReadPath(hostid.GetPath(dataDir))
if err != nil {
if trace.IsNotFound(err) {
continue
Expand Down Expand Up @@ -536,7 +537,7 @@ type NodeDelete struct {

// Run grabs the host UUID of an agent from a disk and deletes the node with that name.
func (n *NodeDelete) Run(ctx context.Context, presence Presence, cluster *clusters.Cluster) error {
hostUUID, err := utils.ReadHostUUID(getAgentDataDir(n.cfg.AgentsDir, cluster.ProfileName))
hostUUID, err := hostid.ReadFile(getAgentDataDir(n.cfg.AgentsDir, cluster.ProfileName))
if trace.IsNotFound(err) {
return nil
}
Expand Down Expand Up @@ -585,7 +586,7 @@ type NodeName struct {

// Get returns the host UUID of the agent from a disk.
func (n *NodeName) Get(cluster *clusters.Cluster) (string, error) {
hostUUID, err := utils.ReadHostUUID(getAgentDataDir(n.cfg.AgentsDir, cluster.ProfileName))
hostUUID, err := hostid.ReadFile(getAgentDataDir(n.cfg.AgentsDir, cluster.ProfileName))
return hostUUID, trace.Wrap(err)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/teleterm/api/uri"
"github.com/gravitational/teleport/lib/teleterm/clusters"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/hostid"
)

func TestRoleSetupRun_WithNonLocalUser(t *testing.T) {
Expand Down Expand Up @@ -472,7 +472,7 @@ func mustMakeHostUUIDFile(t *testing.T, agentsDir string, profileName string) st
err = os.MkdirAll(dataDir, agentsDirStat.Mode())
require.NoError(t, err)

hostUUID, err := utils.ReadOrMakeHostUUID(dataDir)
hostUUID, err := hostid.ReadOrCreateFile(dataDir)
require.NoError(t, err)

return hostUUID
Expand Down
61 changes: 61 additions & 0 deletions lib/utils/hostid/hostid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package hostid

import (
"errors"
"io/fs"
"path/filepath"
"strings"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/lib/utils"
)

const (
// FileName is the file name where the host UUID file is stored
FileName = "host_uuid"
)

// GetPath returns the path to the host UUID file given the data directory.
func GetPath(dataDir string) string {
return filepath.Join(dataDir, FileName)
}

// ExistsLocally checks if dataDir/host_uuid file exists in local storage.
func ExistsLocally(dataDir string) bool {
_, err := ReadFile(dataDir)
return err == nil
}

// ReadFile reads host UUID from the file in the data dir
func ReadFile(dataDir string) (string, error) {
out, err := utils.ReadPath(GetPath(dataDir))
if err != nil {
if errors.Is(err, fs.ErrPermission) {
//do not convert to system error as this loses the ability to compare that it is a permission error
return "", trace.Wrap(err)
}
return "", trace.ConvertSystemError(err)
}
id := strings.TrimSpace(string(out))
if id == "" {
return "", trace.NotFound("host uuid is empty")
}
return id, nil
}
113 changes: 113 additions & 0 deletions lib/utils/hostid/hostid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//go:build !windows

// Teleport
// Copyright (C) 2024 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package hostid_test

import (
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"

"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/hostid"
)

func TestMain(m *testing.M) {
utils.InitLoggerForTests()
os.Exit(m.Run())
}

func TestReadOrCreate(t *testing.T) {
t.Parallel()

dir := t.TempDir()

var wg errgroup.Group
concurrency := 10
ids := make([]string, concurrency)
barrier := make(chan struct{})

for i := 0; i < concurrency; i++ {
wg.Go(func() error {
<-barrier
id, err := hostid.ReadOrCreateFile(dir)
ids[i] = id
return err
})
}

close(barrier)

require.NoError(t, wg.Wait())
require.Equal(t, slices.Repeat([]string{ids[0]}, concurrency), ids)
}

func TestIdempotence(t *testing.T) {
t.Parallel()

// call twice, get same result
dir := t.TempDir()
id, err := hostid.ReadOrCreateFile(dir)
require.Len(t, id, 36)
require.NoError(t, err)
uuidCopy, err := hostid.ReadOrCreateFile(dir)
require.NoError(t, err)
require.Equal(t, id, uuidCopy)
}

func TestBadLocation(t *testing.T) {
t.Parallel()

// call with a read-only dir, make sure to get an error
id, err := hostid.ReadOrCreateFile("/bad-location")
require.Empty(t, id)
require.Error(t, err)
require.Regexp(t, "^.*no such file or directory.*$", err.Error())
}

func TestIgnoreWhitespace(t *testing.T) {
t.Parallel()

// newlines are getting ignored
dir := t.TempDir()
id := fmt.Sprintf("%s\n", uuid.NewString())
err := os.WriteFile(filepath.Join(dir, hostid.FileName), []byte(id), 0666)
require.NoError(t, err)
out, err := hostid.ReadFile(dir)
require.NoError(t, err)
require.Equal(t, strings.TrimSpace(id), out)
}

func TestRegenerateEmpty(t *testing.T) {
t.Parallel()

// empty UUID in file is regenerated
dir := t.TempDir()
err := os.WriteFile(filepath.Join(dir, hostid.FileName), nil, 0666)
require.NoError(t, err)
out, err := hostid.ReadOrCreateFile(dir)
require.NoError(t, err)
require.Len(t, out, 36)
}
Loading
Loading