diff --git a/cmd/mock-driver/main.go b/cmd/mock-driver/main.go index fdf46a61..57d41302 100644 --- a/cmd/mock-driver/main.go +++ b/cmd/mock-driver/main.go @@ -35,6 +35,7 @@ func main() { flag.Int64Var(&config.AttachLimit, "attach-limit", 0, "number of attachable volumes on a node") flag.BoolVar(&config.NodeExpansionRequired, "node-expand-required", false, "Enables NodeServiceCapability_RPC_EXPAND_VOLUME capacity.") flag.BoolVar(&config.DisableControllerExpansion, "disable-controller-expansion", false, "Disables ControllerServiceCapability_RPC_EXPAND_VOLUME capability.") + flag.BoolVar(&config.DisableOnlineExpansion, "disable-online-expansion", false, "Disables online volume expansion capability.") flag.Parse() endpoint := os.Getenv("CSI_ENDPOINT") diff --git a/mock/service/controller.go b/mock/service/controller.go index 50855dbd..b374ddab 100644 --- a/mock/service/controller.go +++ b/mock/service/controller.go @@ -498,6 +498,10 @@ func (s *service) ControllerExpandVolume( return nil, status.Error(codes.NotFound, req.VolumeId) } + if s.config.DisableOnlineExpansion && MockVolumes[v.GetVolumeId()].ISPublished { + return nil, status.Error(codes.Aborted, "volume is published and online volume expansion is not supported") + } + requestBytes := req.CapacityRange.RequiredBytes if v.CapacityBytes > requestBytes { diff --git a/mock/service/identity.go b/mock/service/identity.go index 7e8735a9..41d08aaa 100644 --- a/mock/service/identity.go +++ b/mock/service/identity.go @@ -34,6 +34,12 @@ func (s *service) GetPluginCapabilities( req *csi.GetPluginCapabilitiesRequest) ( *csi.GetPluginCapabilitiesResponse, error) { + volExpType := csi.PluginCapability_VolumeExpansion_ONLINE + + if s.config.DisableOnlineExpansion { + volExpType = csi.PluginCapability_VolumeExpansion_OFFLINE + } + return &csi.GetPluginCapabilitiesResponse{ Capabilities: []*csi.PluginCapability{ { @@ -43,6 +49,13 @@ func (s *service) GetPluginCapabilities( }, }, }, + { + Type: &csi.PluginCapability_VolumeExpansion_{ + VolumeExpansion: &csi.PluginCapability_VolumeExpansion{ + Type: volExpType, + }, + }, + }, }, }, nil } diff --git a/mock/service/service.go b/mock/service/service.go index a6f85a02..6435d597 100644 --- a/mock/service/service.go +++ b/mock/service/service.go @@ -32,6 +32,7 @@ type Config struct { AttachLimit int64 NodeExpansionRequired bool DisableControllerExpansion bool + DisableOnlineExpansion bool } // Service is the CSI Mock service provider. diff --git a/pkg/sanity/identity.go b/pkg/sanity/identity.go index c1a5eb7e..0cefcd9f 100644 --- a/pkg/sanity/identity.go +++ b/pkg/sanity/identity.go @@ -49,11 +49,23 @@ var _ = DescribeSanity("Identity Service", func(sc *SanityContext) { By("checking successful response") Expect(res.GetCapabilities()).NotTo(BeNil()) for _, cap := range res.GetCapabilities() { - switch cap.GetService().GetType() { - case csi.PluginCapability_Service_CONTROLLER_SERVICE: - case csi.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS: + switch cap.GetType().(type) { + case *csi.PluginCapability_Service_: + switch cap.GetService().GetType() { + case csi.PluginCapability_Service_CONTROLLER_SERVICE: + case csi.PluginCapability_Service_VOLUME_ACCESSIBILITY_CONSTRAINTS: + default: + Fail(fmt.Sprintf("Unknown service: %v\n", cap.GetService().GetType())) + } + case *csi.PluginCapability_VolumeExpansion_: + switch cap.GetVolumeExpansion().GetType() { + case csi.PluginCapability_VolumeExpansion_ONLINE: + case csi.PluginCapability_VolumeExpansion_OFFLINE: + default: + Fail(fmt.Sprintf("Unknown volume expansion mode: %v\n", cap.GetVolumeExpansion().GetType())) + } default: - Fail(fmt.Sprintf("Unknown capability: %v\n", cap.GetService().GetType())) + Fail(fmt.Sprintf("Unknown capability: %v\n", cap.GetType())) } } diff --git a/pkg/sanity/node.go b/pkg/sanity/node.go index 41f05440..7cb570e3 100644 --- a/pkg/sanity/node.go +++ b/pkg/sanity/node.go @@ -60,8 +60,7 @@ func isPluginCapabilitySupported(c csi.IdentityClient, Expect(caps.GetCapabilities()).NotTo(BeNil()) for _, cap := range caps.GetCapabilities() { - Expect(cap.GetService()).NotTo(BeNil()) - if cap.GetService().GetType() == capType { + if cap.GetService() != nil && cap.GetService().GetType() == capType { return true } }