diff --git a/client/allocrunner/csi_hook_test.go b/client/allocrunner/csi_hook_test.go index a0fd5aea750..83d28be7cda 100644 --- a/client/allocrunner/csi_hook_test.go +++ b/client/allocrunner/csi_hook_test.go @@ -4,29 +4,24 @@ package allocrunner import ( - "context" "errors" "fmt" - "path/filepath" - "sync" "testing" "time" "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/client/allocrunner/interfaces" "github.com/hashicorp/nomad/client/allocrunner/state" - "github.com/hashicorp/nomad/client/pluginmanager" "github.com/hashicorp/nomad/client/pluginmanager/csimanager" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/helper/pointer" "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" - "github.com/hashicorp/nomad/plugins/csi" "github.com/hashicorp/nomad/plugins/drivers" + "github.com/hashicorp/nomad/testutil" "github.com/shoenig/test/must" "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" ) var _ interfaces.RunnerPrerunHook = (*csiHook)(nil) @@ -71,7 +66,7 @@ func TestCSIHook(t *testing.T) { "vol0": &csimanager.MountInfo{Source: testMountSrc}, }, expectedCalls: map[string]int{ - "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1}, + "claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1}, }, { @@ -92,7 +87,7 @@ func TestCSIHook(t *testing.T) { "vol0": &csimanager.MountInfo{Source: testMountSrc}, }, expectedCalls: map[string]int{ - "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1}, + "claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1}, }, { @@ -137,7 +132,7 @@ func TestCSIHook(t *testing.T) { "vol0": &csimanager.MountInfo{Source: testMountSrc}, }, expectedCalls: map[string]int{ - "claim": 2, "mount": 1, "unmount": 1, "unpublish": 1}, + "claim": 2, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1}, }, { name: "already mounted", @@ -163,7 +158,7 @@ func TestCSIHook(t *testing.T) { expectedMounts: map[string]*csimanager.MountInfo{ "vol0": &csimanager.MountInfo{Source: testMountSrc}, }, - expectedCalls: map[string]int{"hasMount": 1, "unmount": 1, "unpublish": 1}, + expectedCalls: map[string]int{"HasMount": 1, "UnmountVolume": 1, "unpublish": 1}, }, { name: "existing but invalid mounts", @@ -190,7 +185,7 @@ func TestCSIHook(t *testing.T) { "vol0": &csimanager.MountInfo{Source: testMountSrc}, }, expectedCalls: map[string]int{ - "hasMount": 1, "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1}, + "HasMount": 1, "claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1}, }, { @@ -212,7 +207,7 @@ func TestCSIHook(t *testing.T) { "vol0": &csimanager.MountInfo{Source: testMountSrc}, }, expectedCalls: map[string]int{ - "claim": 1, "mount": 1, "unmount": 2, "unpublish": 2}, + "claim": 1, "MountVolume": 1, "UnmountVolume": 2, "unpublish": 2}, }, { @@ -227,12 +222,11 @@ func TestCSIHook(t *testing.T) { alloc.Job.TaskGroups[0].Volumes = tc.volumeRequests - callCounts := &callCounter{counts: map[string]int{}} - mgr := mockPluginManager{mounter: mockVolumeManager{ - hasMounts: tc.startsWithValidMounts, - callCounts: callCounts, - failsFirstUnmount: pointer.Of(tc.failsFirstUnmount), - }} + callCounts := testutil.NewCallCounter() + vm := &csimanager.MockVolumeManager{ + CallCounter: callCounts, + } + mgr := &csimanager.MockCSIManager{VM: vm} rpcer := mockRPCer{ alloc: alloc, callCounts: callCounts, @@ -255,6 +249,17 @@ func TestCSIHook(t *testing.T) { must.NotNil(t, hook) + if tc.startsWithValidMounts { + // TODO: this works, but it requires knowledge of how the mock works. would rather vm.MountVolume() + vm.Mounts = map[string]bool{ + tc.expectedMounts["vol0"].Source: true, + } + } + + if tc.failsFirstUnmount { + vm.NextUnmountVolumeErr = errors.New("bad first attempt") + } + if tc.expectedClaimErr != nil { must.EqError(t, hook.Prerun(), tc.expectedClaimErr.Error()) mounts := ar.res.GetCSIMounts() @@ -274,7 +279,7 @@ func TestCSIHook(t *testing.T) { time.Sleep(100 * time.Millisecond) } - counts := callCounts.get() + counts := callCounts.Get() must.MapEq(t, tc.expectedCalls, counts, must.Sprintf("got calls: %v", counts)) @@ -342,14 +347,12 @@ func TestCSIHook_Prerun_Validation(t *testing.T) { t.Run(tc.name, func(t *testing.T) { alloc.Job.TaskGroups[0].Volumes = volumeRequests - callCounts := &callCounter{counts: map[string]int{}} - mgr := mockPluginManager{mounter: mockVolumeManager{ - callCounts: callCounts, - failsFirstUnmount: pointer.Of(false), - }} + mgr := &csimanager.MockCSIManager{ + VM: &csimanager.MockVolumeManager{}, + } rpcer := mockRPCer{ alloc: alloc, - callCounts: callCounts, + callCounts: testutil.NewCallCounter(), hasExistingClaim: pointer.Of(false), schedulable: pointer.Of(true), } @@ -379,26 +382,9 @@ func TestCSIHook_Prerun_Validation(t *testing.T) { // HELPERS AND MOCKS -type callCounter struct { - lock sync.Mutex - counts map[string]int -} - -func (c *callCounter) inc(name string) { - c.lock.Lock() - defer c.lock.Unlock() - c.counts[name]++ -} - -func (c *callCounter) get() map[string]int { - c.lock.Lock() - defer c.lock.Unlock() - return maps.Clone(c.counts) -} - type mockRPCer struct { alloc *structs.Allocation - callCounts *callCounter + callCounts *testutil.CallCounter hasExistingClaim *bool schedulable *bool } @@ -407,7 +393,7 @@ type mockRPCer struct { func (r mockRPCer) RPC(method string, args any, reply any) error { switch method { case "CSIVolume.Claim": - r.callCounts.inc("claim") + r.callCounts.Inc("claim") req := args.(*structs.CSIVolumeClaimRequest) vol := r.testVolume(req.VolumeID) err := vol.Claim(req.ToClaim(), r.alloc) @@ -427,7 +413,7 @@ func (r mockRPCer) RPC(method string, args any, reply any) error { resp.QueryMeta = structs.QueryMeta{} case "CSIVolume.Unpublish": - r.callCounts.inc("unpublish") + r.callCounts.Inc("unpublish") resp := reply.(*structs.CSIVolumeUnpublishResponse) resp.QueryMeta = structs.QueryMeta{} @@ -470,59 +456,6 @@ func (r mockRPCer) testVolume(id string) *structs.CSIVolume { return vol } -type mockVolumeManager struct { - hasMounts bool - failsFirstUnmount *bool - callCounts *callCounter -} - -func (vm mockVolumeManager) MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usageOpts *csimanager.UsageOptions, publishContext map[string]string) (*csimanager.MountInfo, error) { - vm.callCounts.inc("mount") - return &csimanager.MountInfo{ - Source: filepath.Join("test-alloc-dir", alloc.ID, vol.ID, usageOpts.ToFS()), - }, nil -} - -func (vm mockVolumeManager) UnmountVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *csimanager.UsageOptions) error { - vm.callCounts.inc("unmount") - - if *vm.failsFirstUnmount { - *vm.failsFirstUnmount = false - return fmt.Errorf("could not unmount") - } - - return nil -} - -func (vm mockVolumeManager) HasMount(_ context.Context, mountInfo *csimanager.MountInfo) (bool, error) { - vm.callCounts.inc("hasMount") - return mountInfo != nil && vm.hasMounts, nil -} - -func (vm mockVolumeManager) ExpandVolume(_ context.Context, _, _, _ string, _ *csimanager.UsageOptions, _ *csi.CapacityRange) (int64, error) { - return 0, nil -} - -func (vm mockVolumeManager) ExternalID() string { - return "i-example" -} - -type mockPluginManager struct { - mounter mockVolumeManager -} - -func (mgr mockPluginManager) WaitForPlugin(ctx context.Context, pluginType, pluginID string) error { - return nil -} - -func (mgr mockPluginManager) ManagerForPlugin(ctx context.Context, pluginID string) (csimanager.VolumeManager, error) { - return mgr.mounter, nil -} - -// no-op methods to fulfill the interface -func (mgr mockPluginManager) PluginManager() pluginmanager.PluginManager { return nil } -func (mgr mockPluginManager) Shutdown() {} - type mockAllocRunner struct { res *cstructs.AllocHookResources caps *drivers.Capabilities diff --git a/client/csi_endpoint_test.go b/client/csi_endpoint_test.go index 072a427dacc..e9b2cc267ad 100644 --- a/client/csi_endpoint_test.go +++ b/client/csi_endpoint_test.go @@ -994,24 +994,22 @@ func TestCSINode_DetachVolume(t *testing.T) { ci.Parallel(t) cases := []struct { - Name string - ClientSetupFunc func(*fake.Client) - Request *structs.ClientCSINodeDetachVolumeRequest - ExpectedErr error - ExpectedResponse *structs.ClientCSINodeDetachVolumeResponse + Name string + ModManager func(m *csimanager.MockCSIManager) + Request *structs.ClientCSINodeDetachVolumeRequest + ExpectedErr error }{ { - Name: "returns plugin not found errors", + Name: "success", Request: &structs.ClientCSINodeDetachVolumeRequest{ - PluginID: "some-garbage", - VolumeID: "-", - AllocID: "-", - NodeID: "-", + PluginID: "fake-plugin", + VolumeID: "fake-vol", + AllocID: "fake-alloc", + NodeID: "fake-node", AttachmentMode: nstructs.CSIVolumeAttachmentModeFilesystem, AccessMode: nstructs.CSIVolumeAccessModeMultiNodeReader, ReadOnly: true, }, - ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin some-garbage for type csi-node not found"), }, { Name: "validates volumeid is not empty", @@ -1029,43 +1027,51 @@ func TestCSINode_DetachVolume(t *testing.T) { ExpectedErr: errors.New("CSI.NodeDetachVolume: AllocID is required"), }, { - Name: "returns transitive errors", - ClientSetupFunc: func(fc *fake.Client) { - fc.NextNodeUnpublishVolumeErr = errors.New("wont-see-this") + Name: "returns csi manager errors", + ModManager: func(m *csimanager.MockCSIManager) { + m.NextManagerForPluginErr = errors.New("no plugin") + }, + Request: &structs.ClientCSINodeDetachVolumeRequest{ + PluginID: fakeNodePlugin.Name, + VolumeID: "1234-4321-1234-4321", + AllocID: "4321-1234-4321-1234", + }, + ExpectedErr: errors.New("CSI.NodeDetachVolume: no plugin"), + }, + { + Name: "returns volume manager errors", + ModManager: func(m *csimanager.MockCSIManager) { + m.VM.NextUnmountVolumeErr = errors.New("error unmounting") }, Request: &structs.ClientCSINodeDetachVolumeRequest{ PluginID: fakeNodePlugin.Name, VolumeID: "1234-4321-1234-4321", AllocID: "4321-1234-4321-1234", }, - // we don't have a csimanager in this context - ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin test-plugin for type csi-node not found"), + ExpectedErr: errors.New("CSI.NodeDetachVolume: error unmounting"), }, } for _, tc := range cases { t.Run(tc.Name, func(t *testing.T) { - require := require.New(t) client, cleanup := TestClient(t, nil) defer cleanup() - fakeClient := &fake.Client{} - if tc.ClientSetupFunc != nil { - tc.ClientSetupFunc(fakeClient) + mockManager := &csimanager.MockCSIManager{ + VM: &csimanager.MockVolumeManager{}, } - - dispenserFunc := func(*dynamicplugins.PluginInfo) (interface{}, error) { - return fakeClient, nil + if tc.ModManager != nil { + tc.ModManager(mockManager) } - client.dynamicRegistry.StubDispenserForType(dynamicplugins.PluginTypeCSINode, dispenserFunc) - err := client.dynamicRegistry.RegisterPlugin(fakeNodePlugin) - require.Nil(err) + client.csimanager = mockManager var resp structs.ClientCSINodeDetachVolumeResponse - err = client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp) - require.Equal(tc.ExpectedErr, err) - if tc.ExpectedResponse != nil { - require.Equal(tc.ExpectedResponse, &resp) + err := client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp) + if tc.ExpectedErr != nil { + must.Error(t, err) + must.EqError(t, tc.ExpectedErr, err.Error()) + } else { + must.NoError(t, err) } }) } diff --git a/client/pluginmanager/csimanager/testing.go b/client/pluginmanager/csimanager/testing.go index f27f742655c..88a5055a5e1 100644 --- a/client/pluginmanager/csimanager/testing.go +++ b/client/pluginmanager/csimanager/testing.go @@ -5,10 +5,12 @@ package csimanager import ( "context" + "path/filepath" "github.com/hashicorp/nomad/client/pluginmanager" nstructs "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/plugins/csi" + "github.com/hashicorp/nomad/testutil" ) var _ Manager = &MockCSIManager{} @@ -29,6 +31,9 @@ func (m *MockCSIManager) WaitForPlugin(_ context.Context, pluginType, pluginID s } func (m *MockCSIManager) ManagerForPlugin(_ context.Context, pluginID string) (VolumeManager, error) { + if m.VM == nil { + m.VM = &MockVolumeManager{} + } return m.VM, m.NextManagerForPluginErr } @@ -39,20 +44,68 @@ func (m *MockCSIManager) Shutdown() { var _ VolumeManager = &MockVolumeManager{} type MockVolumeManager struct { + CallCounter *testutil.CallCounter + + Mounts map[string]bool // lazy set + + NextMountVolumeErr error + NextUnmountVolumeErr error + NextExpandVolumeErr error LastExpandVolumeCall *MockExpandVolumeCall } +func (m *MockVolumeManager) mountName(volID, allocID string, usageOpts *UsageOptions) string { + return filepath.Join("test-alloc-dir", allocID, volID, usageOpts.ToFS()) +} + func (m *MockVolumeManager) MountVolume(_ context.Context, vol *nstructs.CSIVolume, alloc *nstructs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) { - panic("implement me") + if m.CallCounter != nil { + m.CallCounter.Inc("MountVolume") + } + + if m.NextMountVolumeErr != nil { + err := m.NextMountVolumeErr + m.NextMountVolumeErr = nil // reset it + return nil, err + } + + // "mount" it + if m.Mounts == nil { + m.Mounts = make(map[string]bool) + } + source := m.mountName(vol.ID, alloc.ID, usageOpts) + m.Mounts[source] = true + + return &MountInfo{ + Source: source, + }, nil } func (m *MockVolumeManager) UnmountVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error { - panic("implement me") + if m.CallCounter != nil { + m.CallCounter.Inc("UnmountVolume") + } + + if m.NextUnmountVolumeErr != nil { + err := m.NextUnmountVolumeErr + m.NextUnmountVolumeErr = nil // reset it + return err + } + + // "unmount" it + delete(m.Mounts, m.mountName(volID, allocID, usageOpts)) + return nil } func (m *MockVolumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) { - panic("implement me") + if m.CallCounter != nil { + m.CallCounter.Inc("HasMount") + } + if m.Mounts == nil { + return false, nil + } + return m.Mounts[mountInfo.Source], nil } func (m *MockVolumeManager) ExpandVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions, capacity *csi.CapacityRange) (int64, error) { diff --git a/testutil/mock_calls.go b/testutil/mock_calls.go new file mode 100644 index 00000000000..5b37832bc69 --- /dev/null +++ b/testutil/mock_calls.go @@ -0,0 +1,42 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package testutil + +import ( + "maps" + "sync" + + "github.com/mitchellh/go-testing-interface" +) + +func NewCallCounter() *CallCounter { + return &CallCounter{ + counts: make(map[string]int), + } +} + +type CallCounter struct { + lock sync.Mutex + counts map[string]int +} + +func (c *CallCounter) Inc(name string) { + c.lock.Lock() + defer c.lock.Unlock() + c.counts[name]++ +} + +func (c *CallCounter) Get() map[string]int { + c.lock.Lock() + defer c.lock.Unlock() + return maps.Clone(c.counts) +} + +func (c *CallCounter) AssertCalled(t testing.T, name string) { + t.Helper() + counts := c.Get() + if _, ok := counts[name]; !ok { + t.Errorf("'%s' not called; all counts: %v", counts) + } +}