diff --git a/device.go b/device.go index 01ee559602..fdcac39116 100644 --- a/device.go +++ b/device.go @@ -71,7 +71,15 @@ var ( scsiHostPath = filepath.Join(sysClassPrefix, "scsi_host") ) -type deviceHandler func(ctx context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error +// Stores a mapping of device names (in host / outer container naming) +// to the device and resources slots in a container spec +type devIndexEntry struct { + idx int + resourceIdx []int +} +type devIndex map[string]devIndexEntry + +type deviceHandler func(ctx context.Context, device pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error var deviceHandlerList = map[string]deviceHandler{ driverMmioBlkType: virtioMmioBlkDeviceHandler, @@ -192,15 +200,15 @@ func getPCIDeviceNameImpl(s *sandbox, pciID string) (string, error) { // device.Id should be the predicted device name (vda, vdb, ...) // device.VmPath already provides a way to send it in -func virtioMmioBlkDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error { +func virtioMmioBlkDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error { if device.VmPath == "" { return fmt.Errorf("Invalid path for virtioMmioBlkDevice") } - return updateSpecDeviceList(device, spec) + return updateSpecDeviceList(device, spec, devIdx) } -func virtioBlkCCWDeviceHandler(ctx context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error { +func virtioBlkCCWDeviceHandler(ctx context.Context, device pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error { devPath, err := getBlkCCWDevPath(s, device.Id) if err != nil { return err @@ -212,13 +220,13 @@ func virtioBlkCCWDeviceHandler(ctx context.Context, device pb.Device, spec *pb.S } device.VmPath = devPath - return updateSpecDeviceList(device, spec) + return updateSpecDeviceList(device, spec, devIdx) } // device.Id should be the PCI address in the format "bridgeAddr/deviceAddr". // Here, bridgeAddr is the address at which the brige is attached on the root bus, // while deviceAddr is the address at which the device is attached on the bridge. -func virtioBlkDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error { +func virtioBlkDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error { // When "Id (PCIAddr)" is not set, we allow to use the predicted "VmPath" passed from kata-runtime if device.Id != "" { // Get the device node path based on the PCI device address @@ -229,11 +237,11 @@ func virtioBlkDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, device.VmPath = devPath } - return updateSpecDeviceList(device, spec) + return updateSpecDeviceList(device, spec, devIdx) } // device.Id should be the SCSI address of the disk in the format "scsiID:lunID" -func virtioSCSIDeviceHandler(ctx context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error { +func virtioSCSIDeviceHandler(ctx context.Context, device pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error { // Retrieve the device path from SCSI address. devPath, err := getSCSIDevPath(s, device.Id) if err != nil { @@ -241,11 +249,11 @@ func virtioSCSIDeviceHandler(ctx context.Context, device pb.Device, spec *pb.Spe } device.VmPath = devPath - return updateSpecDeviceList(device, spec) + return updateSpecDeviceList(device, spec, devIdx) } -func nvdimmDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error { - return updateSpecDeviceList(device, spec) +func nvdimmDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error { + return updateSpecDeviceList(device, spec, devIdx) } // updateSpecDeviceList takes a device description provided by the caller, @@ -254,7 +262,7 @@ func nvdimmDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s * // the same device in the list of devices provided through the OCI spec. // This is needed to update information about minor/major numbers that cannot // be predicted from the caller. -func updateSpecDeviceList(device pb.Device, spec *pb.Spec) error { +func updateSpecDeviceList(device pb.Device, spec *pb.Spec, devIdx devIndex) error { // If no ContainerPath is provided, we won't be able to match and // update the device in the OCI spec device list. This is an error. if device.ContainerPath == "" { @@ -284,42 +292,32 @@ func updateSpecDeviceList(device pb.Device, spec *pb.Spec) error { }).Info("handling block device") // Update the spec - for idx, d := range spec.Linux.Devices { - if d.Path == device.ContainerPath { - hostMajor := spec.Linux.Devices[idx].Major - hostMinor := spec.Linux.Devices[idx].Minor - agentLog.WithFields(logrus.Fields{ - "device-path": device.VmPath, - "host-device-major": hostMajor, - "host-device-minor": hostMinor, - "guest-device-major": major, - "guest-device-minor": minor, - }).Info("updating block device major/minor into the spec") - - spec.Linux.Devices[idx].Major = major - spec.Linux.Devices[idx].Minor = minor - - // there is no resource to update - if spec.Linux == nil || spec.Linux.Resources == nil { - return nil - } + idxData, ok := devIdx[device.ContainerPath] + if !ok { + return grpcStatus.Errorf(codes.Internal, + "Should have found a matching device %s in the spec", + device.ContainerPath) + } - // Resources must be updated since they are used to identify the - // device in the devices cgroup. - for idxRsrc, dRsrc := range spec.Linux.Resources.Devices { - if dRsrc.Major == hostMajor && dRsrc.Minor == hostMinor { - spec.Linux.Resources.Devices[idxRsrc].Major = major - spec.Linux.Resources.Devices[idxRsrc].Minor = minor - } - } + agentLog.WithFields(logrus.Fields{ + "device-path": device.VmPath, + "host-device-major": spec.Linux.Devices[idxData.idx].Major, + "host-device-minor": spec.Linux.Devices[idxData.idx].Minor, + "guest-device-major": major, + "guest-device-minor": minor, + }).Info("updating block device major/minor into the spec") - return nil - } + spec.Linux.Devices[idxData.idx].Major = major + spec.Linux.Devices[idxData.idx].Minor = minor + + // Resources must be updated since they are used to identify the + // device in the devices cgroup. + for _, idxRsrc := range idxData.resourceIdx { + spec.Linux.Resources.Devices[idxRsrc].Major = major + spec.Linux.Resources.Devices[idxRsrc].Minor = minor } - return grpcStatus.Errorf(codes.Internal, - "Should have found a matching device %s in the spec", - device.VmPath) + return nil } // scanSCSIBus scans SCSI bus for the given SCSI address(SCSI-Id and LUN) @@ -419,12 +417,14 @@ func getBlkCCWDevPath(s *sandbox, bus string) (string, error) { } func addDevices(ctx context.Context, devices []*pb.Device, spec *pb.Spec, s *sandbox) error { + devIdx := makeDevIndex(spec) + for _, device := range devices { if device == nil { continue } - err := addDevice(ctx, device, spec, s) + err := addDevice(ctx, device, spec, s, devIdx) if err != nil { return err } @@ -434,7 +434,34 @@ func addDevices(ctx context.Context, devices []*pb.Device, spec *pb.Spec, s *san return nil } -func addDevice(ctx context.Context, device *pb.Device, spec *pb.Spec, s *sandbox) error { +func makeDevIndex(spec *pb.Spec) devIndex { + devIdx := make(devIndex) + + if spec == nil || spec.Linux == nil || spec.Linux.Devices == nil { + return devIdx + } + + for i, d := range spec.Linux.Devices { + rIdx := make([]int, 0) + + if spec.Linux.Resources != nil && spec.Linux.Resources.Devices != nil { + for j, r := range spec.Linux.Resources.Devices { + if r.Major == d.Major && r.Minor == d.Minor { + rIdx = append(rIdx, j) + } + } + } + + devIdx[d.Path] = devIndexEntry{ + idx: i, + resourceIdx: rIdx, + } + } + + return devIdx +} + +func addDevice(ctx context.Context, device *pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error { if device == nil { return grpcStatus.Error(codes.InvalidArgument, "invalid device") } @@ -474,7 +501,7 @@ func addDevice(ctx context.Context, device *pb.Device, spec *pb.Spec, s *sandbox "Unknown device type %q", device.Type) } - return devHandler(ctx, *device, spec, s) + return devHandler(ctx, *device, spec, s, devIdx) } // updateDeviceCgroupForGuestRootfs updates the device cgroup for container diff --git a/device_test.go b/device_test.go index 72a9dac1d9..f6db672f6e 100644 --- a/device_test.go +++ b/device_test.go @@ -46,7 +46,8 @@ func testVirtioBlkDeviceHandlerFailure(t *testing.T, device pb.Device, spec *pb. ctx := context.Background() - err = virtioBlkDeviceHandler(ctx, device, spec, &sandbox{}) + devIdx := makeDevIndex(spec) + err = virtioBlkDeviceHandler(ctx, device, spec, &sandbox{}, devIdx) assert.NotNil(t, err, "blockDeviceHandler() should have failed") savedFunc := getPCIDeviceName @@ -58,7 +59,7 @@ func testVirtioBlkDeviceHandlerFailure(t *testing.T, device pb.Device, spec *pb. getPCIDeviceName = savedFunc }() - err = virtioBlkDeviceHandler(ctx, device, spec, &sandbox{}) + err = virtioBlkDeviceHandler(ctx, device, spec, &sandbox{}, devIdx) assert.Error(t, err) } @@ -212,11 +213,11 @@ func TestAddDevicesNilMountsSuccessful(t *testing.T) { testAddDevicesSuccessful(t, devices, spec) } -func noopDeviceHandlerReturnNil(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error { +func noopDeviceHandlerReturnNil(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error { return nil } -func noopDeviceHandlerReturnError(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox) error { +func noopDeviceHandlerReturnError(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error { return fmt.Errorf("Noop handler failure") } @@ -423,7 +424,8 @@ func TestAddDevice(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - err := addDevice(ctx, d.device, d.spec, s) + devIdx := makeDevIndex(d.spec) + err := addDevice(ctx, d.device, d.spec, s, devIdx) if d.expectError { assert.Error(err, msg) } else { @@ -439,24 +441,25 @@ func TestUpdateSpecDeviceList(t *testing.T) { var err error spec := &pb.Spec{} + devIdx := makeDevIndex(spec) device := pb.Device{} major := int64(7) minor := int64(2) //ContainerPath empty - err = updateSpecDeviceList(device, spec) + err = updateSpecDeviceList(device, spec, devIdx) assert.Error(err) device.ContainerPath = "/dev/null" //Linux is nil - err = updateSpecDeviceList(device, spec) + err = updateSpecDeviceList(device, spec, devIdx) assert.Error(err) spec.Linux = &pb.Linux{} /// Linux.Devices empty - err = updateSpecDeviceList(device, spec) + err = updateSpecDeviceList(device, spec, devIdx) assert.Error(err) spec.Linux.Devices = []pb.LinuxDevice{ @@ -466,21 +469,23 @@ func TestUpdateSpecDeviceList(t *testing.T) { Minor: minor, }, } + devIdx = makeDevIndex(spec) // VmPath empty - err = updateSpecDeviceList(device, spec) + err = updateSpecDeviceList(device, spec, devIdx) assert.Error(err) device.VmPath = "/dev/null" // guest and host path are not the same - err = updateSpecDeviceList(device, spec) + err = updateSpecDeviceList(device, spec, devIdx) assert.Error(err) spec.Linux.Devices[0].Path = device.ContainerPath + devIdx = makeDevIndex(spec) // spec.Linux.Resources is nil - err = updateSpecDeviceList(device, spec) + err = updateSpecDeviceList(device, spec, devIdx) assert.NoError(err) // update both devices and cgroup lists @@ -499,11 +504,117 @@ func TestUpdateSpecDeviceList(t *testing.T) { }, }, } + devIdx = makeDevIndex(spec) - err = updateSpecDeviceList(device, spec) + err = updateSpecDeviceList(device, spec, devIdx) assert.NoError(err) } +// Test handling in the case that one device has the same guest +// major:minor as a different device's host major:minor +func TestUpdateSpecDeviceListGuestHostConflict(t *testing.T) { + assert := assert.New(t) + + var nullStat, zeroStat, fullStat unix.Stat_t + + err := unix.Stat("/dev/null", &nullStat) + assert.NoError(err) + err = unix.Stat("/dev/zero", &zeroStat) + assert.NoError(err) + err = unix.Stat("/dev/full", &fullStat) + assert.NoError(err) + + hostMajorA := int64(unix.Major(nullStat.Rdev)) + hostMinorA := int64(unix.Minor(nullStat.Rdev)) + hostMajorB := int64(unix.Major(zeroStat.Rdev)) + hostMinorB := int64(unix.Minor(zeroStat.Rdev)) + + spec := &pb.Spec{ + Linux: &pb.Linux{ + Devices: []pb.LinuxDevice{ + { + Path: "/dev/a", + Type: "c", + Major: hostMajorA, + Minor: hostMinorA, + }, + { + Path: "/dev/b", + Type: "c", + Major: hostMajorB, + Minor: hostMinorB, + }, + }, + Resources: &pb.LinuxResources{ + Devices: []pb.LinuxDeviceCgroup{ + { + Type: "c", + Major: hostMajorA, + Minor: hostMinorA, + }, + { + Type: "c", + Major: hostMajorB, + Minor: hostMinorB, + }, + }, + }, + }, + } + + devA := pb.Device{ + ContainerPath: "/dev/a", + VmPath: "/dev/zero", + } + guestMajorA := int64(unix.Major(zeroStat.Rdev)) + guestMinorA := int64(unix.Minor(zeroStat.Rdev)) + + devB := pb.Device{ + ContainerPath: "/dev/b", + VmPath: "/dev/full", + } + guestMajorB := int64(unix.Major(fullStat.Rdev)) + guestMinorB := int64(unix.Minor(fullStat.Rdev)) + + devIdx := makeDevIndex(spec) + + assert.Equal(hostMajorA, spec.Linux.Devices[0].Major) + assert.Equal(hostMinorA, spec.Linux.Devices[0].Minor) + assert.Equal(hostMajorB, spec.Linux.Devices[1].Major) + assert.Equal(hostMinorB, spec.Linux.Devices[1].Minor) + + assert.Equal(hostMajorA, spec.Linux.Resources.Devices[0].Major) + assert.Equal(hostMinorA, spec.Linux.Resources.Devices[0].Minor) + assert.Equal(hostMajorB, spec.Linux.Resources.Devices[1].Major) + assert.Equal(hostMinorB, spec.Linux.Resources.Devices[1].Minor) + + err = updateSpecDeviceList(devA, spec, devIdx) + assert.NoError(err) + + assert.Equal(guestMajorA, spec.Linux.Devices[0].Major) + assert.Equal(guestMinorA, spec.Linux.Devices[0].Minor) + assert.Equal(hostMajorB, spec.Linux.Devices[1].Major) + assert.Equal(hostMinorB, spec.Linux.Devices[1].Minor) + + assert.Equal(guestMajorA, spec.Linux.Resources.Devices[0].Major) + assert.Equal(guestMinorA, spec.Linux.Resources.Devices[0].Minor) + assert.Equal(hostMajorB, spec.Linux.Resources.Devices[1].Major) + assert.Equal(hostMinorB, spec.Linux.Resources.Devices[1].Minor) + + err = updateSpecDeviceList(devB, spec, devIdx) + assert.NoError(err) + + assert.Equal(guestMajorA, spec.Linux.Devices[0].Major) + assert.Equal(guestMinorA, spec.Linux.Devices[0].Minor) + assert.Equal(guestMajorB, spec.Linux.Devices[1].Major) + assert.Equal(guestMinorB, spec.Linux.Devices[1].Minor) + + assert.Equal(guestMajorA, spec.Linux.Resources.Devices[0].Major) + assert.Equal(guestMinorA, spec.Linux.Resources.Devices[0].Minor) + assert.Equal(guestMajorB, spec.Linux.Resources.Devices[1].Major) + assert.Equal(guestMinorB, spec.Linux.Resources.Devices[1].Minor) +} + func TestRescanPciBus(t *testing.T) { skipUnlessRoot(t) @@ -548,17 +659,18 @@ func TestVirtioMmioBlkDeviceHandler(t *testing.T) { device := pb.Device{} spec := &pb.Spec{} + devIdx := makeDevIndex(spec) sb := &sandbox{} ctx := context.Background() - err := virtioMmioBlkDeviceHandler(ctx, device, spec, sb) + err := virtioMmioBlkDeviceHandler(ctx, device, spec, sb, devIdx) assert.Error(err) device.VmPath = "foo" device.ContainerPath = "" - err = virtioMmioBlkDeviceHandler(ctx, device, spec, sb) + err = virtioMmioBlkDeviceHandler(ctx, device, spec, sb, devIdx) assert.Error(err) } @@ -567,11 +679,12 @@ func TestVirtioSCSIDeviceHandler(t *testing.T) { device := pb.Device{} spec := &pb.Spec{} + devIdx := makeDevIndex(spec) sb := &sandbox{} ctx, cancel := context.WithCancel(context.Background()) - err := virtioSCSIDeviceHandler(ctx, device, spec, sb) + err := virtioSCSIDeviceHandler(ctx, device, spec, sb, devIdx) assert.Error(err) cancel() @@ -586,7 +699,7 @@ func TestVirtioSCSIDeviceHandler(t *testing.T) { ctx, cancel = context.WithCancel(context.Background()) - err = virtioSCSIDeviceHandler(ctx, device, spec, sb) + err = virtioSCSIDeviceHandler(ctx, device, spec, sb, devIdx) assert.Error(err) cancel() } @@ -596,11 +709,12 @@ func TestNvdimmDeviceHandler(t *testing.T) { device := pb.Device{} spec := &pb.Spec{} + devIdx := makeDevIndex(spec) sb := &sandbox{} ctx := context.Background() - err := nvdimmDeviceHandler(ctx, device, spec, sb) + err := nvdimmDeviceHandler(ctx, device, spec, sb, devIdx) assert.Error(err) }