Skip to content

Commit

Permalink
test: Refactor mock CSI manager (#18554)
Browse files Browse the repository at this point in the history
and MockCSIManager to support the call counting
that csi_hook_test expects

instead of implementing csimanager
interfaces in two separate places:
* client/allocrunner/csi_hook_test
* client/csi_endpoint_test

they can both use the same mocks defined in
client/pluginmanager/csimanager/
alongside the actual implementations of them.

also refactor TestCSINode_DetachVolume
to use use it like Node_ExpandVolume
so we can also test the happy path there
  • Loading branch information
gulducat committed Sep 22, 2023
1 parent 3cda227 commit a612cfb
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 123 deletions.
124 changes: 31 additions & 93 deletions client/allocrunner/csi_hook_test.go
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
package allocrunner

import (
"context"
"errors"
"fmt"
"path/filepath"
"sync"
"testing"
"time"

"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
"github.com/hashicorp/nomad/client/allocrunner/state"
"github.com/hashicorp/nomad/client/pluginmanager"
"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/drivers"
"github.com/hashicorp/nomad/testutil"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)

var _ interfaces.RunnerPrerunHook = (*csiHook)(nil)
Expand Down Expand Up @@ -67,7 +63,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{
"claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
"claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
},

{
Expand All @@ -88,7 +84,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{
"claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
"claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
},

{
Expand Down Expand Up @@ -133,7 +129,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{
"claim": 2, "mount": 1, "unmount": 1, "unpublish": 1},
"claim": 2, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
},
{
name: "already mounted",
Expand All @@ -159,7 +155,7 @@ func TestCSIHook(t *testing.T) {
expectedMounts: map[string]*csimanager.MountInfo{
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{"hasMount": 1, "unmount": 1, "unpublish": 1},
expectedCalls: map[string]int{"HasMount": 1, "UnmountVolume": 1, "unpublish": 1},
},
{
name: "existing but invalid mounts",
Expand All @@ -186,7 +182,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{
"hasMount": 1, "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
"HasMount": 1, "claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
},

{
Expand All @@ -208,7 +204,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{
"claim": 1, "mount": 1, "unmount": 2, "unpublish": 2},
"claim": 1, "MountVolume": 1, "UnmountVolume": 2, "unpublish": 2},
},

{
Expand All @@ -223,12 +219,11 @@ func TestCSIHook(t *testing.T) {

alloc.Job.TaskGroups[0].Volumes = tc.volumeRequests

callCounts := &callCounter{counts: map[string]int{}}
mgr := mockPluginManager{mounter: mockVolumeManager{
hasMounts: tc.startsWithValidMounts,
callCounts: callCounts,
failsFirstUnmount: pointer.Of(tc.failsFirstUnmount),
}}
callCounts := testutil.NewCallCounter()
vm := &csimanager.MockVolumeManager{
CallCounter: callCounts,
}
mgr := &csimanager.MockCSIManager{VM: vm}
rpcer := mockRPCer{
alloc: alloc,
callCounts: callCounts,
Expand All @@ -251,6 +246,17 @@ func TestCSIHook(t *testing.T) {

must.NotNil(t, hook)

if tc.startsWithValidMounts {
// TODO: this works, but it requires knowledge of how the mock works. would rather vm.MountVolume()
vm.Mounts = map[string]bool{
tc.expectedMounts["vol0"].Source: true,
}
}

if tc.failsFirstUnmount {
vm.NextUnmountVolumeErr = errors.New("bad first attempt")
}

if tc.expectedClaimErr != nil {
must.EqError(t, hook.Prerun(), tc.expectedClaimErr.Error())
mounts := ar.res.GetCSIMounts()
Expand All @@ -270,7 +276,7 @@ func TestCSIHook(t *testing.T) {
time.Sleep(100 * time.Millisecond)
}

counts := callCounts.get()
counts := callCounts.Get()
must.MapEq(t, tc.expectedCalls, counts,
must.Sprintf("got calls: %v", counts))

Expand Down Expand Up @@ -338,14 +344,12 @@ func TestCSIHook_Prerun_Validation(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
alloc.Job.TaskGroups[0].Volumes = volumeRequests

callCounts := &callCounter{counts: map[string]int{}}
mgr := mockPluginManager{mounter: mockVolumeManager{
callCounts: callCounts,
failsFirstUnmount: pointer.Of(false),
}}
mgr := &csimanager.MockCSIManager{
VM: &csimanager.MockVolumeManager{},
}
rpcer := mockRPCer{
alloc: alloc,
callCounts: callCounts,
callCounts: testutil.NewCallCounter(),
hasExistingClaim: pointer.Of(false),
schedulable: pointer.Of(true),
}
Expand Down Expand Up @@ -375,26 +379,9 @@ func TestCSIHook_Prerun_Validation(t *testing.T) {

// HELPERS AND MOCKS

type callCounter struct {
lock sync.Mutex
counts map[string]int
}

func (c *callCounter) inc(name string) {
c.lock.Lock()
defer c.lock.Unlock()
c.counts[name]++
}

func (c *callCounter) get() map[string]int {
c.lock.Lock()
defer c.lock.Unlock()
return maps.Clone(c.counts)
}

type mockRPCer struct {
alloc *structs.Allocation
callCounts *callCounter
callCounts *testutil.CallCounter
hasExistingClaim *bool
schedulable *bool
}
Expand All @@ -403,7 +390,7 @@ type mockRPCer struct {
func (r mockRPCer) RPC(method string, args any, reply any) error {
switch method {
case "CSIVolume.Claim":
r.callCounts.inc("claim")
r.callCounts.Inc("claim")
req := args.(*structs.CSIVolumeClaimRequest)
vol := r.testVolume(req.VolumeID)
err := vol.Claim(req.ToClaim(), r.alloc)
Expand All @@ -423,7 +410,7 @@ func (r mockRPCer) RPC(method string, args any, reply any) error {
resp.QueryMeta = structs.QueryMeta{}

case "CSIVolume.Unpublish":
r.callCounts.inc("unpublish")
r.callCounts.Inc("unpublish")
resp := reply.(*structs.CSIVolumeUnpublishResponse)
resp.QueryMeta = structs.QueryMeta{}

Expand Down Expand Up @@ -466,55 +453,6 @@ func (r mockRPCer) testVolume(id string) *structs.CSIVolume {
return vol
}

type mockVolumeManager struct {
hasMounts bool
failsFirstUnmount *bool
callCounts *callCounter
}

func (vm mockVolumeManager) MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usageOpts *csimanager.UsageOptions, publishContext map[string]string) (*csimanager.MountInfo, error) {
vm.callCounts.inc("mount")
return &csimanager.MountInfo{
Source: filepath.Join("test-alloc-dir", alloc.ID, vol.ID, usageOpts.ToFS()),
}, nil
}

func (vm mockVolumeManager) UnmountVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *csimanager.UsageOptions) error {
vm.callCounts.inc("unmount")

if *vm.failsFirstUnmount {
*vm.failsFirstUnmount = false
return fmt.Errorf("could not unmount")
}

return nil
}

func (vm mockVolumeManager) HasMount(_ context.Context, mountInfo *csimanager.MountInfo) (bool, error) {
vm.callCounts.inc("hasMount")
return mountInfo != nil && vm.hasMounts, nil
}

func (vm mockVolumeManager) ExternalID() string {
return "i-example"
}

type mockPluginManager struct {
mounter mockVolumeManager
}

func (mgr mockPluginManager) WaitForPlugin(ctx context.Context, pluginType, pluginID string) error {
return nil
}

func (mgr mockPluginManager) ManagerForPlugin(ctx context.Context, pluginID string) (csimanager.VolumeManager, error) {
return mgr.mounter, nil
}

// no-op methods to fulfill the interface
func (mgr mockPluginManager) PluginManager() pluginmanager.PluginManager { return nil }
func (mgr mockPluginManager) Shutdown() {}

type mockAllocRunner struct {
res *cstructs.AllocHookResources
caps *drivers.Capabilities
Expand Down
68 changes: 38 additions & 30 deletions client/csi_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (

"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/dynamicplugins"
"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
"github.com/hashicorp/nomad/client/structs"
nstructs "github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/csi"
"github.com/hashicorp/nomad/plugins/csi/fake"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -897,24 +899,22 @@ func TestCSINode_DetachVolume(t *testing.T) {
ci.Parallel(t)

cases := []struct {
Name string
ClientSetupFunc func(*fake.Client)
Request *structs.ClientCSINodeDetachVolumeRequest
ExpectedErr error
ExpectedResponse *structs.ClientCSINodeDetachVolumeResponse
Name string
ModManager func(m *csimanager.MockCSIManager)
Request *structs.ClientCSINodeDetachVolumeRequest
ExpectedErr error
}{
{
Name: "returns plugin not found errors",
Name: "success",
Request: &structs.ClientCSINodeDetachVolumeRequest{
PluginID: "some-garbage",
VolumeID: "-",
AllocID: "-",
NodeID: "-",
PluginID: "fake-plugin",
VolumeID: "fake-vol",
AllocID: "fake-alloc",
NodeID: "fake-node",
AttachmentMode: nstructs.CSIVolumeAttachmentModeFilesystem,
AccessMode: nstructs.CSIVolumeAccessModeMultiNodeReader,
ReadOnly: true,
},
ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin some-garbage for type csi-node not found"),
},
{
Name: "validates volumeid is not empty",
Expand All @@ -932,43 +932,51 @@ func TestCSINode_DetachVolume(t *testing.T) {
ExpectedErr: errors.New("CSI.NodeDetachVolume: AllocID is required"),
},
{
Name: "returns transitive errors",
ClientSetupFunc: func(fc *fake.Client) {
fc.NextNodeUnpublishVolumeErr = errors.New("wont-see-this")
Name: "returns csi manager errors",
ModManager: func(m *csimanager.MockCSIManager) {
m.NextManagerForPluginErr = errors.New("no plugin")
},
Request: &structs.ClientCSINodeDetachVolumeRequest{
PluginID: fakeNodePlugin.Name,
VolumeID: "1234-4321-1234-4321",
AllocID: "4321-1234-4321-1234",
},
// we don't have a csimanager in this context
ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin test-plugin for type csi-node not found"),
ExpectedErr: errors.New("CSI.NodeDetachVolume: no plugin"),
},
{
Name: "returns volume manager errors",
ModManager: func(m *csimanager.MockCSIManager) {
m.VM.NextUnmountVolumeErr = errors.New("error unmounting")
},
Request: &structs.ClientCSINodeDetachVolumeRequest{
PluginID: fakeNodePlugin.Name,
VolumeID: "1234-4321-1234-4321",
AllocID: "4321-1234-4321-1234",
},
ExpectedErr: errors.New("CSI.NodeDetachVolume: error unmounting"),
},
}

for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
client, cleanup := TestClient(t, nil)
defer cleanup()

fakeClient := &fake.Client{}
if tc.ClientSetupFunc != nil {
tc.ClientSetupFunc(fakeClient)
mockManager := &csimanager.MockCSIManager{
VM: &csimanager.MockVolumeManager{},
}

dispenserFunc := func(*dynamicplugins.PluginInfo) (interface{}, error) {
return fakeClient, nil
if tc.ModManager != nil {
tc.ModManager(mockManager)
}
client.dynamicRegistry.StubDispenserForType(dynamicplugins.PluginTypeCSINode, dispenserFunc)
err := client.dynamicRegistry.RegisterPlugin(fakeNodePlugin)
require.Nil(err)
client.csimanager = mockManager

var resp structs.ClientCSINodeDetachVolumeResponse
err = client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp)
require.Equal(tc.ExpectedErr, err)
if tc.ExpectedResponse != nil {
require.Equal(tc.ExpectedResponse, &resp)
err := client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp)
if tc.ExpectedErr != nil {
must.Error(t, err)
must.EqError(t, tc.ExpectedErr, err.Error())
} else {
must.NoError(t, err)
}
})
}
Expand Down
Loading

0 comments on commit a612cfb

Please sign in to comment.