Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssh_dialer_test.go edits #44

Merged
merged 1 commit into from
Nov 17, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 61 additions & 111 deletions internal/sshdialer/ssh_dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,28 @@ import (

"github.com/docker/docker/pkg/homedir"
"github.com/pkg/errors"
"github.com/sclevine/spec"
"github.com/sclevine/spec/report"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"

"github.com/buildpacks/pack/internal/sshdialer"
th "github.com/buildpacks/pack/testhelpers"
)

type args struct {
connStr string
credentialConfig sshdialer.Config
}
type testParams struct {
name string
args args
setUpEnv setUpEnvFn
skipOnWin bool
CreateError string
DialError string
}

func TestCreateDialer(t *testing.T) {
for _, privateKey := range []string{"id_ed25519", "id_rsa", "id_dsa"} {
path := filepath.Join("testdata", privateKey)
Expand All @@ -36,24 +51,11 @@ func TestCreateDialer(t *testing.T) {
defer withCleanHome(t)()

connConfig, cleanUp, err := prepareSSHServer(t)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

defer cleanUp()
time.Sleep(time.Second * 1)

type args struct {
connStr string
credentialConfig sshdialer.Config
}
type testParams struct {
name string
args args
setUpEnv setUpEnvFn
skipOnWin bool
CreateError string
DialError string
}
tests := []testParams{
{
name: "read password from input",
Expand Down Expand Up @@ -376,14 +378,16 @@ func TestCreateDialer(t *testing.T) {
}

for _, ttx := range tests {
tt := ttx
t.Run(tt.name, func(t *testing.T) {
// this test cannot be parallelized as they use process wide environment variable $HOME
spec.Run(t, "sshDialer/"+ttx.name, testCreateDialer(connConfig, ttx), spec.Report(report.Terminal{}))
}
}

// this test cannot be parallelized as they use process wide environment variable $HOME
func testCreateDialer(connConfig *SSHServer, tt testParams) func(t *testing.T, when spec.G, it spec.S) {
return func(t *testing.T, when spec.G, it spec.S) {
it("creates a dialer", func() {
u, err := url.Parse(tt.args.connStr)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

if net.ParseIP(u.Hostname()).To4() == nil && connConfig.hostIPv6 == "" {
t.Skip("skipping ipv6 test since test environment doesn't support ipv6 connection")
Expand Down Expand Up @@ -495,21 +499,15 @@ func withKey(t *testing.T, keyName string) setUpEnvFn {
var err error

home, err := os.UserHomeDir()
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

err = os.MkdirAll(filepath.Join(home, ".ssh"), 0700)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

keySrc := filepath.Join("testdata", keyName)
keyDest := filepath.Join(home, ".ssh", keyName)
err = cp(keySrc, keyDest)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

return func() {
os.Remove(keyDest)
Expand All @@ -524,20 +522,14 @@ func withInaccessibleKey(keyName string) setUpEnvFn {
var err error

home, err := os.UserHomeDir()
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

err = os.MkdirAll(filepath.Join(home, ".ssh"), 0700)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

keyDest := filepath.Join(home, ".ssh", keyName)
_, err = os.OpenFile(keyDest, os.O_CREATE|os.O_WRONLY, 0000)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

return func() {
os.Remove(keyDest)
Expand All @@ -554,9 +546,8 @@ func withCleanHome(t *testing.T) func() {
homeName = "USERPROFILE"
}
tmpDir, err := ioutil.TempDir("", "tmpHome")
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

oldHome, hadHome := os.LookupEnv(homeName)
os.Setenv(homeName, tmpDir)

Expand All @@ -578,29 +569,23 @@ func withKnowHosts(connConfig *SSHServer) setUpEnvFn {
knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts")

err := os.MkdirAll(filepath.Join(homedir.Get(), ".ssh"), 0700)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

_, err = os.Stat(knownHosts)
if err == nil || !errors.Is(err, os.ErrNotExist) {
t.Fatal("known_hosts already exists")
}

f, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)
defer f.Close()

// generate known_hosts
serverKeysDir := filepath.Join("testdata", "etc", "ssh")
for _, k := range []string{"ecdsa"} {
keyPath := filepath.Join(serverKeysDir, fmt.Sprintf("ssh_host_%s_key.pub", k))
key, err := ioutil.ReadFile(keyPath)
if err != nil {
t.Fatal(t)
}
th.AssertNil(t, err)

fmt.Fprintf(f, "%s %s", connConfig.hostIPv4, string(key))
fmt.Fprintf(f, "[%s]:%d %s", connConfig.hostIPv4, connConfig.portIPv4, string(key))
Expand All @@ -625,19 +610,15 @@ func withBadKnownHosts(connConfig *SSHServer) setUpEnvFn {
knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts")

err := os.MkdirAll(filepath.Join(homedir.Get(), ".ssh"), 0700)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

_, err = os.Stat(knownHosts)
if err == nil || !errors.Is(err, os.ErrNotExist) {
t.Fatal("known_hosts already exists")
}

f, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)
defer f.Close()

knownHostTemplate := `{{range $host := .}}{{$host}} ssh-dss AAAAB3NzaC1kc3MAAACBAKH4ufS3ABVb780oTgEL1eu+pI1p6YOq/1KJn5s3zm+L3cXXq76r5OM/roGEYrXWUDGRtfVpzYTAKoMWuqcVc0AZ2zOdYkoy1fSjJ3MqDGF53QEO3TXIUt3gUzmLOewwmZWle0RgMa9GHccv7XVVIZB36RR68ZEUswLaTnlVhXQ1AAAAFQCl4t/LnY7kuUI+tL2qT2XmxmiyqwAAAIB72XaO+LfyIiqBOaTkQf+5rvH1i6y6LDO1QD9pzGWUYw3y03AEveHJMjW0EjnYBKJjK39wcZNTieRyU54lhH/HWeWABn9NcQ3duEf1WSO/s7SPsFO2R6quqVSsStkqf2Yfdy4fl24mH41olwtNA6ft5nkVfkqrIa51si4jU8fBVAAAAIB8SSvyYBcyMGLUlQjzQqhhhAHer9x/1YbknVz+y5PHJLLjHjMC4ZRfLgNEojvMKQW46Te9Pwnudcwv19ho4F+kkCOfss7xjyH70gQm6Sj76DxClmnnPoSRq3qEAOMy5Oh+7vyzxm68KHqd/aOmUaiT1LgqgViS9+kNdCoVMGAMOg== mvasek@bellatrix
Expand All @@ -648,9 +629,7 @@ func withBadKnownHosts(connConfig *SSHServer) setUpEnvFn {

tmpl := template.New(knownHostTemplate)
tmpl, err = tmpl.Parse(knownHostTemplate)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

hosts := make([]string, 0, 4)
hosts = append(hosts, connConfig.hostIPv4, fmt.Sprintf("[%s]:%d", connConfig.hostIPv4, connConfig.portIPv4))
Expand All @@ -659,9 +638,7 @@ func withBadKnownHosts(connConfig *SSHServer) setUpEnvFn {
}

err = tmpl.Execute(f, hosts)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

return func() {
os.Remove(knownHosts)
Expand All @@ -676,25 +653,19 @@ func withBrokenKnownHosts(t *testing.T) func() {
knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts")

err := os.MkdirAll(filepath.Join(homedir.Get(), ".ssh"), 0700)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

_, err = os.Stat(knownHosts)
if err == nil || !errors.Is(err, os.ErrNotExist) {
t.Fatal("known_hosts already exists")
}

f, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)
defer f.Close()

_, err = f.WriteString("somegarbage\nsome rubish\n stuff\tqwerty")
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

return func() {
os.Remove(knownHosts)
Expand All @@ -708,19 +679,15 @@ func withInaccessibleKnownHosts(t *testing.T) func() {
knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts")

err := os.MkdirAll(filepath.Join(homedir.Get(), ".ssh"), 0700)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

_, err = os.Stat(knownHosts)
if err == nil || !errors.Is(err, os.ErrNotExist) {
t.Fatal("known_hosts already exists")
}

f, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_WRONLY, 0000)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)
defer f.Close()

return func() {
Expand All @@ -735,19 +702,15 @@ func withEmptyKnownHosts(t *testing.T) func() {
knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts")

err := os.MkdirAll(filepath.Join(homedir.Get(), ".ssh"), 0700)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

_, err = os.Stat(knownHosts)
if err == nil || !errors.Is(err, os.ErrNotExist) {
t.Fatal("known_hosts already exists")
}

_, err = os.Create(knownHosts)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

return func() {
os.Remove(knownHosts)
Expand Down Expand Up @@ -791,13 +754,10 @@ func withGoodSSHAgent(t *testing.T) func() {
t.Helper()

key, err := ioutil.ReadFile(filepath.Join("testdata", "id_ed25519"))
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

signer, err := ssh.ParsePrivateKey(key)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

return withSSHAgent(t, signerAgent{signer})
}
Expand All @@ -812,14 +772,12 @@ func withBadSSHAgent(t *testing.T) func() {
func withSSHAgent(t *testing.T, ag agent.Agent) func() {
t.Helper()
tmpDirForSocket, err := ioutil.TempDir("", "forAuthSock")
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

agentSocketPath := filepath.Join(tmpDirForSocket, "agent.sock")
unixListener, err := net.Listen("unix", agentSocketPath)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

os.Setenv("SSH_AUTH_SOCK", agentSocketPath)

ctx, cancel := context.WithCancel(context.Background())
Expand Down Expand Up @@ -856,9 +814,8 @@ func withSSHAgent(t *testing.T, ag agent.Agent) func() {
os.Unsetenv("SSH_AUTH_SOCK")

err := unixListener.Close()
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

err = <-errChan

if !errors.Is(err, net.ErrClosed) {
Expand Down Expand Up @@ -958,9 +915,8 @@ func withFixedUpSSHCLI(t *testing.T) func() {
}

out, err := exec.Command(which, "ssh").CombinedOutput()
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

sshAbsPath := string(out)
sshAbsPath = strings.Trim(sshAbsPath, "\r\n")

Expand All @@ -975,15 +931,11 @@ SSH_BIN -o PasswordAuthentication=no -o ConnectTimeout=3 -o UserKnownHostsFile=%
sshScript = strings.ReplaceAll(sshScript, "SSH_BIN", sshAbsPath)

home, err := os.UserHomeDir()
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

homeBin := filepath.Join(home, "bin")
err = os.MkdirAll(homeBin, 0700)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

sshScriptName := "ssh"
if runtime.GOOS == "windows" {
Expand All @@ -992,9 +944,7 @@ SSH_BIN -o PasswordAuthentication=no -o ConnectTimeout=3 -o UserKnownHostsFile=%

sshScriptFullPath := filepath.Join(homeBin, sshScriptName)
err = ioutil.WriteFile(sshScriptFullPath, []byte(sshScript), 0700)
if err != nil {
t.Fatal(err)
}
th.AssertNil(t, err)

oldPath := os.Getenv("PATH")
os.Setenv("PATH", homeBin+string(os.PathListSeparator)+oldPath)
Expand Down