From 66c8cbbe1731df021058a0fa59b791e5c542d517 Mon Sep 17 00:00:00 2001
From: Daniel Bennett <dbennett@hashicorp.com>
Date: Mon, 25 Sep 2023 10:30:53 -0500
Subject: [PATCH] backport test: Refactor mock CSI manager (#18554) (#18561)

and MockCSIManager to support the call counting
that csi_hook_test expects

instead of implementing csimanager
interfaces in two separate places:
* client/allocrunner/csi_hook_test
* client/csi_endpoint_test

they can both use the same mocks defined in
client/pluginmanager/csimanager/
alongside the actual implementations of them.

also refactor TestCSINode_DetachVolume
to use use it like Node_ExpandVolume
so we can also test the happy path there
---
 client/allocrunner/csi_hook_test.go        | 124 +++++---------------
 client/csi_endpoint_test.go                |  68 ++++++-----
 client/pluginmanager/csimanager/testing.go | 126 +++++++++++++++++++++
 testutil/mock_calls.go                     |  42 +++++++
 4 files changed, 237 insertions(+), 123 deletions(-)
 create mode 100644 client/pluginmanager/csimanager/testing.go
 create mode 100644 testutil/mock_calls.go

diff --git a/client/allocrunner/csi_hook_test.go b/client/allocrunner/csi_hook_test.go
index ea2c02441ce..484039d9883 100644
--- a/client/allocrunner/csi_hook_test.go
+++ b/client/allocrunner/csi_hook_test.go
@@ -1,18 +1,14 @@
 package allocrunner
 
 import (
-	"context"
 	"errors"
 	"fmt"
-	"path/filepath"
-	"sync"
 	"testing"
 	"time"
 
 	"github.com/hashicorp/nomad/ci"
 	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
 	"github.com/hashicorp/nomad/client/allocrunner/state"
-	"github.com/hashicorp/nomad/client/pluginmanager"
 	"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
 	cstructs "github.com/hashicorp/nomad/client/structs"
 	"github.com/hashicorp/nomad/helper/pointer"
@@ -20,9 +16,9 @@ import (
 	"github.com/hashicorp/nomad/nomad/mock"
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/plugins/drivers"
+	"github.com/hashicorp/nomad/testutil"
 	"github.com/shoenig/test/must"
 	"github.com/stretchr/testify/require"
-	"golang.org/x/exp/maps"
 )
 
 var _ interfaces.RunnerPrerunHook = (*csiHook)(nil)
@@ -67,7 +63,7 @@ func TestCSIHook(t *testing.T) {
 				"vol0": &csimanager.MountInfo{Source: testMountSrc},
 			},
 			expectedCalls: map[string]int{
-				"claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
+				"claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
 		},
 
 		{
@@ -88,7 +84,7 @@ func TestCSIHook(t *testing.T) {
 				"vol0": &csimanager.MountInfo{Source: testMountSrc},
 			},
 			expectedCalls: map[string]int{
-				"claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
+				"claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
 		},
 
 		{
@@ -133,7 +129,7 @@ func TestCSIHook(t *testing.T) {
 				"vol0": &csimanager.MountInfo{Source: testMountSrc},
 			},
 			expectedCalls: map[string]int{
-				"claim": 2, "mount": 1, "unmount": 1, "unpublish": 1},
+				"claim": 2, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
 		},
 		{
 			name: "already mounted",
@@ -159,7 +155,7 @@ func TestCSIHook(t *testing.T) {
 			expectedMounts: map[string]*csimanager.MountInfo{
 				"vol0": &csimanager.MountInfo{Source: testMountSrc},
 			},
-			expectedCalls: map[string]int{"hasMount": 1, "unmount": 1, "unpublish": 1},
+			expectedCalls: map[string]int{"HasMount": 1, "UnmountVolume": 1, "unpublish": 1},
 		},
 		{
 			name: "existing but invalid mounts",
@@ -186,7 +182,7 @@ func TestCSIHook(t *testing.T) {
 				"vol0": &csimanager.MountInfo{Source: testMountSrc},
 			},
 			expectedCalls: map[string]int{
-				"hasMount": 1, "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
+				"HasMount": 1, "claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
 		},
 
 		{
@@ -208,7 +204,7 @@ func TestCSIHook(t *testing.T) {
 				"vol0": &csimanager.MountInfo{Source: testMountSrc},
 			},
 			expectedCalls: map[string]int{
-				"claim": 1, "mount": 1, "unmount": 2, "unpublish": 2},
+				"claim": 1, "MountVolume": 1, "UnmountVolume": 2, "unpublish": 2},
 		},
 
 		{
@@ -223,12 +219,11 @@ func TestCSIHook(t *testing.T) {
 
 			alloc.Job.TaskGroups[0].Volumes = tc.volumeRequests
 
-			callCounts := &callCounter{counts: map[string]int{}}
-			mgr := mockPluginManager{mounter: mockVolumeManager{
-				hasMounts:         tc.startsWithValidMounts,
-				callCounts:        callCounts,
-				failsFirstUnmount: pointer.Of(tc.failsFirstUnmount),
-			}}
+			callCounts := testutil.NewCallCounter()
+			vm := &csimanager.MockVolumeManager{
+				CallCounter: callCounts,
+			}
+			mgr := &csimanager.MockCSIManager{VM: vm}
 			rpcer := mockRPCer{
 				alloc:            alloc,
 				callCounts:       callCounts,
@@ -251,6 +246,17 @@ func TestCSIHook(t *testing.T) {
 
 			must.NotNil(t, hook)
 
+			if tc.startsWithValidMounts {
+				// TODO: this works, but it requires knowledge of how the mock works.  would rather vm.MountVolume()
+				vm.Mounts = map[string]bool{
+					tc.expectedMounts["vol0"].Source: true,
+				}
+			}
+
+			if tc.failsFirstUnmount {
+				vm.NextUnmountVolumeErr = errors.New("bad first attempt")
+			}
+
 			if tc.expectedClaimErr != nil {
 				must.EqError(t, hook.Prerun(), tc.expectedClaimErr.Error())
 				mounts := ar.res.GetCSIMounts()
@@ -270,7 +276,7 @@ func TestCSIHook(t *testing.T) {
 				time.Sleep(100 * time.Millisecond)
 			}
 
-			counts := callCounts.get()
+			counts := callCounts.Get()
 			must.MapEq(t, tc.expectedCalls, counts,
 				must.Sprintf("got calls: %v", counts))
 
@@ -338,14 +344,12 @@ func TestCSIHook_Prerun_Validation(t *testing.T) {
 		t.Run(tc.name, func(t *testing.T) {
 			alloc.Job.TaskGroups[0].Volumes = volumeRequests
 
-			callCounts := &callCounter{counts: map[string]int{}}
-			mgr := mockPluginManager{mounter: mockVolumeManager{
-				callCounts:        callCounts,
-				failsFirstUnmount: pointer.Of(false),
-			}}
+			mgr := &csimanager.MockCSIManager{
+				VM: &csimanager.MockVolumeManager{},
+			}
 			rpcer := mockRPCer{
 				alloc:            alloc,
-				callCounts:       callCounts,
+				callCounts:       testutil.NewCallCounter(),
 				hasExistingClaim: pointer.Of(false),
 				schedulable:      pointer.Of(true),
 			}
@@ -375,26 +379,9 @@ func TestCSIHook_Prerun_Validation(t *testing.T) {
 
 // HELPERS AND MOCKS
 
-type callCounter struct {
-	lock   sync.Mutex
-	counts map[string]int
-}
-
-func (c *callCounter) inc(name string) {
-	c.lock.Lock()
-	defer c.lock.Unlock()
-	c.counts[name]++
-}
-
-func (c *callCounter) get() map[string]int {
-	c.lock.Lock()
-	defer c.lock.Unlock()
-	return maps.Clone(c.counts)
-}
-
 type mockRPCer struct {
 	alloc            *structs.Allocation
-	callCounts       *callCounter
+	callCounts       *testutil.CallCounter
 	hasExistingClaim *bool
 	schedulable      *bool
 }
@@ -403,7 +390,7 @@ type mockRPCer struct {
 func (r mockRPCer) RPC(method string, args any, reply any) error {
 	switch method {
 	case "CSIVolume.Claim":
-		r.callCounts.inc("claim")
+		r.callCounts.Inc("claim")
 		req := args.(*structs.CSIVolumeClaimRequest)
 		vol := r.testVolume(req.VolumeID)
 		err := vol.Claim(req.ToClaim(), r.alloc)
@@ -423,7 +410,7 @@ func (r mockRPCer) RPC(method string, args any, reply any) error {
 		resp.QueryMeta = structs.QueryMeta{}
 
 	case "CSIVolume.Unpublish":
-		r.callCounts.inc("unpublish")
+		r.callCounts.Inc("unpublish")
 		resp := reply.(*structs.CSIVolumeUnpublishResponse)
 		resp.QueryMeta = structs.QueryMeta{}
 
@@ -466,55 +453,6 @@ func (r mockRPCer) testVolume(id string) *structs.CSIVolume {
 	return vol
 }
 
-type mockVolumeManager struct {
-	hasMounts         bool
-	failsFirstUnmount *bool
-	callCounts        *callCounter
-}
-
-func (vm mockVolumeManager) MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usageOpts *csimanager.UsageOptions, publishContext map[string]string) (*csimanager.MountInfo, error) {
-	vm.callCounts.inc("mount")
-	return &csimanager.MountInfo{
-		Source: filepath.Join("test-alloc-dir", alloc.ID, vol.ID, usageOpts.ToFS()),
-	}, nil
-}
-
-func (vm mockVolumeManager) UnmountVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *csimanager.UsageOptions) error {
-	vm.callCounts.inc("unmount")
-
-	if *vm.failsFirstUnmount {
-		*vm.failsFirstUnmount = false
-		return fmt.Errorf("could not unmount")
-	}
-
-	return nil
-}
-
-func (vm mockVolumeManager) HasMount(_ context.Context, mountInfo *csimanager.MountInfo) (bool, error) {
-	vm.callCounts.inc("hasMount")
-	return mountInfo != nil && vm.hasMounts, nil
-}
-
-func (vm mockVolumeManager) ExternalID() string {
-	return "i-example"
-}
-
-type mockPluginManager struct {
-	mounter mockVolumeManager
-}
-
-func (mgr mockPluginManager) WaitForPlugin(ctx context.Context, pluginType, pluginID string) error {
-	return nil
-}
-
-func (mgr mockPluginManager) ManagerForPlugin(ctx context.Context, pluginID string) (csimanager.VolumeManager, error) {
-	return mgr.mounter, nil
-}
-
-// no-op methods to fulfill the interface
-func (mgr mockPluginManager) PluginManager() pluginmanager.PluginManager { return nil }
-func (mgr mockPluginManager) Shutdown()                                  {}
-
 type mockAllocRunner struct {
 	res     *cstructs.AllocHookResources
 	caps    *drivers.Capabilities
diff --git a/client/csi_endpoint_test.go b/client/csi_endpoint_test.go
index 7b6df1534e7..0c551d2fe58 100644
--- a/client/csi_endpoint_test.go
+++ b/client/csi_endpoint_test.go
@@ -6,10 +6,12 @@ import (
 
 	"github.com/hashicorp/nomad/ci"
 	"github.com/hashicorp/nomad/client/dynamicplugins"
+	"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
 	"github.com/hashicorp/nomad/client/structs"
 	nstructs "github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/plugins/csi"
 	"github.com/hashicorp/nomad/plugins/csi/fake"
+	"github.com/shoenig/test/must"
 	"github.com/stretchr/testify/require"
 )
 
@@ -897,24 +899,22 @@ func TestCSINode_DetachVolume(t *testing.T) {
 	ci.Parallel(t)
 
 	cases := []struct {
-		Name             string
-		ClientSetupFunc  func(*fake.Client)
-		Request          *structs.ClientCSINodeDetachVolumeRequest
-		ExpectedErr      error
-		ExpectedResponse *structs.ClientCSINodeDetachVolumeResponse
+		Name        string
+		ModManager  func(m *csimanager.MockCSIManager)
+		Request     *structs.ClientCSINodeDetachVolumeRequest
+		ExpectedErr error
 	}{
 		{
-			Name: "returns plugin not found errors",
+			Name: "success",
 			Request: &structs.ClientCSINodeDetachVolumeRequest{
-				PluginID:       "some-garbage",
-				VolumeID:       "-",
-				AllocID:        "-",
-				NodeID:         "-",
+				PluginID:       "fake-plugin",
+				VolumeID:       "fake-vol",
+				AllocID:        "fake-alloc",
+				NodeID:         "fake-node",
 				AttachmentMode: nstructs.CSIVolumeAttachmentModeFilesystem,
 				AccessMode:     nstructs.CSIVolumeAccessModeMultiNodeReader,
 				ReadOnly:       true,
 			},
-			ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin some-garbage for type csi-node not found"),
 		},
 		{
 			Name: "validates volumeid is not empty",
@@ -932,43 +932,51 @@ func TestCSINode_DetachVolume(t *testing.T) {
 			ExpectedErr: errors.New("CSI.NodeDetachVolume: AllocID is required"),
 		},
 		{
-			Name: "returns transitive errors",
-			ClientSetupFunc: func(fc *fake.Client) {
-				fc.NextNodeUnpublishVolumeErr = errors.New("wont-see-this")
+			Name: "returns csi manager errors",
+			ModManager: func(m *csimanager.MockCSIManager) {
+				m.NextManagerForPluginErr = errors.New("no plugin")
 			},
 			Request: &structs.ClientCSINodeDetachVolumeRequest{
 				PluginID: fakeNodePlugin.Name,
 				VolumeID: "1234-4321-1234-4321",
 				AllocID:  "4321-1234-4321-1234",
 			},
-			// we don't have a csimanager in this context
-			ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin test-plugin for type csi-node not found"),
+			ExpectedErr: errors.New("CSI.NodeDetachVolume: no plugin"),
+		},
+		{
+			Name: "returns volume manager errors",
+			ModManager: func(m *csimanager.MockCSIManager) {
+				m.VM.NextUnmountVolumeErr = errors.New("error unmounting")
+			},
+			Request: &structs.ClientCSINodeDetachVolumeRequest{
+				PluginID: fakeNodePlugin.Name,
+				VolumeID: "1234-4321-1234-4321",
+				AllocID:  "4321-1234-4321-1234",
+			},
+			ExpectedErr: errors.New("CSI.NodeDetachVolume: error unmounting"),
 		},
 	}
 
 	for _, tc := range cases {
 		t.Run(tc.Name, func(t *testing.T) {
-			require := require.New(t)
 			client, cleanup := TestClient(t, nil)
 			defer cleanup()
 
-			fakeClient := &fake.Client{}
-			if tc.ClientSetupFunc != nil {
-				tc.ClientSetupFunc(fakeClient)
+			mockManager := &csimanager.MockCSIManager{
+				VM: &csimanager.MockVolumeManager{},
 			}
-
-			dispenserFunc := func(*dynamicplugins.PluginInfo) (interface{}, error) {
-				return fakeClient, nil
+			if tc.ModManager != nil {
+				tc.ModManager(mockManager)
 			}
-			client.dynamicRegistry.StubDispenserForType(dynamicplugins.PluginTypeCSINode, dispenserFunc)
-			err := client.dynamicRegistry.RegisterPlugin(fakeNodePlugin)
-			require.Nil(err)
+			client.csimanager = mockManager
 
 			var resp structs.ClientCSINodeDetachVolumeResponse
-			err = client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp)
-			require.Equal(tc.ExpectedErr, err)
-			if tc.ExpectedResponse != nil {
-				require.Equal(tc.ExpectedResponse, &resp)
+			err := client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp)
+			if tc.ExpectedErr != nil {
+				must.Error(t, err)
+				must.EqError(t, tc.ExpectedErr, err.Error())
+			} else {
+				must.NoError(t, err)
 			}
 		})
 	}
diff --git a/client/pluginmanager/csimanager/testing.go b/client/pluginmanager/csimanager/testing.go
new file mode 100644
index 00000000000..88a5055a5e1
--- /dev/null
+++ b/client/pluginmanager/csimanager/testing.go
@@ -0,0 +1,126 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package csimanager
+
+import (
+	"context"
+	"path/filepath"
+
+	"github.com/hashicorp/nomad/client/pluginmanager"
+	nstructs "github.com/hashicorp/nomad/nomad/structs"
+	"github.com/hashicorp/nomad/plugins/csi"
+	"github.com/hashicorp/nomad/testutil"
+)
+
+var _ Manager = &MockCSIManager{}
+
+type MockCSIManager struct {
+	VM *MockVolumeManager
+
+	NextWaitForPluginErr    error
+	NextManagerForPluginErr error
+}
+
+func (m *MockCSIManager) PluginManager() pluginmanager.PluginManager {
+	panic("implement me")
+}
+
+func (m *MockCSIManager) WaitForPlugin(_ context.Context, pluginType, pluginID string) error {
+	return m.NextWaitForPluginErr
+}
+
+func (m *MockCSIManager) ManagerForPlugin(_ context.Context, pluginID string) (VolumeManager, error) {
+	if m.VM == nil {
+		m.VM = &MockVolumeManager{}
+	}
+	return m.VM, m.NextManagerForPluginErr
+}
+
+func (m *MockCSIManager) Shutdown() {
+	panic("implement me")
+}
+
+var _ VolumeManager = &MockVolumeManager{}
+
+type MockVolumeManager struct {
+	CallCounter *testutil.CallCounter
+
+	Mounts map[string]bool // lazy set
+
+	NextMountVolumeErr   error
+	NextUnmountVolumeErr error
+
+	NextExpandVolumeErr  error
+	LastExpandVolumeCall *MockExpandVolumeCall
+}
+
+func (m *MockVolumeManager) mountName(volID, allocID string, usageOpts *UsageOptions) string {
+	return filepath.Join("test-alloc-dir", allocID, volID, usageOpts.ToFS())
+}
+
+func (m *MockVolumeManager) MountVolume(_ context.Context, vol *nstructs.CSIVolume, alloc *nstructs.Allocation, usageOpts *UsageOptions, publishContext map[string]string) (*MountInfo, error) {
+	if m.CallCounter != nil {
+		m.CallCounter.Inc("MountVolume")
+	}
+
+	if m.NextMountVolumeErr != nil {
+		err := m.NextMountVolumeErr
+		m.NextMountVolumeErr = nil // reset it
+		return nil, err
+	}
+
+	// "mount" it
+	if m.Mounts == nil {
+		m.Mounts = make(map[string]bool)
+	}
+	source := m.mountName(vol.ID, alloc.ID, usageOpts)
+	m.Mounts[source] = true
+
+	return &MountInfo{
+		Source: source,
+	}, nil
+}
+
+func (m *MockVolumeManager) UnmountVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions) error {
+	if m.CallCounter != nil {
+		m.CallCounter.Inc("UnmountVolume")
+	}
+
+	if m.NextUnmountVolumeErr != nil {
+		err := m.NextUnmountVolumeErr
+		m.NextUnmountVolumeErr = nil // reset it
+		return err
+	}
+
+	// "unmount" it
+	delete(m.Mounts, m.mountName(volID, allocID, usageOpts))
+	return nil
+}
+
+func (m *MockVolumeManager) HasMount(_ context.Context, mountInfo *MountInfo) (bool, error) {
+	if m.CallCounter != nil {
+		m.CallCounter.Inc("HasMount")
+	}
+	if m.Mounts == nil {
+		return false, nil
+	}
+	return m.Mounts[mountInfo.Source], nil
+}
+
+func (m *MockVolumeManager) ExpandVolume(_ context.Context, volID, remoteID, allocID string, usageOpts *UsageOptions, capacity *csi.CapacityRange) (int64, error) {
+	m.LastExpandVolumeCall = &MockExpandVolumeCall{
+		volID, remoteID, allocID, usageOpts, capacity,
+	}
+	return capacity.RequiredBytes, m.NextExpandVolumeErr
+}
+
+type MockExpandVolumeCall struct {
+	VolID, RemoteID, AllocID string
+	UsageOpts                *UsageOptions
+	Capacity                 *csi.CapacityRange
+}
+
+func (m *MockVolumeManager) ExternalID() string {
+	return "mock-volume-manager"
+}
diff --git a/testutil/mock_calls.go b/testutil/mock_calls.go
new file mode 100644
index 00000000000..5b37832bc69
--- /dev/null
+++ b/testutil/mock_calls.go
@@ -0,0 +1,42 @@
+// Copyright (c) HashiCorp, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package testutil
+
+import (
+	"maps"
+	"sync"
+
+	"github.com/mitchellh/go-testing-interface"
+)
+
+func NewCallCounter() *CallCounter {
+	return &CallCounter{
+		counts: make(map[string]int),
+	}
+}
+
+type CallCounter struct {
+	lock   sync.Mutex
+	counts map[string]int
+}
+
+func (c *CallCounter) Inc(name string) {
+	c.lock.Lock()
+	defer c.lock.Unlock()
+	c.counts[name]++
+}
+
+func (c *CallCounter) Get() map[string]int {
+	c.lock.Lock()
+	defer c.lock.Unlock()
+	return maps.Clone(c.counts)
+}
+
+func (c *CallCounter) AssertCalled(t testing.T, name string) {
+	t.Helper()
+	counts := c.Get()
+	if _, ok := counts[name]; !ok {
+		t.Errorf("'%s' not called; all counts: %v", counts)
+	}
+}