Skip to content

Commit

Permalink
Run jobstack bug (#20)
Browse files Browse the repository at this point in the history
Fixes #18
  • Loading branch information
DiscoRiver authored Oct 14, 2021
1 parent ef73d25 commit cce0781
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
11 changes: 9 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,16 @@ func runStream(c *Config, rs chan Result) {
// run sets up goroutines, worker pool, and returns the command results for all hosts as a slice of Result. This can cause
// excessive memory usage if returning a large amount of data for a large number of hosts.
func run(c *Config) (res []Result) {
// Channels length is always how many hosts we have multiplied by the number of jobs we're running.
var resultChanLength int
if c.JobStack != nil {
resultChanLength = len(c.Hosts) * len(*c.JobStack)
} else {
resultChanLength = len(c.Hosts)
}
// Channels length is always how many hosts we have
hosts := make(chan string, len(c.Hosts))
results := make(chan Result, len(c.Hosts))
results := make(chan Result, resultChanLength)

// Set up a worker pool that will accept hosts on the hosts channel.
for i := 0; i < c.WorkerPool; i++ {
Expand All @@ -251,7 +258,7 @@ func run(c *Config) (res []Result) {
}
close(hosts)

for r := 0; r < len(c.Hosts); r++ {
for r := 0; r < resultChanLength; r++ {
res = append(res, <-results)
}

Expand Down
10 changes: 8 additions & 2 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func TestBulkWithJobStack(t *testing.T) {
}()

// Add our stack
testConfig.JobStack = &[]Job{*testJob, *testJob2}
testConfig.JobStack = &[]Job{*testJob, *testJob2, *testJob3}

if err := testConfig.SetPrivateKeyAuth("~/.ssh/id_rsa", ""); err != nil {
t.Log(err)
Expand All @@ -226,12 +226,18 @@ func TestBulkWithJobStack(t *testing.T) {
t.FailNow()
}

expectedLength := len(*testConfig.JobStack)*len(testConfig.Hosts)
if len(res) != expectedLength {
t.Logf("Expected %d results, got %d", expectedLength, len(res))
t.FailNow()
}

for i := range res {
if !strings.Contains(string(res[i].Output), "Hello, World") {
t.Logf("Expected output from bulk test not received from host %s: \n \t Output: %s \n \t Error: %s\n", res[i].Host, res[i].Output, res[i].Error)
t.FailNow()
}
fmt.Println(res[i].Host, ": ", string(res[i].Output))
fmt.Printf("%s: %s", res[i].Host, res[i].Output)
}
}

Expand Down

0 comments on commit cce0781

Please sign in to comment.