From e4e042813fdf6fd87ad6bcb431fe1ae26cb2a1d6 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Thu, 2 Feb 2023 09:52:31 +0200 Subject: [PATCH 1/4] wip --- cmd/rigtest/rigtest.go | 47 +++--- go.mod | 1 + go.sum | 1 + pkg/ssh/config/config.go | 182 ++++++++++++++++++++++ ssh.go | 321 ++++++++++++++++++++++++++------------- test/test.sh | 18 ++- 6 files changed, 440 insertions(+), 130 deletions(-) create mode 100644 pkg/ssh/config/config.go diff --git a/cmd/rigtest/rigtest.go b/cmd/rigtest/rigtest.go index c2b17f3a..adac68cd 100644 --- a/cmd/rigtest/rigtest.go +++ b/cmd/rigtest/rigtest.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/rand" "crypto/sha256" + "encoding/json" "errors" "flag" "fmt" @@ -20,6 +21,7 @@ import ( "github.com/k0sproject/rig/os" "github.com/k0sproject/rig/os/registry" _ "github.com/k0sproject/rig/os/support" + sshconf "github.com/k0sproject/rig/pkg/ssh/config" "github.com/kevinburke/ssh_config" "github.com/stretchr/testify/require" ) @@ -111,24 +113,33 @@ func main() { println("at least host required, see -help") goos.Exit(1) } - - if configPath := goos.Getenv("SSH_CONFIG"); configPath != "" { - f, err := goos.Open(configPath) - if err != nil { - panic(err) - } - cfg, err := ssh_config.Decode(f) - if err != nil { - panic(err) - } - rig.SSHConfigGetAll = func(dst, key string) []string { - res, err := cfg.GetAll(dst, key) + fieldset := sshconf.DefaultFieldSet + opts := fieldset.GetOptions(*dh) + enc := json.NewEncoder(goos.Stdout) + enc.Encode(opts) + hn := ssh_config.Get(*dh, "Host") + p := ssh_config.Get(*dh, "Port") + println("host:", hn, "port:", p) + + /* + if configPath := goos.Getenv("SSH_CONFIG"); configPath != "" { + f, err := goos.Open(configPath) if err != nil { - return nil + panic(err) + } + cfg, err := ssh_config.Decode(f) + if err != nil { + panic(err) + } + rig.SSHConfigGetAll = func(dst, key string) []string { + res, err := cfg.GetAll(dst, key) + if err != nil { + return nil + } + return res } - return res } - } + */ var passfunc func() (string, error) if *pc { @@ -141,16 +152,16 @@ func main() { } var hosts []*Host + var port *int for _, address := range strings.Split(*dh, ",") { - port := 22 if addr, portstr, ok := strings.Cut(address, ":"); ok { address = addr p, err := strconv.Atoi(portstr) if err != nil { panic("invalid port " + portstr) } - port = p + port = &p } var h *Host @@ -172,7 +183,7 @@ func main() { Connection: rig.Connection{ WinRM: &rig.WinRM{ Address: *dh, - Port: port, + Port: *port, User: *usr, UseHTTPS: *https, Insecure: true, diff --git a/go.mod b/go.mod index 93050114..800773a8 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/stretchr/testify v1.8.0 golang.org/x/crypto v0.4.0 golang.org/x/term v0.3.0 + gopkg.in/yaml.v2 v2.2.2 ) require ( diff --git a/go.sum b/go.sum index 329944cd..67b9e362 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,7 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/ssh/config/config.go b/pkg/ssh/config/config.go new file mode 100644 index 00000000..f6307699 --- /dev/null +++ b/pkg/ssh/config/config.go @@ -0,0 +1,182 @@ +// Package config provides tools for getting data from OpenSSH config files +package config + +import ( + "reflect" + "strconv" + + "github.com/kevinburke/ssh_config" +) + +// DefaultOptions is set to the default values for host "*" from ssh_config on init +var defaultOptions *Options +var DefaultFieldSet *FieldSet +var KnownFields []string + +// Options has fields for all the settings available from ssh config files +type Options struct { + Host string + + BatchMode bool + BindAddress string + ChallengeResponseAuthentication bool + CheckHostIP bool + Ciphers string + ClearAllForwardings bool + Compression bool + CompressionLevel int + ConnectionAttempts int + ConnectTimeout int + ControlMaster bool + ControlPath string + DynamicForward string + EnableSSHKeysign bool + EscapeChar string + ExitOnForwardFailure bool + ForwardAgent bool + ForwardX11 bool + ForwardX11Trusted bool + GatewayPorts bool + GlobalKnownHostsFile string + GSSAPIAuthentication bool + GSSAPIDelegateCredentials bool + GSSAPIRenewalForcesRekey bool + GSSAPITrustDNS bool + HashKnownHosts bool + HostbasedAuthentication bool + HostKeyAlgorithms string + HostKeyAlias string + HostName string + IdentitiesOnly bool + IdentityFile []string + KbdInteractiveAuthentication bool + LocalCommand string + LocalForward string + LogLevel string + MACs string + NoHostAuthenticationForLocalhost bool + NumberOfPasswordPrompts int + PasswordAuthentication bool + PermitLocalCommand bool + Port int + PreferredAuthentications string + Protocol int + ProxyCommand string + PublicKeyAuthentication bool + RekeyLimit string + RemoteForward string + RhostsRSAAuthentication bool + RSAAuthentication bool + SendEnv []string + ServerAliveCountMax int + ServerAliveInterval int + SmartcardDevice string + StrictHostKeyChecking bool + TCPKeepAlive bool + Tunnel bool + TunnelDevice string + UsePrivilegedPort bool + User string + UserKnownHostsFile string + VerifyHostKeyDNS bool + VisualHostKey bool + XAuthLocation string + + fieldSet *FieldSet + isSet map[string]bool +} + +type FieldSet struct { + Fields []string + defaultOptions *Options +} + +func (f *FieldSet) GetOptions(host string) *Options { + opts := &Options{Host: host, fieldSet: f} + opts.populate() + return opts +} + +func NewFieldSet(fields []string) *FieldSet { + fs := &FieldSet{Fields: fields} + fs.defaultOptions = fs.GetOptions("*") + return fs +} + +func getString(host, field string) string { + return ssh_config.Get(host, field) +} + +func getStringAll(host, field string) []string { + return ssh_config.GetAll(host, field) +} + +func getBool(host, field string) bool { + return ssh_config.Get(host, field) == "yes" +} + +func getInt(host, field string) int { + val := ssh_config.Get(host, field) + if val == "" { + return 0 + } + if i, err := strconv.Atoi(val); err == nil { + return i + } + return 0 +} + +func (o *Options) getField(name string) reflect.Value { + return reflect.Indirect(reflect.ValueOf(o)).FieldByName(name) +} + +func (o *Options) populate() { + for _, fieldName := range o.fieldSet.Fields { + field := o.getField(fieldName) + if !field.CanSet() { + continue + } + + if ssh_config.SupportsMultiple(fieldName) { + field.Set(reflect.ValueOf(getStringAll(o.Host, fieldName))) + if defaultOptions != nil { + defaultField := defaultOptions.getField(fieldName) + o.isSet[fieldName] = !reflect.DeepEqual(field.Interface(), defaultField.Interface()) + } + continue + } + switch field.Kind() { //nolint:exhaustive + case reflect.String: + field.Set(reflect.ValueOf(getString(o.Host, fieldName))) + case reflect.Bool: + field.Set(reflect.ValueOf(getBool(o.Host, fieldName))) + case reflect.Int: + field.Set(reflect.ValueOf(getInt(o.Host, fieldName))) + default: + continue + } + if defaultOptions != nil { + defaultField := defaultOptions.getField(fieldName) + o.isSet[fieldName] = !reflect.DeepEqual(field.Interface(), defaultField.Interface()) + } + } +} + +// GetOptions returns an Options struct for the given host +func GetOptions(host string) *Options { + return DefaultFieldSet.GetOptions(host) +} + +func init() { + opt := Options{} + obj := reflect.ValueOf(opt) + KnownFields = []string{} + for i := 0; i < obj.NumField(); i++ { + f := obj.Type().Field(i) + if f.Name == "Host" { + continue + } + KnownFields = append(KnownFields, f.Name) + } + DefaultFieldSet = NewFieldSet(KnownFields) +} diff --git a/ssh.go b/ssh.go index 029eef83..b4ab0f50 100644 --- a/ssh.go +++ b/ssh.go @@ -28,12 +28,13 @@ import ( type SSH struct { Address string `yaml:"address" validate:"required,hostname|ip"` User string `yaml:"user" validate:"required" default:"root"` - Port int `yaml:"port" default:"22" validate:"gt=0,lte=65535"` + Port *int `yaml:"port" default:"22" validate:"gt=0,lte=65535"` KeyPath *string `yaml:"keyPath" validate:"omitempty"` HostKey string `yaml:"hostKey,omitempty"` Bastion *SSH `yaml:"bastion,omitempty"` PasswordCallback PasswordCallback `yaml:"-"` - name string + + connAddr string isWindows bool knowOs bool @@ -47,19 +48,23 @@ type SSH struct { // PasswordCallback is a function that is called when a passphrase is needed to decrypt a private key type PasswordCallback func() (secret string, err error) +type sshconfig interface { + GetAll(alias, key string) []string + Get(alias, key string) string +} + var ( - authMethodCache = sync.Map{} - defaultKeypaths = []string{"~/.ssh/id_rsa", "~/.ssh/identity", "~/.ssh/id_dsa"} - dummyhostKeyPaths []string - globalOnce sync.Once - knownHostsMU sync.Mutex + authMethodCache = sync.Map{} + defaultKeypaths = []string{} + globalOnce sync.Once + knownHostsMU sync.Mutex + + sshConfig sshconfig = ssh_config.DefaultUserSettings // ErrChecksumMismatch is returned when the checksum of an uploaded file does not match expectation ErrChecksumMismatch = errstring.New("checksum mismatch") ) -const hopefullyNonexistentHost = "thisH0stDoe5not3xist" - // returns the current user homedir, prefers $HOME env var func homeDir() (string, error) { if home, ok := os.LookupEnv("HOME"); ok { @@ -112,78 +117,157 @@ func expandAndValidatePath(path string) (string, error) { return path, nil } -func (c *SSH) keypathsFromConfig() []string { - log.Tracef("%s: trying to get a keyfile path from ssh config", c) - if idf := c.getConfigAll("IdentityFile"); len(idf) > 0 { - log.Tracef("%s: detected %d identity file paths from ssh config: %v", c, len(idf), idf) - return idf +func flattenPaths(paths []string) []string { + var out []string + for _, p := range paths { + pp, err := shlex.Split(p) + if err == nil { + out = append(out, pp...) + } } - log.Tracef("%s: no identity file paths found in ssh config", c) - return []string{} + return out } -func (c *SSH) initGlobalDefaults() { - log.Tracef("discovering global default keypaths") - dummyHostIdentityFiles := SSHConfigGetAll(hopefullyNonexistentHost, "IdentityFile") - for _, keyPath := range dummyHostIdentityFiles { - if expanded, err := expandAndValidatePath(keyPath); err != nil { - dummyhostKeyPaths = append(dummyhostKeyPaths, expanded) - } +func uniqStrings(elems []string) []string { + if len(elems) < 2 { + return elems + } + uniq := make(map[string]struct{}) + for _, e := range elems { + uniq[e] = struct{}{} } + out := make([]string, 0, len(uniq)) + for k := range uniq { + out = append(out, k) + } + return out } -func findUniq(a, b []string) (string, bool) { - for _, s := range a { +func initSSHDefaults() { + keyPaths := sshConfig.GetAll("*", "IdentityFile") + if len(keyPaths) > 0 { + defaultKeypaths = flattenPaths(keyPaths) + } +} + +// sliceContainsAll returns true if string slice B contains only strings that are present in slice A +func sliceContainsAll(a, b []string) bool { + for _, s := range b { found := false - for _, t := range b { + for _, t := range a { if s == t { found = true break } } if !found { - return s, true + return false } } - return "", false + return true } -// SetDefaults sets various default values -func (c *SSH) SetDefaults() { - globalOnce.Do(c.initGlobalDefaults) - c.once.Do(func() { - if c.KeyPath != nil && *c.KeyPath != "" { - if expanded, err := expandAndValidatePath(*c.KeyPath); err == nil { - c.keyPaths = append(c.keyPaths, expanded) +func (c *SSH) nonDefaultKeypaths() []string { + var keyPaths []string + for _, p := range c.keyPaths { + var found bool + for _, d := range defaultKeypaths { + if p == d { + found = true + break } - // keypath is explicitly set, accept the fact even if it's invalid and - // don't try to find it from ssh config/defaults - return } - c.KeyPath = nil - - paths := c.keypathsFromConfig() - if len(paths) == 0 { - // no paths found in ssh config either, use defaults - paths = append(paths, defaultKeypaths...) + if !found { + keyPaths = append(keyPaths, p) } + } + return keyPaths +} - for _, p := range paths { - expanded, err := expandAndValidatePath(p) - if err != nil { - log.Tracef("%s: %s: %v", c, p, err) - continue - } - log.Debugf("%s: using identity file %s", c, expanded) - c.keyPaths = append(c.keyPaths, expanded) - } +func intPtr(num int) *int { + return &num +} + +func (c *SSH) setupAddress() { + if addr := sshConfig.Get(c.Address, "HostName"); addr != "" { + log.Debugf("%s: using hostname %s from ssh config as connection address", c.Address, addr) + c.connAddr = addr + return + } + c.connAddr = c.Address +} - // check if all the paths that were found are global defaults - // errors are handled differently when a keypath is explicitly set vs when it's defaulted - if uniq, found := findUniq(c.keyPaths, dummyhostKeyPaths); found { - c.KeyPath = &uniq +func (c *SSH) setupPort() { + if c.Port != nil { + return + } + + portS := sshConfig.Get(c.Address, "Port") + port, err := strconv.Atoi(portS) + if err == nil { + c.Port = intPtr(port) + log.Tracef("%s: using port %d from ssh config", c, c.Port) + return + } + c.Port = intPtr(22) + log.Tracef("%s: using default port", c) +} + +func (c *SSH) setupKeyPaths() { + if sshConfig.Get(c.Address, "PubkeyAuthentication") == "no" { + log.Infof("%s: public key based authentication disabled in ssh config", c) + c.keyPaths = nil + return + } + + // if keypath is set, use that + if c.KeyPath != nil && *c.KeyPath != "" { + c.keyPaths = []string{*c.KeyPath} + return + } + + log.Tracef("%s: trying to get a keyfile path from ssh config", c) + if paths := sshConfig.GetAll(c.Address, "IdentityFile"); len(paths) > 0 { + c.keyPaths = uniqStrings(flattenPaths(paths)) + log.Tracef("%s: detected %d identity file paths from ssh config: %v", c, len(c.keyPaths), c.keyPaths) + return + } +} + +func (c *SSH) sanitizeKeyPaths() { + if len(c.keyPaths) == 0 { + return + } + var newPaths []string + for _, p := range c.keyPaths { + if p == "" { + continue } - }) + p, err := expandAndValidatePath(p) + if err != nil { + log.Tracef("%s: failed to validate key path %s: %v", c, p, err) + continue + } + newPaths = append(newPaths, p) + } + c.keyPaths = uniqStrings(newPaths) +} + +func (c *SSH) setup() { + for _, f := range []func(){ + c.setupAddress, + c.setupPort, + c.setupKeyPaths, + c.sanitizeKeyPaths, + } { + f() + } +} + +// SetDefaults sets various default values +func (c *SSH) SetDefaults() { + globalOnce.Do(initSSHDefaults) + c.once.Do(c.setup) } // Protocol returns the protocol name, "SSH" @@ -196,26 +280,20 @@ func (c *SSH) IPAddress() string { return c.Address } -// SSHConfigGetAll by default points to ssh_config package's GetAll() function -// you can override it with your own implementation for testing purposes -var SSHConfigGetAll = ssh_config.GetAll - // try with port, if no results, try without func (c *SSH) getConfigAll(key string) []string { - dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) - if val := SSHConfigGetAll(dst, key); len(val) > 0 { - return val - } - return SSHConfigGetAll(c.Address, key) + return sshConfig.GetAll(c.Address, key) } // String returns the connection's printable name func (c *SSH) String() string { - if c.name == "" { - c.name = fmt.Sprintf("[ssh] %s", net.JoinHostPort(c.Address, strconv.Itoa(c.Port))) + var name string + if c.connAddr != c.Address { + name = c.Address + } else { + name = c.connAddr } - - return c.name + return fmt.Sprintf("[ssh] %s", net.JoinHostPort(name, strconv.Itoa(*c.Port))) } // IsConnected returns true if the client is connected @@ -319,16 +397,23 @@ func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { config.HostKeyCallback = hkc var signers []ssh.Signer - agent, err := agentClient() - if err != nil { - log.Tracef("%s: failed to get ssh agent client: %v", c, err) + + if sshConfig.Get(c.Address, "IdentitiesOnly") == "yes" { + log.Debugf("%s: IdentitiesOnly is set to 'yes', not using ssh-agent", c) } else { - signers, err = agent.Signers() + agent, err := agentClient() if err != nil { - log.Debugf("%s: failed to list signers from ssh agent: %v", c, err) + log.Tracef("%s: failed to get ssh agent client: %v", c, err) + } else { + signers, err = agent.Signers() + if err != nil { + log.Debugf("%s: failed to list signers from ssh agent: %v", c, err) + } } } + nonDefaultPaths := c.nonDefaultKeypaths() + for _, keyPath := range c.keyPaths { if am, ok := authMethodCache.Load(keyPath); ok { switch authM := am.(type) { @@ -344,7 +429,17 @@ func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { } privateKeyAuth, err := c.pkeySigner(signers, keyPath) if err != nil { - log.Debugf("%s: failed to obtain a signer for identity %s: %v", c, keyPath, err) + if c.KeyPath != nil { + return nil, ErrCantConnect.Wrapf("can't use explicitly set identity file %s: %w", *c.KeyPath, err) + } + + for _, p := range nonDefaultPaths { + if p == keyPath { + return nil, ErrCantConnect.Wrapf("can't use identity file at %s: %w", keyPath, err) + } + } + + log.Debugf("%s: failed to obtain a signer for identity file %s: %v", c, keyPath, err) // store the error so this key won't be loaded again authMethodCache.Store(keyPath, err) } else { @@ -364,31 +459,7 @@ func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { return config, nil } -// Connect opens the SSH connection -func (c *SSH) Connect() error { - if err := defaults.Set(c); err != nil { - return ErrValidationFailed.Wrapf("set defaults: %w", err) - } - - config, err := c.clientConfig() - if err != nil { - return ErrCantConnect.Wrapf("create config: %w", err) - } - - dst := net.JoinHostPort(c.Address, strconv.Itoa(c.Port)) - - if c.Bastion == nil { - clientDirect, err := ssh.Dial("tcp", dst, config) - if err != nil { - if errors.Is(err, hostkey.ErrHostKeyMismatch) { - return ErrCantConnect.Wrap(err) - } - return fmt.Errorf("ssh dial: %w", err) - } - c.client = clientDirect - return nil - } - +func (c *SSH) connectViaBastion(dst string, cfg *ssh.ClientConfig) error { if err := c.Bastion.Connect(); err != nil { if errors.Is(err, hostkey.ErrHostKeyMismatch) { return ErrCantConnect.Wrapf("bastion connect: %w", err) @@ -399,7 +470,7 @@ func (c *SSH) Connect() error { if err != nil { return fmt.Errorf("bastion dial: %w", err) } - client, chans, reqs, err := ssh.NewClientConn(bconn, dst, config) + client, chans, reqs, err := ssh.NewClientConn(bconn, dst, cfg) if err != nil { if errors.Is(err, hostkey.ErrHostKeyMismatch) { return ErrCantConnect.Wrapf("bastion client connect: %w", err) @@ -411,13 +482,50 @@ func (c *SSH) Connect() error { return nil } +// Connect opens the SSH connection +func (c *SSH) Connect() error { + if err := defaults.Set(c); err != nil { + return ErrValidationFailed.Wrapf("set defaults: %w", err) + } + + config, err := c.clientConfig() + if err != nil { + return ErrCantConnect.Wrapf("create config: %w", err) + } + + var port string + if c.Port == nil { + port = "22" + } else { + port = strconv.Itoa(*c.Port) + } + dst := net.JoinHostPort(c.connAddr, port) + log.Debugf("%s: connecting to %s", c, dst) + + if c.Bastion != nil { + return c.connectViaBastion(dst, config) + } + + clientDirect, err := ssh.Dial("tcp", dst, config) + if err != nil { + if errors.Is(err, hostkey.ErrHostKeyMismatch) { + return ErrCantConnect.Wrap(err) + } + return fmt.Errorf("ssh dial: %w", err) + } + c.client = clientDirect + return nil +} + +// pubkeysigner returns an ssh.AuthMethod for an ssh public key from the supplied signers func (c *SSH) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMethod, error) { if len(signers) == 0 { return nil, ErrCantConnect.Wrapf("signer not found for public key") } + keyM := key.Marshal() for _, s := range signers { - if bytes.Equal(key.Marshal(), s.PublicKey().Marshal()) { + if bytes.Equal(keyM, s.PublicKey().Marshal()) { log.Debugf("%s: signer for public key available in ssh agent", c) return ssh.PublicKeys(s), nil } @@ -426,6 +534,7 @@ func (c *SSH) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMet return nil, ErrAuthFailed.Wrapf("the provided key is a public key and is not known by agent") } +// pkeySigner returns an AuthMethod for the given keyPath. the signers are passed around to avoid querying them from agent multiple times func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, error) { log.Tracef("%s: checking identity file %s", c, path) key, err := os.ReadFile(path) @@ -455,6 +564,10 @@ func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, err } } + if sshConfig.Get(c.Address, "BatchMode") == "yes" { + return nil, ErrCantConnect.Wrapf("passphrase required for encrypted key but BatchMode is set for host in ssh config: %w", err) + } + if c.PasswordCallback != nil { log.Tracef("%s: asking for a password to decrypt %s", c, path) pass, err := c.PasswordCallback() diff --git a/test/test.sh b/test/test.sh index 9336fdb9..449db848 100755 --- a/test/test.sh +++ b/test/test.sh @@ -38,7 +38,7 @@ rig_test_agent_with_public_key() { ssh-add .ssh/identity rm -f .ssh/identity set +e - HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity.pub -connect + SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity.pub -connect local exit_code=$? set -e kill $SSH_AGENT_PID @@ -59,7 +59,7 @@ rig_test_agent_with_private_key() { ' set +e # path points to a private key, rig should try to look for the .pub for it - HOME=$(pwd) SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity -connect + SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath .ssh/identity -connect local exit_code=$? set -e kill $SSH_AGENT_PID @@ -76,7 +76,7 @@ rig_test_agent() { rm -f .ssh/identity set +e ssh-add -l - HOME=. SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath "" -connect + SSH_AUTH_SOCK=$SSH_AUTH_SOCK ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -keypath "" -connect local exit_code=$? set -e kill $SSH_AGENT_PID @@ -92,7 +92,7 @@ rig_test_ssh_config() { echo "Host 127.0.0.1:$(ssh_port node0)" > .ssh/config echo " IdentityFile .ssh/identity2" >> .ssh/config set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -connect + ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -connect local exit_code=$? set -e RET=$exit_code @@ -107,7 +107,7 @@ rig_test_ssh_config_strict() { echo " UserKnownHostsFile $(pwd)/.ssh/known" >> .ssh/config cat .ssh/config set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect + ./rigtest -host "${addr}" -user root -connect local exit_code=$? set -e if [ $exit_code -ne 0 ]; then @@ -121,7 +121,7 @@ rig_test_ssh_config_strict() { echo "${addr} ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBBgejI9UJnRY/i4HNM/os57oFcRjE77gEbVfUkuGr5NRh3N7XxUnnBKdzrAiQNPttUjKmUm92BN7nCUxbwsoSPw=" > .ssh/known cat .ssh/known set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect + ./rigtest -host "${addr}" -user root -connect exit_code=$? set -e @@ -142,7 +142,7 @@ rig_test_ssh_config_no_strict() { echo " UserKnownHostsFile $(pwd)/.ssh/known" >> .ssh/config echo " StrictHostKeyChecking no" >> .ssh/config set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect + ./rigtest -host "${addr}" -user root -connect local exit_code=$? set -e if [ $? -ne 0 ]; then @@ -152,7 +152,7 @@ rig_test_ssh_config_no_strict() { # modify the known hosts file to make it mismatch echo "${addr} ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBBgejI9UJnRY/i4HNM/os57oFcRjE77gEbVfUkuGr5NRh3N7XxUnnBKdzrAiQNPttUjKmUm92BN7nCUxbwsoSPw=" > .ssh/known set +e - HOME=. SSH_CONFIG=.ssh/config ./rigtest -host "${addr}" -user root -connect + ./rigtest -host "${addr}" -user root -connect exit_code=$? set -e RET=$exit_code @@ -202,6 +202,8 @@ if ! sanity_check; then exit 1 fi +export HOME=$( cd "$( dirname "${BASH_SOURCE[0]}" )" > /dev/null && pwd ) + for test in $(declare -F|grep rig_test_|cut -d" " -f3); do if [ "$FOCUS" != "" ] && [ "$FOCUS" != "$test" ]; then continue From 0b2bda3113c7fcc6fa49c16b117cf69a98d5c62b Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Thu, 2 Feb 2023 14:35:27 +0200 Subject: [PATCH 2/4] wip --- test/Makefile | 2 +- test/test.sh | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/Makefile b/test/Makefile index 1c0f8f2e..d717524c 100644 --- a/test/Makefile +++ b/test/Makefile @@ -73,7 +73,7 @@ sshport: .PHONY: run run: rigtest create-host ./rigtest \ - -host 127.0.0.1:$(shell $(MAKE) sshport) \ + -host 127.0.0.1 sshport) \ -keypath $(KEY_PATH) \ -user root diff --git a/test/test.sh b/test/test.sh index 449db848..9c0d69c1 100755 --- a/test/test.sh +++ b/test/test.sh @@ -89,8 +89,9 @@ rig_test_ssh_config() { color_echo "- Testing getting identity path from ssh config" make create-host mv .ssh/identity .ssh/identity2 - echo "Host 127.0.0.1:$(ssh_port node0)" > .ssh/config + echo "Host 127.0.0.1" > .ssh/config echo " IdentityFile .ssh/identity2" >> .ssh/config + echo " Port $(ssh_port node0)" >> .ssh/config set +e ./rigtest -host 127.0.0.1:$(ssh_port node0) -user root -connect local exit_code=$? @@ -102,7 +103,7 @@ rig_test_ssh_config_strict() { color_echo "- Testing StrictHostkeyChecking=yes in ssh config" make create-host local addr="127.0.0.1:$(ssh_port node0)" - echo "Host ${addr}" > .ssh/config + echo "Host 127.0.0.1" > .ssh/config echo " IdentityFile .ssh/identity" >> .ssh/config echo " UserKnownHostsFile $(pwd)/.ssh/known" >> .ssh/config cat .ssh/config From 11e1043662f37b3d19233f1d861310d4b5e4bea2 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Fri, 3 Feb 2023 15:26:42 +0200 Subject: [PATCH 3/4] Lint --- pkg/ssh/config/config.go | 182 --------------------------------------- ssh.go | 107 +++++++++++------------ 2 files changed, 52 insertions(+), 237 deletions(-) delete mode 100644 pkg/ssh/config/config.go diff --git a/pkg/ssh/config/config.go b/pkg/ssh/config/config.go deleted file mode 100644 index f6307699..00000000 --- a/pkg/ssh/config/config.go +++ /dev/null @@ -1,182 +0,0 @@ -// Package config provides tools for getting data from OpenSSH config files -package config - -import ( - "reflect" - "strconv" - - "github.com/kevinburke/ssh_config" -) - -// DefaultOptions is set to the default values for host "*" from ssh_config on init -var defaultOptions *Options -var DefaultFieldSet *FieldSet -var KnownFields []string - -// Options has fields for all the settings available from ssh config files -type Options struct { - Host string - - BatchMode bool - BindAddress string - ChallengeResponseAuthentication bool - CheckHostIP bool - Ciphers string - ClearAllForwardings bool - Compression bool - CompressionLevel int - ConnectionAttempts int - ConnectTimeout int - ControlMaster bool - ControlPath string - DynamicForward string - EnableSSHKeysign bool - EscapeChar string - ExitOnForwardFailure bool - ForwardAgent bool - ForwardX11 bool - ForwardX11Trusted bool - GatewayPorts bool - GlobalKnownHostsFile string - GSSAPIAuthentication bool - GSSAPIDelegateCredentials bool - GSSAPIRenewalForcesRekey bool - GSSAPITrustDNS bool - HashKnownHosts bool - HostbasedAuthentication bool - HostKeyAlgorithms string - HostKeyAlias string - HostName string - IdentitiesOnly bool - IdentityFile []string - KbdInteractiveAuthentication bool - LocalCommand string - LocalForward string - LogLevel string - MACs string - NoHostAuthenticationForLocalhost bool - NumberOfPasswordPrompts int - PasswordAuthentication bool - PermitLocalCommand bool - Port int - PreferredAuthentications string - Protocol int - ProxyCommand string - PublicKeyAuthentication bool - RekeyLimit string - RemoteForward string - RhostsRSAAuthentication bool - RSAAuthentication bool - SendEnv []string - ServerAliveCountMax int - ServerAliveInterval int - SmartcardDevice string - StrictHostKeyChecking bool - TCPKeepAlive bool - Tunnel bool - TunnelDevice string - UsePrivilegedPort bool - User string - UserKnownHostsFile string - VerifyHostKeyDNS bool - VisualHostKey bool - XAuthLocation string - - fieldSet *FieldSet - isSet map[string]bool -} - -type FieldSet struct { - Fields []string - defaultOptions *Options -} - -func (f *FieldSet) GetOptions(host string) *Options { - opts := &Options{Host: host, fieldSet: f} - opts.populate() - return opts -} - -func NewFieldSet(fields []string) *FieldSet { - fs := &FieldSet{Fields: fields} - fs.defaultOptions = fs.GetOptions("*") - return fs -} - -func getString(host, field string) string { - return ssh_config.Get(host, field) -} - -func getStringAll(host, field string) []string { - return ssh_config.GetAll(host, field) -} - -func getBool(host, field string) bool { - return ssh_config.Get(host, field) == "yes" -} - -func getInt(host, field string) int { - val := ssh_config.Get(host, field) - if val == "" { - return 0 - } - if i, err := strconv.Atoi(val); err == nil { - return i - } - return 0 -} - -func (o *Options) getField(name string) reflect.Value { - return reflect.Indirect(reflect.ValueOf(o)).FieldByName(name) -} - -func (o *Options) populate() { - for _, fieldName := range o.fieldSet.Fields { - field := o.getField(fieldName) - if !field.CanSet() { - continue - } - - if ssh_config.SupportsMultiple(fieldName) { - field.Set(reflect.ValueOf(getStringAll(o.Host, fieldName))) - if defaultOptions != nil { - defaultField := defaultOptions.getField(fieldName) - o.isSet[fieldName] = !reflect.DeepEqual(field.Interface(), defaultField.Interface()) - } - continue - } - switch field.Kind() { //nolint:exhaustive - case reflect.String: - field.Set(reflect.ValueOf(getString(o.Host, fieldName))) - case reflect.Bool: - field.Set(reflect.ValueOf(getBool(o.Host, fieldName))) - case reflect.Int: - field.Set(reflect.ValueOf(getInt(o.Host, fieldName))) - default: - continue - } - if defaultOptions != nil { - defaultField := defaultOptions.getField(fieldName) - o.isSet[fieldName] = !reflect.DeepEqual(field.Interface(), defaultField.Interface()) - } - } -} - -// GetOptions returns an Options struct for the given host -func GetOptions(host string) *Options { - return DefaultFieldSet.GetOptions(host) -} - -func init() { - opt := Options{} - obj := reflect.ValueOf(opt) - KnownFields = []string{} - for i := 0; i < obj.NumField(); i++ { - f := obj.Type().Field(i) - if f.Name == "Host" { - continue - } - KnownFields = append(KnownFields, f.Name) - } - DefaultFieldSet = NewFieldSet(KnownFields) -} diff --git a/ssh.go b/ssh.go index b4ab0f50..40aa5604 100644 --- a/ssh.go +++ b/ssh.go @@ -150,23 +150,6 @@ func initSSHDefaults() { } } -// sliceContainsAll returns true if string slice B contains only strings that are present in slice A -func sliceContainsAll(a, b []string) bool { - for _, s := range b { - found := false - for _, t := range a { - if s == t { - found = true - break - } - } - if !found { - return false - } - } - return true -} - func (c *SSH) nonDefaultKeypaths() []string { var keyPaths []string for _, p := range c.keyPaths { @@ -385,6 +368,38 @@ func (c *SSH) hostkeyCallback() (ssh.HostKeyCallback, error) { return knownhostsCallback(defaultPath, permissive) } +func (c *SSH) signers() []ssh.Signer { + if sshConfig.Get(c.Address, "IdentitiesOnly") == "yes" { + log.Debugf("%s: IdentitiesOnly is set to 'yes', not using ssh-agent", c) + return []ssh.Signer{} + } + agent, err := agentClient() + if err != nil { + log.Tracef("%s: failed to get ssh agent client: %v", c, err) + return []ssh.Signer{} + } + signers, err := agent.Signers() + if err != nil { + log.Debugf("%s: failed to list signers from ssh agent: %v", c, err) + return []ssh.Signer{} + } + return signers +} + +func getCachedAuth(keyPath string) ssh.AuthMethod { + if am, ok := authMethodCache.Load(keyPath); ok { + switch authM := am.(type) { + case ssh.AuthMethod: + return authM + case error: + log.Tracef("already discarded key before %s: %v", keyPath, authM) + default: + log.Tracef("unexpected type %T for cached auth method for %s", am, keyPath) + } + } + return nil +} + func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { config := &ssh.ClientConfig{ User: c.User, @@ -396,56 +411,38 @@ func (c *SSH) clientConfig() (*ssh.ClientConfig, error) { } config.HostKeyCallback = hkc - var signers []ssh.Signer - - if sshConfig.Get(c.Address, "IdentitiesOnly") == "yes" { - log.Debugf("%s: IdentitiesOnly is set to 'yes', not using ssh-agent", c) - } else { - agent, err := agentClient() - if err != nil { - log.Tracef("%s: failed to get ssh agent client: %v", c, err) - } else { - signers, err = agent.Signers() - if err != nil { - log.Debugf("%s: failed to list signers from ssh agent: %v", c, err) - } - } - } + signers := c.signers() nonDefaultPaths := c.nonDefaultKeypaths() for _, keyPath := range c.keyPaths { - if am, ok := authMethodCache.Load(keyPath); ok { - switch authM := am.(type) { - case ssh.AuthMethod: - log.Tracef("%s: using cached auth method for %s", c, keyPath) - config.Auth = append(config.Auth, authM) - case error: - log.Tracef("%s: already discarded key %s: %v", c, keyPath, authM) - default: - log.Tracef("%s: unexpected type %T for cached auth method for %s", c, am, keyPath) - } + if am := getCachedAuth(keyPath); am != nil { + log.Tracef("%s: using a cached auth method for identity file %s", c, keyPath) + config.Auth = append(config.Auth, am) continue } - privateKeyAuth, err := c.pkeySigner(signers, keyPath) + keyAuth, err := c.keyfileAuth(signers, keyPath) if err != nil { + authMethodCache.Store(keyPath, err) + // store the error so this key won't be loaded again + if c.KeyPath != nil { - return nil, ErrCantConnect.Wrapf("can't use explicitly set identity file %s: %w", *c.KeyPath, err) + return nil, ErrCantConnect.Wrapf("can't use configured identity file %s: %w", *c.KeyPath, err) } + // if the key isn't one of the default paths, assume it was explicitly set, and + // treat this as a fatal error for _, p := range nonDefaultPaths { if p == keyPath { - return nil, ErrCantConnect.Wrapf("can't use identity file at %s: %w", keyPath, err) + return nil, ErrCantConnect.Wrapf("can't use identity file %s: %w", keyPath, err) } } log.Debugf("%s: failed to obtain a signer for identity file %s: %v", c, keyPath, err) - // store the error so this key won't be loaded again - authMethodCache.Store(keyPath, err) - } else { - authMethodCache.Store(keyPath, privateKeyAuth) - config.Auth = append(config.Auth, privateKeyAuth) + continue } + authMethodCache.Store(keyPath, keyAuth) + config.Auth = append(config.Auth, keyAuth) } if len(config.Auth) == 0 { @@ -518,7 +515,7 @@ func (c *SSH) Connect() error { } // pubkeysigner returns an ssh.AuthMethod for an ssh public key from the supplied signers -func (c *SSH) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMethod, error) { +func (c *SSH) pubkeyAuth(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMethod, error) { if len(signers) == 0 { return nil, ErrCantConnect.Wrapf("signer not found for public key") } @@ -534,8 +531,8 @@ func (c *SSH) pubkeySigner(signers []ssh.Signer, key ssh.PublicKey) (ssh.AuthMet return nil, ErrAuthFailed.Wrapf("the provided key is a public key and is not known by agent") } -// pkeySigner returns an AuthMethod for the given keyPath. the signers are passed around to avoid querying them from agent multiple times -func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, error) { +// keyfileAuth returns an AuthMethod for the given keyPath. the signers are passed around to avoid querying them from agent multiple times +func (c *SSH) keyfileAuth(signers []ssh.Signer, path string) (ssh.AuthMethod, error) { log.Tracef("%s: checking identity file %s", c, path) key, err := os.ReadFile(path) if err != nil { @@ -545,7 +542,7 @@ func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, err pubKey, _, _, _, err := ssh.ParseAuthorizedKey(key) if err == nil { log.Debugf("%s: file %s is a public key", c, path) - return c.pubkeySigner(signers, pubKey) + return c.pubkeyAuth(signers, pubKey) } signer, err := ssh.ParsePrivateKey(key) @@ -559,7 +556,7 @@ func (c *SSH) pkeySigner(signers []ssh.Signer, path string) (ssh.AuthMethod, err log.Debugf("%s: key %s is encrypted", c, path) if len(signers) > 0 { - if signer, err := c.pkeySigner(signers, path+".pub"); err == nil { + if signer, err := c.keyfileAuth(signers, path+".pub"); err == nil { return signer, nil } } From 7d46117522a2cf8fb46445798c31b03ac5f79606 Mon Sep 17 00:00:00 2001 From: Kimmo Lehto Date: Fri, 3 Feb 2023 15:48:21 +0200 Subject: [PATCH 4/4] Lint lint. --- cmd/rigtest/rigtest.go | 30 ------------------------------ go.mod | 1 - go.sum | 1 - 3 files changed, 32 deletions(-) diff --git a/cmd/rigtest/rigtest.go b/cmd/rigtest/rigtest.go index adac68cd..042f3cb1 100644 --- a/cmd/rigtest/rigtest.go +++ b/cmd/rigtest/rigtest.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/rand" "crypto/sha256" - "encoding/json" "errors" "flag" "fmt" @@ -21,8 +20,6 @@ import ( "github.com/k0sproject/rig/os" "github.com/k0sproject/rig/os/registry" _ "github.com/k0sproject/rig/os/support" - sshconf "github.com/k0sproject/rig/pkg/ssh/config" - "github.com/kevinburke/ssh_config" "github.com/stretchr/testify/require" ) @@ -113,33 +110,6 @@ func main() { println("at least host required, see -help") goos.Exit(1) } - fieldset := sshconf.DefaultFieldSet - opts := fieldset.GetOptions(*dh) - enc := json.NewEncoder(goos.Stdout) - enc.Encode(opts) - hn := ssh_config.Get(*dh, "Host") - p := ssh_config.Get(*dh, "Port") - println("host:", hn, "port:", p) - - /* - if configPath := goos.Getenv("SSH_CONFIG"); configPath != "" { - f, err := goos.Open(configPath) - if err != nil { - panic(err) - } - cfg, err := ssh_config.Decode(f) - if err != nil { - panic(err) - } - rig.SSHConfigGetAll = func(dst, key string) []string { - res, err := cfg.GetAll(dst, key) - if err != nil { - return nil - } - return res - } - } - */ var passfunc func() (string, error) if *pc { diff --git a/go.mod b/go.mod index 800773a8..93050114 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,6 @@ require ( github.com/stretchr/testify v1.8.0 golang.org/x/crypto v0.4.0 golang.org/x/term v0.3.0 - gopkg.in/yaml.v2 v2.2.2 ) require ( diff --git a/go.sum b/go.sum index 67b9e362..329944cd 100644 --- a/go.sum +++ b/go.sum @@ -111,7 +111,6 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=