From b751dfb3ec4dd87df8413ccdf000bba59d68b6d5 Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewa Date: Mon, 11 Jan 2021 17:05:06 +0100 Subject: [PATCH] Table: grouping (#1392) --- .../std-lib/Base/src/Data/Vector.enso | 2 +- .../std-lib/Table/src/Data/Column.enso | 110 ++++++++++++++++++ .../std-lib/Table/src/Data/Table.enso | 75 +++++++++++- distribution/std-lib/Table/src/Main.enso | 2 +- .../expression/builtin/mutable/CopyNode.java | 50 +++++++- .../interpreter/runtime/builtin/Error.java | 15 ++- .../operation/aggregate/Aggregator.java | 28 +++++ .../operation/aggregate/CountAggregator.java | 32 +++++ .../aggregate/FunctionAggregator.java | 55 +++++++++ .../numeric/LongToLongAggregator.java | 59 ++++++++++ .../aggregate/numeric/NumericAggregator.java | 78 +++++++++++++ .../data/column/storage/BoolStorage.java | 8 +- .../data/column/storage/DoubleStorage.java | 29 +++-- .../data/column/storage/LongStorage.java | 86 ++++++++++++-- .../data/column/storage/NumericStorage.java | 59 ++++++++++ .../data/column/storage/ObjectStorage.java | 3 +- .../table/data/column/storage/Storage.java | 47 +++++++- .../data/column/storage/StringStorage.java | 8 +- .../enso/table/data/index/DefaultIndex.java | 10 ++ .../org/enso/table/data/index/HashIndex.java | 66 +++++------ .../java/org/enso/table/data/index/Index.java | 10 ++ .../org/enso/table/data/table/Column.java | 9 ++ .../java/org/enso/table/data/table/Table.java | 31 +++++ .../data/table/aggregate/AggregateColumn.java | 58 +++++++++ .../data/table/aggregate/AggregateTable.java | 54 +++++++++ test/Table_Tests/src/Table_Spec.enso | 22 ++++ 26 files changed, 921 insertions(+), 85 deletions(-) create mode 100644 table/src/main/java/org/enso/table/data/column/operation/aggregate/Aggregator.java create mode 100644 table/src/main/java/org/enso/table/data/column/operation/aggregate/CountAggregator.java create mode 100644 table/src/main/java/org/enso/table/data/column/operation/aggregate/FunctionAggregator.java create mode 100644 table/src/main/java/org/enso/table/data/column/operation/aggregate/numeric/LongToLongAggregator.java create mode 100644 table/src/main/java/org/enso/table/data/column/operation/aggregate/numeric/NumericAggregator.java create mode 100644 table/src/main/java/org/enso/table/data/column/storage/NumericStorage.java create mode 100644 table/src/main/java/org/enso/table/data/table/aggregate/AggregateColumn.java create mode 100644 table/src/main/java/org/enso/table/data/table/aggregate/AggregateTable.java diff --git a/distribution/std-lib/Base/src/Data/Vector.enso b/distribution/std-lib/Base/src/Data/Vector.enso index 561f62bde498..fb24265e6b36 100644 --- a/distribution/std-lib/Base/src/Data/Vector.enso +++ b/distribution/std-lib/Base/src/Data/Vector.enso @@ -462,7 +462,7 @@ type Vector not want to sort in place on the original vector, as `sort` is not intended to be mutable. new_vec_arr = Array.new this.length - this.to_array.copy 0 new_vec_arr 0 this.length + Array.copy this.to_array 0 new_vec_arr 0 this.length ## As we want to account for both custom projections and custom comparisons we need to construct a comparator for internal use that diff --git a/distribution/std-lib/Table/src/Data/Column.enso b/distribution/std-lib/Table/src/Data/Column.enso index cda5c0da29c9..204c92a0cfd2 100644 --- a/distribution/std-lib/Table/src/Data/Column.enso +++ b/distribution/std-lib/Table/src/Data/Column.enso @@ -262,10 +262,120 @@ type Column fields = Map.singleton "name" (Json.String name) . insert "data" storage_json Json.Object fields + ## Efficiently joins two tables based on either the index or the specified + key column. + + The resulting table contains rows of `this` extended with rows of + `other` with matching indexes. If the index values in `other` are not + unique, the corresponding rows of `this` will be duplicated in the + result. + + Arguments: + - other: the table being the right operand of this join operation. + - on: the column of `this` that should be used as the join key. If + this argument is not provided, the index of `this` will be used. + - drop_unmatched: whether the rows of `this` without corresponding + matches in `other` should be dropped from the result. + - left_suffix: a suffix that should be added to the columns of `this` + when there's a name conflict with a column of `other`. + - right_suffix: a suffix that should be added to the columns of `other` + when there's a name conflict with a column of `this`. + join : Table.Table | Column -> Text | Nothing -> Boolean -> Text -> Text -> Table + join other on=Nothing drop_unmatched=False left_suffix='_left' right_suffix='_right' = + this.to_table.join other on drop_unmatched left_suffix right_suffix + + ## Converts this column into a single-column table. + to_table : Table.Table + to_table = Table.Table (this.java_column.toTable []) + ## Creates a new column given a name and a vector of elements. from_vector : Text -> Vector -> Column from_vector name items = Column (Java_Column.fromItems [name, items.to_array]) +## Wraps a column grouped by its index. Allows performing aggregation operations + on the contained values. +type Aggregate_Column + type Aggregate_Column java_column + + ## Converts this aggregate column into a column, aggregating groups + with the provided `function`. + + Arguments: + - function: the function used for value aggregation. Values belonging + to each group are passed to this function in a vector. + - skip_missing: controls whether missing values should be included + in groups. + - name_suffix: a suffix that will be appended to the original column + name to generate the resulting column name. + reduce : (Vector.Vector -> Any) -> Boolean -> Text -> Column + reduce function skip_missing=True name_suffix="_result" = + f arr = function (Vector.Vector arr) + r = this.java_column.aggregate [Nothing, name_suffix, f, skip_missing] + Column r + + ## Sums the values in each group. + + Arguments: + - name_suffix: a suffix that will be appended to the original column + name to generate the resulting column name. + sum : Text -> Column + sum name_suffix='_sum' = + r = this.java_column.aggregate ['sum', name_suffix, (x-> Vector.Vector x . reduce (+)), True] + Column r + + ## Computes the maximum element of each group. + + Arguments: + - name_suffix: a suffix that will be appended to the original column + name to generate the resulting column name. + max : Text -> Column + max name_suffix='_max' = + r = this.java_column.aggregate ['max', name_suffix, (x-> Vector.Vector x . reduce max), True] + Column r + + ## Computes the minimum element of each group. + + Arguments: + - name_suffix: a suffix that will be appended to the original column + name to generate the resulting column name. + min : Text -> Column + min name_suffix='_min' = + r = this.java_column.aggregate ['min', name_suffix, (x-> Vector.Vector x . reduce min), True] + Column r + + ## Computes the number of non-missing elements in each group. + + Arguments: + - name_suffix: a suffix that will be appended to the original column + name to generate the resulting column name. + count : Text -> Column + count name_suffix='_count' = + r = this.java_column.aggregate ['count', name_suffix, (x-> x.length), True] + Column r + + ## Computes the mean of non-missing elements in each group. + + Arguments: + - name_suffix: a suffix that will be appended to the original column + name to generate the resulting column name. + mean : Text -> Column + mean name_suffix='_mean' = + vec_mean v = if v.length == 0 then Nothing else + (Vector.Vector v).reduce (+) / v.length + r = this.java_column.aggregate ['mean', name_suffix, vec_mean, True] + Column r + + ## Gathers all elements in a group into a vector and returns a column of + such vectors. + + Arguments: + - name_suffix: a suffix that will be appended to the original column + name to generate the resulting column name. + values : Text -> Column + values name_suffix='_values' = + r = this.java_column.aggregate [Nothing, name_suffix, Vector.Vector, False] + Column r + ## PRIVATE run_vectorized_binary_op column name fallback_fn operand = case operand of Column col2 -> diff --git a/distribution/std-lib/Table/src/Data/Table.enso b/distribution/std-lib/Table/src/Data/Table.enso index c4c0c7631f56..2b8e2636a0e9 100644 --- a/distribution/std-lib/Table/src/Data/Table.enso +++ b/distribution/std-lib/Table/src/Data/Table.enso @@ -84,11 +84,13 @@ type Table select : Vector -> Table select columns = Table (this.java_table.selectColumns [columns.to_array]) - ## Efficiently joins two tables based on either the index or a key column. + ## Efficiently joins two tables based on either the index or the specified + key column. The resulting table contains rows of `this` extended with rows of - `other` with matching indexes. If the index in `other` is not unique, - the corresponding rows of `this` will be duplicated in the result. + `other` with matching indexes. If the index values in `other` are not + unique, the corresponding rows of `this` will be duplicated in the + result. Arguments: - other: the table being the right operand of this join operation. @@ -100,9 +102,12 @@ type Table when there's a name conflict with a column of `other`. - right_suffix: a suffix that should be added to the columns of `other` when there's a name conflict with a column of `this`. - join : Table -> Text | Nothing -> Boolean -> Text -> Text -> Table + join : Table | Column.Column -> Text | Nothing -> Boolean -> Text -> Text -> Table join other on=Nothing drop_unmatched=False left_suffix='_left' right_suffix='_right' = - Table (this.java_table.join [other.java_table, drop_unmatched, on, left_suffix, right_suffix]) + case other of + Column.Column _ -> this.join other.to_table on drop_unmatched left_suffix right_suffix + Table t -> + Table (this.java_table.join [t, drop_unmatched, on, left_suffix, right_suffix]) ## Returns a new Table without rows that contained missing values in any of the columns. @@ -136,6 +141,59 @@ type Table cols = this.columns here.new [["Column", cols.map name], ["Items Count", cols.map count], ["Storage Type", cols.map storage_type]] + ## Returns an aggregate table resulting from grouping the elements by the + value of the specified column. + + If the `by` argument is not set, the index is used for grouping instead. + + > Example + Creates a simple table and computes aggregation statistics: + name = ['name', ["foo", "bar", "foo", "baz", "foo", "bar", "quux"]] + price = ['price', [0.4, 3.5, Nothing, 6.7, Nothing, 97, Nothing]] + quantity = ['quantity', [10, 20, 30, 40, 50, 60, 70]] + t = Table.new [name, price, quantity] + + agg = t.group by='name' + + records_num = agg.count + total_quantity = agg.at 'quantity' . sum + mean_price = agg.at 'price' . mean + + Table.join [records_num, total_quantity, mean_price] + group : Text | Nothing -> Aggregate_Table + group by=Nothing = + Aggregate_Table (this.java_table.group [by]) + +## Represents a table with grouped rows. +type Aggregate_Table + type Aggregate_Table java_table + + ## Returns a vector of aggregate columns in this table. + columns : Vector.Vector + columns = Vector.Vector (this.java_table.getColumns []) . map Column.Aggregate_Column + + ## Returns a table containing columns resulting from calling `values` on + each column in `this`. + values : Table + values = this.columns . map (_.values name_suffix='') . reduce join + + ## Returns a column containing the number of elements in each group. + count : Column + count = Column.Column (this.java_table.count []) + + ## Returns an aggregate column with the given name, contained in this table. + at : Text -> Column | Nothing + at name = case this.java_table.getColumnByName [name] of + Nothing -> Nothing + c -> Column.Aggregate_Column c + + ## Prints an ASCII-art table with this data to the standard output. + + Arguments: + - show_rows: the number of initial rows that should be displayed. + print : Integer -> Nothing + print show_rows=10 = this.values.print show_rows + ## PRIVATE from_columns cols = Table (Java_Table.new [cols.to_array].to_array) @@ -149,6 +207,13 @@ new columns = Column.from_vector (c.at 0) (c.at 1) . java_column here.from_columns cols +## Joins a vector of tables (or columns) into a single table, using each table's + index as the join key. Particularly useful for joining multiple columns + derived from one original table into a new table. +join : Vector -> Table +join tables = + tables.reduce join + ## PRIVATE pad txt len = true_len = txt.characters.length diff --git a/distribution/std-lib/Table/src/Main.enso b/distribution/std-lib/Table/src/Main.enso index 330058735559..9f87c8cd3b40 100644 --- a/distribution/std-lib/Table/src/Main.enso +++ b/distribution/std-lib/Table/src/Main.enso @@ -6,7 +6,7 @@ import Table.Data.Column from Table.Io.Csv export all hiding Parser export Table.Data.Column -from Table.Data.Table export new +from Table.Data.Table export new, join ## Converts a JSON array into a dataframe, by looking up the requested keys from each item. diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/mutable/CopyNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/mutable/CopyNode.java index 67b18280c2c5..f1f87d4e5006 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/mutable/CopyNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/mutable/CopyNode.java @@ -2,12 +2,19 @@ import com.oracle.truffle.api.TruffleLanguage.ContextReference; import com.oracle.truffle.api.dsl.CachedContext; +import com.oracle.truffle.api.dsl.Fallback; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.interop.InteropLibrary; +import com.oracle.truffle.api.interop.InvalidArrayIndexException; +import com.oracle.truffle.api.interop.UnsupportedMessageException; +import com.oracle.truffle.api.library.CachedLibrary; import com.oracle.truffle.api.nodes.Node; import org.enso.interpreter.Language; import org.enso.interpreter.dsl.BuiltinMethod; import org.enso.interpreter.runtime.Context; +import org.enso.interpreter.runtime.builtin.Builtins; import org.enso.interpreter.runtime.data.Array; +import org.enso.interpreter.runtime.error.PanicException; @BuiltinMethod(type = "Array", name = "copy", description = "Copies one array to another.") public abstract class CopyNode extends Node { @@ -16,18 +23,51 @@ static CopyNode build() { return CopyNodeGen.create(); } - abstract Object execute(Array _this, long source_index, Array that, long dest_index, long count); + abstract Object execute( + Object _this, Object src, long source_index, Array that, long dest_index, long count); @Specialization Object doArray( - Array _this, + Object _this, + Array src, long source_index, Array that, long dest_index, long count, - @CachedContext(Language.class) ContextReference ctxRef) { + @CachedContext(Language.class) Context ctx) { System.arraycopy( - _this.getItems(), (int) source_index, that.getItems(), (int) dest_index, (int) count); - return ctxRef.get().getBuiltins().nothing().newInstance(); + src.getItems(), (int) source_index, that.getItems(), (int) dest_index, (int) count); + return ctx.getBuiltins().nothing().newInstance(); + } + + @Specialization(guards = "arrays.hasArrayElements(src)") + Object doPolyglotArray( + Object _this, + Object src, + long source_index, + Array that, + long dest_index, + long count, + @CachedLibrary(limit = "3") InteropLibrary arrays, + @CachedContext(Language.class) Context ctx) { + try { + for (int i = 0; i < count; i++) { + that.getItems()[(int) dest_index + i] = arrays.readArrayElement(src, source_index + i); + } + } catch (UnsupportedMessageException e) { + throw new IllegalStateException("Unreachable"); + } catch (InvalidArrayIndexException e) { + throw new PanicException( + ctx.getBuiltins().error().makeInvalidArrayIndexError(src, e.getInvalidIndex()), this); + } + return ctx.getBuiltins().nothing().newInstance(); + } + + @Fallback + Object doOther( + Object _this, Object src, long source_index, Array that, long dest_index, long count) { + Builtins builtins = lookupContextReference(Language.class).get().getBuiltins(); + throw new PanicException( + builtins.error().makeTypeError(builtins.mutable().array().newInstance(), src), this); } } diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Error.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Error.java index 9646cbf8e688..bc56af39c0b2 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Error.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/builtin/Error.java @@ -19,6 +19,7 @@ public class Error { private final AtomConstructor polyglotError; private final AtomConstructor moduleNotInPackageError; private final AtomConstructor arithmeticError; + private final AtomConstructor invalidArrayIndexError; private final Atom arithmeticErrorShiftTooBig; @@ -68,6 +69,11 @@ public Error(Language language, ModuleScope scope) { .initializeFields( new ArgumentDefinition(0, "message", ArgumentDefinition.ExecutionMode.EXECUTE)); arithmeticErrorShiftTooBig = arithmeticError.newInstance(shiftTooBigMessage); + invalidArrayIndexError = + new AtomConstructor("Invalid_Array_Index_Error", scope) + .initializeFields( + new ArgumentDefinition(0, "array", ArgumentDefinition.ExecutionMode.EXECUTE), + new ArgumentDefinition(1, "index", ArgumentDefinition.ExecutionMode.EXECUTE)); scope.registerConstructor(syntaxError); scope.registerConstructor(compileError); @@ -77,6 +83,7 @@ public Error(Language language, ModuleScope scope) { scope.registerConstructor(polyglotError); scope.registerConstructor(moduleNotInPackageError); scope.registerConstructor(arithmeticError); + scope.registerConstructor(invalidArrayIndexError); } /** @return the builtin {@code Syntax_Error} atom constructor. */ @@ -85,7 +92,9 @@ public AtomConstructor syntaxError() { } /** @return the builtin {@code Type_Error} atom constructor. */ - public AtomConstructor typeError() { return typeError; } + public AtomConstructor typeError() { + return typeError; + } /** @return the builtin {@code Compile_Error} atom constructor. */ public AtomConstructor compileError() { @@ -153,4 +162,8 @@ public Atom makeArithmeticError(Text reason) { public Atom getShiftAmountTooLargeError() { return arithmeticErrorShiftTooBig; } + + public Atom makeInvalidArrayIndexError(Object array, Object index) { + return invalidArrayIndexError.newInstance(array, index); + } } diff --git a/table/src/main/java/org/enso/table/data/column/operation/aggregate/Aggregator.java b/table/src/main/java/org/enso/table/data/column/operation/aggregate/Aggregator.java new file mode 100644 index 000000000000..019c3fc5348e --- /dev/null +++ b/table/src/main/java/org/enso/table/data/column/operation/aggregate/Aggregator.java @@ -0,0 +1,28 @@ +package org.enso.table.data.column.operation.aggregate; + +import org.enso.table.data.column.storage.Storage; + +import java.util.List; + +/** + * Represents a fold-like operation on a storage. An aggregator is usually created for a given + * storage, then {@link #nextGroup(List)} is repeatedly called and the aggregator is responsible for + * collecting the results of such calls. After that, {@link #seal()} is called to obtain a storage + * containing all the results. + */ +public abstract class Aggregator { + /** + * Requests the aggregator to append the result of aggregating the values at the specified + * positions. + * + * @param positions the positions to aggregate in this round. + */ + public abstract void nextGroup(List positions); + + /** + * Returns the results of all previous {@link #nextGroup(List)} calls. + * + * @return the storage containing all aggregation results. + */ + public abstract Storage seal(); +} diff --git a/table/src/main/java/org/enso/table/data/column/operation/aggregate/CountAggregator.java b/table/src/main/java/org/enso/table/data/column/operation/aggregate/CountAggregator.java new file mode 100644 index 000000000000..5bb590f0a5e8 --- /dev/null +++ b/table/src/main/java/org/enso/table/data/column/operation/aggregate/CountAggregator.java @@ -0,0 +1,32 @@ +package org.enso.table.data.column.operation.aggregate; + +import org.enso.table.data.column.storage.LongStorage; +import org.enso.table.data.column.storage.Storage; + +import java.util.List; + +/** Aggregates a storage by counting the non-missing values in each group. */ +public class CountAggregator extends Aggregator { + private final Storage storage; + private final long[] counts; + private int position = 0; + + /** + * @param storage the storage used as data source + * @param resultSize the exact number of times {@link #nextGroup(List)} will be called. + */ + public CountAggregator(Storage storage, int resultSize) { + this.storage = storage; + this.counts = new long[resultSize]; + } + + @Override + public void nextGroup(List positions) { + counts[position++] = positions.stream().filter(i -> !storage.isNa(i)).count(); + } + + @Override + public Storage seal() { + return new LongStorage(counts); + } +} diff --git a/table/src/main/java/org/enso/table/data/column/operation/aggregate/FunctionAggregator.java b/table/src/main/java/org/enso/table/data/column/operation/aggregate/FunctionAggregator.java new file mode 100644 index 000000000000..9d2fbceafed0 --- /dev/null +++ b/table/src/main/java/org/enso/table/data/column/operation/aggregate/FunctionAggregator.java @@ -0,0 +1,55 @@ +package org.enso.table.data.column.operation.aggregate; + +import org.enso.table.data.column.builder.object.InferredBuilder; +import org.enso.table.data.column.storage.Storage; + +import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** Aggregates the storage using a provided {@link Function}. */ +public class FunctionAggregator extends Aggregator { + private final Function, Object> aggregateFunction; + private final boolean skipNa; + private final Storage storage; + private final InferredBuilder builder; + + /** + * @param aggregateFunction the function used to obtain aggregation of a group + * @param storage the storage serving as data source + * @param skipNa whether missing values should be passed to the function + * @param resultSize the number of times {@link #nextGroup(List)} will be called + */ + public FunctionAggregator( + Function, Object> aggregateFunction, + Storage storage, + boolean skipNa, + int resultSize) { + this.aggregateFunction = aggregateFunction; + this.storage = storage; + this.skipNa = skipNa; + this.builder = new InferredBuilder(resultSize); + } + + @Override + public void nextGroup(List positions) { + List items = getItems(positions); + Object result = aggregateFunction.apply(items); + builder.append(result); + } + + private List getItems(List positions) { + Stream items = positions.stream().map(storage::getItemBoxed); + if (skipNa) { + items = items.filter(Objects::nonNull); + } + return items.collect(Collectors.toList()); + } + + @Override + public Storage seal() { + return builder.seal(); + } +} diff --git a/table/src/main/java/org/enso/table/data/column/operation/aggregate/numeric/LongToLongAggregator.java b/table/src/main/java/org/enso/table/data/column/operation/aggregate/numeric/LongToLongAggregator.java new file mode 100644 index 000000000000..3cbc4d11ee95 --- /dev/null +++ b/table/src/main/java/org/enso/table/data/column/operation/aggregate/numeric/LongToLongAggregator.java @@ -0,0 +1,59 @@ +package org.enso.table.data.column.operation.aggregate.numeric; + +import org.enso.table.data.column.operation.aggregate.Aggregator; +import org.enso.table.data.column.storage.LongStorage; +import org.enso.table.data.column.storage.Storage; + +import java.util.BitSet; +import java.util.List; +import java.util.stream.LongStream; + +/** An aggregator consuming a {@link LongStorage} and returning a {@link LongStorage} */ +public abstract class LongToLongAggregator extends Aggregator { + private final LongStorage storage; + private final long[] items; + private final BitSet missing; + private int position = 0; + + /** + * @param storage the data source + * @param resultSize the number of times {@link #nextGroup(List)} will be called + */ + public LongToLongAggregator(LongStorage storage, int resultSize) { + this.storage = storage; + this.items = new long[resultSize]; + this.missing = new BitSet(); + } + + /** Used by subclasses to return a missing value from a given group. */ + protected void submitMissing() { + missing.set(position++); + } + + /** + * Used by subclasses to return a value from a given group. + * + * @param value the return value of a group + */ + protected void submit(long value) { + items[position++] = value; + } + + /** + * Runs the aggregation on a particular set of values. + * + * @param items the values contained in the current group + */ + protected abstract void runGroup(LongStream items); + + @Override + public void nextGroup(List positions) { + LongStream items = positions.stream().filter(x -> !storage.isNa(x)).mapToLong(storage::getItem); + runGroup(items); + } + + @Override + public Storage seal() { + return new LongStorage(items, items.length, missing); + } +} diff --git a/table/src/main/java/org/enso/table/data/column/operation/aggregate/numeric/NumericAggregator.java b/table/src/main/java/org/enso/table/data/column/operation/aggregate/numeric/NumericAggregator.java new file mode 100644 index 000000000000..d3f76d00b569 --- /dev/null +++ b/table/src/main/java/org/enso/table/data/column/operation/aggregate/numeric/NumericAggregator.java @@ -0,0 +1,78 @@ +package org.enso.table.data.column.operation.aggregate.numeric; + +import org.enso.table.data.column.operation.aggregate.Aggregator; +import org.enso.table.data.column.storage.DoubleStorage; +import org.enso.table.data.column.storage.NumericStorage; +import org.enso.table.data.column.storage.Storage; + +import java.util.BitSet; +import java.util.List; +import java.util.OptionalDouble; +import java.util.stream.DoubleStream; + +/** + * An aggregator sourcing data from any {@link NumericStorage} and returning a {@link + * DoubleStorage}. + */ +public abstract class NumericAggregator extends Aggregator { + private final NumericStorage storage; + private final long[] data; + private final BitSet missing; + private int position = 0; + + /** + * @param storage the data source + * @param resultSize the number of times {@link #nextGroup(List)} will be called + */ + public NumericAggregator(NumericStorage storage, int resultSize) { + this.storage = storage; + this.data = new long[resultSize]; + this.missing = new BitSet(); + } + + /** + * Runs the aggregation on a particular set of values. + * + * @param elements the values contained in the current group + */ + protected abstract void runGroup(DoubleStream elements); + + /** + * Used by subclasses to return a value from a given group. + * + * @param value the return value of a group + */ + protected void submit(double value) { + data[position++] = Double.doubleToRawLongBits(value); + } + + /** + * Used by subclasses to return a value from a given group. + * + * @param value the return value of a group + */ + protected void submit(OptionalDouble value) { + if (value.isPresent()) { + submit(value.getAsDouble()); + } else { + submitMissing(); + } + } + + /** Used by subclasses to return a missing value from a given group. */ + protected void submitMissing() { + missing.set(position++); + } + + @Override + public void nextGroup(List positions) { + DoubleStream elements = + positions.stream().filter(i -> !storage.isNa(i)).mapToDouble(storage::getItemDouble); + runGroup(elements); + } + + @Override + public Storage seal() { + return new DoubleStorage(data, data.length, missing); + } +} diff --git a/table/src/main/java/org/enso/table/data/column/storage/BoolStorage.java b/table/src/main/java/org/enso/table/data/column/storage/BoolStorage.java index 0394630a1b45..8b1a7f069498 100644 --- a/table/src/main/java/org/enso/table/data/column/storage/BoolStorage.java +++ b/table/src/main/java/org/enso/table/data/column/storage/BoolStorage.java @@ -156,7 +156,7 @@ public boolean isNegated() { private static MapOpStorage buildOps() { MapOpStorage ops = new MapOpStorage<>(); ops.add( - new UnaryMapOperation<>(Ops.NOT) { + new UnaryMapOperation<>(Maps.NOT) { @Override protected Storage run(BoolStorage storage) { return new BoolStorage( @@ -164,7 +164,7 @@ protected Storage run(BoolStorage storage) { } }) .add( - new MapOperation<>(Ops.EQ) { + new MapOperation<>(Maps.EQ) { @Override public Storage runMap(BoolStorage storage, Object arg) { if (arg instanceof Boolean) { @@ -196,7 +196,7 @@ public Storage runZip(BoolStorage storage, Storage arg) { } }) .add( - new MapOperation<>(Ops.AND) { + new MapOperation<>(Maps.AND) { @Override public Storage runMap(BoolStorage storage, Object arg) { if (arg instanceof Boolean) { @@ -240,7 +240,7 @@ public Storage runZip(BoolStorage storage, Storage arg) { } }) .add( - new MapOperation<>(Ops.OR) { + new MapOperation<>(Maps.OR) { @Override public Storage runMap(BoolStorage storage, Object arg) { if (arg instanceof Boolean) { diff --git a/table/src/main/java/org/enso/table/data/column/storage/DoubleStorage.java b/table/src/main/java/org/enso/table/data/column/storage/DoubleStorage.java index 2ca977618c93..e640f4f46c8d 100644 --- a/table/src/main/java/org/enso/table/data/column/storage/DoubleStorage.java +++ b/table/src/main/java/org/enso/table/data/column/storage/DoubleStorage.java @@ -9,7 +9,7 @@ import org.enso.table.data.index.Index; /** A column containing floating point numbers. */ -public class DoubleStorage extends Storage { +public class DoubleStorage extends NumericStorage { private final long[] data; private final BitSet isMissing; private final int size; @@ -47,6 +47,11 @@ public double getItem(long idx) { return Double.longBitsToDouble(data[(int) idx]); } + @Override + public double getItemDouble(int idx) { + return getItem(idx); + } + @Override public Object getItemBoxed(int idx) { return isMissing.get(idx) ? null : Double.longBitsToDouble(data[idx]); @@ -159,56 +164,56 @@ public BitSet getIsMissing() { private static MapOpStorage buildOps() { MapOpStorage ops = new MapOpStorage<>(); ops.add( - new DoubleNumericOp(Ops.ADD) { + new DoubleNumericOp(Maps.ADD) { @Override protected double doDouble(double a, double b) { return a + b; } }) .add( - new DoubleNumericOp(Ops.SUB) { + new DoubleNumericOp(Maps.SUB) { @Override protected double doDouble(double a, double b) { return a - b; } }) .add( - new DoubleNumericOp(Ops.MUL) { + new DoubleNumericOp(Maps.MUL) { @Override protected double doDouble(double a, double b) { return a * b; } }) .add( - new DoubleNumericOp(Ops.DIV) { + new DoubleNumericOp(Maps.DIV) { @Override protected double doDouble(double a, double b) { return a / b; } }) .add( - new DoubleNumericOp(Ops.MOD) { + new DoubleNumericOp(Maps.MOD) { @Override protected double doDouble(double a, double b) { return a % b; } }) .add( - new DoubleBooleanOp(Ops.LT) { + new DoubleBooleanOp(Maps.LT) { @Override protected boolean doDouble(double a, double b) { return a < b; } }) .add( - new DoubleBooleanOp(Ops.LTE) { + new DoubleBooleanOp(Maps.LTE) { @Override protected boolean doDouble(double a, double b) { return a <= b; } }) .add( - new DoubleBooleanOp(Ops.EQ) { + new DoubleBooleanOp(Maps.EQ) { @Override protected boolean doDouble(double a, double b) { return a == b; @@ -220,21 +225,21 @@ protected boolean doObject(double a, Object o) { } }) .add( - new DoubleBooleanOp(Ops.GT) { + new DoubleBooleanOp(Maps.GT) { @Override protected boolean doDouble(double a, double b) { return a > b; } }) .add( - new DoubleBooleanOp(Ops.GTE) { + new DoubleBooleanOp(Maps.GTE) { @Override protected boolean doDouble(double a, double b) { return a >= b; } }) .add( - new UnaryMapOperation<>(Ops.IS_MISSING) { + new UnaryMapOperation<>(Maps.IS_MISSING) { @Override public Storage run(DoubleStorage storage) { return new BoolStorage(storage.isMissing, new BitSet(), storage.size, false); diff --git a/table/src/main/java/org/enso/table/data/column/storage/LongStorage.java b/table/src/main/java/org/enso/table/data/column/storage/LongStorage.java index 2da6698d6a9d..3a573bd115b0 100644 --- a/table/src/main/java/org/enso/table/data/column/storage/LongStorage.java +++ b/table/src/main/java/org/enso/table/data/column/storage/LongStorage.java @@ -1,7 +1,16 @@ package org.enso.table.data.column.storage; +import java.util.Arrays; import java.util.BitSet; +import java.util.OptionalDouble; +import java.util.OptionalLong; +import java.util.stream.DoubleStream; +import java.util.stream.LongStream; + import org.enso.table.data.column.builder.object.NumericBuilder; +import org.enso.table.data.column.operation.aggregate.Aggregator; +import org.enso.table.data.column.operation.aggregate.numeric.LongToLongAggregator; +import org.enso.table.data.column.operation.aggregate.numeric.NumericAggregator; import org.enso.table.data.column.operation.map.MapOpStorage; import org.enso.table.data.column.operation.map.UnaryMapOperation; import org.enso.table.data.column.operation.map.numeric.LongBooleanOp; @@ -9,7 +18,7 @@ import org.enso.table.data.index.Index; /** A column storing 64-bit integers. */ -public class LongStorage extends Storage { +public class LongStorage extends NumericStorage { private final long[] data; private final BitSet isMissing; private final int size; @@ -27,6 +36,10 @@ public LongStorage(long[] data, int size, BitSet isMissing) { this.size = size; } + public LongStorage(long[] data) { + this(data, data.length, new BitSet()); + } + /** @inheritDoc */ @Override public int size() { @@ -47,6 +60,11 @@ public long getItem(int idx) { return data[idx]; } + @Override + public double getItemDouble(int idx) { + return getItem(idx); + } + @Override public Object getItemBoxed(int idx) { return isMissing.get(idx) ? null : data[idx]; @@ -79,6 +97,50 @@ protected Storage runVectorizedZip(String name, Storage argument) { return ops.runZip(name, this, argument); } + @Override + protected Aggregator getVectorizedAggregator(String name, int resultSize) { + switch (name) { + case Aggregators.SUM: + return new LongToLongAggregator(this, resultSize) { + @Override + protected void runGroup(LongStream items) { + long[] elements = items.toArray(); + if (elements.length == 0) { + submitMissing(); + } else { + submit(LongStream.of(elements).sum()); + } + } + }; + case Aggregators.MAX: + return new LongToLongAggregator(this, resultSize) { + @Override + protected void runGroup(LongStream items) { + OptionalLong r = items.max(); + if (r.isPresent()) { + submit(r.getAsLong()); + } else { + submitMissing(); + } + } + }; + case Aggregators.MIN: + return new LongToLongAggregator(this, resultSize) { + @Override + protected void runGroup(LongStream items) { + OptionalLong r = items.min(); + if (r.isPresent()) { + submit(r.getAsLong()); + } else { + submitMissing(); + } + } + }; + default: + return super.getVectorizedAggregator(name, resultSize); + } + } + private Storage fillMissingDouble(double arg) { final var builder = NumericBuilder.createDoubleBuilder(size()); long rawArg = Double.doubleToRawLongBits(arg); @@ -172,7 +234,7 @@ public BitSet getIsMissing() { private static MapOpStorage buildOps() { MapOpStorage ops = new MapOpStorage<>(); ops.add( - new LongNumericOp(Ops.ADD) { + new LongNumericOp(Maps.ADD) { @Override public double doDouble(long in, double arg) { return in + arg; @@ -184,7 +246,7 @@ public long doLong(long in, long arg) { } }) .add( - new LongNumericOp(Ops.SUB) { + new LongNumericOp(Maps.SUB) { @Override public double doDouble(long in, double arg) { return in - arg; @@ -196,7 +258,7 @@ public long doLong(long in, long arg) { } }) .add( - new LongNumericOp(Ops.MUL) { + new LongNumericOp(Maps.MUL) { @Override public double doDouble(long in, double arg) { return in * arg; @@ -208,7 +270,7 @@ public long doLong(long in, long arg) { } }) .add( - new LongNumericOp(Ops.MOD) { + new LongNumericOp(Maps.MOD) { @Override public double doDouble(long in, double arg) { return in % arg; @@ -220,7 +282,7 @@ public long doLong(long in, long arg) { } }) .add( - new LongNumericOp(Ops.DIV, true) { + new LongNumericOp(Maps.DIV, true) { @Override public double doDouble(long in, double arg) { return in / arg; @@ -232,7 +294,7 @@ public long doLong(long in, long arg) { } }) .add( - new LongBooleanOp(Ops.GT) { + new LongBooleanOp(Maps.GT) { @Override protected boolean doLong(long a, long b) { return a > b; @@ -244,7 +306,7 @@ protected boolean doDouble(long a, double b) { } }) .add( - new LongBooleanOp(Ops.GTE) { + new LongBooleanOp(Maps.GTE) { @Override protected boolean doLong(long a, long b) { return a >= b; @@ -256,7 +318,7 @@ protected boolean doDouble(long a, double b) { } }) .add( - new LongBooleanOp(Ops.LT) { + new LongBooleanOp(Maps.LT) { @Override protected boolean doLong(long a, long b) { return a < b; @@ -268,7 +330,7 @@ protected boolean doDouble(long a, double b) { } }) .add( - new LongBooleanOp(Ops.LTE) { + new LongBooleanOp(Maps.LTE) { @Override protected boolean doLong(long a, long b) { return a <= b; @@ -280,7 +342,7 @@ protected boolean doDouble(long a, double b) { } }) .add( - new LongBooleanOp(Ops.EQ) { + new LongBooleanOp(Maps.EQ) { @Override protected boolean doLong(long a, long b) { return a == b; @@ -297,7 +359,7 @@ protected boolean doObject(long x, Object o) { } }) .add( - new UnaryMapOperation<>(Ops.IS_MISSING) { + new UnaryMapOperation<>(Maps.IS_MISSING) { @Override public Storage run(LongStorage storage) { return new BoolStorage(storage.isMissing, new BitSet(), storage.size, false); diff --git a/table/src/main/java/org/enso/table/data/column/storage/NumericStorage.java b/table/src/main/java/org/enso/table/data/column/storage/NumericStorage.java new file mode 100644 index 000000000000..57bef25a4e7d --- /dev/null +++ b/table/src/main/java/org/enso/table/data/column/storage/NumericStorage.java @@ -0,0 +1,59 @@ +package org.enso.table.data.column.storage; + +import org.enso.table.data.column.operation.aggregate.Aggregator; +import org.enso.table.data.column.operation.aggregate.numeric.NumericAggregator; + +import java.util.stream.DoubleStream; + +/** A storage containing items representable as a {@code double}. */ +public abstract class NumericStorage extends Storage { + /** + * Returns the value stored at the given index. The return value if the given index is missing + * ({@link #isNa(long)}) is undefined. + * + * @param idx the index to look up + * @return the value associated with {@code idx} + */ + public abstract double getItemDouble(int idx); + + @Override + protected Aggregator getVectorizedAggregator(String name, int resultSize) { + switch (name) { + case Aggregators.MAX: + return new NumericAggregator(this, resultSize) { + @Override + protected void runGroup(DoubleStream elements) { + submit(elements.max()); + } + }; + case Aggregators.MIN: + return new NumericAggregator(this, resultSize) { + @Override + protected void runGroup(DoubleStream elements) { + submit(elements.min()); + } + }; + case Aggregators.SUM: + return new NumericAggregator(this, resultSize) { + @Override + protected void runGroup(DoubleStream elements) { + double[] its = elements.toArray(); + if (its.length == 0) { + submitMissing(); + } else { + submit(DoubleStream.of(its).sum()); + } + } + }; + case Aggregators.MEAN: + return new NumericAggregator(this, resultSize) { + @Override + protected void runGroup(DoubleStream elements) { + submit(elements.average()); + } + }; + default: + return super.getVectorizedAggregator(name, resultSize); + } + } +} diff --git a/table/src/main/java/org/enso/table/data/column/storage/ObjectStorage.java b/table/src/main/java/org/enso/table/data/column/storage/ObjectStorage.java index cda37cd9d302..2048ce165be8 100644 --- a/table/src/main/java/org/enso/table/data/column/storage/ObjectStorage.java +++ b/table/src/main/java/org/enso/table/data/column/storage/ObjectStorage.java @@ -1,6 +1,7 @@ package org.enso.table.data.column.storage; import java.util.BitSet; + import org.enso.table.data.column.operation.map.MapOpStorage; import org.enso.table.data.column.operation.map.UnaryMapOperation; import org.enso.table.data.index.Index; @@ -122,7 +123,7 @@ public Object[] getData() { private static MapOpStorage buildOps() { MapOpStorage ops = new MapOpStorage<>(); ops.add( - new UnaryMapOperation<>(Ops.IS_MISSING) { + new UnaryMapOperation<>(Maps.IS_MISSING) { @Override protected Storage run(ObjectStorage storage) { BitSet r = new BitSet(); diff --git a/table/src/main/java/org/enso/table/data/column/storage/Storage.java b/table/src/main/java/org/enso/table/data/column/storage/Storage.java index 06a14d1c2e4b..152c9192c5cd 100644 --- a/table/src/main/java/org/enso/table/data/column/storage/Storage.java +++ b/table/src/main/java/org/enso/table/data/column/storage/Storage.java @@ -1,6 +1,13 @@ package org.enso.table.data.column.storage; +import org.enso.table.data.column.builder.object.Builder; +import org.enso.table.data.column.builder.object.InferredBuilder; +import org.enso.table.data.column.operation.aggregate.Aggregator; +import org.enso.table.data.column.operation.aggregate.CountAggregator; +import org.enso.table.data.column.operation.aggregate.FunctionAggregator; + import java.util.BitSet; +import java.util.List; import java.util.function.BiFunction; import java.util.function.Function; import org.enso.table.data.column.builder.object.Builder; @@ -50,7 +57,7 @@ public static final class Type { } /** A container for names of vectorizable operation. */ - public static final class Ops { + public static final class Maps { public static final String EQ = "=="; public static final String LT = "<"; public static final String LTE = "<="; @@ -70,6 +77,14 @@ public static final class Ops { public static final String CONTAINS = "contains"; } + public static final class Aggregators { + public static final String SUM = "sum"; + public static final String MEAN = "mean"; + public static final String MAX = "max"; + public static final String MIN = "min"; + public static final String COUNT = "count"; + } + protected abstract boolean isOpVectorized(String name); protected abstract Storage runVectorizedMap(String name, Object argument); @@ -102,6 +117,36 @@ public final Storage bimap( return builder.seal(); } + protected Aggregator getVectorizedAggregator(String name, int resultSize) { + if (name.equals(Aggregators.COUNT)) { + return new CountAggregator(this, resultSize); + } + return null; + } + + /** + * Returns an aggregator created based on the provided parameters. + * + * @param name name of a vectorized operation that can be used if possible. If null is passed, + * this parameter is unused. + * @param fallback the function to use if a vectorized operation is not available. + * @param skipNa whether missing values should be passed to the {@code fallback} function. + * @param resultSize the number of times the {@link Aggregator#nextGroup(List)} method will be + * called. + * @return an aggregator satisfying the above properties. + */ + public final Aggregator getAggregator( + String name, Function, Object> fallback, boolean skipNa, int resultSize) { + Aggregator result = null; + if (name != null) { + result = getVectorizedAggregator(name, resultSize); + } + if (result == null) { + result = new FunctionAggregator(fallback, this, skipNa, resultSize); + } + return result; + } + /** * Runs a function on each non-missing element in this storage and gathers the results. * diff --git a/table/src/main/java/org/enso/table/data/column/storage/StringStorage.java b/table/src/main/java/org/enso/table/data/column/storage/StringStorage.java index ae98314e839b..9436d29a2a34 100644 --- a/table/src/main/java/org/enso/table/data/column/storage/StringStorage.java +++ b/table/src/main/java/org/enso/table/data/column/storage/StringStorage.java @@ -78,7 +78,7 @@ public StringStorage countMask(int[] counts, int total) { private static MapOpStorage buildOps() { MapOpStorage t = ObjectStorage.ops.makeChild(); t.add( - new MapOperation<>(Ops.EQ) { + new MapOperation<>(Maps.EQ) { @Override public Storage runMap(StringStorage storage, Object arg) { BitSet r = new BitSet(); @@ -108,21 +108,21 @@ public Storage runZip(StringStorage storage, Storage arg) { } }); t.add( - new StringBooleanOp(Ops.STARTS_WITH) { + new StringBooleanOp(Maps.STARTS_WITH) { @Override protected boolean doString(String a, String b) { return a.startsWith(b); } }); t.add( - new StringBooleanOp(Ops.ENDS_WITH) { + new StringBooleanOp(Maps.ENDS_WITH) { @Override protected boolean doString(String a, String b) { return a.endsWith(b); } }); t.add( - new StringBooleanOp(Ops.CONTAINS) { + new StringBooleanOp(Maps.CONTAINS) { @Override protected boolean doString(String a, String b) { return a.contains(b); diff --git a/table/src/main/java/org/enso/table/data/index/DefaultIndex.java b/table/src/main/java/org/enso/table/data/index/DefaultIndex.java index 3ea35fed0cb4..e507ea4c2517 100644 --- a/table/src/main/java/org/enso/table/data/index/DefaultIndex.java +++ b/table/src/main/java/org/enso/table/data/index/DefaultIndex.java @@ -50,4 +50,14 @@ public Index mask(BitSet mask, int cardinality) { public Index countMask(int[] counts, int total) { return new DefaultIndex(total); } + + @Override + public Index unique() { + return this; + } + + @Override + public int size() { + return size; + } } diff --git a/table/src/main/java/org/enso/table/data/index/HashIndex.java b/table/src/main/java/org/enso/table/data/index/HashIndex.java index 7ccf43f87fb2..526f8d379405 100644 --- a/table/src/main/java/org/enso/table/data/index/HashIndex.java +++ b/table/src/main/java/org/enso/table/data/index/HashIndex.java @@ -1,46 +1,37 @@ package org.enso.table.data.index; import org.enso.table.data.column.storage.Storage; -import org.enso.table.data.column.storage.StringStorage; - import java.util.*; -import java.util.stream.Collectors; public class HashIndex extends Index { - private final Object[] items; + private final Storage items; private final Map> locs; private final String name; - private final int size; + private Index uniqueIndex = null; - private HashIndex(Object[] items, Map> locs, String name, int size) { + private HashIndex(Storage items, Map> locs, String name) { this.items = items; this.locs = locs; this.name = name; - this.size = size; } - private HashIndex(String name, Object[] items, int size) { + private HashIndex(String name, Storage items, int size) { Map> locations = new HashMap<>(); for (int i = 0; i < size; i++) { - List its = locations.computeIfAbsent(items[i], x -> new ArrayList<>()); + List its = locations.computeIfAbsent(items.getItemBoxed(i), x -> new ArrayList<>()); its.add(i); } this.locs = locations; this.items = items; this.name = name; - this.size = size; } public static HashIndex fromStorage(String name, Storage storage) { - Object[] data = new Object[(int) storage.size()]; - for (int i = 0; i < storage.size(); i++) { - data[i] = storage.getItemBoxed(i); - } - return new HashIndex(name, data, (int) storage.size()); + return new HashIndex(name, storage, storage.size()); } public Object iloc(int i) { - return items[i]; + return items.getItemBoxed(i); } @Override @@ -60,33 +51,32 @@ public String getName() { @Override public Index mask(BitSet mask, int cardinality) { - Map> newLocs = new HashMap<>(); - for (Map.Entry> entry : locs.entrySet()) { - List newIxes = - entry.getValue().stream().filter(mask::get).collect(Collectors.toList()); - if (!newIxes.isEmpty()) { - newLocs.put(entry.getKey(), newIxes); - } - } - Object[] newItems = new Object[cardinality]; - int j = 0; - for (int i = 0; i < size; i++) { - if (mask.get(i)) { - newItems[j++] = items[i]; - } - } - return new HashIndex(newItems, newLocs, name, cardinality); + Storage newSt = items.mask(mask, cardinality); + return HashIndex.fromStorage(name, newSt); } @Override public Index countMask(int[] counts, int total) { - Object[] newItems = new Object[total]; - int pos = 0; - for (int i = 0; i < size; i++) { - for (int j = 0; j < counts[i]; j++) { - newItems[pos++] = items[i]; + Storage newSt = items.countMask(counts, total); + return HashIndex.fromStorage(name, newSt); + } + + @Override + public Index unique() { + HashMap> newLocs = new HashMap<>(); + BitSet mask = new BitSet(); + for (int i = 0; i < items.size(); i++) { + if (!newLocs.containsKey(items.getItemBoxed(i))) { + newLocs.put(items.getItemBoxed(i), Collections.singletonList(i)); + mask.set(i); } } - return new HashIndex(name, newItems, total); + Storage newItems = items.mask(mask, locs.size()); + return new HashIndex(newItems, newLocs, name); + } + + @Override + public int size() { + return items.size(); } } diff --git a/table/src/main/java/org/enso/table/data/index/Index.java b/table/src/main/java/org/enso/table/data/index/Index.java index cff745da9189..ac44d79018a5 100644 --- a/table/src/main/java/org/enso/table/data/index/Index.java +++ b/table/src/main/java/org/enso/table/data/index/Index.java @@ -32,6 +32,13 @@ public abstract class Index { */ public abstract List loc(Object item); + /** + * Builds an index containing the same values as this one, but with only one occurrence of each. + * + * @return a unique index obtained from this one. + */ + public abstract Index unique(); + /** @return the name of this index */ public abstract String getName(); @@ -55,4 +62,7 @@ public abstract class Index { * @return the index masked according to the specified rules */ public abstract Index countMask(int[] counts, int total); + + /** @return the number of elements in this index. */ + public abstract int size(); } diff --git a/table/src/main/java/org/enso/table/data/table/Column.java b/table/src/main/java/org/enso/table/data/table/Column.java index 37a5ae5e15c9..ba077188367a 100644 --- a/table/src/main/java/org/enso/table/data/table/Column.java +++ b/table/src/main/java/org/enso/table/data/table/Column.java @@ -27,6 +27,15 @@ public Column(String name, Index index, Storage storage) { this.index = index; } + /** + * Converts this column to a single-column table. + * + * @return a table containing only this column + */ + public Table toTable() { + return new Table(new Column[] {this}, index); + } + /** @return the column name */ public String getName() { return name; diff --git a/table/src/main/java/org/enso/table/data/table/Table.java b/table/src/main/java/org/enso/table/data/table/Table.java index 13909c4614a8..50c34abe7516 100644 --- a/table/src/main/java/org/enso/table/data/table/Table.java +++ b/table/src/main/java/org/enso/table/data/table/Table.java @@ -11,6 +11,7 @@ import org.enso.table.data.index.DefaultIndex; import org.enso.table.data.index.HashIndex; import org.enso.table.data.index.Index; +import org.enso.table.data.table.aggregate.AggregateTable; import org.enso.table.error.NoSuchColumnException; import org.enso.table.error.UnexpectedColumnTypeException; @@ -186,6 +187,10 @@ public Table selectColumns(List colNames) { */ @SuppressWarnings("unchecked") public Table join(Table other, boolean dropUnmatched, String on, String lsuffix, String rsuffix) { + if (other.index == index) { + // The tables have exactly the same indexes, so they may be just be concatenated horizontally + return hconcat(other, lsuffix, rsuffix); + } int s = (int) nrows(); List[] matches = new List[s]; if (on == null) { @@ -249,4 +254,30 @@ public Table join(Table other, boolean dropUnmatched, String on, String lsuffix, private String suffixIfNecessary(Set names, String name, String suffix) { return names.contains(name) ? name + suffix : name; } + + private Table hconcat(Table other, String lsuffix, String rsuffix) { + Column[] newColumns = new Column[this.columns.length + other.columns.length]; + Set lnames = + Arrays.stream(this.columns).map(Column::getName).collect(Collectors.toSet()); + Set rnames = + Arrays.stream(other.columns).map(Column::getName).collect(Collectors.toSet()); + for (int i = 0; i < columns.length; i++) { + Column original = columns[i]; + newColumns[i] = + new Column( + suffixIfNecessary(rnames, original.getName(), lsuffix), index, original.getStorage()); + } + for (int i = 0; i < other.columns.length; i++) { + Column original = other.columns[i]; + newColumns[i + columns.length] = + new Column( + suffixIfNecessary(lnames, original.getName(), rsuffix), index, original.getStorage()); + } + return new Table(newColumns, index); + } + + public AggregateTable group(String by) { + Table t = by == null ? this : indexFromColumn(by); + return new AggregateTable(t); + } } diff --git a/table/src/main/java/org/enso/table/data/table/aggregate/AggregateColumn.java b/table/src/main/java/org/enso/table/data/table/aggregate/AggregateColumn.java new file mode 100644 index 000000000000..622f367da1b9 --- /dev/null +++ b/table/src/main/java/org/enso/table/data/table/aggregate/AggregateColumn.java @@ -0,0 +1,58 @@ +package org.enso.table.data.table.aggregate; + +import org.enso.table.data.column.builder.object.InferredBuilder; +import org.enso.table.data.column.operation.aggregate.Aggregator; +import org.enso.table.data.column.storage.Storage; +import org.enso.table.data.index.Index; +import org.enso.table.data.table.Column; + +import java.util.List; +import java.util.function.Function; + +/** A column wrapper used for aggregation operations. */ +public class AggregateColumn { + private final Index uniqueIndex; + private final Column column; + + /** + * Creates a new column + * + * @param uniqueIndex the unique index obtained from the column's index + * @param column the wrapped column + */ + public AggregateColumn(Index uniqueIndex, Column column) { + this.uniqueIndex = uniqueIndex; + this.column = column; + } + + /** + * Aggregates the groups using a given aggregation operation. + * + * @param aggName name of a vectorized operation that can be used if possible. If null is passed, + * this parameter is unused. + * @param outSuffix a string appended to the name of the resulting column. + * @param aggregatorFunction the function to use if a vectorized operation is not available. + * @param skipNa whether missing values should be passed to the {@code fallback} function. + * @return a column indexed by the unique index of this aggregate, storing results of applying the + * specified operation. + */ + public Column aggregate( + String aggName, + String outSuffix, + Function, Object> aggregatorFunction, + boolean skipNa) { + Aggregator aggregator = + column.getStorage().getAggregator(aggName, aggregatorFunction, skipNa, uniqueIndex.size()); + + for (int i = 0; i < uniqueIndex.size(); i++) { + List ixes = column.getIndex().loc(uniqueIndex.iloc(i)); + aggregator.nextGroup(ixes); + } + return new Column(column.getName() + outSuffix, uniqueIndex, aggregator.seal()); + } + + /** @return the underlying (ungroupped) column. */ + public Column getColumn() { + return column; + } +} diff --git a/table/src/main/java/org/enso/table/data/table/aggregate/AggregateTable.java b/table/src/main/java/org/enso/table/data/table/aggregate/AggregateTable.java new file mode 100644 index 000000000000..9e9b74633b45 --- /dev/null +++ b/table/src/main/java/org/enso/table/data/table/aggregate/AggregateTable.java @@ -0,0 +1,54 @@ +package org.enso.table.data.table.aggregate; + +import org.enso.table.data.column.storage.LongStorage; +import org.enso.table.data.index.Index; +import org.enso.table.data.table.Column; +import org.enso.table.data.table.Table; + +import java.util.Arrays; +import java.util.List; + +/** Represents a table grouped by a given index. */ +public class AggregateTable { + private final Table table; + private final Index uniqueIndex; + + /** @param table the underlying table */ + public AggregateTable(Table table) { + this.table = table; + this.uniqueIndex = table.getIndex().unique(); + } + + /** @return a column containing group sizes in this aggregate. */ + public Column count() { + long[] counts = new long[uniqueIndex.size()]; + for (int i = 0; i < uniqueIndex.size(); i++) { + List items = table.getIndex().loc(uniqueIndex.iloc(i)); + counts[i] = items == null ? 0 : items.size(); + } + LongStorage storage = new LongStorage(counts); + return new Column("count", uniqueIndex, storage); + } + + /** + * Returns a column with the given name. + * + * @param n the column name + * @return column with the given name or null if does not exist + */ + public AggregateColumn getColumnByName(String n) { + Column c = table.getColumnByName(n); + if (c == null) { + return null; + } else { + return new AggregateColumn(uniqueIndex, c); + } + } + + /** @return Aggregate columns contained in this table. */ + public AggregateColumn[] getColumns() { + return Arrays.stream(table.getColumns()) + .map(c -> new AggregateColumn(uniqueIndex, c)) + .toArray(AggregateColumn[]::new); + } +} diff --git a/test/Table_Tests/src/Table_Spec.enso b/test/Table_Tests/src/Table_Spec.enso index e2e1bc2fc391..d95d88279d39 100644 --- a/test/Table_Tests/src/Table_Spec.enso +++ b/test/Table_Tests/src/Table_Spec.enso @@ -300,3 +300,25 @@ spec = i.at "Column" . to_vector . should_equal ["strs", "ints", "objs"] i.at "Items Count" . to_vector . should_equal [3, 2, 4] i.at "Storage Type" . to_vector . should_equal [Storage.Text, Storage.Integer, Storage.Any] + + describe "Aggregation" <| + name = ['name', ["foo", "bar", "foo", "baz", "foo", "bar", "quux"]] + price = ['price', [0.4, 3.5, Nothing, 6.7, Nothing, 97, Nothing]] + quantity = ['quantity', [10, 20, 30, 40, 50, 60, 70]] + t = Table.new [name, price, quantity] + agg = t.group by='name' + + it "should allow counting group sizes" <| + agg.count.to_vector.should_equal [3, 2, 1, 1] + + it "should allow aggregating columns with basic arithmetic aggregators" <| + agg.at 'price' . mean . to_vector . should_equal [0.4, 50.25, 6.7, Nothing] + agg.at 'price' . min . to_vector . should_equal [0.4, 3.5, 6.7, Nothing] + + it "should allow aggregating with user-defined aggregate functions" <| + median vec = + sorted = vec.sort + if sorted.is_empty then Nothing else sorted.at (sorted.length-1 / 2).floor + agg.at 'quantity' . reduce median . to_vector . should_equal [30, 20, 40, 70] + +