diff --git a/agent.go b/agent.go index 7e2a86df48..95aae80ccd 100644 --- a/agent.go +++ b/agent.go @@ -156,6 +156,22 @@ var collatedTrace = false // if true, coredump when an internal error occurs or a fatal signal is received var crashOnError = false +// commType is used to denote the communication channel type used. +type commType int + +const ( + // virtio-serial channel + serialCh commType = iota + + // vsock channel + vsockCh + + // channel type not passed explicitly + unknownCh +) + +var commCh = unknownCh + // This is the list of file descriptors we can properly close after the process // has been started. When the new process is exec(), those file descriptors are // duplicated and it is our responsibility to close them since we have opened diff --git a/channel.go b/channel.go index f39ef50873..47d107c6e1 100644 --- a/channel.go +++ b/channel.go @@ -50,24 +50,29 @@ func newChannel(ctx context.Context) (channel, error) { defer span.Finish() var serialErr error - var serialPath string var vsockErr error - var vSockSupported bool + var ch channel for i := 0; i < channelExistMaxTries; i++ { - // check vsock path - if _, err := os.Stat(vSockDevPath); err == nil { - if vSockSupported, vsockErr = isAFVSockSupportedFunc(); vSockSupported && vsockErr == nil { - span.SetTag("channel-type", "vsock") - return &vSockChannel{}, nil + switch commCh { + case serialCh: + if ch, serialErr = checkForSerialChannel(ctx); serialErr == nil && ch.(*serialChannel) != nil { + return ch, nil + } + case vsockCh: + if ch, vsockErr = checkForVsockChannel(ctx); vsockErr == nil && ch.(*vSockChannel) != nil { + return ch, nil } - } - // Check serial port path - if serialPath, serialErr = findVirtualSerialPath(serialChannelName); serialErr == nil { - span.SetTag("channel-type", "serial") - span.SetTag("serial-path", serialPath) - return &serialChannel{serialPath: serialPath}, nil + case unknownCh: + // If we have not been explicitly passed if vsock is used or not, maybe due to + // an older runtime, try to check for vsock support. + if ch, vsockErr = checkForVsockChannel(ctx); vsockErr == nil && ch.(*vSockChannel) != nil { + return ch, nil + } + if ch, serialErr = checkForSerialChannel(ctx); serialErr == nil && ch.(*serialChannel) != nil { + return ch, nil + } } time.Sleep(channelExistWaitTime) @@ -84,6 +89,41 @@ func newChannel(ctx context.Context) (channel, error) { return nil, fmt.Errorf("Neither vsocks nor serial ports were found") } +func checkForSerialChannel(ctx context.Context) (*serialChannel, error) { + span, _ := trace(ctx, "channel", "checkForSerialChannel") + defer span.Finish() + + // Check serial port path + serialPath, serialErr := findVirtualSerialPath(serialChannelName) + if serialErr == nil { + span.SetTag("channel-type", "serial") + span.SetTag("serial-path", serialPath) + agentLog.Debug("Serial channel type detected") + return &serialChannel{serialPath: serialPath}, nil + } + + return nil, serialErr +} + +func checkForVsockChannel(ctx context.Context) (*vSockChannel, error) { + span, _ := trace(ctx, "channel", "checkForVsockChannel") + defer span.Finish() + + // check vsock path + if _, err := os.Stat(vSockDevPath); err != nil { + return nil, err + } + + vSockSupported, vsockErr := isAFVSockSupportedFunc() + if vSockSupported && vsockErr == nil { + span.SetTag("channel-type", "vsock") + agentLog.Debug("Vsock channel type detected") + return &vSockChannel{}, nil + } + + return nil, fmt.Errorf("Vsock not found : %s", vsockErr) +} + type vSockChannel struct { } @@ -228,23 +268,51 @@ func (c *serialChannel) teardown() error { return c.serialConn.Close() } +// isAFVSockSupported checks if vsock channel is used by the runtime +// by checking for devices under the vhost-vsock driver path. +// It returns true if a device is found for the vhost-vsock driver. func isAFVSockSupported() (bool, error) { - fd, err := unix.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0) - if err != nil { - // This case is valid. It means AF_VSOCK is not a supported - // domain on this system. - if err == unix.EAFNOSUPPORT { - return false, nil - } + // Driver path for virtio-vsock + sysVsockPath := "/sys/bus/virtio/drivers/vmw_vsock_virtio_transport/" + + files, err := ioutil.ReadDir(sysVsockPath) + // This should not happen for a hypervisor with vsock driver + if err != nil { return false, err } - if err := unix.Close(fd); err != nil { - return true, err + // standard driver files that should be ignored + driverFiles := []string{"bind", "uevent", "unbind"} + + for _, file := range files { + for _, f := range driverFiles { + if file.Name() == f { + continue + } + } + + fPath := filepath.Join(sysVsockPath, file.Name()) + fInfo, err := os.Lstat(fPath) + if err != nil { + return false, err + } + + if fInfo.Mode()&os.ModeSymlink == 0 { + continue + } + + link, err := os.Readlink(fPath) + if err != nil { + return false, err + } + + if strings.Contains(link, "devices") { + return true, nil + } } - return true, nil + return false, nil } func findVirtualSerialPath(serialName string) (string, error) { diff --git a/config.go b/config.go index 0f31e6ea1a..8d8ea9a691 100644 --- a/config.go +++ b/config.go @@ -8,6 +8,7 @@ package main import ( "io/ioutil" + "strconv" "strings" "github.com/sirupsen/logrus" @@ -20,6 +21,7 @@ const ( logLevelFlag = optionPrefix + "log" devModeFlag = optionPrefix + "devmode" traceModeFlag = optionPrefix + "trace" + useVsockFlag = optionPrefix + "use_vsock" kernelCmdlineFile = "/proc/cmdline" traceValueIsolated = "isolated" traceValueCollated = "collated" @@ -102,6 +104,18 @@ func (c *agentConfig) parseCmdlineOption(option string) error { case traceValueCollated: enableTracing(true) } + case useVsockFlag: + flag, err := strconv.ParseBool(split[valuePosition]) + if err != nil { + return err + } + if flag { + agentLog.Debug("Param passed to use vsock channel") + commCh = vsockCh + } else { + agentLog.Debug("Param passed to NOT use vsock channel") + commCh = serialCh + } default: if strings.HasPrefix(split[optionPosition], optionPrefix) { return grpcStatus.Errorf(codes.NotFound, "Unknown option %s", split[optionPosition]) diff --git a/config_test.go b/config_test.go index ae31c19cfd..42c936fdd7 100644 --- a/config_test.go +++ b/config_test.go @@ -285,3 +285,46 @@ func TestEnableTracing(t *testing.T) { } } } + +func TestParseCmdlineOptionWrongOptionVsock(t *testing.T) { + t.Skip() + assert := assert.New(t) + + a := &agentConfig{} + + wrongOption := "use_vsockkk=true" + + err := a.parseCmdlineOption(wrongOption) + assert.Errorf(err, "Parsing should fail because wrong option %q", wrongOption) +} + +func TestParseCmdlineOptionsVsock(t *testing.T) { + assert := assert.New(t) + + a := &agentConfig{} + + type testData struct { + val string + shouldErr bool + expectedCommCh commType + } + + data := []testData{ + {"true", false, vsockCh}, + {"false", false, serialCh}, + {"blah", true, unknownCh}, + } + + for _, d := range data { + commCh = unknownCh + option := useVsockFlag + "=" + d.val + + err := a.parseCmdlineOption(option) + if d.shouldErr { + assert.Error(err) + } else { + assert.NoError(err) + } + assert.Equal(commCh, d.expectedCommCh) + } +}