diff --git a/dense_arithmetic.go b/dense_arithmetic.go old mode 100644 new mode 100755 index 5e7545a..1d24ef2 --- a/dense_arithmetic.go +++ b/dense_arithmetic.go @@ -4,7 +4,10 @@ package matrix -import "runtime" +import ( + "runtime" + "sync/atomic" +) func (A *DenseMatrix) Plus(B MatrixRO) (Matrix, error) { C := A.Copy() @@ -167,32 +170,33 @@ func parTimes1(A, B, C *DenseMatrix) { func parTimes2(A, B, C *DenseMatrix) { const threshold = 8 - currentGoroutineCount := 1 - maxGoroutines := runtime.GOMAXPROCS(0) + 2 + var currentGoroutineCount int32 = 1 + maxGoroutines := int32(runtime.GOMAXPROCS(0) + 2) var aux func(sync chan bool, A, B, C *DenseMatrix, rs, re, cs, ce, ks, ke int) aux = func(sync chan bool, A, B, C *DenseMatrix, rs, re, cs, ce, ks, ke int) { dr := re - rs dc := ce - cs dk := ke - ks + currentCount := atomic.LoadInt32(¤tGoroutineCount) switch { - case currentGoroutineCount < maxGoroutines && dr >= dc && dr >= dk && dr >= threshold: + case currentCount < maxGoroutines && dr >= dc && dr >= dk && dr >= threshold: sync0 := make(chan bool, 1) rm := (rs + re) / 2 - currentGoroutineCount++ + atomic.AddInt32(¤tGoroutineCount, 1) go aux(sync0, A, B, C, rs, rm, cs, ce, ks, ke) aux(nil, A, B, C, rm, re, cs, ce, ks, ke) <-sync0 - currentGoroutineCount-- - case currentGoroutineCount < maxGoroutines && dc >= dk && dc >= dr && dc >= threshold: + atomic.AddInt32(¤tGoroutineCount, -1) + case currentCount < maxGoroutines && dc >= dk && dc >= dr && dc >= threshold: sync0 := make(chan bool, 1) cm := (cs + ce) / 2 - currentGoroutineCount++ + atomic.AddInt32(¤tGoroutineCount, 1) go aux(sync0, A, B, C, rs, re, cs, cm, ks, ke) aux(nil, A, B, C, rs, re, cm, ce, ks, ke) <-sync0 - currentGoroutineCount-- - case currentGoroutineCount < maxGoroutines && dk >= dc && dk >= dr && dk >= threshold: + atomic.AddInt32(¤tGoroutineCount, -1) + case currentCount < maxGoroutines && dk >= dc && dk >= dr && dk >= threshold: km := (ks + ke) / 2 aux(nil, A, B, C, rs, re, cs, ce, ks, km) aux(nil, A, B, C, rs, re, cs, ce, km, ke) @@ -245,8 +249,8 @@ func (A *DenseMatrix) TimesDenseFill(B, C *DenseMatrix) (err error) { default: for i := 0; i < A.rows; i++ { sums := C.elements[i*C.step : (i+1)*C.step] - for k, a := range A.elements[i*A.step : i*A.step + A.cols] { - for j, b := range B.elements[k*B.step : k * B.step + B.cols] { + for k, a := range A.elements[i*A.step : i*A.step+A.cols] { + for j, b := range B.elements[k*B.step : k*B.step+B.cols] { sums[j] += a * b } }