diff --git a/pkg/ssh/config.go b/pkg/ssh/config.go index 6cff9cf54..a155d7110 100644 --- a/pkg/ssh/config.go +++ b/pkg/ssh/config.go @@ -1,10 +1,13 @@ package ssh import ( + "bufio" "fmt" "io" "os" "path/filepath" + "runtime" + "slices" "strings" "sync" @@ -48,7 +51,6 @@ func addHost(path, host, user, context, workspace, workdir, command string, gpga if err != nil { return "", err } - newLines := []string{} // get path to executable execPath, err := os.Executable() @@ -56,6 +58,11 @@ func addHost(path, host, user, context, workspace, workdir, command string, gpga return "", err } + return addHostSection(newConfig, execPath, host, user, context, workspace, workdir, command, gpgagent) +} + +func addHostSection(config, execPath, host, user, context, workspace, workdir, command string, gpgagent bool) (string, error) { + newLines := []string{} // add new section startMarker := MarkerStartPrefix + host endMarker := MarkerEndPrefix + host @@ -79,13 +86,47 @@ func addHost(path, host, user, context, workspace, workdir, command string, gpga } newLines = append(newLines, " User "+user) newLines = append(newLines, endMarker) - // add a space between blocks - newLines = append(newLines, "") // now we append the original config - // keep our blocks on top of the file for priority reasons - newLines = append(newLines, newConfig) - return strings.Join(newLines, "\n"), nil + // keep our blocks on top of the hosts for priority reasons, but below any includes + lineNumber := 0 + found := false + lines := []string{} + commentLines := 0 + scanner := bufio.NewScanner(strings.NewReader(config)) + for scanner.Scan() { + line := scanner.Text() + // Check `Host` keyword + if strings.HasPrefix(strings.TrimSpace(line), "Host") && !found { + found = true + lineNumber = max(lineNumber-commentLines, 0) + } + + // Preserve comments + if strings.HasPrefix(strings.TrimSpace(line), "#") { + commentLines++ + } else { + commentLines = 0 + } + + if !found { + lineNumber++ + } + + lines = append(lines, line) + } + if err := scanner.Err(); err != nil { + return config, err + } + + lines = slices.Insert(lines, lineNumber, newLines...) + + newLineSep := "\n" + if runtime.GOOS == "windows" { + newLineSep = "\r\n" + } + + return strings.Join(lines, newLineSep), nil } func GetUser(workspaceID string, sshConfigPath string) (string, error) { diff --git a/pkg/ssh/config_test.go b/pkg/ssh/config_test.go new file mode 100644 index 000000000..ceb3b28f2 --- /dev/null +++ b/pkg/ssh/config_test.go @@ -0,0 +1,263 @@ +package ssh + +import ( + "fmt" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestAddHostSection(t *testing.T) { + tests := []struct { + name string + config string + execPath string + host string + user string + context string + workspace string + workdir string + command string + gpgagent bool + expected string + }{ + { + name: "Basic host addition", + config: "", + execPath: "/path/to/exec", + host: "testhost", + user: "testuser", + context: "testcontext", + workspace: "testworkspace", + workdir: "", + command: "", + gpgagent: false, + expected: `# DevPod Start testhost +Host testhost + ForwardAgent yes + LogLevel error + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ProxyCommand "/path/to/exec" ssh --stdio --context testcontext --user testuser testworkspace + User testuser +# DevPod End testhost`, + }, + { + name: "Host addition with workdir", + config: "", + execPath: "/path/to/exec", + host: "testhost", + user: "testuser", + context: "testcontext", + workspace: "testworkspace", + workdir: "/path/to/workdir", + command: "", + gpgagent: false, + expected: `# DevPod Start testhost +Host testhost + ForwardAgent yes + LogLevel error + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ProxyCommand "/path/to/exec" ssh --stdio --context testcontext --user testuser testworkspace --workdir /path/to/workdir + User testuser +# DevPod End testhost`, + }, + { + name: "Host addition with gpg agent", + config: "", + execPath: "/path/to/exec", + host: "testhost", + user: "testuser", + context: "testcontext", + workspace: "testworkspace", + workdir: "", + command: "", + gpgagent: true, + expected: `# DevPod Start testhost +Host testhost + ForwardAgent yes + LogLevel error + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ProxyCommand "/path/to/exec" ssh --gpg-agent-forwarding --stdio --context testcontext --user testuser testworkspace + User testuser +# DevPod End testhost`, + }, + { + name: "Host addition with custom command", + config: "", + execPath: "/path/to/exec", + host: "testhost", + user: "testuser", + context: "testcontext", + workspace: "testworkspace", + workdir: "", + command: "ssh -W %h:%p bastion", + gpgagent: false, + expected: `# DevPod Start testhost +Host testhost + ForwardAgent yes + LogLevel error + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ProxyCommand "ssh -W %h:%p bastion" + User testuser +# DevPod End testhost`, + }, + { + name: "Host addition to existing config", + config: `Host existinghost + User existinguser`, + execPath: "/path/to/exec", + host: "testhost", + user: "testuser", + context: "testcontext", + workspace: "testworkspace", + workdir: "", + command: "", + gpgagent: false, + expected: `# DevPod Start testhost +Host testhost + ForwardAgent yes + LogLevel error + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ProxyCommand "/path/to/exec" ssh --stdio --context testcontext --user testuser testworkspace + User testuser +# DevPod End testhost +Host existinghost + User existinguser`, + }, + { + name: "Host addition to existing config with DevPod host", + config: `# DevPod Start existingtesthost +Host existingtesthost + ForwardAgent yes + LogLevel error + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ProxyCommand "/path/to/exec" ssh --stdio --context testcontext --user testuser testworkspace + User testuser +# DevPod End testhost + +Host existinghost + User existinguser`, + execPath: "/path/to/exec", + host: "testhost", + user: "testuser", + context: "testcontext", + workspace: "testworkspace", + workdir: "", + command: "", + gpgagent: false, + expected: `# DevPod Start testhost +Host testhost + ForwardAgent yes + LogLevel error + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ProxyCommand "/path/to/exec" ssh --stdio --context testcontext --user testuser testworkspace + User testuser +# DevPod End testhost +# DevPod Start existingtesthost +Host existingtesthost + ForwardAgent yes + LogLevel error + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ProxyCommand "/path/to/exec" ssh --stdio --context testcontext --user testuser testworkspace + User testuser +# DevPod End testhost + +Host existinghost + User existinguser`, + }, + { + name: "Host addition after top level includes", + config: `Include ~/config1 + +Include ~/config2 + + + +Include ~/config3`, + execPath: "/path/to/exec", + host: "testhost", + user: "testuser", + context: "testcontext", + workspace: "testworkspace", + workdir: "", + command: "", + gpgagent: false, + expected: `Include ~/config1 + +Include ~/config2 + + + +Include ~/config3 +# DevPod Start testhost +Host testhost + ForwardAgent yes + LogLevel error + StrictHostKeyChecking no + UserKnownHostsFile /dev/null + HostKeyAlgorithms rsa-sha2-256,rsa-sha2-512,ssh-rsa + ProxyCommand "/path/to/exec" ssh --stdio --context testcontext --user testuser testworkspace + User testuser +# DevPod End testhost`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := addHostSection(tt.config, tt.execPath, tt.host, tt.user, tt.context, tt.workspace, tt.workdir, tt.command, tt.gpgagent) + if err != nil { + t.Errorf("Failed with err: %v", err) + } + + if result != tt.expected { + t.Errorf("addHostSection result does not match expected.\nGot:\n%s\nExpected:\n%s", result, tt.expected) + t.Errorf("addHostSection result does not match expected:\n%s", cmp.Diff(result, tt.expected)) + } + + if !strings.Contains(result, MarkerEndPrefix+tt.host) { + t.Errorf("Result does not contain the end marker: %s", MarkerEndPrefix+tt.host) + } + + if !strings.Contains(result, "Host "+tt.host) { + t.Errorf("Result does not contain the Host line: Host %s", tt.host) + } + + if !strings.Contains(result, "User "+tt.user) { + t.Errorf("Result does not contain the User line: User %s", tt.user) + } + + if tt.command != "" && !strings.Contains(result, fmt.Sprintf("ProxyCommand \"%s\"", tt.command)) { + t.Errorf("Result does not contain the custom ProxyCommand: %s", tt.command) + } + + if tt.workdir != "" && !strings.Contains(result, "--workdir "+tt.workdir) { + t.Errorf("Result does not contain the workdir: %s", tt.workdir) + } + + if tt.gpgagent && !strings.Contains(result, "--gpg-agent-forwarding") { + t.Errorf("Result does not contain gpg-agent-forwarding flag") + } + + if tt.config != "" && !strings.Contains(result, tt.config) { + t.Errorf("Result does not contain the original config") + } + }) + } +}