diff --git a/cmd/rigtest/rigtest.go b/cmd/rigtest/rigtest.go index c2b17f3a..042f3cb1 100644 --- a/cmd/rigtest/rigtest.go +++ b/cmd/rigtest/rigtest.go @@ -20,7 +20,6 @@ import ( "github.com/k0sproject/rig/os" "github.com/k0sproject/rig/os/registry" _ "github.com/k0sproject/rig/os/support" - "github.com/kevinburke/ssh_config" "github.com/stretchr/testify/require" ) @@ -112,24 +111,6 @@ func main() { goos.Exit(1) } - if configPath := goos.Getenv("SSH_CONFIG"); configPath != "" { - f, err := goos.Open(configPath) - if err != nil { - panic(err) - } - cfg, err := ssh_config.Decode(f) - if err != nil { - panic(err) - } - rig.SSHConfigGetAll = func(dst, key string) []string { - res, err := cfg.GetAll(dst, key) - if err != nil { - return nil - } - return res - } - } - var passfunc func() (string, error) if *pc { passfunc = func() (string, error) { @@ -141,16 +122,16 @@ func main() { } var hosts []*Host + var port *int for _, address := range strings.Split(*dh, ",") { - port := 22 if addr, portstr, ok := strings.Cut(address, ":"); ok { address = addr p, err := strconv.Atoi(portstr) if err != nil { panic("invalid port " + portstr) } - port = p + port = &p } var h *Host @@ -172,7 +153,7 @@ func main() { Connection: rig.Connection{ WinRM: &rig.WinRM{ Address: *dh, - Port: port, + Port: *port, User: *usr, UseHTTPS: *https, Insecure: true, diff --git a/ssh.go b/ssh.go index 029eef83..40aa5604 100644 --- a/ssh.go +++ b/ssh.go @@ -28,12 +28,13 @@ import ( type SSH struct { Address string `yaml:"address" validate:"required,hostname|ip"` User string `yaml:"user" validate:"required" default:"root"` - Port int `yaml:"port" default:"22" validate:"gt=0,lte=65535"` + Port *int `yaml:"port" default:"22" validate:"gt=0,lte=65535"` KeyPath *string `yaml:"keyPath" validate:"omitempty"` HostKey string `yaml:"hostKey,omitempty"` Bastion *SSH `yaml:"bastion,omitempty"` PasswordCallback PasswordCallback `yaml:"-"` - name string + + connAddr string isWindows bool knowOs bool @@ -47,19 +48,23 @@ type SSH struct { // PasswordCallback is a function that is called when a passphrase is needed to decrypt a private key type PasswordCallback func() (secret string, err error) +type sshconfig interface { + GetAll(alias, key string) []string + Get(alias, key string) string +} + var ( - authMethodCache = sync.Map{} - defaultKeypaths = []string{"~/.ssh/id_rsa", "~/.ssh/identity", "~/.ssh/id_dsa"} - dummyhostKeyPaths []string - globalOnce sync.Once - knownHostsMU sync.Mutex + authMethodCache = sync.Map{} + defaultKeypaths = []string{} + globalOnce sync.Once + knownHostsMU sync.Mutex + + sshConfig sshconfig = ssh_config.DefaultUserSettings // ErrChecksumMismatch is returned when the checksum of an uploaded file does not match expectation ErrChecksumMismatch = errstring.New("checksum mismatch") ) -const hopefullyNonexistentHost = "thisH0stDoe5not3xist" - // returns the current user homedir, prefers $HOME env var func homeDir() (string, error) { if home, ok := os.LookupEnv("HOME"); ok { @@ -112,78 +117,140 @@ func expandAndValidatePath(path string) (string, error) { return path, nil } -func (c *SSH) keypathsFromConfig() []string { - log.Tracef("%s: trying to get a keyfile path from ssh config", c) - if idf := c.getConfigAll("IdentityFile"); len(idf) > 0 { - log.Tracef("%s: detected %d identity file paths from ssh config: %v", c, len(idf), idf) - return idf +func flattenPaths(paths []string) []string { + var out []string + for _, p := range paths { + pp, err := shlex.Split(p) + if err == nil { + out = append(out, pp...) + } } - log.Tracef("%s: no identity file paths found in ssh config", c) - return []string{} + return out } -func (c *SSH) initGlobalDefaults() { - log.Tracef("discovering global default keypaths") - dummyHostIdentityFiles := SSHConfigGetAll(hopefullyNonexistentHost, "IdentityFile") - for _, keyPath := range dummyHostIdentityFiles { - if expanded, err := expandAndValidatePath(keyPath); err != nil { - dummyhostKeyPaths = append(dummyhostKeyPaths, expanded) - } +func uniqStrings(elems []string) []string { + if len(elems) < 2 { + return elems + } + uniq := make(map[string]struct{}) + for _, e := range elems { + uniq[e] = struct{}{} + } + out := make([]string, 0, len(uniq)) + for k := range uniq { + out = append(out, k) + } + return out +} + +func initSSHDefaults() { + keyPaths := sshConfig.GetAll("*", "IdentityFile") + if len(keyPaths) > 0 { + defaultKeypaths = flattenPaths(keyPaths) } } -func findUniq(a, b []string) (string, bool) { - for _, s := range a { - found := false - for _, t := range b { - if s == t { +func (c *SSH) nonDefaultKeypaths() []string { + var keyPaths []string + for _, p := range c.keyPaths { + var found bool + for _, d := range defaultKeypaths { + if p == d { found = true break } } if !found { - return s, true + keyPaths = append(keyPaths, p) } } - return "", false + return keyPaths } -// SetDefaults sets various default values -func (c *SSH) SetDefaults() { - globalOnce.Do(c.initGlobalDefaults) - c.once.Do(func() { - if c.KeyPath != nil && *c.KeyPath != "" { - if expanded, err := expandAndValidatePath(*c.KeyPath); err == nil { - c.keyPaths = append(c.keyPaths, expanded) - } - // keypath is explicitly set, accept the fact even if it's invalid and - // don't try to find it from ssh config/defaults - return - } - c.KeyPath = nil +func intPtr(num int) *int { + return &num +} - paths := c.keypathsFromConfig() - if len(paths) == 0 { - // no paths found in ssh config either, use defaults - paths = append(paths, defaultKeypaths...) - } +func (c *SSH) setupAddress() { + if addr := sshConfig.Get(c.Address, "HostName"); addr != "" { + log.Debugf("%s: using hostname %s from ssh config as connection address", c.Address, addr) + c.connAddr = addr + return + } + c.connAddr = c.Address +} - for _, p := range paths { - expanded, err := expandAndValidatePath(p) - if err != nil { - log.Tracef("%s: %s: %v", c, p, err) - continue - } - log.Debugf("%s: using identity file %s", c, expanded) - c.keyPaths = append(c.keyPaths, expanded) - } +func (c *SSH) setupPort() { + if c.Port != nil { + return + } + + portS := sshConfig.Get(c.Address, "Port") + port, err := strconv.Atoi(portS) + if err == nil { + c.Port = intPtr(port) + log.Tracef("%s: using port %d from ssh config", c, c.Port) + return + } + c.Port = intPtr(22) + log.Tracef("%s: using default port", c) +} + +func (c *SSH) setupKeyPaths() { + if sshConfig.Get(c.Address, "PubkeyAuthentication") == "no" { + log.Infof("%s: public key based authentication disabled in ssh config", c) + c.keyPaths = nil + return + } + + // if keypath is set, use that + if c.KeyPath != nil && *c.KeyPath != "" { + c.keyPaths = []string{*c.KeyPath} + return + } + + log.Tracef("%s: trying to get a keyfile path from ssh config", c) + if paths := sshConfig.GetAll(c.Address, "IdentityFile"); len(paths) > 0 { + c.keyPaths = uniqStrings(flattenPaths(paths)) + log.Tracef("%s: detected %d identity file paths from ssh config: %v", c, len(c.keyPaths), c.keyPaths) + return + } +} - // check if all the paths that were found are global defaults - // errors are handled differently when a keypath is explicitly set vs when it's defaulted - if uniq, found := findUniq(c.keyPaths, dummyhostKeyPaths); found { - c.KeyPath = &uniq +func (c *SSH) sanitizeKeyPaths() { + if len(c.keyPaths) == 0 { + return + } + var newPaths []string + for _, p := range c.keyPaths { + if p == "" { + continue + } + p, err := expandAndValidatePath(p) + if err != nil { + log.Tracef("%s: failed to validate key path %s: %v", c, p, err) + continue } - }) + newPaths = append(newPaths, p) + } + c.keyPaths = uniqStrings(newPaths) +} + +func (c *SSH) setup() { + for _, f := range []func(){ + c.setupAddress, + c.setupPort, + c.setupKeyPaths, + c.sanitizeKeyPaths, + } { + f() + } +} + +// SetDefaults sets various default values +func (c *SSH) SetDefaults() { + globalOnce.Do(initSSHDefaults) + c.once.Do(c.setup) } // Protocol returns the protocol name, "SSH" @@ -196,26 +263,20 @@ func (c *SSH) IPAddress() string { return c.Address } -// SSHConfigGetAll by default points to ssh_config package's GetAll() function -// you can override it with your own implementation for testing purposes -var SSHConfigGetAll = ssh_config.GetAll - // try with port, if no results, try without func (c *SSH) getConfigAll(key string) []string { - dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) - if val := SSHConfigGetAll(dst, key); len(val) > 0 { - return val - } - return SSHConfigGetAll(c.Address, key) + return sshConfig.GetAll(c.Address, key) } // String returns the connection's printable name func (c *SSH) String() string { - if c.name == "" { - c.name = fmt.Sprintf("[ssh] %s", net.JoinHostPort(c.Address, strconv.Itoa(c.Port))) + var name string + if c.connAddr != c.Address { + name = c.Address + } else { + name = c.connAddr } - - return c.name + return fmt.Sprintf("[ssh] %s", net.JoinHostPort(name, strconv.Itoa(*c.Port))) } // IsConnected returns true if the client is connected @@ -307,6 +368,38 @@ func (c *SSH) hostkeyCallback() (ssh.HostKeyCallback, error) { return knownhostsCallback(defaultPath, permissive) } +func (c *SSH) signers() []ssh.Signer { + if sshConfig.Get(c.Address, "IdentitiesOnly") == "yes" { + log.Debugf("%s: IdentitiesOnly is set to 'yes', not using ssh-agent", c) + return []ssh.Signer{} + } + agent, err := agentClient() + if err != nil { + log.Tracef("%s: failed to get ssh agent client: %v", c, err) + return []ssh.Signer{} + } + signers, err := agent.Signers() + if err != nil { + log.Debugf("%s: failed to list signers from ssh agent: %v", c, err) + return []ssh.Signer{} + } + return signers +} + +func getCachedAuth(keyPath string) ssh.AuthMethod { + if am, ok := authMethodCache.Load(keyPath); ok { + switch authM := am.(type) { + case ssh.AuthMethod: + return authM + case error: + log.Tracef("already discarded key before %s: %v", keyPath, authM) + default: + log.Tracef("unexpected type %T for cached auth method for %s", am, keyPath) + } + } + return nil +} + func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { config := &ssh.ClientConfig{ User: c.User, @@ -318,39 +411,38 @@ func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { } config.HostKeyCallback = hkc - var signers []ssh.Signer - agent, err := agentClient() - if err != nil { - log.Tracef("%s: failed to get ssh agent client: %v", c, err) - } else { - signers, err = agent.Signers() - if err != nil { - log.Debugf("%s: failed to list signers from ssh agent: %v", c, err) - } - } + signers := c.signers() + + nonDefaultPaths := c.nonDefaultKeypaths() for _, keyPath := range c.keyPaths { - if am, ok := authMethodCache.Load(keyPath); ok { - switch authM := am.(type) { - case ssh.AuthMethod: - log.Tracef("%s: using cached auth method for %s", c, keyPath) - config.Auth = append(config.Auth, authM) - case error: - log.Tracef("%s: already discarded key %s: %v", c, keyPath, authM) - default: - log.Tracef("%s: unexpected type %T for cached auth method for %s", c, am, keyPath) - } + if am := getCachedAuth(keyPath); am != nil { + log.Tracef("%s: using a cached auth method for identity file %s", c, keyPath) + config.Auth = append(config.Auth, am) continue } - privateKeyAuth, err := c.pkeySigner(signers, keyPath) + keyAuth, err := c.keyfileAuth(signers, keyPath) if err != nil { - log.Debugf("%s: failed to obtain a signer for identity %s: %v", c, keyPath, err) - // store the error so this key won't be loaded again authMethodCache.Store(keyPath, err) - } else { - authMethodCache.Store(keyPath, privateKeyAuth) - config.Auth = append(config.Auth, privateKeyAuth) + // store the error so this key won't be loaded again + + if c.KeyPath != nil { + return nil, ErrCantConnect.Wrapf("can't use configured identity file %s: %w", *c.KeyPath, err) + } + + // if the key isn't one of the default paths, assume it was explicitly set, and + // treat this as a fatal error + for _, p := range nonDefaultPaths { + if p == keyPath { + return nil, ErrCantConnect.Wrapf("can't use identity file %s: %w", keyPath, err) + } + } + + log.Debugf("%s: failed to obtain a signer for identity file %s: %v", c, keyPath, err) + continue } + authMethodCache.Store(keyPath, keyAuth) + config.Auth = append(config.Auth, keyAuth) } if len(config.Auth) == 0 { @@ -364,31 +456,7 @@ func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { return config, nil } -// Connect opens the SSH connection -func (c *SSH) Connect() error { - if err := defaults.Set(c); err != nil { - return ErrValidationFailed.Wrapf("set defaults: %w", err) - } - - config, err := c.clientConfig() - if err != nil { - return ErrCantConnect.Wrapf("create config: %w", err) - } - - dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) - - if c.Bastion == nil { - clientDirect, err := ssh.Dial("tcp", dst, config) - if err != nil { - if errors.Is(err, hostkey.ErrHostKeyMismatch) { - return ErrCantConnect.Wrap(err) - } - return fmt.Errorf("ssh dial: %w", err) - } - c.client = clientDirect - return nil - } - +func (c *SSH) connectViaBastion(dst string, cfg *ssh.ClientConfig) error { if err := c.Bastion.Connect(); err != nil { if errors.Is(err, hostkey.ErrHostKeyMismatch) { return ErrCantConnect.Wrapf("bastion connect: %w", err) @@ -399,7 +467,7 @@ func (c *SSH) Connect() error { if err != nil { return fmt.Errorf("bastion dial: %w", err) } - client, chans, reqs, err := ssh.NewClientConn(bconn, dst, config) + client, chans, reqs, err := ssh.NewClientConn(bconn, dst, cfg) if err != nil { if errors.Is(err, hostkey.ErrHostKeyMismatch) { return ErrCantConnect.Wrapf("bastion client connect: %w", err) @@ -411,13 +479,50 @@ func (c *SSH) Connect() error { return nil } -func (c *SSH) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMethod, error) { +// Connect opens the SSH connection +func (c *SSH) Connect() error { + if err := defaults.Set(c); err != nil { + return ErrValidationFailed.Wrapf("set defaults: %w", err) + } + + config, err := c.clientConfig() + if err != nil { + return ErrCantConnect.Wrapf("create config: %w", err) + } + + var port string + if c.Port == nil { + port = "22" + } else { + port = strconv.Itoa(*c.Port) + } + dst := net.JoinHostPort(c.connAddr, port) + log.Debugf("%s: connecting to %s", c, dst) + + if c.Bastion != nil { + return c.connectViaBastion(dst, config) + } + + clientDirect, err := ssh.Dial("tcp", dst, config) + if err != nil { + if errors.Is(err, hostkey.ErrHostKeyMismatch) { + return ErrCantConnect.Wrap(err) + } + return fmt.Errorf("ssh dial: %w", err) + } + c.client = clientDirect + return nil +} + +// pubkeysigner returns an ssh.AuthMethod for an ssh public key from the supplied signers +func (c *SSH) pubkeyAuth(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMethod, error) { if len(signers) == 0 { return nil, ErrCantConnect.Wrapf("signer not found for public key") } + keyM := key.Marshal() for _, s := range signers { - if bytes.Equal(key.Marshal(), s.PublicKey().Marshal()) { + if bytes.Equal(keyM, s.PublicKey().Marshal()) { log.Debugf("%s: signer for public key available in ssh agent", c) return ssh.PublicKeys(s), nil } @@ -426,7 +531,8 @@ func (c *SSH) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMet return nil, ErrAuthFailed.Wrapf("the provided key is a public key and is not known by agent") } -func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, error) { +// keyfileAuth returns an AuthMethod for the given keyPath. the signers are passed around to avoid querying them from agent multiple times +func (c *SSH) keyfileAuth(signers []ssh.Signer, path string) (ssh.AuthMethod, error) { log.Tracef("%s: checking identity file %s", c, path) key, err := os.ReadFile(path) if err != nil { @@ -436,7 +542,7 @@ func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, err pubKey, _, _, _, err := ssh.ParseAuthorizedKey(key) if err == nil { log.Debugf("%s: file %s is a public key", c, path) - return c.pubkeySigner(signers, pubKey) + return c.pubkeyAuth(signers, pubKey) } signer, err := ssh.ParsePrivateKey(key) @@ -450,11 +556,15 @@ func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, err log.Debugf("%s: key %s is encrypted", c, path) if len(signers) > 0 { - if signer, err := c.pkeySigner(signers, path+".pub"); err == nil { + if signer, err := c.keyfileAuth(signers, path+".pub"); err == nil { return signer, nil } } + if sshConfig.Get(c.Address, "BatchMode") == "yes" { + return nil, ErrCantConnect.Wrapf("passphrase required for encrypted key but BatchMode is set for host in ssh config: %w", err) + } + if c.PasswordCallback != nil { log.Tracef("%s: asking for a password to decrypt %s", c, path) pass, err := c.PasswordCallback() diff --git a/test/Makefile b/test/Makefile index 1c0f8f2e..d717524c 100644 --- a/test/Makefile +++ b/test/Makefile @@ -73,7 +73,7 @@ sshport: .PHONY: run run: rigtest create-host ./rigtest \ - -host 127.0.0.1:$(shell $(MAKE) sshport) \ + -host 127.0.0.1 sshport) \ -keypath $(KEY_PATH) \ -user root diff --git a/test/test.sh b/test/test.sh index 9336fdb9..9c0d69c1 100755 --- a/test/test.sh +++ b/test/test.sh @@ -38,7 +38,7 @@ rig_test_agent_with_public_key() { ssh-add .ssh/identity rm -f .ssh/identity set +e - HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity.pub -connect + SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity.pub -connect local exit_code=$? set -e kill $SSH_AGENT_PID @@ -59,7 +59,7 @@ rig_test_agent_with_private_key() { ' set +e # path points to a private key, rig should try to look for the .pub for it - HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity -connect + SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity -connect local exit_code=$? set -e kill $SSH_AGENT_PID @@ -76,7 +76,7 @@ rig_test_agent() { rm -f .ssh/identity set +e ssh-add -l - HOME=. SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath "" -connect + SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath "" -connect local exit_code=$? set -e kill $SSH_AGENT_PID @@ -89,10 +89,11 @@ rig_test_ssh_config() { color_echo "- Testing getting identity path from ssh config" make create-host mv .ssh/identity .ssh/identity2 - echo "Host 127.0.0.1:$(ssh_port node0)" > .ssh/config + echo "Host 127.0.0.1" > .ssh/config echo " IdentityFile .ssh/identity2" >> .ssh/config + echo " Port $(ssh_port node0)" >> .ssh/config set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -connect + ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -connect local exit_code=$? set -e RET=$exit_code @@ -102,12 +103,12 @@ rig_test_ssh_config_strict() { color_echo "- Testing StrictHostkeyChecking=yes in ssh config" make create-host local addr="127.0.0.1:$(ssh_port node0)" - echo "Host ${addr}" > .ssh/config + echo "Host 127.0.0.1" > .ssh/config echo " IdentityFile .ssh/identity" >> .ssh/config echo " UserKnownHostsFile $(pwd)/.ssh/known" >> .ssh/config cat .ssh/config set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect + ./rigtest -host "${addr}" -user root -connect local exit_code=$? set -e if [ $exit_code -ne 0 ]; then @@ -121,7 +122,7 @@ rig_test_ssh_config_strict() { echo "${addr} ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBBgejI9UJnRY/i4HNM/os57oFcRjE77gEbVfUkuGr5NRh3N7XxUnnBKdzrAiQNPttUjKmUm92BN7nCUxbwsoSPw=" > .ssh/known cat .ssh/known set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect + ./rigtest -host "${addr}" -user root -connect exit_code=$? set -e @@ -142,7 +143,7 @@ rig_test_ssh_config_no_strict() { echo " UserKnownHostsFile $(pwd)/.ssh/known" >> .ssh/config echo " StrictHostKeyChecking no" >> .ssh/config set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect + ./rigtest -host "${addr}" -user root -connect local exit_code=$? set -e if [ $? -ne 0 ]; then @@ -152,7 +153,7 @@ rig_test_ssh_config_no_strict() { # modify the known hosts file to make it mismatch echo "${addr} ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBBgejI9UJnRY/i4HNM/os57oFcRjE77gEbVfUkuGr5NRh3N7XxUnnBKdzrAiQNPttUjKmUm92BN7nCUxbwsoSPw=" > .ssh/known set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect + ./rigtest -host "${addr}" -user root -connect exit_code=$? set -e RET=$exit_code @@ -202,6 +203,8 @@ if ! sanity_check; then exit 1 fi +export HOME=$( cd "$( dirname "${BASH_SOURCE[0]}" )" > /dev/null && pwd ) + for test in $(declare -F|grep rig_test_|cut -d" " -f3); do if [ "$FOCUS" != "" ] && [ "$FOCUS" != "$test" ]; then continue