diff --git a/powershell.go b/powershell.go new file mode 100644 index 0000000..e390458 --- /dev/null +++ b/powershell.go @@ -0,0 +1,11 @@ +package main + +const ( + psPurge = ` +Get-ChildItem -Path cert:\CurrentUser\My -Recurse -EKU "*Client Authentication*" -ExpiringInDays 0 | + Remove-Item +` + psImport = ` +Import-PfxCertificate -FilePath %s -CertStoreLocation Cert:\CurrentUser\My +` +) diff --git a/prodaccess_nix.go b/prodaccess_nix.go index facfca7..7599ee5 100644 --- a/prodaccess_nix.go +++ b/prodaccess_nix.go @@ -10,7 +10,10 @@ import ( "log" "os" "os/exec" + "path" "strings" + + "golang.org/x/sys/unix" ) var ( @@ -84,27 +87,23 @@ func saveVmwareCertificate(c string, k string) { kf.Write([]byte(k)) cp := cf.Name() kp := kf.Name() + defer os.Remove(cp) + defer os.Remove(kp) + cf.Close() cf.Close() kf.Close() fp := os.ExpandEnv(*vmwareCertPath) os.Remove(fp) os.OpenFile(fp, os.O_CREATE, 0600) - cmd := exec.Command("/usr/bin/env", "openssl", "pkcs12", "-export", "-password", "pass:", - "-in", cp, "-inkey", kp, "-out", fp) - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr + _, _ = executeWithStdout("openssl", "pkcs12", "-export", "-password", "pass:", "-in", cp, "-inkey", kp, "-out", fp) - err := cmd.Run() - if err != nil { - log.Printf("Failed to emit VMware certificate: %v", err) - log.Printf("Standard output: %s", stdout.String()) - log.Printf("Error output: %s", stderr.String()) + if isWSL() { + if err := importCertFromWSL(fp); err != nil { + log.Printf("Failed to import certificate: %v", err) + } } - os.Remove(cp) - os.Remove(kp) } func saveBrowserCertificate(c string, k string) { @@ -114,25 +113,89 @@ func saveBrowserCertificate(c string, k string) { kf.Write([]byte(k)) cp := cf.Name() kp := kf.Name() + defer os.Remove(cp) + defer os.Remove(kp) cf.Close() kf.Close() fp := os.ExpandEnv(*browserCertPath) os.Remove(fp) os.OpenFile(fp, os.O_CREATE, 0600) - cmd := exec.Command("/usr/bin/env", "openssl", "pkcs12", "-export", "-password", "pass:", - "-in", cp, "-inkey", kp, "-out", fp) + _, _ = executeWithStdout("openssl", "pkcs12", "-export", "-password", "pass:", "-in", cp, "-inkey", kp, "-out", fp) + if isWSL() { + if err := importCertFromWSL(fp); err != nil { + log.Printf("Failed to import certificate: %v", err) + } + } +} + +func isWSL() bool { + u := unix.Utsname{} + _ = unix.Uname(&u) + return strings.Contains(string(u.Release[:]), "Microsoft") +} + +func executeWithStdout(cmd ...string) (string, error) { + return executeWithStdoutWithStdin("", cmd...) +} + +func executeWithStdoutWithStdin(stdin string, cmd ...string) (string, error) { + c := exec.Command("/usr/bin/env", cmd...) var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr + si := bytes.NewBufferString(stdin) + c.Stdout = &stdout + c.Stderr = &stderr + c.Stdin = si - err := cmd.Run() + err := c.Run() if err != nil { - log.Printf("Failed to emit Browser certificate: %v", err) + log.Printf("Failed to execute %v: %v", cmd, err) log.Printf("Standard output: %s", stdout.String()) log.Printf("Error output: %s", stderr.String()) + return "", err } - os.Remove(cp) - os.Remove(kp) + + return stdout.String(), nil +} + +func executeWithStdin(stdin string, cmd ...string) error { + _, err := executeWithStdoutWithStdin(stdin, cmd...) + return err +} + +// If running under WSL invoke PowerShell to import certificate +func importCertFromWSL(pfx string) error { + winpath, err := executeWithStdout("powershell.exe", "echo $env:TEMP") + if err != nil { + return err + } + + winpath = strings.Trim(winpath, "\r\n") + winparts := strings.Split(winpath, "\\") + drive := strings.ToLower(winparts[0][:1]) + winparts = winparts[1:] + temppath := path.Join("/mnt", drive, path.Join(winparts...), "prodaccess.pfx") + + pfxdata, err := ioutil.ReadFile(pfx) + if err != nil { + return err + } + + err = ioutil.WriteFile(temppath, pfxdata, 0644) + if err != nil { + return err + } + defer os.Remove(temppath) + + ps := fmt.Sprintf(psImport, winpath+"\\prodaccess.pfx") + fmt.Printf("%s\n", ps) + o, err := executeWithStdoutWithStdin(ps, "powershell.exe", "-Command", "-") + if err != nil { + return err + } + + log.Printf("Imported certificate: %s", o) + + return executeWithStdin(psPurge, "powershell.exe", "-Command", "-") }