Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Refactor mock CSI manager #18554

Merged
merged 2 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 31 additions & 98 deletions client/allocrunner/csi_hook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,24 @@
package allocrunner

import (
"context"
"errors"
"fmt"
"path/filepath"
"sync"
"testing"
"time"

"github.com/hashicorp/nomad/ci"
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
"github.com/hashicorp/nomad/client/allocrunner/state"
"github.com/hashicorp/nomad/client/pluginmanager"
"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper/pointer"
"github.com/hashicorp/nomad/helper/testlog"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/csi"
"github.com/hashicorp/nomad/plugins/drivers"
"github.com/hashicorp/nomad/testutil"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)

var _ interfaces.RunnerPrerunHook = (*csiHook)(nil)
Expand Down Expand Up @@ -71,7 +66,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{
"claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
"claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
},

{
Expand All @@ -92,7 +87,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{
"claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
"claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
},

{
Expand Down Expand Up @@ -137,7 +132,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{
"claim": 2, "mount": 1, "unmount": 1, "unpublish": 1},
"claim": 2, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
},
{
name: "already mounted",
Expand All @@ -163,7 +158,7 @@ func TestCSIHook(t *testing.T) {
expectedMounts: map[string]*csimanager.MountInfo{
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{"hasMount": 1, "unmount": 1, "unpublish": 1},
expectedCalls: map[string]int{"HasMount": 1, "UnmountVolume": 1, "unpublish": 1},
},
{
name: "existing but invalid mounts",
Expand All @@ -190,7 +185,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{
"hasMount": 1, "claim": 1, "mount": 1, "unmount": 1, "unpublish": 1},
"HasMount": 1, "claim": 1, "MountVolume": 1, "UnmountVolume": 1, "unpublish": 1},
},

{
Expand All @@ -212,7 +207,7 @@ func TestCSIHook(t *testing.T) {
"vol0": &csimanager.MountInfo{Source: testMountSrc},
},
expectedCalls: map[string]int{
"claim": 1, "mount": 1, "unmount": 2, "unpublish": 2},
"claim": 1, "MountVolume": 1, "UnmountVolume": 2, "unpublish": 2},
},

{
Expand All @@ -227,12 +222,11 @@ func TestCSIHook(t *testing.T) {

alloc.Job.TaskGroups[0].Volumes = tc.volumeRequests

callCounts := &callCounter{counts: map[string]int{}}
mgr := mockPluginManager{mounter: mockVolumeManager{
hasMounts: tc.startsWithValidMounts,
callCounts: callCounts,
failsFirstUnmount: pointer.Of(tc.failsFirstUnmount),
}}
callCounts := testutil.NewCallCounter()
vm := &csimanager.MockVolumeManager{
CallCounter: callCounts,
}
mgr := &csimanager.MockCSIManager{VM: vm}
rpcer := mockRPCer{
alloc: alloc,
callCounts: callCounts,
Expand All @@ -255,6 +249,17 @@ func TestCSIHook(t *testing.T) {

must.NotNil(t, hook)

if tc.startsWithValidMounts {
// TODO: this works, but it requires knowledge of how the mock works. would rather vm.MountVolume()
vm.Mounts = map[string]bool{
tc.expectedMounts["vol0"].Source: true,
}
}

if tc.failsFirstUnmount {
vm.NextUnmountVolumeErr = errors.New("bad first attempt")
}

if tc.expectedClaimErr != nil {
must.EqError(t, hook.Prerun(), tc.expectedClaimErr.Error())
mounts := ar.res.GetCSIMounts()
Expand All @@ -274,7 +279,7 @@ func TestCSIHook(t *testing.T) {
time.Sleep(100 * time.Millisecond)
}

counts := callCounts.get()
counts := callCounts.Get()
must.MapEq(t, tc.expectedCalls, counts,
must.Sprintf("got calls: %v", counts))

Expand Down Expand Up @@ -342,14 +347,12 @@ func TestCSIHook_Prerun_Validation(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
alloc.Job.TaskGroups[0].Volumes = volumeRequests

callCounts := &callCounter{counts: map[string]int{}}
mgr := mockPluginManager{mounter: mockVolumeManager{
callCounts: callCounts,
failsFirstUnmount: pointer.Of(false),
}}
mgr := &csimanager.MockCSIManager{
VM: &csimanager.MockVolumeManager{},
}
rpcer := mockRPCer{
alloc: alloc,
callCounts: callCounts,
callCounts: testutil.NewCallCounter(),
hasExistingClaim: pointer.Of(false),
schedulable: pointer.Of(true),
}
Expand Down Expand Up @@ -379,26 +382,9 @@ func TestCSIHook_Prerun_Validation(t *testing.T) {

// HELPERS AND MOCKS

type callCounter struct {
lock sync.Mutex
counts map[string]int
}

func (c *callCounter) inc(name string) {
c.lock.Lock()
defer c.lock.Unlock()
c.counts[name]++
}

func (c *callCounter) get() map[string]int {
c.lock.Lock()
defer c.lock.Unlock()
return maps.Clone(c.counts)
}

type mockRPCer struct {
alloc *structs.Allocation
callCounts *callCounter
callCounts *testutil.CallCounter
hasExistingClaim *bool
schedulable *bool
}
Expand All @@ -407,7 +393,7 @@ type mockRPCer struct {
func (r mockRPCer) RPC(method string, args any, reply any) error {
switch method {
case "CSIVolume.Claim":
r.callCounts.inc("claim")
r.callCounts.Inc("claim")
req := args.(*structs.CSIVolumeClaimRequest)
vol := r.testVolume(req.VolumeID)
err := vol.Claim(req.ToClaim(), r.alloc)
Expand All @@ -427,7 +413,7 @@ func (r mockRPCer) RPC(method string, args any, reply any) error {
resp.QueryMeta = structs.QueryMeta{}

case "CSIVolume.Unpublish":
r.callCounts.inc("unpublish")
r.callCounts.Inc("unpublish")
resp := reply.(*structs.CSIVolumeUnpublishResponse)
resp.QueryMeta = structs.QueryMeta{}

Expand Down Expand Up @@ -470,59 +456,6 @@ func (r mockRPCer) testVolume(id string) *structs.CSIVolume {
return vol
}

type mockVolumeManager struct {
hasMounts bool
failsFirstUnmount *bool
callCounts *callCounter
}

func (vm mockVolumeManager) MountVolume(ctx context.Context, vol *structs.CSIVolume, alloc *structs.Allocation, usageOpts *csimanager.UsageOptions, publishContext map[string]string) (*csimanager.MountInfo, error) {
vm.callCounts.inc("mount")
return &csimanager.MountInfo{
Source: filepath.Join("test-alloc-dir", alloc.ID, vol.ID, usageOpts.ToFS()),
}, nil
}

func (vm mockVolumeManager) UnmountVolume(ctx context.Context, volID, remoteID, allocID string, usageOpts *csimanager.UsageOptions) error {
vm.callCounts.inc("unmount")

if *vm.failsFirstUnmount {
*vm.failsFirstUnmount = false
return fmt.Errorf("could not unmount")
}

return nil
}

func (vm mockVolumeManager) HasMount(_ context.Context, mountInfo *csimanager.MountInfo) (bool, error) {
vm.callCounts.inc("hasMount")
return mountInfo != nil && vm.hasMounts, nil
}

func (vm mockVolumeManager) ExpandVolume(_ context.Context, _, _, _ string, _ *csimanager.UsageOptions, _ *csi.CapacityRange) (int64, error) {
return 0, nil
}

func (vm mockVolumeManager) ExternalID() string {
return "i-example"
}

type mockPluginManager struct {
mounter mockVolumeManager
}

func (mgr mockPluginManager) WaitForPlugin(ctx context.Context, pluginType, pluginID string) error {
return nil
}

func (mgr mockPluginManager) ManagerForPlugin(ctx context.Context, pluginID string) (csimanager.VolumeManager, error) {
return mgr.mounter, nil
}

// no-op methods to fulfill the interface
func (mgr mockPluginManager) PluginManager() pluginmanager.PluginManager { return nil }
func (mgr mockPluginManager) Shutdown() {}

type mockAllocRunner struct {
res *cstructs.AllocHookResources
caps *drivers.Capabilities
Expand Down
66 changes: 36 additions & 30 deletions client/csi_endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -994,24 +994,22 @@ func TestCSINode_DetachVolume(t *testing.T) {
ci.Parallel(t)

cases := []struct {
Name string
ClientSetupFunc func(*fake.Client)
Request *structs.ClientCSINodeDetachVolumeRequest
ExpectedErr error
ExpectedResponse *structs.ClientCSINodeDetachVolumeResponse
Name string
ModManager func(m *csimanager.MockCSIManager)
Request *structs.ClientCSINodeDetachVolumeRequest
ExpectedErr error
}{
{
Name: "returns plugin not found errors",
Name: "success",
Request: &structs.ClientCSINodeDetachVolumeRequest{
PluginID: "some-garbage",
VolumeID: "-",
AllocID: "-",
NodeID: "-",
PluginID: "fake-plugin",
VolumeID: "fake-vol",
AllocID: "fake-alloc",
NodeID: "fake-node",
AttachmentMode: nstructs.CSIVolumeAttachmentModeFilesystem,
AccessMode: nstructs.CSIVolumeAccessModeMultiNodeReader,
ReadOnly: true,
},
ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin some-garbage for type csi-node not found"),
},
{
Name: "validates volumeid is not empty",
Expand All @@ -1029,43 +1027,51 @@ func TestCSINode_DetachVolume(t *testing.T) {
ExpectedErr: errors.New("CSI.NodeDetachVolume: AllocID is required"),
},
{
Name: "returns transitive errors",
ClientSetupFunc: func(fc *fake.Client) {
fc.NextNodeUnpublishVolumeErr = errors.New("wont-see-this")
Name: "returns csi manager errors",
ModManager: func(m *csimanager.MockCSIManager) {
m.NextManagerForPluginErr = errors.New("no plugin")
},
Request: &structs.ClientCSINodeDetachVolumeRequest{
PluginID: fakeNodePlugin.Name,
VolumeID: "1234-4321-1234-4321",
AllocID: "4321-1234-4321-1234",
},
ExpectedErr: errors.New("CSI.NodeDetachVolume: no plugin"),
},
{
Name: "returns volume manager errors",
ModManager: func(m *csimanager.MockCSIManager) {
m.VM.NextUnmountVolumeErr = errors.New("error unmounting")
},
Request: &structs.ClientCSINodeDetachVolumeRequest{
PluginID: fakeNodePlugin.Name,
VolumeID: "1234-4321-1234-4321",
AllocID: "4321-1234-4321-1234",
},
// we don't have a csimanager in this context
ExpectedErr: errors.New("CSI.NodeDetachVolume: plugin test-plugin for type csi-node not found"),
ExpectedErr: errors.New("CSI.NodeDetachVolume: error unmounting"),
},
}

for _, tc := range cases {
t.Run(tc.Name, func(t *testing.T) {
require := require.New(t)
client, cleanup := TestClient(t, nil)
defer cleanup()

fakeClient := &fake.Client{}
if tc.ClientSetupFunc != nil {
tc.ClientSetupFunc(fakeClient)
mockManager := &csimanager.MockCSIManager{
VM: &csimanager.MockVolumeManager{},
}

dispenserFunc := func(*dynamicplugins.PluginInfo) (interface{}, error) {
return fakeClient, nil
if tc.ModManager != nil {
tc.ModManager(mockManager)
}
client.dynamicRegistry.StubDispenserForType(dynamicplugins.PluginTypeCSINode, dispenserFunc)
err := client.dynamicRegistry.RegisterPlugin(fakeNodePlugin)
require.Nil(err)
client.csimanager = mockManager

var resp structs.ClientCSINodeDetachVolumeResponse
err = client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp)
require.Equal(tc.ExpectedErr, err)
if tc.ExpectedResponse != nil {
require.Equal(tc.ExpectedResponse, &resp)
err := client.ClientRPC("CSI.NodeDetachVolume", tc.Request, &resp)
if tc.ExpectedErr != nil {
must.Error(t, err)
must.EqError(t, tc.ExpectedErr, err.Error())
} else {
must.NoError(t, err)
}
})
}
Expand Down
Loading
Loading