Skip to content

Commit

Permalink
Table: grouping (#1392)
Browse files Browse the repository at this point in the history
  • Loading branch information
kustosz authored Jan 11, 2021
1 parent 7fd1184 commit b751dfb
Show file tree
Hide file tree
Showing 26 changed files with 921 additions and 85 deletions.
2 changes: 1 addition & 1 deletion distribution/std-lib/Base/src/Data/Vector.enso
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ type Vector
not want to sort in place on the original vector, as `sort` is not
intended to be mutable.
new_vec_arr = Array.new this.length
this.to_array.copy 0 new_vec_arr 0 this.length
Array.copy this.to_array 0 new_vec_arr 0 this.length

## As we want to account for both custom projections and custom
comparisons we need to construct a comparator for internal use that
Expand Down
110 changes: 110 additions & 0 deletions distribution/std-lib/Table/src/Data/Column.enso
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,120 @@ type Column
fields = Map.singleton "name" (Json.String name) . insert "data" storage_json
Json.Object fields

## Efficiently joins this column, converted to a single-column table, with
another table or column, based on either the index or the specified key
column.

The resulting table contains rows of `this` extended with rows of
`other` with matching indexes. If the index values in `other` are not
unique, the corresponding rows of `this` will be duplicated in the
result.

Arguments:
- other: the table or column being the right operand of this join operation.
- on: the column of `this` that should be used as the join key. If
this argument is not provided, the index of `this` will be used.
- drop_unmatched: whether the rows of `this` without corresponding
matches in `other` should be dropped from the result.
- left_suffix: a suffix that should be added to the columns of `this`
when there's a name conflict with a column of `other`.
- right_suffix: a suffix that should be added to the columns of `other`
when there's a name conflict with a column of `this`.
join : Table.Table | Column -> Text | Nothing -> Boolean -> Text -> Text -> Table
join other on=Nothing drop_unmatched=False left_suffix='_left' right_suffix='_right' =
this.to_table.join other on drop_unmatched left_suffix right_suffix

## Converts this column into a single-column table.
to_table : Table.Table
to_table = Table.Table (this.java_column.toTable [])

## Creates a new column given a name and a vector of elements.
from_vector : Text -> Vector -> Column
from_vector name items = Column (Java_Column.fromItems [name, items.to_array])

## A column whose values have been grouped by its index. Exposes aggregation
   operations that collapse every group into a single value.
type Aggregate_Column
    type Aggregate_Column java_column

    ## Collapses every group into a single value using the provided
       `function` and returns the results as an ordinary column.

       Arguments:
       - function: the aggregation function. It receives the values of one
         group, wrapped in a vector, and returns the aggregated value.
       - skip_missing: when `True`, missing values are left out of the groups
         passed to `function`; when `False`, they are kept.
       - name_suffix: a suffix appended to the original column name to form
         the name of the resulting column.
    reduce : (Vector.Vector -> Any) -> Boolean -> Text -> Column
    reduce function skip_missing=True name_suffix="_result" =
        apply_to_group items = function (Vector.Vector items)
        Column (this.java_column.aggregate [Nothing, name_suffix, apply_to_group, skip_missing])

    ## Adds up the values of each group.

       Arguments:
       - name_suffix: a suffix appended to the original column name to form
         the name of the resulting column.
    sum : Text -> Column
    sum name_suffix='_sum' =
        total items = Vector.Vector items . reduce (+)
        Column (this.java_column.aggregate ['sum', name_suffix, total, True])

    ## Finds the largest value of each group.

       Arguments:
       - name_suffix: a suffix appended to the original column name to form
         the name of the resulting column.
    max : Text -> Column
    max name_suffix='_max' =
        largest items = Vector.Vector items . reduce max
        Column (this.java_column.aggregate ['max', name_suffix, largest, True])

    ## Finds the smallest value of each group.

       Arguments:
       - name_suffix: a suffix appended to the original column name to form
         the name of the resulting column.
    min : Text -> Column
    min name_suffix='_min' =
        smallest items = Vector.Vector items . reduce min
        Column (this.java_column.aggregate ['min', name_suffix, smallest, True])

    ## Counts the non-missing values of each group.

       Arguments:
       - name_suffix: a suffix appended to the original column name to form
         the name of the resulting column.
    count : Text -> Column
    count name_suffix='_count' =
        group_size items = items.length
        Column (this.java_column.aggregate ['count', name_suffix, group_size, True])

    ## Averages the non-missing values of each group. A group without any
       values yields `Nothing`.

       Arguments:
       - name_suffix: a suffix appended to the original column name to form
         the name of the resulting column.
    mean : Text -> Column
    mean name_suffix='_mean' =
        group_mean items = if items.length == 0 then Nothing else
            (Vector.Vector items).reduce (+) / items.length
        Column (this.java_column.aggregate ['mean', name_suffix, group_mean, True])

    ## Collects the values of each group, missing ones included, into a vector
       and returns a column of such vectors.

       Arguments:
       - name_suffix: a suffix appended to the original column name to form
         the name of the resulting column.
    values : Text -> Column
    values name_suffix='_values' =
        Column (this.java_column.aggregate [Nothing, name_suffix, Vector.Vector, False])

## PRIVATE
run_vectorized_binary_op column name fallback_fn operand = case operand of
Column col2 ->
Expand Down
75 changes: 70 additions & 5 deletions distribution/std-lib/Table/src/Data/Table.enso
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,13 @@ type Table
select : Vector -> Table
select columns = Table (this.java_table.selectColumns [columns.to_array])

## Efficiently joins two tables based on either the index or a key column.
## Efficiently joins two tables based on either the index or the specified
key column.

The resulting table contains rows of `this` extended with rows of
`other` with matching indexes. If the index in `other` is not unique,
the corresponding rows of `this` will be duplicated in the result.
`other` with matching indexes. If the index values in `other` are not
unique, the corresponding rows of `this` will be duplicated in the
result.

Arguments:
- other: the table being the right operand of this join operation. A column
may be passed as well; it is converted to a single-column table first.
Expand All @@ -100,9 +102,12 @@ type Table
when there's a name conflict with a column of `other`.
- right_suffix: a suffix that should be added to the columns of `other`
when there's a name conflict with a column of `this`.
join : Table -> Text | Nothing -> Boolean -> Text -> Text -> Table
join : Table | Column.Column -> Text | Nothing -> Boolean -> Text -> Text -> Table
join other on=Nothing drop_unmatched=False left_suffix='_left' right_suffix='_right' =
Table (this.java_table.join [other.java_table, drop_unmatched, on, left_suffix, right_suffix])
case other of
Column.Column _ -> this.join other.to_table on drop_unmatched left_suffix right_suffix
Table t ->
Table (this.java_table.join [t, drop_unmatched, on, left_suffix, right_suffix])

## Returns a new Table without rows that contained missing values in any of
the columns.
Expand Down Expand Up @@ -136,6 +141,59 @@ type Table
cols = this.columns
here.new [["Column", cols.map name], ["Items Count", cols.map count], ["Storage Type", cols.map storage_type]]

## Groups the rows of this table by the values of the given column and
   returns the resulting aggregate table.

   When `by` is left unset, the rows are grouped by the index instead.

   Arguments:
   - by: the name of the column whose values determine the groups, or
     `Nothing` to group by the index.

   > Example
     Creates a simple table and computes aggregation statistics:
         name = ['name', ["foo", "bar", "foo", "baz", "foo", "bar", "quux"]]
         price = ['price', [0.4, 3.5, Nothing, 6.7, Nothing, 97, Nothing]]
         quantity = ['quantity', [10, 20, 30, 40, 50, 60, 70]]
         t = Table.new [name, price, quantity]

         agg = t.group by='name'

         records_num = agg.count
         total_quantity = agg.at 'quantity' . sum
         mean_price = agg.at 'price' . mean

         Table.join [records_num, total_quantity, mean_price]
group : Text | Nothing -> Aggregate_Table
group by=Nothing =
    grouped = this.java_table.group [by]
    Aggregate_Table grouped

## Represents a table with grouped rows.
type Aggregate_Table
    type Aggregate_Table java_table

    ## Returns a vector of aggregate columns in this table.
    columns : Vector.Vector
    columns = Vector.Vector (this.java_table.getColumns []) . map Column.Aggregate_Column

    ## Returns a table containing columns resulting from calling `values` on
       each column in `this`.

       Every column is converted to a single-column table before joining, so
       the result is a table even when `this` contains exactly one column
       (reducing a one-element vector never invokes `join`, which previously
       let the bare column leak out).
    values : Table
    values =
        as_table c = (c.values name_suffix='').to_table
        this.columns . map as_table . reduce join

    ## Returns a column containing the number of elements in each group.
    count : Column.Column
    count = Column.Column (this.java_table.count [])

    ## Returns an aggregate column with the given name, contained in this
       table, or `Nothing` when the underlying table has no column of that
       name.

       Arguments:
       - name: the name of the column to look up.
    at : Text -> Column.Aggregate_Column | Nothing
    at name = case this.java_table.getColumnByName [name] of
        Nothing -> Nothing
        c -> Column.Aggregate_Column c

    ## Prints an ASCII-art table with this data to the standard output.

       The printed data is the table produced by `values`.

       Arguments:
       - show_rows: the number of initial rows that should be displayed.
    print : Integer -> Nothing
    print show_rows=10 = this.values.print show_rows

## PRIVATE
from_columns cols = Table (Java_Table.new [cols.to_array].to_array)

Expand All @@ -149,6 +207,13 @@ new columns =
Column.from_vector (c.at 0) (c.at 1) . java_column
here.from_columns cols

## Joins a vector of tables (or columns) into a single table, using each table's
   index as the join key. Particularly useful for joining multiple columns
   derived from one original table into a new table.

   Every column in the input is converted to a single-column table before
   joining, so the result is a table even when the vector holds a single
   column (reducing a one-element vector never invokes `join`).

   Arguments:
   - tables: the vector of tables and/or columns to join, from left to right.
join : Vector -> Table
join tables =
    as_table item = case item of
        Column.Column _ -> item.to_table
        _ -> item
    tables.map as_table . reduce join

## PRIVATE
pad txt len =
true_len = txt.characters.length
Expand Down
2 changes: 1 addition & 1 deletion distribution/std-lib/Table/src/Main.enso
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Table.Data.Column

from Table.Io.Csv export all hiding Parser
export Table.Data.Column
from Table.Data.Table export new
from Table.Data.Table export new, join

## Converts a JSON array into a dataframe, by looking up the requested keys
from each item.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@

import com.oracle.truffle.api.TruffleLanguage.ContextReference;
import com.oracle.truffle.api.dsl.CachedContext;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.nodes.Node;
import org.enso.interpreter.Language;
import org.enso.interpreter.dsl.BuiltinMethod;
import org.enso.interpreter.runtime.Context;
import org.enso.interpreter.runtime.builtin.Builtins;
import org.enso.interpreter.runtime.data.Array;
import org.enso.interpreter.runtime.error.PanicException;

@BuiltinMethod(type = "Array", name = "copy", description = "Copies one array to another.")
public abstract class CopyNode extends Node {
Expand All @@ -16,18 +23,51 @@ static CopyNode build() {
return CopyNodeGen.create();
}

abstract Object execute(Array _this, long source_index, Array that, long dest_index, long count);
abstract Object execute(
Object _this, Object src, long source_index, Array that, long dest_index, long count);

@Specialization
Object doArray(
Array _this,
Object _this,
Array src,
long source_index,
Array that,
long dest_index,
long count,
@CachedContext(Language.class) ContextReference<Context> ctxRef) {
@CachedContext(Language.class) Context ctx) {
System.arraycopy(
_this.getItems(), (int) source_index, that.getItems(), (int) dest_index, (int) count);
return ctxRef.get().getBuiltins().nothing().newInstance();
src.getItems(), (int) source_index, that.getItems(), (int) dest_index, (int) count);
return ctx.getBuiltins().nothing().newInstance();
}

/**
 * Copies {@code count} elements from a polyglot array into an Enso array, reading each element
 * through the interop library.
 *
 * @param _this the receiver of the builtin call; not used by the copy itself
 * @param src the polyglot object to read from; the guard ensures it reports array elements
 * @param source_index the index in {@code src} of the first element to copy
 * @param that the destination array
 * @param dest_index the index in {@code that} at which the first element is written
 * @param count the number of elements to copy
 * @param arrays the interop library used to access {@code src}
 * @param ctx the current language context, used to obtain builtin atoms
 * @return the builtin {@code Nothing} instance
 * @throws PanicException carrying an {@code Invalid_Array_Index_Error} when {@code src} rejects
 *     one of the indices being read
 */
@Specialization(guards = "arrays.hasArrayElements(src)")
Object doPolyglotArray(
    Object _this,
    Object src,
    long source_index,
    Array that,
    long dest_index,
    long count,
    @CachedLibrary(limit = "3") InteropLibrary arrays,
    @CachedContext(Language.class) Context ctx) {
  // The destination storage does not change during the copy, so resolve it once.
  Object[] destination = that.getItems();
  try {
    for (int i = 0; i < count; i++) {
      destination[(int) dest_index + i] = arrays.readArrayElement(src, source_index + i);
    }
  } catch (UnsupportedMessageException e) {
    // The guard already checked hasArrayElements(src), so this is not expected to happen. Keep
    // the cause attached so that a misbehaving interop object can still be diagnosed.
    throw new IllegalStateException("Unreachable", e);
  } catch (InvalidArrayIndexException e) {
    throw new PanicException(
        ctx.getBuiltins().error().makeInvalidArrayIndexError(src, e.getInvalidIndex()), this);
  }
  return ctx.getBuiltins().nothing().newInstance();
}

/**
 * Fallback used when {@code src} matches none of the specializations: raises a panic carrying a
 * type error built from a fresh array atom and the offending value.
 */
@Fallback
Object doOther(
    Object _this, Object src, long source_index, Array that, long dest_index, long count) {
  Context context = lookupContextReference(Language.class).get();
  Builtins builtinTypes = context.getBuiltins();
  throw new PanicException(
      builtinTypes.error().makeTypeError(builtinTypes.mutable().array().newInstance(), src),
      this);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public class Error {
private final AtomConstructor polyglotError;
private final AtomConstructor moduleNotInPackageError;
private final AtomConstructor arithmeticError;
private final AtomConstructor invalidArrayIndexError;

private final Atom arithmeticErrorShiftTooBig;

Expand Down Expand Up @@ -68,6 +69,11 @@ public Error(Language language, ModuleScope scope) {
.initializeFields(
new ArgumentDefinition(0, "message", ArgumentDefinition.ExecutionMode.EXECUTE));
arithmeticErrorShiftTooBig = arithmeticError.newInstance(shiftTooBigMessage);
invalidArrayIndexError =
new AtomConstructor("Invalid_Array_Index_Error", scope)
.initializeFields(
new ArgumentDefinition(0, "array", ArgumentDefinition.ExecutionMode.EXECUTE),
new ArgumentDefinition(1, "index", ArgumentDefinition.ExecutionMode.EXECUTE));

scope.registerConstructor(syntaxError);
scope.registerConstructor(compileError);
Expand All @@ -77,6 +83,7 @@ public Error(Language language, ModuleScope scope) {
scope.registerConstructor(polyglotError);
scope.registerConstructor(moduleNotInPackageError);
scope.registerConstructor(arithmeticError);
scope.registerConstructor(invalidArrayIndexError);
}

/** @return the builtin {@code Syntax_Error} atom constructor. */
Expand All @@ -85,7 +92,9 @@ public AtomConstructor syntaxError() {
}

/** @return the builtin {@code Type_Error} atom constructor. */
public AtomConstructor typeError() { return typeError; }
public AtomConstructor typeError() {
return typeError;
}

/** @return the builtin {@code Compile_Error} atom constructor. */
public AtomConstructor compileError() {
Expand Down Expand Up @@ -153,4 +162,8 @@ public Atom makeArithmeticError(Text reason) {
public Atom getShiftAmountTooLargeError() {
return arithmeticErrorShiftTooBig;
}

/**
 * Creates an instance of the builtin {@code Invalid_Array_Index_Error} atom.
 *
 * @param array the array that was being accessed; stored in the atom's {@code array} field
 * @param index the offending index; stored in the atom's {@code index} field
 * @return a new {@code Invalid_Array_Index_Error} atom holding both values
 */
public Atom makeInvalidArrayIndexError(Object array, Object index) {
  return invalidArrayIndexError.newInstance(array, index);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.enso.table.data.column.operation.aggregate;

import org.enso.table.data.column.storage.Storage;

import java.util.List;

/**
 * Represents a fold-like operation on a storage. An aggregator is usually created for a given
 * storage, then {@link #nextGroup(List)} is repeatedly called and the aggregator is responsible for
 * collecting the results of such calls. After that, {@link #seal()} is called to obtain a storage
 * containing all the results.
 *
 * <p>Intended call sequence: any number of {@link #nextGroup(List)} calls, one per group, followed
 * by {@link #seal()} to retrieve the collected results.
 */
public abstract class Aggregator {
  /**
   * Requests the aggregator to append the result of aggregating the values at the specified
   * positions.
   *
   * @param positions the positions to aggregate in this round.
   */
  public abstract void nextGroup(List<Integer> positions);

  /**
   * Returns the results of all previous {@link #nextGroup(List)} calls.
   *
   * @return the storage containing all aggregation results.
   */
  public abstract Storage seal();
}
Loading

0 comments on commit b751dfb

Please sign in to comment.