Skip to content

Commit

Permalink
feat(holtwinters): replaced our optimization with gonum optimize (#4816)
Browse files Browse the repository at this point in the history
* feat(holtwinters): replaced our optimization with gonum optimize

Replaced our in-house Nelder-Mead optimizer with the Nelder-Mead implementation from gonum's optimize package.

fixes: #4307

* fix: added benchmark for seasonality

* perf: allocate fixed-size memory and copy values

* perf: more memory allocation fixes

* fix: incorporated review comments

* feat: added minSSE to the api, controlled by the bool param withMinSSE

* fix: default initParams to 1 for the missing/null values

* fix: return a user error for NaN/Inf values in the input

added description for minSSE
  • Loading branch information
skartikey authored Jul 12, 2022
1 parent cedce32 commit 95caa7f
Show file tree
Hide file tree
Showing 10 changed files with 510 additions and 172 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ require (
golang.org/x/exp v0.0.0-20211216164055-b2b84827b756
golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
golang.org/x/tools v0.1.9
gonum.org/v1/gonum v0.9.3
gonum.org/v1/gonum v0.11.0
google.golang.org/api v0.47.0
google.golang.org/grpc v1.44.0
gopkg.in/yaml.v2 v2.3.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -976,10 +976,10 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
gonum.org/v1/gonum v0.9.3 h1:DnoIG+QAMaF5NvxnGe/oKsgKcAc6PcUyl8q0VetfQ8s=
gonum.org/v1/gonum v0.9.3/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
gonum.org/v1/gonum v0.11.0 h1:f1IJhK4Km5tBJmaiJXtk/PkL4cdVX6J+tGiM187uT5E=
gonum.org/v1/gonum v0.11.0/go.mod h1:fSG4YDCxxUZQJ7rKsQrj0gMOg00Il0Z96/qMA4bVQhA=
gonum.org/v1/netlib v0.0.0-20181029234149-ec6d1f5cefe6/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY=
Expand Down
5 changes: 5 additions & 0 deletions internal/mutable/numericarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,8 @@ func (b *Float64Array) Value(i int) float64 {
// Set stores v at index i of the array's raw data.
//
// The write goes straight to b.rawData[i] with no explicit bounds check,
// so an out-of-range i panics through normal Go indexing.
// NOTE(review): presumably the caller must size the array first (e.g. via
// Resize) — confirm against the rest of the type.
func (b *Float64Array) Set(i int, v float64) {
b.rawData[i] = v
}

// Float64Values returns the underlying float64 slice.
//
// The returned slice is b.rawData itself, not a copy, so writes made through
// it are visible to the array and vice versa.
// NOTE(review): presumably the slice must not be used after the array's
// buffer is reallocated or released (e.g. Resize/Release) — confirm.
func (b *Float64Array) Float64Values() []float64 {
return b.rawData
}
4 changes: 2 additions & 2 deletions interpreter/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -883,8 +883,8 @@ func TestStack(t *testing.T) {
FunctionName: "window",
Location: ast.SourceLocation{
File: "universe/universe.flux",
Start: ast.Position{Line: 3768, Column: 12},
End: ast.Position{Line: 3768, Column: 51},
Start: ast.Position{Line: 3775, Column: 12},
End: ast.Position{Line: 3775, Column: 51},
Source: `window(every: inf, timeColumn: timeDst)`,
},
},
Expand Down
4 changes: 2 additions & 2 deletions libflux/go/libflux/buildinfo.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ var sourceHashes = map[string]string{
"stdlib/universe/histogram_quantile_test.flux": "8f6ef6a0c738bf70d77215857f3c3a677a1c2ff7c3dc818cfd3774d0fac32478",
"stdlib/universe/histogram_test.flux": "3016cb26398869dbaa2b7c0f35e2c99b105d3df411855da2d562eccefa3ff8a7",
"stdlib/universe/holt_winters_panic_test.flux": "204eb8044d634e5350a364eac466eb51e7f549e4ac7f454de7b244ba272b248f",
"stdlib/universe/holt_winters_test.flux": "3a2a95e6a96cbf588afdd9cdb17f4a7e8c8778361fcf62891fbaed5a44321694",
"stdlib/universe/holt_winters_test.flux": "b7f017eb36fcfd52dea337a8394d26b799efa285e16f320bb569b8593c4d21fa",
"stdlib/universe/hour_selection_test.flux": "af919d0f515984e1155194a05dbe92d533c212bef1eb23e014e862ffb85810a5",
"stdlib/universe/increase_test.flux": "3e94fe06b5b71aadc5ac8b6f070278bba663a9633b4d4b3ec58903792bfe0bc8",
"stdlib/universe/integral_test.flux": "7206d881f059f0e6009fe3f0cfeb18b5fcec9ba192b607f63d375541851fcb85",
Expand Down Expand Up @@ -601,7 +601,7 @@ var sourceHashes = map[string]string{
"stdlib/universe/union_heterogeneous_test.flux": "3298ba8e24903621505c78f4c48e4148f2d7c45e507e19ab46f547756a3173f4",
"stdlib/universe/union_test.flux": "f853a7bf588fedceee217d931733eb5f3b86b1f4717c2af24d59890b3c86f71c",
"stdlib/universe/unique_test.flux": "516e9fea81513c8cbb0c7a23545c9080e56c00149cc40bfc2187c649bfa4c958",
"stdlib/universe/universe.flux": "cd67c0d1148f3a1cf38872e08fab0fb31ac4b2c819ae5d6086f8482a976c7c14",
"stdlib/universe/universe.flux": "d95dc71b65a0674bd47ab3ab447061921a1b0a352f7a2a85ceeaac7ea8fd4ae2",
"stdlib/universe/universe_truncateTimeColumn_test.flux": "8acb700c612e9eba87c0525b33fd1f0528e6139cc912ed844932caef25d37b56",
"stdlib/universe/window_aggregate_test.flux": "cd0a1a7e788a50fa04289aa6e8b557f6c960eaf6ae95f9d8c0ff3044a48b4beb",
"stdlib/universe/window_default_start_align_test.flux": "0aaf612796fbb5ac421579151ad32a8861f4494a314ea615d0ccedd18067b980",
Expand Down
59 changes: 53 additions & 6 deletions stdlib/universe/holt_winters.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package universe

import (
"fmt"
"math"

"github.com/influxdata/flux"
"github.com/influxdata/flux/array"
Expand All @@ -25,6 +26,7 @@ type HoltWintersOpSpec struct {
N int64 `json:"n"`
S int64 `json:"s"`
Interval flux.Duration `json:"interval"`
WithMinSSE bool `json:"with_minsse"`
}

func init() {
Expand Down Expand Up @@ -76,6 +78,11 @@ func createHoltWintersOpSpec(args flux.Arguments, a *flux.Administration) (flux.
} else {
spec.S = 0
}
if withMinSSE, ok, err := args.GetBool("withMinSSE"); err != nil {
return nil, err
} else if ok {
spec.WithMinSSE = withMinSSE
}
return spec, nil
}

Expand All @@ -95,6 +102,7 @@ type HoltWintersProcedureSpec struct {
N int64
S int64
Interval flux.Duration
WithMinSSE bool
}

func newHoltWintersProcedure(qs flux.OperationSpec, pa plan.Administration) (plan.ProcedureSpec, error) {
Expand All @@ -109,6 +117,7 @@ func newHoltWintersProcedure(qs flux.OperationSpec, pa plan.Administration) (pla
N: spec.N,
S: spec.S,
Interval: spec.Interval,
WithMinSSE: spec.WithMinSSE,
}, nil
}

Expand Down Expand Up @@ -149,6 +158,7 @@ type holtWintersTransformation struct {
n int64
s int64
interval values.Duration
withMinSSE bool
}

func NewHoltWintersTransformation(d execute.Dataset, cache execute.TableBuilderCache, alloc memory.Allocator, spec *HoltWintersProcedureSpec) *holtWintersTransformation {
Expand All @@ -162,6 +172,7 @@ func NewHoltWintersTransformation(d execute.Dataset, cache execute.TableBuilderC
n: spec.N,
s: spec.S,
interval: values.Duration(spec.Interval),
withMinSSE: spec.WithMinSSE,
}
}

Expand Down Expand Up @@ -205,6 +216,16 @@ func (hwt *holtWintersTransformation) Process(id execute.DatasetID, tbl flux.Tab
if err != nil {
return err
}
newMinSSEIdx := -1
if hwt.withMinSSE {
newMinSSEIdx, err = builder.AddCol(flux.ColMeta{
Label: "minSSE",
Type: flux.TFloat,
})
if err != nil {
return err
}
}

// Cleaning data for HoltWinters input.
vs, start, stop, err := hwt.getCleanData(tbl, colIdx, timeIdx)
Expand All @@ -214,7 +235,7 @@ func (hwt *holtWintersTransformation) Process(id execute.DatasetID, tbl flux.Tab

// Holt Winters.
hw := holt_winters.New(int(hwt.n), int(hwt.s), hwt.withFit, fluxarrow.NewAllocator(hwt.alloc))
newVs := hw.Do(vs)
newVs, minSSE := hw.Do(vs)
// don't need vs anymore
vs.Release()

Expand Down Expand Up @@ -242,6 +263,21 @@ func (hwt *holtWintersTransformation) Process(id execute.DatasetID, tbl flux.Tab
if err := builder.AppendFloats(newValueIdx, newVs); err != nil {
return err
}

if hwt.withMinSSE {
minSSEb := array.NewFloatBuilder(fluxarrow.NewAllocator(hwt.alloc))
for i := 0; i < newVs.Len(); i++ {
minSSEb.Append(minSSE)
}
newMinSSE := minSSEb.NewFloatArray()
defer func() {
newMinSSE.Release()
}()

if err := builder.AppendFloats(newMinSSEIdx, newMinSSE); err != nil {
return err
}
}
if err := execute.AppendKeyValuesN(tbl.Key(), builder, newVs.Len()); err != nil {
return err
}
Expand Down Expand Up @@ -270,7 +306,7 @@ func (hwt *holtWintersTransformation) getCleanData(tbl flux.Table, colIdx, timeI
bucketEnd += int64(hwt.interval.Duration())
bucketFilled = false
}
appendV := func(cr flux.ColReader, i int) {
appendV := func(cr flux.ColReader, i int) error {
switch typ := tbl.Cols()[colIdx].Type; typ {
case flux.TInt:
c := cr.Ints(colIdx)
Expand All @@ -288,15 +324,20 @@ func (hwt *holtWintersTransformation) getCleanData(tbl flux.Table, colIdx, timeI
}
case flux.TFloat:
c := cr.Floats(colIdx)
if c.IsNull(i) {
if math.IsNaN(c.Value(i)) || math.IsInf(c.Value(i), 0) {
// If there's NaN/Inf in the user input,
// gonum will panic with message caught panic: optimize: initial function value is NaN/Inf
return errors.Newf(codes.Invalid, "NaN/Inf in input")
} else if c.IsNull(i) {
vs.AppendNull()
} else {
vs.Append(float64(c.Value(i)))
}
default:
panic(fmt.Sprintf("cannot append non-numerical type %s", typ.String()))
return errors.Newf(codes.Invalid, "cannot append non-numerical type %s", typ.String())
}
bucketFilled = true
return nil
}
isNull := func(cr flux.ColReader, i int) bool {
switch typ := tbl.Cols()[colIdx].Type; typ {
Expand Down Expand Up @@ -328,7 +369,10 @@ func (hwt *holtWintersTransformation) getCleanData(tbl flux.Table, colIdx, timeI
if isFirst() {
start = trueT
bucketEnd = roundT
appendV(cr, i)
err := appendV(cr, i)
if err != nil {
return err
}
continue
}
if roundT <= bucketEnd && bucketFilled {
Expand All @@ -343,7 +387,10 @@ func (hwt *holtWintersTransformation) getCleanData(tbl flux.Table, colIdx, timeI
nextBucket()
}
// this is the first value for the bucket
appendV(cr, i)
err := appendV(cr, i)
if err != nil {
return err
}
stop = trueT
}
}
Expand Down
67 changes: 39 additions & 28 deletions stdlib/universe/holt_winters/holt_winters.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/influxdata/flux/array"
"github.com/influxdata/flux/arrow"
"github.com/influxdata/flux/internal/mutable"
"gonum.org/v1/gonum/optimize"
)

// HoltWinters forecasts a series into the future.
Expand Down Expand Up @@ -60,11 +61,11 @@ func New(n, s int, withFit bool, alloc memory.Allocator) *HoltWinters {
}

// Do returns the points generated by the HoltWinters algorithm given a dataset.
func (r *HoltWinters) Do(vs *array.Float) *array.Float {
func (r *HoltWinters) Do(vs *array.Float) (*array.Float, float64) {
r.vs = vs
l := vs.Len() // l is the length of both times and values
if l < 2 || r.seasonal && l < r.s || r.n <= 0 {
return arrow.NewFloat(nil, nil)
return arrow.NewFloat(nil, nil), 0
}
m := r.s

Expand Down Expand Up @@ -114,14 +115,30 @@ func (r *HoltWinters) Do(vs *array.Float) *array.Float {
if vs.IsValid(i) {
initParams.Set(i+6, vs.Value(i)/l0)
} else {
initParams.Set(i+6, 0)
// default to 1 for missing/null values
initParams.Set(i+6, 1)
}
}
}

// Determine best fit for the various parameters
minSSE := math.Inf(1)
var bestParams *mutable.Float64Array
var bestParams []float64

// Params for gonum optimize
f := mutable.NewFloat64Array(r.alloc)
f.Resize(size)
defer f.Release()
problem := optimize.Problem{
Func: func(par []float64) float64 {
for i := 0; i < len(par); i++ {
f.Set(i, par[i])
}
return r.sse(f)
},
}
settings := optimize.Settings{Converger: &optimize.FunctionConverge{Absolute: 1e-10, Iterations: 100}}

for alpha := hwGuessLower; alpha < hwGuessUpper; alpha += hwGuessStep {
for beta := hwGuessLower; beta < hwGuessUpper; beta += hwGuessStep {
for gamma := hwGuessLower; gamma < hwGuessUpper; gamma += hwGuessStep {
Expand All @@ -130,33 +147,32 @@ func (r *HoltWinters) Do(vs *array.Float) *array.Float {
initParams.Set(1, beta)
initParams.Set(2, gamma)
initParams.Set(3, phi)
// Optimize creates new parameters every time it is called.
sse, newParams := r.optim.Optimize(r.sse, initParams, r.epsilon, 1)
if sse < minSSE || bestParams == nil {
if bestParams != nil {
// Previous bestParams are not the best anymore. We can release them.
bestParams.Release()
}
minSSE = sse
bestParams = newParams

// Minimize creates new parameters every time it is called.
result, err := optimize.Minimize(problem, initParams.Float64Values(), &settings, &optimize.NelderMead{})
if err != nil {
panic(err)
}
if bestParams != newParams {
// NewParams are not the best. They are useless. Release them.
newParams.Release()

if result.F < minSSE || bestParams == nil {
minSSE = result.F
bestParams = result.X
}
}
}
}
}

// Final forecast
fcast := func() *mutable.Float64Array {
fcast := r.forecast(bestParams, false)
// Now that bestParams have been used to generate the final forecast, they can be released.
defer bestParams.Release()
return fcast
}()
return fcast.NewFloat64Array()
bestParamsF := mutable.NewFloat64Array(r.alloc)
bestParamsF.Resize(size)
defer bestParamsF.Release()
for i := 0; i < len(bestParams); i++ {
bestParamsF.Set(i, bestParams[i])
}
fcast := r.forecast(bestParamsF, false)

return fcast.NewFloat64Array(), minSSE
}

// Using the recursive relations compute the next values
Expand Down Expand Up @@ -259,11 +275,6 @@ func (r *HoltWinters) sse(params *mutable.Float64Array) float64 {
for i := 0; i < fcast.Len(); i++ {
// Skip missing values since we cannot use them to compute an error.
if r.vs.IsValid(i) {
// Compute error
if math.IsNaN(fcast.Value(i)) {
// Penalize fcast NaNs
return math.Inf(1)
}
diff := fcast.Value(i) - r.vs.Value(i)
sse += diff * diff
}
Expand Down
Loading

0 comments on commit 95caa7f

Please sign in to comment.