Skip to content
This repository has been archived by the owner on May 12, 2021. It is now read-only.

Handle PCI paths consistently and more generally #855

Merged
merged 4 commits into from
Oct 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 53 additions & 51 deletions device.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,13 @@ const (
)

var (
sysBusPrefix = sysfsDir + "/bus/pci/devices"
pciBusRescanFile = sysfsDir + "/bus/pci/rescan"
pciBusPathFormat = "%s/%s/pci_bus/"
systemDevPath = "/dev"
getSCSIDevPath = getSCSIDevPathImpl
getPmemDevPath = getPmemDevPathImpl
getPCIDeviceName = getPCIDeviceNameImpl
getDevicePCIAddress = getDevicePCIAddressImpl
scanSCSIBus = scanSCSIBusImpl
pciBusRescanFile = sysfsDir + "/bus/pci/rescan"
systemDevPath = "/dev"
getSCSIDevPath = getSCSIDevPathImpl
getPmemDevPath = getPmemDevPathImpl
getPCIDeviceName = getPCIDeviceNameImpl
pciPathToSysfs = pciPathToSysfsImpl
scanSCSIBus = scanSCSIBusImpl
)

// CCW variables
Expand Down Expand Up @@ -79,6 +77,18 @@ type devIndexEntry struct {
}
type devIndex map[string]devIndexEntry

// PciPath is a guest-side PCI path, which identifies a PCI device by
// where it sits in the PCI topology.
//
// It has the format "xx/.../yy/zz". Here, zz is the slot of the device
// on its PCI bridge, yy is the slot of the bridge on its parent bridge
// and so forth until xx is the slot of the "most upstream" bridge on
// the root bus. If a device is connected directly to the root bus,
// its PciPath is just "zz".
type PciPath struct {
// path holds the slash-separated slot string described above.
path string
}

type deviceHandler func(ctx context.Context, device pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error

var deviceHandlerList = map[string]deviceHandler{
Expand All @@ -93,47 +103,42 @@ func rescanPciBus() error {
return ioutil.WriteFile(pciBusRescanFile, []byte{'1'}, pciBusMode)
}

// getDevicePCIAddress fetches the complete PCI address in sysfs, based on the PCI
// identifier provided. This should be in the format: "bridgeAddr/deviceAddr".
// Here, bridgeAddr is the address at which the bridge is attached on the root bus,
// while deviceAddr is the address at which the device is attached on the bridge.
func getDevicePCIAddressImpl(pciID string) (string, error) {
tokens := strings.Split(pciID, "/")

if len(tokens) != 2 {
return "", fmt.Errorf("PCI Identifier for device should be of format [bridgeAddr/deviceAddr], got %s", pciID)
}
// pciPathToSysfs fetches the sysfs path for a PCI path, relative to
// the sysfs path for the PCI host bridge, based on the PCI path
// provided.
func pciPathToSysfsImpl(pciPath PciPath) (string, error) {
var relPath string
bus := "0000:00"

bridgeID := tokens[0]
deviceID := tokens[1]
tokens := strings.Split(pciPath.path, "/")

// Deduce the complete bridge address based on the bridge address identifier passed
// and the fact that bridges are attached on the main bus with function 0.
pciBridgeAddr := fmt.Sprintf("0000:00:%s.0", bridgeID)
for i, slot := range tokens {
// Full PCI address of this device along the path
bdf := fmt.Sprintf("%s:%s.0", bus, slot)

// Find out the bus exposed by bridge
bridgeBusPath := fmt.Sprintf(pciBusPathFormat, sysBusPrefix, pciBridgeAddr)
relPath = filepath.Join(relPath, bdf)

files, err := ioutil.ReadDir(bridgeBusPath)
if err != nil {
return "", fmt.Errorf("Error with getting bridge pci bus : %s", err)
}
if i == len(tokens)-1 {
// Final device need not be a bridge
break
}

busNum := len(files)
if busNum != 1 {
return "", fmt.Errorf("Expected an entry for bus in %s, got %d entries instead", bridgeBusPath, busNum)
}
// Find out the bus exposed by bridge
bridgeBusPath := filepath.Join(sysfsDir, rootBusPath, relPath, "pci_bus")

bus := files[0].Name()
files, err := ioutil.ReadDir(bridgeBusPath)
if err != nil {
return "", fmt.Errorf("Error reading %s : %s", bridgeBusPath, err)
}

// Device address is based on the bus of the bridge to which it is attached.
// We do not pass devices as multifunction, hence the trailing 0 in the address.
pciDeviceAddr := fmt.Sprintf("%s:%s.0", bus, deviceID)
if len(files) != 1 {
return "", fmt.Errorf("Expected exactly one PCI bus in %s, got %d instead", bridgeBusPath, len(files))
}

bridgeDevicePCIAddr := fmt.Sprintf("%s/%s", pciBridgeAddr, pciDeviceAddr)
agentLog.WithField("completePCIAddr", bridgeDevicePCIAddr).Info("Fetched PCI address for device")
bus = files[0].Name()
}

return bridgeDevicePCIAddr, nil
return relPath, nil
}

func getDeviceName(s *sandbox, devID string) (string, error) {
Expand Down Expand Up @@ -181,21 +186,21 @@ func getDeviceName(s *sandbox, devID string) (string, error) {
return filepath.Join(systemDevPath, devName), nil
}

func getPCIDeviceNameImpl(s *sandbox, pciID string) (string, error) {
pciAddr, err := getDevicePCIAddress(pciID)
func getPCIDeviceNameImpl(s *sandbox, pciPath PciPath) (string, error) {
sysfsRelPath, err := pciPathToSysfs(pciPath)
if err != nil {
return "", err
}

fieldLogger := agentLog.WithField("pciAddr", pciAddr)
fieldLogger := agentLog.WithField("sysfsRelPath", sysfsRelPath)

// Rescan pci bus if we need to wait for a new pci device
if err = rescanPciBus(); err != nil {
fieldLogger.WithError(err).Error("Failed to scan pci bus")
return "", err
}

return getDeviceName(s, pciAddr)
return getDeviceName(s, sysfsRelPath)
}

// device.Id should be the predicted device name (vda, vdb, ...)
Expand Down Expand Up @@ -223,14 +228,11 @@ func virtioBlkCCWDeviceHandler(ctx context.Context, device pb.Device, spec *pb.S
return updateSpecDeviceList(device, spec, devIdx)
}

// device.Id should be the PCI address in the format "bridgeAddr/deviceAddr".
// Here, bridgeAddr is the address at which the bridge is attached on the root bus,
// while deviceAddr is the address at which the device is attached on the bridge.
// device.Id should be a PCI path (see type PciPath)
func virtioBlkDeviceHandler(_ context.Context, device pb.Device, spec *pb.Spec, s *sandbox, devIdx devIndex) error {
// When "Id (PCIAddr)" is not set, we allow to use the predicted "VmPath" passed from kata-runtime
// When "Id" (PCI path) is not set, we allow using the predicted "VmPath" passed from kata-runtime
if device.Id != "" {
// Get the device node path based on the PCI device address
devPath, err := getPCIDeviceName(s, device.Id)
devPath, err := getPCIDeviceName(s, PciPath{device.Id})
if err != nil {
return err
}
Expand Down
98 changes: 66 additions & 32 deletions device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func testVirtioBlkDeviceHandlerFailure(t *testing.T, device pb.Device, spec *pb.
assert.NotNil(t, err, "blockDeviceHandler() should have failed")

savedFunc := getPCIDeviceName
getPCIDeviceName = func(s *sandbox, pciID string) (string, error) {
getPCIDeviceName = func(s *sandbox, pciPath PciPath) (string, error) {
return "foo", nil
}

Expand Down Expand Up @@ -92,47 +92,81 @@ func TestVirtioBlkDeviceHandlerEmptyLinuxDevicesSpecFailure(t *testing.T) {
testVirtioBlkDeviceHandlerFailure(t, device, spec)
}

func TestGetPCIAddress(t *testing.T) {
func TestPciPathToSysfs(t *testing.T) {
testDir, err := ioutil.TempDir("", "kata-agent-tmp-")
if err != nil {
t.Fatal(t, err)
}
defer os.RemoveAll(testDir)

pciID := "02"
_, err = getDevicePCIAddress(pciID)
assert.NotNil(t, err)
// Set sysfsDir to test directory for unit tests.
sysfsDir = testDir
rootBus := filepath.Join(sysfsDir, rootBusPath)
err = os.MkdirAll(rootBus, mountPerm)
assert.NoError(t, err)

pciID = "02/03/04"
_, err = getDevicePCIAddress(pciID)
assert.NotNil(t, err)
sysRelPath, err := pciPathToSysfs(PciPath{"02"})
assert.NoError(t, err)
assert.Equal(t, sysRelPath, "0000:00:02.0")

bridgeID := "02"
deviceID := "03"
pciBus := "0000:01"
expectedPCIAddress := "0000:00:02.0/0000:01:03.0"
pciID = fmt.Sprintf("%s/%s", bridgeID, deviceID)
_, err = pciPathToSysfs(PciPath{"02/03"})
assert.Error(t, err)

// Set sysBusPrefix to test directory for unit tests.
sysBusPrefix = testDir
bridgeBusPath := fmt.Sprintf(pciBusPathFormat, sysBusPrefix, "0000:00:02.0")
_, err = pciPathToSysfs(PciPath{"02/03/04"})
assert.Error(t, err)

_, err = getDevicePCIAddress(pciID)
assert.NotNil(t, err)
// Create mock sysfs files for the device at 0000:00:02.0
bridge2Path := filepath.Join(rootBus, "0000:00:02.0")

err = os.MkdirAll(bridgeBusPath, mountPerm)
assert.Nil(t, err)
err = os.MkdirAll(bridge2Path, mountPerm)
assert.NoError(t, err)

_, err = getDevicePCIAddress(pciID)
assert.NotNil(t, err)
sysRelPath, err = pciPathToSysfs(PciPath{"02"})
assert.NoError(t, err)
assert.Equal(t, sysRelPath, "0000:00:02.0")

err = os.MkdirAll(filepath.Join(bridgeBusPath, pciBus), mountPerm)
assert.Nil(t, err)
_, err = pciPathToSysfs(PciPath{"02/03"})
assert.Error(t, err)

addr, err := getDevicePCIAddress(pciID)
assert.Nil(t, err)
_, err = pciPathToSysfs(PciPath{"02/03/04"})
assert.Error(t, err)

// Create mock sysfs files to indicate that 0000:00:02.0 is a bridge to bus 01
bridge2Bus := "0000:01"
err = os.MkdirAll(filepath.Join(bridge2Path, "pci_bus", bridge2Bus), mountPerm)
assert.NoError(t, err)

assert.Equal(t, addr, expectedPCIAddress)
sysRelPath, err = pciPathToSysfs(PciPath{"02"})
assert.NoError(t, err)
assert.Equal(t, sysRelPath, "0000:00:02.0")

sysRelPath, err = pciPathToSysfs(PciPath{"02/03"})
assert.NoError(t, err)
assert.Equal(t, sysRelPath, "0000:00:02.0/0000:01:03.0")

_, err = pciPathToSysfs(PciPath{"02/03/04"})
assert.Error(t, err)

// Create mock sysfs files for a bridge at 0000:01:03.0 to bus 02
bridge3Path := filepath.Join(bridge2Path, "0000:01:03.0")
bridge3Bus := "0000:02"
err = os.MkdirAll(filepath.Join(bridge3Path, "pci_bus", bridge3Bus), mountPerm)
assert.NoError(t, err)

err = os.MkdirAll(bridge3Path, mountPerm)
assert.NoError(t, err)

sysRelPath, err = pciPathToSysfs(PciPath{"02"})
assert.NoError(t, err)
assert.Equal(t, sysRelPath, "0000:00:02.0")

sysRelPath, err = pciPathToSysfs(PciPath{"02/03"})
assert.NoError(t, err)
assert.Equal(t, sysRelPath, "0000:00:02.0/0000:01:03.0")

sysRelPath, err = pciPathToSysfs(PciPath{"02/03/04"})
assert.NoError(t, err)
assert.Equal(t, sysRelPath, "0000:00:02.0/0000:01:03.0/0000:02:04.0")
}

func TestScanSCSIBus(t *testing.T) {
Expand Down Expand Up @@ -804,27 +838,27 @@ func TestGetPCIDeviceName(t *testing.T) {

sysfsDir = testSysfsDir

savedFunc := getDevicePCIAddress
savedFunc := pciPathToSysfs
defer func() {
getDevicePCIAddress = savedFunc
pciPathToSysfs = savedFunc
}()

getDevicePCIAddress = func(pciID string) (string, error) {
pciPathToSysfs = func(pciPath PciPath) (string, error) {
return "", nil
}

sb := sandbox{
deviceWatchers: make(map[string](chan string)),
}

_, err = getPCIDeviceNameImpl(&sb, "")
_, err = getPCIDeviceNameImpl(&sb, PciPath{""})
assert.Error(err)

rescanDir := filepath.Dir(pciBusRescanFile)
err = os.MkdirAll(rescanDir, testDirMode)
assert.NoError(err)

_, err = getPCIDeviceNameImpl(&sb, "")
_, err = getPCIDeviceNameImpl(&sb, PciPath{""})
assert.Error(err)
}

Expand Down
6 changes: 3 additions & 3 deletions mount.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,8 @@ func virtioFSStorageHandler(_ context.Context, storage pb.Storage, s *sandbox) (
// virtioBlkStorageHandler handles the storage for blk driver.
func virtioBlkStorageHandler(_ context.Context, storage pb.Storage, s *sandbox) (string, error) {

// If hot-plugged, get the device node path based on the PCI address else
// use the virt path provided in Storage Source
// If hot-plugged, get the device node path based on the PCI
// path else use the virt path provided in Storage Source
if strings.HasPrefix(storage.Source, "/dev") {

FileInfo, err := os.Stat(storage.Source)
Expand All @@ -312,7 +312,7 @@ func virtioBlkStorageHandler(_ context.Context, storage pb.Storage, s *sandbox)
}

} else {
devPath, err := getPCIDeviceName(s, storage.Source)
devPath, err := getPCIDeviceName(s, PciPath{storage.Source})
if err != nil {
return "", err
}
Expand Down
14 changes: 7 additions & 7 deletions mount_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,13 @@ func TestVirtioBlkStorageHandlerSuccessful(t *testing.T) {
bridgeID := "02"
deviceID := "03"
pciBus := "0000:01"
completePCIAddr := fmt.Sprintf("0000:00:%s.0/%s:%s.0", bridgeID, pciBus, deviceID)
sysRelPath := fmt.Sprintf("0000:00:%s.0/%s:%s.0", bridgeID, pciBus, deviceID)

pciID := fmt.Sprintf("%s/%s", bridgeID, deviceID)

sysBusPrefix = testDir
bridgeBusPath := fmt.Sprintf(pciBusPathFormat, sysBusPrefix, "0000:00:02.0")
pciPath := fmt.Sprintf("%s/%s", bridgeID, deviceID)

// Set sysfsDir to test directory for unit tests.
sysfsDir = testDir
bridgeBusPath := filepath.Join(sysfsDir, rootBusPath, "0000:00:02.0", "pci_bus")
err = os.MkdirAll(filepath.Join(bridgeBusPath, pciBus), mountPerm)
assert.Nil(t, err)

Expand All @@ -242,7 +242,7 @@ func TestVirtioBlkStorageHandlerSuccessful(t *testing.T) {
defer os.RemoveAll(dirPath)

storage := pb.Storage{
Source: pciID,
Source: pciPath,
MountPoint: filepath.Join(dirPath, "test-mount"),
}
defer syscall.Unmount(storage.MountPoint, 0)
Expand All @@ -252,7 +252,7 @@ func TestVirtioBlkStorageHandlerSuccessful(t *testing.T) {
}

s.Lock()
s.sysToDevMap[completePCIAddr] = devPath
s.sysToDevMap[sysRelPath] = devPath
s.Unlock()

storage.Fstype = "bind"
Expand Down
11 changes: 5 additions & 6 deletions network.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,13 @@ func (s *sandbox) updateInterface(netHandle *netlink.Handle, iface *types.Interf
fieldLogger := agentLog.WithFields(logrus.Fields{
"mac-address": iface.HwAddr,
"interface-name": iface.Device,
"pci-address": iface.PciAddr,
"pci-path": iface.PciPath,
})

// If the PCI address of the network device is provided, wait/check for the device
// to be available first
if iface.PciAddr != "" {
// iface.PciAddr is in the format bridgeAddr/deviceAddr eg. 05/06
_, err := getPCIDeviceName(s, iface.PciAddr)
// If the PCI path of the network device is provided,
// wait/check for the device to be available first
if iface.PciPath != "" {
_, err := getPCIDeviceName(s, PciPath{iface.PciPath})
if err != nil {
return nil, err
}
Expand Down
Loading