diff --git a/virtcontainers/container.go b/virtcontainers/container.go index e14d78d37b..d870be6994 100644 --- a/virtcontainers/container.go +++ b/virtcontainers/container.go @@ -457,8 +457,7 @@ func (c *Container) mountSharedDirMounts(hostSharedDir, guestSharedDir string) ( // instead of passing this as a shared mount. if len(m.BlockDeviceID) > 0 { // Attach this block device, all other devices passed in the config have been attached at this point - if err := c.sandbox.devManager.AttachDevice(m.BlockDeviceID, c.sandbox); err != nil && - err != manager.ErrDeviceAttached { + if err := c.sandbox.devManager.AttachDevice(m.BlockDeviceID, c.sandbox); err != nil { return nil, err } @@ -1153,10 +1152,6 @@ func (c *Container) attachDevices() error { // and rollbackFailingContainerCreation could do all the rollbacks for _, dev := range c.devices { if err := c.sandbox.devManager.AttachDevice(dev.ID, c.sandbox); err != nil { - if err == manager.ErrDeviceAttached { - // skip if device is already attached before - continue - } return err } } diff --git a/virtcontainers/device/api/interface.go b/virtcontainers/device/api/interface.go index 6bdba9b1bc..da15831aef 100644 --- a/virtcontainers/device/api/interface.go +++ b/virtcontainers/device/api/interface.go @@ -49,13 +49,20 @@ type Device interface { // DeviceType indicates which kind of device it is // e.g. block, vfio or vhost user DeviceType() config.DeviceType + // GetMajorMinor returns major and minor numbers + GetMajorMinor() (int64, int64) // GetDeviceInfo returns device specific data used for hotplugging by hypervisor // Caller could cast the return value to device specific struct // e.g. Block device returns *config.BlockDrive and // vfio device returns []*config.VFIODev GetDeviceInfo() interface{} - // IsAttached checks if the device is attached - IsAttached() bool + // GetAttachCount returns how many times the device has been attached + GetAttachCount() uint + + // Reference adds one reference to device then returns final ref count + Reference() uint + // Dereference removes one reference to device then returns final ref count + Dereference() uint } // DeviceManager can be used to create a new device, this can be used as single diff --git a/virtcontainers/device/config/config.go b/virtcontainers/device/config/config.go index 3f77540d20..66de00cfc6 100644 --- a/virtcontainers/device/config/config.go +++ b/virtcontainers/device/config/config.go @@ -75,10 +75,6 @@ type DeviceInfo struct { // id of the device group. GID uint32 - // Hotplugged is used to store device state indicating if the - // device was hotplugged. - Hotplugged bool - // ID for the device that is passed to the hypervisor. ID string diff --git a/virtcontainers/device/drivers/block.go b/virtcontainers/device/drivers/block.go index 41d7087c17..cbbe7bb316 100644 --- a/virtcontainers/device/drivers/block.go +++ b/virtcontainers/device/drivers/block.go @@ -35,7 +35,11 @@ func NewBlockDevice(devInfo *config.DeviceInfo) *BlockDevice { // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver func (device *BlockDevice) Attach(devReceiver api.DeviceReceiver) (err error) { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } @@ -47,6 +51,8 @@ func (device *BlockDevice) Attach(devReceiver api.DeviceReceiver) (err error) { defer func() { if err != nil { devReceiver.DecrementSandboxBlockIndex() + } else { + device.AttachCount = 1 } }() @@ -84,15 +90,17 @@ func (device *BlockDevice) Attach(devReceiver api.DeviceReceiver) (err error) { return err } - device.DeviceInfo.Hotplugged = true - return nil } // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *BlockDevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(false) + if err != nil { + return err + } + if skip { return nil } @@ -102,7 +110,7 @@ func (device *BlockDevice) Detach(devReceiver api.DeviceReceiver) error { deviceLogger().WithError(err).Error("Failed to unplug block device") return err } - device.DeviceInfo.Hotplugged = false + device.AttachCount = 0 return nil } @@ -116,5 +124,5 @@ func (device *BlockDevice) GetDeviceInfo() interface{} { return device.BlockDrive } -// It should implement IsAttached() and DeviceID() as api.Device implementation +// It should implement GetAttachCount() and DeviceID() as api.Device implementation // here it shares function from *GenericDevice so we don't need duplicate codes diff --git a/virtcontainers/device/drivers/generic.go b/virtcontainers/device/drivers/generic.go index 8b03b9c75a..10835e3f16 100644 --- a/virtcontainers/device/drivers/generic.go +++ b/virtcontainers/device/drivers/generic.go @@ -7,6 +7,8 @@ package drivers import ( + "fmt" + "github.com/kata-containers/runtime/virtcontainers/device/api" "github.com/kata-containers/runtime/virtcontainers/device/config" ) @@ -15,6 +17,9 @@ import ( type GenericDevice struct { ID string DeviceInfo *config.DeviceInfo + + RefCount uint + AttachCount uint } // NewGenericDevice creates a new GenericDevice @@ -27,19 +32,27 @@ func NewGenericDevice(devInfo *config.DeviceInfo) *GenericDevice { // Attach is standard interface of api.Device func (device *GenericDevice) Attach(devReceiver api.DeviceReceiver) error { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } - + device.AttachCount = 1 return nil } // Detach is standard interface of api.Device func (device *GenericDevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(false) + if err != nil { + return err + } + if skip { return nil } - + device.AttachCount = 0 return nil } @@ -53,12 +66,64 @@ func (device *GenericDevice) GetDeviceInfo() interface{} { return device.DeviceInfo } -// IsAttached checks if the device is attached -func (device *GenericDevice) IsAttached() bool { - return device.DeviceInfo.Hotplugged +// GetAttachCount returns how many times the device has been attached +func (device *GenericDevice) GetAttachCount() uint { + return device.AttachCount } // DeviceID returns device ID func (device *GenericDevice) DeviceID() string { return device.ID } + +// GetMajorMinor returns device major and minor numbers +func (device *GenericDevice) GetMajorMinor() (int64, int64) { + return device.DeviceInfo.Major, device.DeviceInfo.Minor +} + +// Reference adds one reference to device +func (device *GenericDevice) Reference() uint { + if device.RefCount != intMax { + device.RefCount++ + } + return device.RefCount +} + +// Dereference remove one reference from device +func (device *GenericDevice) Dereference() uint { + if device.RefCount != 0 { + device.RefCount-- + } + return device.RefCount +} + +// bumpAttachCount is used to add/minus attach count for a device +// * attach bool: true means attach, false means detach +// return values: +// * skip bool: no need to do real attach/detach, skip following actions. +// * err error: error while do attach count bump +func (device *GenericDevice) bumpAttachCount(attach bool) (skip bool, err error) { + if attach { // attach use case + switch device.AttachCount { + case 0: + // do real attach + return false, nil + case intMax: + return true, fmt.Errorf("device was attached too many times") + default: + device.AttachCount++ + return true, nil + } + } else { // detach use case + switch device.AttachCount { + case 0: + return true, fmt.Errorf("detaching a device that wasn't attached") + case 1: + // do real work + return false, nil + default: + device.AttachCount-- + return true, nil + } + } +} diff --git a/virtcontainers/device/drivers/generic_test.go b/virtcontainers/device/drivers/generic_test.go new file mode 100644 index 0000000000..5b0c28484f --- /dev/null +++ b/virtcontainers/device/drivers/generic_test.go @@ -0,0 +1,44 @@ +// Copyright (c) 2018 Huawei Corporation +// +// SPDX-License-Identifier: Apache-2.0 +// + +package drivers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBumpAttachCount(t *testing.T) { + type testData struct { + attach bool + attachCount uint + expectedAC uint + expectSkip bool + expectErr bool + } + + data := []testData{ + {true, 0, 0, false, false}, + {true, 1, 2, true, false}, + {true, intMax, intMax, true, true}, + {false, 0, 0, true, true}, + {false, 1, 1, false, false}, + {false, intMax, intMax - 1, true, false}, + } + + dev := &GenericDevice{} + for _, d := range data { + dev.AttachCount = d.attachCount + skip, err := dev.bumpAttachCount(d.attach) + assert.Equal(t, skip, d.expectSkip, "") + assert.Equal(t, dev.GetAttachCount(), d.expectedAC, "") + if d.expectErr { + assert.NotNil(t, err) + } else { + assert.Nil(t, err) + } + } +} diff --git a/virtcontainers/device/drivers/utils.go b/virtcontainers/device/drivers/utils.go index c60eef0f07..33c18ae8ea 100644 --- a/virtcontainers/device/drivers/utils.go +++ b/virtcontainers/device/drivers/utils.go @@ -12,6 +12,8 @@ import ( "github.com/kata-containers/runtime/virtcontainers/device/api" ) +const intMax uint = ^uint(0) + func deviceLogger() *logrus.Entry { return api.DeviceLogger() } diff --git a/virtcontainers/device/drivers/vfio.go b/virtcontainers/device/drivers/vfio.go index 351b867353..b29cccd675 100644 --- a/virtcontainers/device/drivers/vfio.go +++ b/virtcontainers/device/drivers/vfio.go @@ -47,7 +47,11 @@ func NewVFIODevice(devInfo *config.DeviceInfo) *VFIODevice { // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver func (device *VFIODevice) Attach(devReceiver api.DeviceReceiver) error { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } @@ -83,14 +87,18 @@ func (device *VFIODevice) Attach(devReceiver api.DeviceReceiver) error { "device-group": device.DeviceInfo.HostPath, "device-type": "vfio-passthrough", }).Info("Device group attached") - device.DeviceInfo.Hotplugged = true + device.AttachCount = 1 return nil } // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VFIODevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(false) + if err != nil { + return err + } + if skip { return nil } @@ -104,7 +112,7 @@ func (device *VFIODevice) Detach(devReceiver api.DeviceReceiver) error { "device-group": device.DeviceInfo.HostPath, "device-type": "vfio-passthrough", }).Info("Device group detached") - device.DeviceInfo.Hotplugged = false + device.AttachCount = 0 return nil } @@ -118,7 +126,7 @@ func (device *VFIODevice) GetDeviceInfo() interface{} { return device.vfioDevs } -// It should implement IsAttached() and DeviceID() as api.Device implementation +// It should implement GetAttachCount() and DeviceID() as api.Device implementation // here it shares function from *GenericDevice so we don't need duplicate codes // getBDF returns the BDF of pci device diff --git a/virtcontainers/device/drivers/vhost_user_blk.go b/virtcontainers/device/drivers/vhost_user_blk.go index 80e4c598fe..cfc53d4145 100644 --- a/virtcontainers/device/drivers/vhost_user_blk.go +++ b/virtcontainers/device/drivers/vhost_user_blk.go @@ -27,7 +27,11 @@ type VhostUserBlkDevice struct { // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver func (device *VhostUserBlkDevice) Attach(devReceiver api.DeviceReceiver) (err error) { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } @@ -43,7 +47,7 @@ func (device *VhostUserBlkDevice) Attach(devReceiver api.DeviceReceiver) (err er defer func() { if err == nil { - device.DeviceInfo.Hotplugged = true + device.AttachCount = 1 } }() return devReceiver.AppendDevice(device) @@ -52,11 +56,15 @@ func (device *VhostUserBlkDevice) Attach(devReceiver api.DeviceReceiver) (err er // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VhostUserBlkDevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } - device.DeviceInfo.Hotplugged = false + device.AttachCount = 0 return nil } @@ -71,5 +79,5 @@ func (device *VhostUserBlkDevice) GetDeviceInfo() interface{} { return &device.VhostUserDeviceAttrs } -// It should implement IsAttached() and DeviceID() as api.Device implementation +// It should implement GetAttachCount() and DeviceID() as api.Device implementation // here it shares function from *GenericDevice so we don't need duplicate codes diff --git a/virtcontainers/device/drivers/vhost_user_net.go b/virtcontainers/device/drivers/vhost_user_net.go index b5961ea8d6..2ab31dd620 100644 --- a/virtcontainers/device/drivers/vhost_user_net.go +++ b/virtcontainers/device/drivers/vhost_user_net.go @@ -27,7 +27,11 @@ type VhostUserNetDevice struct { // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver func (device *VhostUserNetDevice) Attach(devReceiver api.DeviceReceiver) (err error) { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } @@ -43,7 +47,7 @@ func (device *VhostUserNetDevice) Attach(devReceiver api.DeviceReceiver) (err er defer func() { if err == nil { - device.DeviceInfo.Hotplugged = true + device.AttachCount = 1 } }() return devReceiver.AppendDevice(device) @@ -52,11 +56,15 @@ func (device *VhostUserNetDevice) Attach(devReceiver api.DeviceReceiver) (err er // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VhostUserNetDevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(false) + if err != nil { + return err + } + if skip { return nil } - device.DeviceInfo.Hotplugged = false + device.AttachCount = 0 return nil } @@ -71,5 +79,5 @@ func (device *VhostUserNetDevice) GetDeviceInfo() interface{} { return &device.VhostUserDeviceAttrs } -// It should implement IsAttached() and DeviceID() as api.Device implementation +// It should implement GetAttachCount() and DeviceID() as api.Device implementation // here it shares function from *GenericDevice so we don't need duplicate codes diff --git a/virtcontainers/device/drivers/vhost_user_scsi.go b/virtcontainers/device/drivers/vhost_user_scsi.go index dcd79f47c3..d34c50ec04 100644 --- a/virtcontainers/device/drivers/vhost_user_scsi.go +++ b/virtcontainers/device/drivers/vhost_user_scsi.go @@ -27,7 +27,11 @@ type VhostUserSCSIDevice struct { // Attach is standard interface of api.Device, it's used to add device to some // DeviceReceiver func (device *VhostUserSCSIDevice) Attach(devReceiver api.DeviceReceiver) (err error) { - if device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(true) + if err != nil { + return err + } + if skip { return nil } @@ -43,7 +47,7 @@ func (device *VhostUserSCSIDevice) Attach(devReceiver api.DeviceReceiver) (err e defer func() { if err == nil { - device.DeviceInfo.Hotplugged = true + device.AttachCount = 1 } }() return devReceiver.AppendDevice(device) @@ -52,11 +56,14 @@ func (device *VhostUserSCSIDevice) Attach(devReceiver api.DeviceReceiver) (err e // Detach is standard interface of api.Device, it's used to remove device from some // DeviceReceiver func (device *VhostUserSCSIDevice) Detach(devReceiver api.DeviceReceiver) error { - if !device.DeviceInfo.Hotplugged { + skip, err := device.bumpAttachCount(false) + if err != nil { + return err + } + if skip { return nil } - - device.DeviceInfo.Hotplugged = false + device.AttachCount = 0 return nil } @@ -71,5 +78,5 @@ func (device *VhostUserSCSIDevice) GetDeviceInfo() interface{} { return &device.VhostUserDeviceAttrs } -// It should implement IsAttached() and DeviceID() as api.Device implementation +// It should implement GetAttachCount() and DeviceID() as api.Device implementation // here it shares function from *GenericDevice so we don't need duplicate codes diff --git a/virtcontainers/device/manager/manager.go b/virtcontainers/device/manager/manager.go index 62ce0742a6..5841b938da 100644 --- a/virtcontainers/device/manager/manager.go +++ b/virtcontainers/device/manager/manager.go @@ -32,10 +32,11 @@ var ( ErrIDExhausted = errors.New("IDs are exhausted") // ErrDeviceNotExist represents device hasn't been created before ErrDeviceNotExist = errors.New("device with specified ID hasn't been created") - // ErrDeviceAttached represents the device is already attached - ErrDeviceAttached = errors.New("device is already attached") // ErrDeviceNotAttached represents the device isn't attached ErrDeviceNotAttached = errors.New("device isn't attached") + // ErrRemoveAttachedDevice represents the device isn't detached + // so not allow to remove from list + ErrRemoveAttachedDevice = errors.New("can't remove attached device") ) type deviceManager struct { @@ -66,14 +67,34 @@ func NewDeviceManager(blockDriver string, devices []api.Device) api.DeviceManage return dm } +func (dm *deviceManager) findDeviceByMajorMinor(major, minor int64) api.Device { + for _, dev := range dm.devices { + dma, dmi := dev.GetMajorMinor() + if dma == major && dmi == minor { + return dev + } + } + return nil +} + // createDevice creates one device based on DeviceInfo -func (dm *deviceManager) createDevice(devInfo config.DeviceInfo) (api.Device, error) { +func (dm *deviceManager) createDevice(devInfo config.DeviceInfo) (dev api.Device, err error) { path, err := config.GetHostPathFunc(devInfo) if err != nil { return nil, err } devInfo.HostPath = path + defer func() { + if err == nil { + dev.Reference() + } + }() + + if existingDev := dm.findDeviceByMajorMinor(devInfo.Major, devInfo.Minor); existingDev != nil { + return existingDev, nil + } + // device ID must be generated by manager instead of device itself // in case of ID collision if devInfo.ID, err = dm.newDeviceID(); err != nil { @@ -108,10 +129,17 @@ func (dm *deviceManager) NewDevice(devInfo config.DeviceInfo) (api.Device, error func (dm *deviceManager) RemoveDevice(id string) error { dm.Lock() defer dm.Unlock() - if _, ok := dm.devices[id]; !ok { + dev, ok := dm.devices[id] + if !ok { return ErrDeviceNotExist } - delete(dm.devices, id) + + if dev.Dereference() == 0 { + if dev.GetAttachCount() > 0 { + return ErrRemoveAttachedDevice + } + delete(dm.devices, id) + } return nil } @@ -141,10 +169,6 @@ func (dm *deviceManager) AttachDevice(id string, dr api.DeviceReceiver) error { return ErrDeviceNotExist } - if d.IsAttached() { - return ErrDeviceAttached - } - if err := d.Attach(dr); err != nil { return err } @@ -159,7 +183,7 @@ func (dm *deviceManager) DetachDevice(id string, dr api.DeviceReceiver) error { if !ok { return ErrDeviceNotExist } - if !d.IsAttached() { + if d.GetAttachCount() <= 0 { return ErrDeviceNotAttached } @@ -168,6 +192,7 @@ func (dm *deviceManager) DetachDevice(id string, dr api.DeviceReceiver) error { } return nil } + func (dm *deviceManager) GetDeviceByID(id string) api.Device { dm.RLock() defer dm.RUnlock() @@ -194,5 +219,5 @@ func (dm *deviceManager) IsDeviceAttached(id string) bool { if !ok { return false } - return d.IsAttached() + return d.GetAttachCount() > 0 } diff --git a/virtcontainers/device/manager/manager_test.go b/virtcontainers/device/manager/manager_test.go index 2818820109..b99dd992b4 100644 --- a/virtcontainers/device/manager/manager_test.go +++ b/virtcontainers/device/manager/manager_test.go @@ -216,13 +216,18 @@ func TestAttachDetachDevice(t *testing.T) { device, err := dm.NewDevice(deviceInfo) assert.Nil(t, err) + // attach non-exist device + err = dm.AttachDevice("non-exist", devReceiver) + assert.NotNil(t, err) + // attach device err = dm.AttachDevice(device.DeviceID(), devReceiver) assert.Nil(t, err) + assert.Equal(t, device.GetAttachCount(), uint(1), "attach device count should be 1") // attach device again(twice) err = dm.AttachDevice(device.DeviceID(), devReceiver) - assert.NotNil(t, err) - assert.Equal(t, err, ErrDeviceAttached, "attach device twice should report error %q", ErrDeviceAttached) + assert.Nil(t, err) + assert.Equal(t, device.GetAttachCount(), uint(2), "attach device count should be 2") attached := dm.IsDeviceAttached(device.DeviceID()) assert.True(t, attached) @@ -230,12 +235,20 @@ func TestAttachDetachDevice(t *testing.T) { // detach device err = dm.DetachDevice(device.DeviceID(), devReceiver) assert.Nil(t, err) + assert.Equal(t, device.GetAttachCount(), uint(1), "attach device count should be 1") // detach device again(twice) err = dm.DetachDevice(device.DeviceID(), devReceiver) + assert.Nil(t, err) + assert.Equal(t, device.GetAttachCount(), uint(0), "attach device count should be 0") + // detach device again should report error + err = dm.DetachDevice(device.DeviceID(), devReceiver) assert.NotNil(t, err) - assert.Equal(t, err, ErrDeviceNotAttached, "attach device twice should report error %q", ErrDeviceNotAttached) + assert.Equal(t, err, ErrDeviceNotAttached, "") + assert.Equal(t, device.GetAttachCount(), uint(0), "attach device count should be 0") attached = dm.IsDeviceAttached(device.DeviceID()) assert.False(t, attached) + err = dm.RemoveDevice(device.DeviceID()) + assert.Nil(t, err) } diff --git a/virtcontainers/kata_agent.go b/virtcontainers/kata_agent.go index 761155dca3..4d70fe1a1b 100644 --- a/virtcontainers/kata_agent.go +++ b/virtcontainers/kata_agent.go @@ -969,7 +969,7 @@ func (k *kataAgent) handleBlockVolumes(c *Container) []*grpc.Storage { // Add the block device to the list of container devices, to make sure the // device is detached with detachDevices() for a container. - c.devices = append(c.devices, ContainerDevice{ID: id}) + c.devices = append(c.devices, ContainerDevice{ID: id, ContainerPath: m.Destination}) if err := c.storeDevices(); err != nil { k.Logger().WithField("device", id).WithError(err).Error("store device failed") return nil diff --git a/virtcontainers/kata_agent_test.go b/virtcontainers/kata_agent_test.go index 56532a3f34..8e381ad5af 100644 --- a/virtcontainers/kata_agent_test.go +++ b/virtcontainers/kata_agent_test.go @@ -382,7 +382,9 @@ func TestAppendDevices(t *testing.T) { id := "test-append-block" ctrDevices := []api.Device{ &drivers.BlockDevice{ - ID: id, + GenericDevice: &drivers.GenericDevice{ + ID: id, + }, BlockDrive: &config.BlockDrive{ PCIAddr: testPCIAddr, }, diff --git a/virtcontainers/qemu_arch_base_test.go b/virtcontainers/qemu_arch_base_test.go index 3465f8cdad..601341fbb4 100644 --- a/virtcontainers/qemu_arch_base_test.go +++ b/virtcontainers/qemu_arch_base_test.go @@ -391,7 +391,7 @@ func TestQemuArchBaseAppendVhostUserDevice(t *testing.T) { Type: config.VhostUserNet, MacAddress: macAddress, } - vhostUserDevice.ID = id + vhostUserDevice.DevID = id vhostUserDevice.SocketPath = socketPath testQemuArchBaseAppend(t, vhostUserDevice, expectedOut)