diff --git a/pkg/credentials/gitcreds/creds_test.go b/pkg/credentials/gitcreds/creds_test.go index a1d9ad56..71793970 100644 --- a/pkg/credentials/gitcreds/creds_test.go +++ b/pkg/credentials/gitcreds/creds_test.go @@ -218,6 +218,7 @@ func TestSSHFlagHandling(t *testing.T) { expectedSSHConfig := fmt.Sprintf(`Host github.com HostName github.com IdentityFile %s + Port 22 `, filepath.Join(os.Getenv("HOME"), ".ssh", "id_foo")) if string(b) != expectedSSHConfig { t.Errorf("got: %v, wanted: %v", string(b), expectedSSHConfig) @@ -243,7 +244,7 @@ func TestSSHFlagHandling(t *testing.T) { } } -func TestSSHFlagHandlingTwice(t *testing.T) { +func TestSSHFlagHandlingThrice(t *testing.T) { credentials.VolumePath, _ = ioutil.TempDir("", "") fooDir := credentials.VolumeName("foo") if err := os.MkdirAll(fooDir, os.ModePerm); err != nil { @@ -265,12 +266,23 @@ func TestSSHFlagHandlingTwice(t *testing.T) { if err := ioutil.WriteFile(filepath.Join(barDir, "known_hosts"), []byte("ssh-rsa bbbb"), 0777); err != nil { t.Fatalf("ioutil.WriteFile(known_hosts) = %v", err) } + bazDir := credentials.VolumeName("baz") + if err := os.MkdirAll(bazDir, os.ModePerm); err != nil { + t.Fatalf("os.MkdirAll(%s) = %v", bazDir, err) + } + if err := ioutil.WriteFile(filepath.Join(bazDir, corev1.SSHAuthPrivateKey), []byte("derp"), 0777); err != nil { + t.Fatalf("ioutil.WriteFile(ssh-privatekey) = %v", err) + } + if err := ioutil.WriteFile(filepath.Join(bazDir, "known_hosts"), []byte("ssh-rsa cccc"), 0777); err != nil { + t.Fatalf("ioutil.WriteFile(known_hosts) = %v", err) + } fs := flag.NewFlagSet("test", flag.ContinueOnError) flags(fs) err := fs.Parse([]string{ "-ssh-git=foo=github.com", "-ssh-git=bar=gitlab.com", + "-ssh-git=baz=gitlab.example.com:2222", }) if err != nil { t.Fatalf("flag.CommandLine.Parse() = %v", err) @@ -289,11 +301,18 @@ func TestSSHFlagHandlingTwice(t *testing.T) { expectedSSHConfig := fmt.Sprintf(`Host github.com HostName github.com IdentityFile %s + Port 22 Host gitlab.com HostName gitlab.com IdentityFile %s + Port 22 +Host gitlab.example.com + HostName gitlab.example.com + IdentityFile %s + Port 2222 `, filepath.Join(os.Getenv("HOME"), ".ssh", "id_foo"), - filepath.Join(os.Getenv("HOME"), ".ssh", "id_bar")) + filepath.Join(os.Getenv("HOME"), ".ssh", "id_bar"), + filepath.Join(os.Getenv("HOME"), ".ssh", "id_baz")) if string(b) != expectedSSHConfig { t.Errorf("got: %v, wanted: %v", string(b), expectedSSHConfig) } @@ -303,7 +322,8 @@ Host gitlab.com t.Fatalf("ioutil.ReadFile(.ssh/known_hosts) = %v", err) } expectedSSHKnownHosts := `ssh-rsa aaaa -ssh-rsa bbbb` +ssh-rsa bbbb +ssh-rsa cccc` if string(b) != expectedSSHKnownHosts { t.Errorf("got: %v, wanted: %v", string(b), expectedSSHKnownHosts) } @@ -327,6 +347,16 @@ ssh-rsa bbbb` if string(b) != expectedIDBar { t.Errorf("got: %v, wanted: %v", string(b), expectedIDBar) } + + b, err = ioutil.ReadFile(filepath.Join(credentials.VolumePath, ".ssh", "id_baz")) + if err != nil { + t.Fatalf("ioutil.ReadFile(.ssh/id_baz) = %v", err) + } + + expectedIDBaz := `derp` + if string(b) != expectedIDBaz { + t.Errorf("got: %v, wanted: %v", string(b), expectedIDBaz) + } } func TestSSHFlagHandlingMissingFiles(t *testing.T) { diff --git a/pkg/credentials/gitcreds/ssh.go b/pkg/credentials/gitcreds/ssh.go index 43cc0341..8107419c 100644 --- a/pkg/credentials/gitcreds/ssh.go +++ b/pkg/credentials/gitcreds/ssh.go @@ -20,6 +20,7 @@ import ( "bytes" "fmt" "io/ioutil" + "net" "os" "os/exec" "path/filepath" @@ -85,8 +86,15 @@ func (dc *sshGitConfig) Write() error { // 2. Compute its part of "~/.ssh/config" // 3. Compute its part of "~/.ssh/known_hosts" var configEntries []string + var defaultPort = "22" var knownHosts []string for _, k := range dc.order { + var host, port string + var err error + if host, port, err = net.SplitHostPort(k); err != nil { + host = k + port = defaultPort + } v := dc.entries[k] if err := v.Write(sshDir); err != nil { return err @@ -94,7 +102,8 @@ func (dc *sshGitConfig) Write() error { configEntries = append(configEntries, fmt.Sprintf(`Host %s HostName %s IdentityFile %s -`, k, k, v.path(sshDir))) + Port %s +`, host, host, v.path(sshDir), port)) knownHosts = append(knownHosts, v.knownHosts) }