Skip to content

Commit

Permalink
Vector.sort handles incomparable types
Browse files Browse the repository at this point in the history
  • Loading branch information
Akirathan committed Mar 17, 2023
1 parent 1b237e6 commit 31398a0
Show file tree
Hide file tree
Showing 5 changed files with 372 additions and 16 deletions.
79 changes: 64 additions & 15 deletions distribution/lib/Standard/Base/0.0.0-dev/src/Data/Vector.enso
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import project.Data.Filter_Condition.Filter_Condition
import project.Data.List.List
import project.Data.Map.Map
import project.Data.Numbers.Integer
import project.Data.Ordering.Ordering
import project.Data.Ordering.Sort_Direction.Sort_Direction
import project.Data.Pair.Pair
import project.Data.Range.Range
Expand All @@ -22,7 +21,11 @@ import project.Math
import project.Nothing.Nothing
import project.Panic.Panic
import project.Random
import project.Warning.Warning

# We also have to import the conversion methods, so we import everything from
# the Ordering module.
from project.Data.Ordering import all
from project.Data.Boolean import Boolean, True, False
from project.Data.Index_Sub_Range import Index_Sub_Range, take_helper, drop_helper

Expand Down Expand Up @@ -842,11 +845,12 @@ type Vector a
- on: A projection from the element type to the value of that element
being sorted on.
- by: A function that compares the result of applying `on` to two
elements, returning an Ordering to compare them.
elements, returning an `Ordering` if the two elements are comparable
or `Nothing` if they are not. If set to `Nothing` (the default argument),
`Ordering.compare _ _` method will be used.

By default, elements are sorted in ascending order.

By default, elements are sorted in ascending order, using the comparator
acquired from each element. A custom compare function may be passed to
the sort method.

This is a stable sort, meaning that items that compare the same will not
have their order changed by the sorting process.
Expand All @@ -864,6 +868,18 @@ type Vector a
is partially sorted. When the vector is randomly ordered, the
performance is equivalent to a standard mergesort.

? Multiple comparators
Elements with different comparators are incomparable by definition.
This case is handled by first grouping the `self` vector into groups
with the same comparator, recursively sorting these groups, and then
merging them back together. The order of the sorted groups in the
resulting vector is based on the order of elements' comparators in the
`self` vector, with the exception of the group for the default
comparator, which is always the first group.

Additionally, a warning will be attached, explaining that incomparable
values were encountered.

It takes equal advantage of ascending and descending runs in the array,
making it much simpler to merge two or more sorted arrays: simply
concatenate them and sort.
Expand All @@ -877,16 +893,49 @@ type Vector a
Sorting a vector of `Pair`s on the first element, descending.

[Pair 1 2, Pair -1 8].sort Sort_Direction.Descending (_.first)
sort : Sort_Direction -> (Any -> Any) -> (Any -> Any -> Ordering) -> Vector Any ! Incomparable_Values
sort self (order = Sort_Direction.Ascending) (on = x -> x) (by = (Ordering.compare _ _)) =
    # A single comparison function: project both elements with `on`, then
    # hand them to `by`, swapping the operands when sorting descending.
    compare l r = if order == Sort_Direction.Ascending then by (on l) (on r) else
        by (on r) (on l)

    # Delegate to the array sort; a dataflow error coming out of it is
    # re-thrown, otherwise the sorted array is wrapped back into a Vector.
    sorted_arr = self.to_array.sort compare
    if sorted_arr.is_error then Error.throw sorted_arr else
        Vector.from_polyglot_array sorted_arr

> Example
Sorting a vector whose elements have different comparators. The values 1
and My_Type have different comparators. 1 is sorted before My_Type
because it has the default comparator, and a warning is attached to
the resulting vector.

[My_Type.Value 'hello', 1].sort == [1, My_Type.Value 'hello']
sort : Sort_Direction -> (Any -> Any)|Nothing -> (Any -> Any -> (Ordering|Nothing))|Nothing -> Vector Any ! Incomparable_Values
sort self (order = Sort_Direction.Ascending) on=Nothing by=Nothing =
    # `on = Nothing` means "no projection". It is normalised once here so that
    # `on` is never applied as a function while it is still `Nothing`.
    project = if on == Nothing then (x-> x) else on
    comps = self.map (it-> Comparable.from (project it)) . distinct
    # Vector.sort_builtin returns exactly the values it is handed. It may
    # therefore only be used when neither a projection nor a custom comparison
    # was requested; sorting `self.map on` would return the projected keys
    # instead of the original elements.
    optimized_case = comps == [Default_Comparator] && on == Nothing && by == Nothing
    # In optimized_case, forward to Vector.sort_builtin, otherwise split to
    # groups based on different comparators, and forward to Array.sort
    case optimized_case of
        True ->
            self.sort_builtin order.to_sign
        False ->
            # The group for the default comparator always goes first; the
            # remaining groups keep the order in which their comparators first
            # appear in `self`.
            default_first = comps.filter (c-> c == Default_Comparator)
            others = comps.filter (c-> c != Default_Comparator)
            ordered_comps = default_first + others
            # Groups of elements with different comparators
            groups = ordered_comps.map comp->
                self.filter it->
                    Comparable.from (project it) == comp
            # TODO: Runtime.assert groups.reduce 0 acc->it-> acc + it.length == self.length
            case groups.length of
                # An empty vector has no comparators at all: nothing to sort
                # and no reason to attach a warning.
                0 -> self
                # self consists only of elements with the same comparator.
                # Forward to Array.sort
                1 ->
                    # The default value of `by` parameter is `Ordering.compare _ _`
                    by_non_null = if by == Nothing then (Ordering.compare _ _) else by
                    comp_ascending l r = by_non_null (project l) (project r)
                    comp_descending l r = by_non_null (project r) (project l)
                    compare = if order == Sort_Direction.Ascending then comp_ascending else
                        comp_descending
                    new_vec_arr = self.to_array.sort compare
                    if new_vec_arr.is_error then Error.throw new_vec_arr else
                        Vector.from_polyglot_array new_vec_arr
                _ ->
                    # Recurse on each group (each one has a single comparator,
                    # so the recursion ends after one level), and attach a
                    # warning about the incomparable values.
                    sorted_groups = groups.map it-> it.sort order on by
                    comparators_text = comps.to_text
                    Warning.attach ("Different comparators: " + comparators_text) sorted_groups.flatten

## UNSTABLE
Keeps only unique elements within the Vector, removing any duplicates.
Expand Down
2 changes: 1 addition & 1 deletion distribution/lib/Standard/Test/0.0.0-dev/src/Problems.enso
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ from Standard.Base import all
from project import Test
import project.Extensions

## Returns values of warnings attached to the value.Nothing
## Returns values of warnings attached to the value.
get_attached_warnings v =
    # Collect every warning attached to `v`, then unwrap each one to the
    # value it carries.
    attached = Warning.get_all v
    attached.map (w-> w.value)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.dsl.Fallback;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
Expand All @@ -26,6 +27,7 @@
name = "type_of",
description = "Returns the type of a value.",
autoRegister = false)
@GenerateUncached
public abstract class TypeOfNode extends Node {

public abstract Object execute(@AcceptsError Object value);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
package org.enso.interpreter.node.expression.builtin.ordering;

import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.InteropLibrary;
import com.oracle.truffle.api.interop.InvalidArrayIndexException;
import com.oracle.truffle.api.interop.UnsupportedMessageException;
import com.oracle.truffle.api.library.CachedLibrary;
import com.oracle.truffle.api.nodes.Node;
import java.util.Arrays;
import org.enso.interpreter.dsl.AcceptsError;
import org.enso.interpreter.dsl.BuiltinMethod;
import org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode;
import org.enso.interpreter.node.expression.builtin.meta.EqualsNode;
import org.enso.interpreter.node.expression.builtin.meta.TypeOfNode;
import org.enso.interpreter.node.expression.builtin.text.AnyToTextNode;
import org.enso.interpreter.runtime.EnsoContext;
import org.enso.interpreter.runtime.data.Array;
import org.enso.interpreter.runtime.data.ArrayRope;
import org.enso.interpreter.runtime.data.Vector;
import org.enso.interpreter.runtime.data.text.Text;
import org.enso.interpreter.runtime.error.PanicException;
import org.enso.interpreter.runtime.error.Warning;
import org.enso.interpreter.runtime.error.WarningsLibrary;
import org.enso.interpreter.runtime.error.WithWarnings;
import org.enso.interpreter.runtime.util.Collections.ArrayListObj;

/**
 * Sorts a vector with elements that have only Default_Comparator, thus, only elements with a
 * builtin type, which is the most common scenario for sorting.
 *
 * <p>Pairs of values that cannot be ordered by {@code LessThanNode} are ordered by a fixed rank of
 * their type instead (see {@link #typeOrder}), and a warning is attached to the result for each
 * such pair encountered.
 *
 * <p>TODO: Max number of attached Incomparable values warnings?
 *   - hardcode or pass via a new parameter to Vector.sort?
 */
@BuiltinMethod(
    type = "Vector",
    name = "sort_builtin",
    description = "Returns a sorted vector."
)
@GenerateUncached
public abstract class SortVectorNode extends Node {
  public static SortVectorNode build() {
    return SortVectorNodeGen.create();
  }

  /**
   * Sorts a vector with elements that have only Default_Comparator, thus, only builtin types.
   *
   * @param self Vector that has elements with only Default_Comparator, that are elements with
   *             builtin types.
   * @param ascending -1 for descending, 1 for ascending
   * @return A new, sorted vector
   */
  public abstract Object execute(@AcceptsError Object self, long ascending);

  /**
   * Copies the elements of {@code self} into a Java array, sorts that array and wraps it in a new
   * vector. Warnings collected by the comparator during sorting are appended to the result.
   *
   * @throws PanicException if an element of {@code self} is not readable
   * @throws IllegalStateException if interop rejects a size or element read
   */
  @Specialization(guards = {
      "interop.hasArrayElements(self)"
  })
  Object sortCached(Object self, long ascending,
      @Cached LessThanNode lessThanNode,
      @Cached EqualsNode equalsNode,
      @Cached HostValueToEnsoNode hostValueToEnsoNode,
      @Cached TypeOfNode typeOfNode,
      @Cached AnyToTextNode toTextNode,
      @CachedLibrary(limit = "10") InteropLibrary interop,
      @CachedLibrary(limit = "5") WarningsLibrary warningsLib) {
    EnsoContext ctx = EnsoContext.get(this);
    Object[] elems;
    try {
      long size = interop.getArraySize(self);
      assert size < Integer.MAX_VALUE;
      elems = new Object[(int) size];
      for (int i = 0; i < size; i++) {
        if (interop.isArrayElementReadable(self, i)) {
          elems[i] = hostValueToEnsoNode.execute(
              interop.readArrayElement(self, i)
          );
        } else {
          throw new PanicException(
              ctx.getBuiltins().error().makeUnsupportedArgumentsError(
                  new Object[]{self},
                  "Cannot read array element at index " + i + " of " + self
              ),
              this
          );
        }
      }
    } catch (UnsupportedMessageException | InvalidArrayIndexException e) {
      throw new IllegalStateException(e);
    }
    // Any positive value of `ascending` selects ascending order.
    var comparator = new Comparator(lessThanNode, equalsNode, typeOfNode, toTextNode, ascending > 0);
    Arrays.sort(elems, comparator);
    var vector = Vector.fromArray(new Array(elems));

    // Check for the warnings attached from the Comparator
    if (comparator.encounteredWarnings()) {
      Warning[] collected = (Warning[]) comparator.getWarnings();
      return WithWarnings.appendTo(vector, new ArrayRope<>(collected));
    }
    return vector;
  }

  /**
   * Returns the rank used to order values that cannot be compared directly. Lower ranks sort first
   * in ascending order: numbers (1), text (2), booleans (3), dates (4), date-times (5), durations
   * (6), then NaN (100) and finally Nothing (200).
   *
   * @throws IllegalStateException if the type of {@code object} is none of the above
   */
  private int typeOrder(Object object, TypeOfNode typeOfNode) {
    var ctx = EnsoContext.get(this);
    var builtins = ctx.getBuiltins();
    if (isNothing(object, ctx)) {
      return 200;
    }
    var type = typeOfNode.execute(object);
    if (type == builtins.number().getNumber()
        || type == builtins.number().getInteger()
        || type == builtins.number().getDecimal()) {
      if (object instanceof Double dbl && dbl.isNaN()) {
        return 100;
      } else {
        return 1;
      }
    } else if (type == builtins.text()) {
      return 2;
    } else if (type == builtins.bool().getType()) {
      return 3;
    } else if (type == builtins.date()) {
      return 4;
    } else if (type == builtins.dateTime()) {
      return 5;
    } else if (type == builtins.duration()) {
      return 6;
    } else {
      throw new IllegalStateException("Unexpected type: " + type);
    }
  }

  /** Returns true only when {@code object} is the boxed Java {@code true}. */
  private boolean isTrue(Object object) {
    return Boolean.TRUE.equals(object);
  }

  /** Returns true when {@code object} is the Nothing builtin of this node's context. */
  private boolean isNothing(Object object) {
    return isNothing(object, EnsoContext.get(this));
  }

  /** Returns true when {@code object} is the Nothing builtin of {@code ctx}. */
  private boolean isNothing(Object object, EnsoContext ctx) {
    return object == ctx.getBuiltins().nothing();
  }

  /**
   * Comparator handed to {@link Arrays#sort}. It orders values with {@code EqualsNode} and {@code
   * LessThanNode}, falls back to {@link #typeOrder} for pairs those nodes cannot order, and
   * records one warning per such pair.
   */
  private final class Comparator implements java.util.Comparator<Object> {

    private final LessThanNode lessThanNode;
    private final EqualsNode equalsNode;
    private final TypeOfNode typeOfNode;
    private final AnyToTextNode toTextNode;
    private final boolean ascending;
    private final ArrayListObj<Warning> warnings = new ArrayListObj<>();

    private Comparator(LessThanNode lessThanNode, EqualsNode equalsNode, TypeOfNode typeOfNode,
        AnyToTextNode toTextNode, boolean ascending) {
      this.lessThanNode = lessThanNode;
      this.equalsNode = equalsNode;
      this.typeOfNode = typeOfNode;
      this.toTextNode = toTextNode;
      this.ascending = ascending;
    }

    @Override
    public int compare(Object x, Object y) {
      if (equalsNode.execute(x, y)) {
        return 0;
      } else {
        // Check if x < y
        Object xLessThanYRes = lessThanNode.execute(x, y);
        if (isNothing(xLessThanYRes)) {
          // x and y are incomparable - this can happen if x and y are different types
          attachIncomparableValuesWarning(x, y);
          return compareTypes(x, y);
        } else if (isTrue(xLessThanYRes)) {
          return ascending ? -1 : 1;
        } else {
          // Check if x > y
          Object yLessThanXRes = lessThanNode.execute(y, x);
          if (isTrue(yLessThanXRes)) {
            return ascending ? 1 : -1;
          } else {
            // yLessThanXRes is either Nothing or False
            attachIncomparableValuesWarning(y, x);
            // The result must describe x relative to y, exactly like the fallback above.
            // Swapping the arguments here would invert the type ranking for this branch only,
            // which makes the overall ordering non-transitive.
            return compareTypes(x, y);
          }
        }
      }
    }

    /** Orders x relative to y by their type ranks, inverted when sorting descending. */
    private int compareTypes(Object x, Object y) {
      int res = Integer.compare(
          typeOrder(x, typeOfNode),
          typeOrder(y, typeOfNode)
      );
      return ascending ? res : -res;
    }

    /** Records a warning naming the two values that could not be compared. */
    private void attachIncomparableValuesWarning(Object x, Object y) {
      var xStr = toTextNode.execute(x).toString();
      var yStr = toTextNode.execute(y).toString();
      var payload = Text.create("Values " + xStr + " and " + yStr + " are incomparable");
      var sortNode = SortVectorNode.this;
      var warn = Warning.create(EnsoContext.get(sortNode), payload, sortNode);
      warnings.add(warn);
    }

    private boolean encounteredWarnings() {
      return warnings.size() > 0;
    }

    /** Returns the recorded warnings; the runtime type of the returned array is Warning[]. */
    private Object[] getWarnings() {
      Warning[] warns = new Warning[warnings.size()];
      warns = warnings.toArray(warns.getClass());
      return warns;
    }
  }
}
Loading

0 comments on commit 31398a0

Please sign in to comment.