diff --git a/mount.go b/mount.go index 50088c9788..a2d962aacd 100644 --- a/mount.go +++ b/mount.go @@ -271,13 +271,28 @@ func virtioFSStorageHandler(storage pb.Storage, s *sandbox) (string, error) { // virtioBlkStorageHandler handles the storage for blk driver. func virtioBlkStorageHandler(storage pb.Storage, s *sandbox) (string, error) { - // Get the device node path based on the PCI address provided - // in Storage Source - devPath, err := getPCIDeviceName(s, storage.Source) - if err != nil { - return "", err + + // If hot-plugged, get the device node path based on the PCI address else + // use the virt path provided in Storage Source + if strings.HasPrefix(storage.Source, "/dev") { + + FileInfo, err := os.Stat(storage.Source) + if err != nil { + return "", err + } + // Make sure the virt path is valid + if FileInfo.Mode()&os.ModeDevice == 0 { + return "", err + } + + } else { + devPath, err := getPCIDeviceName(s, storage.Source) + if err != nil { + return "", err + } + + storage.Source = devPath } - storage.Source = devPath return commonStorageHandler(storage) } diff --git a/mount_test.go b/mount_test.go index 2d6894f66b..05b517062d 100644 --- a/mount_test.go +++ b/mount_test.go @@ -125,6 +125,30 @@ func TestVirtio9pStorageHandlerSuccessful(t *testing.T) { assert.Nil(t, err, "storage9pDriverHandler() failed: %v", err) } +func TestVirtioBlkStoragePathFailure(t *testing.T) { + s := &sandbox{} + + storage := pb.Storage{ + Source: "/home/developer/test", + } + + _, err := virtioBlkStorageHandler(storage, s) + agentLog.WithError(err).Error("virtioBlkStorageHandler error") + assert.NotNil(t, err, "virtioBlkStorageHandler() should have failed") +} + +func TestVirtioBlkStorageDeviceFailure(t *testing.T) { + s := &sandbox{} + + storage := pb.Storage{ + Source: "/dev/foo", + } + + _, err := virtioBlkStorageHandler(storage, s) + agentLog.WithError(err).Error("virtioBlkStorageHandler error") + assert.NotNil(t, err, "virtioBlkStorageHandler() should have failed") +} + func TestVirtioBlkStorageHandlerSuccessful(t *testing.T) { skipUnlessRoot(t)