Skip to content

Commit

Permalink
Fix shell command (distribworks#1434)
Browse files Browse the repository at this point in the history
* Fix shell command

In upcoming release we're going to drop window support.

(cherry picked from commit 9186d0c)
  • Loading branch information
Victor Castell authored and ivan-kripakov-m10 committed Dec 21, 2023
1 parent 78e4d7b commit 7a0c8ce
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 112 deletions.
112 changes: 0 additions & 112 deletions builtin/bins/dkron-executor-shell/shell.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
package main

import (
"encoding/base64"
"errors"
"fmt"
"log"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"syscall"
"time"

"github.com/armon/circbuf"
dkplugin "github.com/distribworks/dkron/v3/plugin"
Expand Down Expand Up @@ -53,111 +46,6 @@ func (s *Shell) Execute(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper)
return resp, nil
}

// ExecuteImpl do execute command
func (s *Shell) ExecuteImpl(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper) ([]byte, error) {
output, _ := circbuf.NewBuffer(maxBufSize)

shell, err := strconv.ParseBool(args.Config["shell"])
if err != nil {
shell = false
}
command := args.Config["command"]
env := strings.Split(args.Config["env"], ",")
cwd := args.Config["cwd"]

executionInfo := strings.Split(fmt.Sprintf("ENV_JOB_NAME=%s", args.JobName), ",")
env = append(env, executionInfo...)

cmd, err := buildCmd(command, shell, env, cwd)
if err != nil {
return nil, err
}
err = setCmdAttr(cmd, args.Config)
if err != nil {
return nil, err
}
// use same buffer for both channels, for the full return at the end
cmd.Stderr = reportingWriter{buffer: output, cb: cb, isError: true}
cmd.Stdout = reportingWriter{buffer: output, cb: cb}

stdin, err := cmd.StdinPipe()
if err != nil {
return nil, err
}

defer stdin.Close()

payload, err := base64.StdEncoding.DecodeString(args.Config["payload"])
if err != nil {
return nil, err
}

stdin.Write(payload)
stdin.Close()

jobTimeout := args.Config["timeout"]
var jt time.Duration

if jobTimeout != "" {
jt, err = time.ParseDuration(jobTimeout)
if err != nil {
return nil, errors.New("shell: Error parsing job timeout")
}
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}

log.Printf("shell: going to run %s", command)

err = cmd.Start()
if err != nil {
return nil, err
}

var jobTimeoutMessage string
var jobTimedOut bool

if jt != 0 {
slowTimer := time.AfterFunc(jt, func() {
// Kill child process to avoid cmd.Wait()
err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) // note the minus sign
if err != nil {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. SIGKILL returned error. Job may not have been killed", command, jt)
} else {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. Job was killed", command, jt)
}

jobTimedOut = true
})

defer slowTimer.Stop()
}

quit := make(chan int)

go CollectProcessMetrics(args.JobName, cmd.Process.Pid, quit)

err = cmd.Wait()
quit <- cmd.ProcessState.ExitCode()
close(quit) // exit metric refresh goroutine after job is finished

if jobTimedOut {
_, err := output.Write([]byte(jobTimeoutMessage))
if err != nil {
log.Printf("Error writing output on timeout event: %v", err)
}
}

// Warn if buffer is overwritten
if output.TotalWritten() > output.Size() {
log.Printf("shell: Script '%s' generated %d bytes of output, truncated to %d", command, output.TotalWritten(), output.Size())
}

// Always log output
log.Printf("shell: Command output %s", output)

return output.Bytes(), err
}

// Determine the shell invocation based on OS
func buildCmd(command string, useShell bool, env []string, cwd string) (cmd *exec.Cmd, err error) {
var shell, flag string
Expand Down
115 changes: 115 additions & 0 deletions builtin/bins/dkron-executor-shell/shell_unix.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
//go:build !windows
// +build !windows

package main

import (
"encoding/base64"
"errors"
"fmt"
"log"
"os/exec"
"os/user"
"strconv"
"strings"
"syscall"
"time"

"github.com/armon/circbuf"
dkplugin "github.com/distribworks/dkron/v3/plugin"
dktypes "github.com/distribworks/dkron/v3/plugin/types"
)

func setCmdAttr(cmd *exec.Cmd, config map[string]string) error {
Expand Down Expand Up @@ -37,3 +47,108 @@ func setCmdAttr(cmd *exec.Cmd, config map[string]string) error {
}
return nil
}

// ExecuteImpl do execute command
func (s *Shell) ExecuteImpl(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper) ([]byte, error) {
output, _ := circbuf.NewBuffer(maxBufSize)

shell, err := strconv.ParseBool(args.Config["shell"])
if err != nil {
shell = false
}
command := args.Config["command"]
env := strings.Split(args.Config["env"], ",")
cwd := args.Config["cwd"]

executionInfo := strings.Split(fmt.Sprintf("ENV_JOB_NAME=%s", args.JobName), ",")
env = append(env, executionInfo...)

cmd, err := buildCmd(command, shell, env, cwd)
if err != nil {
return nil, err
}
err = setCmdAttr(cmd, args.Config)
if err != nil {
return nil, err
}
// use same buffer for both channels, for the full return at the end
cmd.Stderr = reportingWriter{buffer: output, cb: cb, isError: true}
cmd.Stdout = reportingWriter{buffer: output, cb: cb}

stdin, err := cmd.StdinPipe()
if err != nil {
return nil, err
}

defer stdin.Close()

payload, err := base64.StdEncoding.DecodeString(args.Config["payload"])
if err != nil {
return nil, err
}

stdin.Write(payload)
stdin.Close()

jobTimeout := args.Config["timeout"]
var jt time.Duration

if jobTimeout != "" {
jt, err = time.ParseDuration(jobTimeout)
if err != nil {
return nil, errors.New("shell: Error parsing job timeout")
}
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}

log.Printf("shell: going to run %s", command)

err = cmd.Start()
if err != nil {
return nil, err
}

var jobTimeoutMessage string
var jobTimedOut bool

if jt != 0 {
slowTimer := time.AfterFunc(jt, func() {
// Kill child process to avoid cmd.Wait()
err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) // note the minus sign
if err != nil {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. SIGKILL returned error. Job may not have been killed", command, jt)
} else {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. Job was killed", command, jt)
}

jobTimedOut = true
})

defer slowTimer.Stop()
}

quit := make(chan int)

go CollectProcessMetrics(args.JobName, cmd.Process.Pid, quit)

err = cmd.Wait()
quit <- cmd.ProcessState.ExitCode()
close(quit) // exit metric refresh goroutine after job is finished

if jobTimedOut {
_, err := output.Write([]byte(jobTimeoutMessage))
if err != nil {
log.Printf("Error writing output on timeout event: %v", err)
}
}

// Warn if buffer is overwritten
if output.TotalWritten() > output.Size() {
log.Printf("shell: Script '%s' generated %d bytes of output, truncated to %d", command, output.TotalWritten(), output.Size())
}

// Always log output
log.Printf("shell: Command output %s", output)

return output.Bytes(), err
}
118 changes: 118 additions & 0 deletions builtin/bins/dkron-executor-shell/shell_windows.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,129 @@
//go:build windows
// +build windows

package main

import (
"encoding/base64"
"errors"
"fmt"
"log"
"os/exec"
"strconv"
"strings"
"syscall"
"time"

"github.com/armon/circbuf"
dkplugin "github.com/distribworks/dkron/v3/plugin"
dktypes "github.com/distribworks/dkron/v3/plugin/types"
)

func setCmdAttr(cmd *exec.Cmd, config map[string]string) error {
return nil
}

// ExecuteImpl do execute command
func (s *Shell) ExecuteImpl(args *dktypes.ExecuteRequest, cb dkplugin.StatusHelper) ([]byte, error) {
output, _ := circbuf.NewBuffer(maxBufSize)

shell, err := strconv.ParseBool(args.Config["shell"])
if err != nil {
shell = false
}
command := args.Config["command"]
env := strings.Split(args.Config["env"], ",")
cwd := args.Config["cwd"]

executionInfo := strings.Split(fmt.Sprintf("ENV_JOB_NAME=%s", args.JobName), ",")
env = append(env, executionInfo...)

cmd, err := buildCmd(command, shell, env, cwd)
if err != nil {
return nil, err
}
err = setCmdAttr(cmd, args.Config)
if err != nil {
return nil, err
}
// use same buffer for both channels, for the full return at the end
cmd.Stderr = reportingWriter{buffer: output, cb: cb, isError: true}
cmd.Stdout = reportingWriter{buffer: output, cb: cb}

stdin, err := cmd.StdinPipe()
if err != nil {
return nil, err
}

defer stdin.Close()

payload, err := base64.StdEncoding.DecodeString(args.Config["payload"])
if err != nil {
return nil, err
}

stdin.Write(payload)
stdin.Close()

jobTimeout := args.Config["timeout"]
var jt time.Duration

if jobTimeout != "" {
jt, err = time.ParseDuration(jobTimeout)
if err != nil {
return nil, errors.New("shell: Error parsing job timeout")
}
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
}

log.Printf("shell: going to run %s", command)

err = cmd.Start()
if err != nil {
return nil, err
}

var jobTimeoutMessage string
var jobTimedOut bool

if jt != 0 {
slowTimer := time.AfterFunc(jt, func() {
// Kill child process to avoid cmd.Wait()
err = cmd.Process.Kill()
if err != nil {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. SIGKILL returned error. Job may not have been killed", command, jt)
} else {
jobTimeoutMessage = fmt.Sprintf("shell: Job '%s' execution time exceeding defined timeout %v. Job was killed", command, jt)
}

jobTimedOut = true
})

defer slowTimer.Stop()
}

quit := make(chan int)

go CollectProcessMetrics(args.JobName, cmd.Process.Pid, quit)

err = cmd.Wait()
quit <- cmd.ProcessState.ExitCode()
close(quit) // exit metric refresh goroutine after job is finished

if jobTimedOut {
_, err := output.Write([]byte(jobTimeoutMessage))
if err != nil {
log.Printf("Error writing output on timeout event: %v", err)
}
}

// Warn if buffer is overwritten
if output.TotalWritten() > output.Size() {
log.Printf("shell: Script '%s' generated %d bytes of output, truncated to %d", command, output.TotalWritten(), output.Size())
}

// Always log output
log.Printf("shell: Command output %s", output)

return output.Bytes(), err
}

0 comments on commit 7a0c8ce

Please sign in to comment.