diff --git a/extensions/fnutils/fnutils.lua b/extensions/fnutils/fnutils.lua index d2a0f5a7e..221f0968d 100644 --- a/extensions/fnutils/fnutils.lua +++ b/extensions/fnutils/fnutils.lua @@ -225,13 +225,14 @@ function fnutils.mapCat(t, fn) return nt end ---- hs.fnutils.reduce(table, fn) -> table +--- hs.fnutils.reduce(table, fn[, initial]) -> table --- Function --- Reduce a table to a single element, using a function --- --- Parameters: --- * table - A table containing some sort of data --- * fn - A function that takes two parameters, which will be elements of the supplied table. It should choose one of these elements and return it +--- * initial - If given, the first call to fn will be with this value and the first element of the table --- --- Returns: --- * The element of the supplied table that was chosen by the iterative reducer function @@ -239,13 +240,19 @@ end --- Notes: --- * table cannot be a sparse table, see [http://www.luafaq.org/gotchas.html#T6.4](http://www.luafaq.org/gotchas.html#T6.4) --- * The first iteration of the reducer will call fn with the first and second elements of the table. The second iteration will call fn with the result of the first iteration, and the third element. This repeats until there is only one element left -function fnutils.reduce(t, fn) +function fnutils.reduce(t, fn, ...) + local rest = {...} local len = #t - if len == 0 then return nil end - if len == 1 then return t[1] end - - local result = t[1] - for i = 2, #t do + local start, result + if #rest == 0 then + if len == 0 then return nil end + result = t[1] + start = 2 + else + result = rest[1] + start = 1 + end + for i = start, len do result = fn(result, t[i]) end return result diff --git a/extensions/fnutils/test_fnutils.lua b/extensions/fnutils/test_fnutils.lua new file mode 100644 index 000000000..6755ab7fe --- /dev/null +++ b/extensions/fnutils/test_fnutils.lua @@ -0,0 +1,10 @@ +function test_reduce() + local reduce = hs.fnutils.reduce + + assert(reduce({}, function(x, y) return x + y end) == nil) + assert(reduce({}, function(x, y) return x + y end, 10) == 10) + assert(reduce({1}, function(x, y) return x + y end) == 1) + assert(reduce({1}, function(x, y) return x + y end, 10) == 11) + assert(reduce({1, 2}, function(x, y) return x + y end) == 3) + assert(reduce({1, 2}, function(x, y) return x + y end, 10) == 13) +end