From 31398a0404a021f1350be8a3e7f6fa71fa69de65 Mon Sep 17 00:00:00 2001 From: Pavel Marek Date: Fri, 17 Mar 2023 18:45:22 +0100 Subject: [PATCH] Vector.sort handles incomparable types --- .../Base/0.0.0-dev/src/Data/Vector.enso | 79 ++++-- .../Standard/Test/0.0.0-dev/src/Problems.enso | 2 +- .../expression/builtin/meta/TypeOfNode.java | 2 + .../builtin/ordering/SortVectorNode.java | 227 ++++++++++++++++++ .../interpreter/runtime/util/Collections.java | 78 ++++++ 5 files changed, 372 insertions(+), 16 deletions(-) create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java create mode 100644 engine/runtime/src/main/java/org/enso/interpreter/runtime/util/Collections.java diff --git a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Vector.enso b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Vector.enso index f3710dfd8cd91..96e7d0ef99f02 100644 --- a/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Vector.enso +++ b/distribution/lib/Standard/Base/0.0.0-dev/src/Data/Vector.enso @@ -4,7 +4,6 @@ import project.Data.Filter_Condition.Filter_Condition import project.Data.List.List import project.Data.Map.Map import project.Data.Numbers.Integer -import project.Data.Ordering.Ordering import project.Data.Ordering.Sort_Direction.Sort_Direction import project.Data.Pair.Pair import project.Data.Range.Range @@ -22,7 +21,11 @@ import project.Math import project.Nothing.Nothing import project.Panic.Panic import project.Random +import project.Warning.Warning +# We have to import also conversion methods, therefore, we import all from the Ordering +# module +from project.Data.Ordering import all from project.Data.Boolean import Boolean, True, False from project.Data.Index_Sub_Range import Index_Sub_Range, take_helper, drop_helper @@ -842,11 +845,12 @@ type Vector a - on: A projection from the element type to the value of that element being sorted on. 
- by: A function that compares the result of applying `on` to two - elements, returning an Ordering to compare them. + elements, returning an `Ordering` if the two elements are comparable + or `Nothing` if they are not. If set to `Nothing` (the default argument), + the `Ordering.compare _ _` method will be used. + + By default, elements are sorted in ascending order. - By default, elements are sorted in ascending order, using the comparator - acquired from each element. A custom compare function may be passed to - the sort method. This is a stable sort, meaning that items that compare the same will not have their order changed by the sorting process. @@ -864,6 +868,18 @@ type Vector a is partially sorted. When the vector is randomly ordered, the performance is equivalent to a standard mergesort. + ? Multiple comparators + Elements with different comparators are incomparable by definition. + This case is handled by first grouping the `self` vector into groups + with the same comparator, recursively sorting these groups, and then + merging them back together. The order of the sorted groups in the + resulting vector is based on the order of elements' comparators in the + `self` vector, with the exception of the group for the default + comparator, which is always the first group. + + Additionally, a warning will be attached, explaining that incomparable + values were encountered. + It takes equal advantage of ascending and descending runs in the array, making it much simpler to merge two or more sorted arrays: simply concatenate them and sort. @@ -877,16 +893,49 @@ type Vector a Sorting a vector of `Pair`s on the first element, descending. [Pair 1 2, Pair -1 8].sort Sort_Direction.Descending (_.first) - sort : Sort_Direction -> (Any -> Any) -> (Any -> Any -> Ordering) -> Vector Any ! 
Incomparable_Values - sort self (order = Sort_Direction.Ascending) (on = x -> x) (by = (Ordering.compare _ _)) = - comp_ascending l r = by (on l) (on r) - comp_descending l r = by (on r) (on l) - compare = if order == Sort_Direction.Ascending then comp_ascending else - comp_descending - - new_vec_arr = self.to_array.sort compare - if new_vec_arr.is_error then Error.throw new_vec_arr else - Vector.from_polyglot_array new_vec_arr + + > Example + Sorting a vector with elements with different comparators. Values 1 + and My_Type have different comparators. 1 will be sorted before My_Type + because it has the default comparator, and warning will be attached to + the resulting vector. + + [My_Type.Value 'hello', 1].sort == [1, My_Type.Value 'hello'] + sort : Sort_Direction -> (Any -> Any)|Nothing -> (Any -> Any -> (Ordering|Nothing))|Nothing -> Vector Any ! Incomparable_Values + sort self (order = Sort_Direction.Ascending) on=x->x by=Nothing = + comps = self.map (it-> Comparable.from (on it)) . distinct + optimized_case = comps == [Default_Comparator] && by == Nothing + # In optimize_case, forward to Vector.sort_builtin, otherwise split to groups + # based on different comparators, and forward to Array.sort + case optimized_case of + True -> + elems = if on == Nothing then self else self.map it-> on it + elems.sort_builtin order.to_sign + False -> + # Groups of elements with different comparators + groups = comps.distinct.map comp-> + self.filter it-> + Comparable.from (on it) == comp + # TODO: Runtime.assert groups.reduce 0 acc->it-> acc + it.length == self.length + case groups.length of + # self consists only of elements with the same comparator. 
+ # Forward to Array.sort + 1 -> + # The default value of `by` parameter is `Ordering.compare _ _` + by_non_null = if by == Nothing then (Ordering.compare _ _) else by + comp_ascending l r = by_non_null (on l) (on r) + comp_descending l r = by_non_null (on r) (on l) + compare = if order == Sort_Direction.Ascending then comp_ascending else + comp_descending + new_vec_arr = self.to_array.sort compare + if new_vec_arr.is_error then Error.throw new_vec_arr else + Vector.from_polyglot_array new_vec_arr + _ -> + # Recurse on each group, and attach a warning + # TODO: Sort Default_Comparator group as the first one? + sorted_groups = groups.map it-> it.sort order on by + comparators_text = comps.distinct.to_text + Warning.attach ("Different comparators: " + comparators_text) sorted_groups.flatten ## UNSTABLE Keeps only unique elements within the Vector, removing any duplicates. diff --git a/distribution/lib/Standard/Test/0.0.0-dev/src/Problems.enso b/distribution/lib/Standard/Test/0.0.0-dev/src/Problems.enso index a5aff8e873d8b..375a2b99c2fe4 100644 --- a/distribution/lib/Standard/Test/0.0.0-dev/src/Problems.enso +++ b/distribution/lib/Standard/Test/0.0.0-dev/src/Problems.enso @@ -3,7 +3,7 @@ from Standard.Base import all from project import Test import project.Extensions -## Returns values of warnings attached to the value.Nothing +## Returns values of warnings attached to the value. get_attached_warnings v = Warning.get_all v . 
map .value diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/TypeOfNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/TypeOfNode.java index f8c2c8077b7ae..3371fa1fefaab 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/TypeOfNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/TypeOfNode.java @@ -2,6 +2,7 @@ import com.oracle.truffle.api.CompilerDirectives; import com.oracle.truffle.api.dsl.Fallback; +import com.oracle.truffle.api.dsl.GenerateUncached; import com.oracle.truffle.api.dsl.Specialization; import com.oracle.truffle.api.interop.InteropLibrary; import com.oracle.truffle.api.interop.UnsupportedMessageException; @@ -26,6 +27,7 @@ name = "type_of", description = "Returns the type of a value.", autoRegister = false) +@GenerateUncached public abstract class TypeOfNode extends Node { public abstract Object execute(@AcceptsError Object value); diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java new file mode 100644 index 0000000000000..525616482a13c --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java @@ -0,0 +1,227 @@ +package org.enso.interpreter.node.expression.builtin.ordering; + +import com.oracle.truffle.api.dsl.Cached; +import com.oracle.truffle.api.dsl.GenerateUncached; +import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.interop.InteropLibrary; +import com.oracle.truffle.api.interop.InvalidArrayIndexException; +import com.oracle.truffle.api.interop.UnsupportedMessageException; +import com.oracle.truffle.api.library.CachedLibrary; +import com.oracle.truffle.api.nodes.Node; +import java.util.Arrays; +import 
org.enso.interpreter.dsl.AcceptsError; +import org.enso.interpreter.dsl.BuiltinMethod; +import org.enso.interpreter.node.expression.builtin.interop.syntax.HostValueToEnsoNode; +import org.enso.interpreter.node.expression.builtin.meta.EqualsNode; +import org.enso.interpreter.node.expression.builtin.meta.TypeOfNode; +import org.enso.interpreter.node.expression.builtin.text.AnyToTextNode; +import org.enso.interpreter.runtime.EnsoContext; +import org.enso.interpreter.runtime.data.Array; +import org.enso.interpreter.runtime.data.ArrayRope; +import org.enso.interpreter.runtime.data.Vector; +import org.enso.interpreter.runtime.data.text.Text; +import org.enso.interpreter.runtime.error.PanicException; +import org.enso.interpreter.runtime.error.Warning; +import org.enso.interpreter.runtime.error.WarningsLibrary; +import org.enso.interpreter.runtime.error.WithWarnings; +import org.enso.interpreter.runtime.util.Collections.ArrayListObj; + +/** + * Sorts a vector with elements that have only Default_Comparator, thus, only elements with a + * builtin type, which is the most common scenario for sorting. + * + * TODO: Max number of attached Incomparable values warnings? + * - hardcode or pass via a new parameter to Vector.sort? + */ +@BuiltinMethod( + type = "Vector", + name = "sort_builtin", + description = "Returns a sorted vector." +) +@GenerateUncached +public abstract class SortVectorNode extends Node { + public static SortVectorNode build() { + return SortVectorNodeGen.create(); + } + + /** + * Sorts a vector with elements that have only Default_Comparator, thus, only builtin types. + * + * @param self Vector that has elements with only Default_Comparator, that are elements with + * builtin types. 
+ * @param ascending -1 for descending, 1 for ascending + * @return A new, sorted vector + */ + public abstract Object execute(@AcceptsError Object self, long ascending); + + @Specialization(guards = { + "interop.hasArrayElements(self)" + }) + Object sortCached(Object self, long ascending, + @Cached LessThanNode lessThanNode, + @Cached EqualsNode equalsNode, + @Cached HostValueToEnsoNode hostValueToEnsoNode, + @Cached TypeOfNode typeOfNode, + @Cached AnyToTextNode toTextNode, + @CachedLibrary(limit = "10") InteropLibrary interop, + @CachedLibrary(limit = "5") WarningsLibrary warningsLib) { + EnsoContext ctx = EnsoContext.get(this); + Object[] elems; + try { + long size = interop.getArraySize(self); + assert size < Integer.MAX_VALUE; + elems = new Object[(int) size]; + for (int i = 0; i < size; i++) { + if (interop.isArrayElementReadable(self, i)) { + elems[i] = hostValueToEnsoNode.execute( + interop.readArrayElement(self, i) + ); + } else { + throw new PanicException( + ctx.getBuiltins().error().makeUnsupportedArgumentsError( + new Object[]{self}, + "Cannot read array element at index " + i + " of " + self + ), + this + ); + } + } + } catch (UnsupportedMessageException | InvalidArrayIndexException e) { + throw new IllegalStateException(e); + } + var comparator = new Comparator(lessThanNode, equalsNode, typeOfNode, toTextNode, ascending > 0); + Arrays.sort(elems, comparator); + var vector = Vector.fromArray(new Array(elems)); + + // Check for the warnings attached from the Comparator + Warning[] currWarns = null; + if (comparator.encounteredWarnings()) { + currWarns = (Warning[]) comparator.getWarnings(); + } + if (currWarns != null) { + return WithWarnings.appendTo(vector, new ArrayRope<>(currWarns)); + } else { + return vector; + } + } + + private int typeOrder(Object object, TypeOfNode typeOfNode) { + var ctx = EnsoContext.get(this); + var builtins = ctx.getBuiltins(); + if (isNothing(object, ctx)) { + return 200; + } + var type = typeOfNode.execute(object); + if 
(type == builtins.number().getNumber() + || type == builtins.number().getInteger() + || type == builtins.number().getDecimal()) { + if (object instanceof Double dbl && dbl.isNaN()) { + return 100; + } else { + return 1; + } + } + else if (type == builtins.text()) { + return 2; + } + else if (type == builtins.bool().getType()) { + return 3; + } + else if (type == builtins.date()) { + return 4; + } + else if (type == builtins.dateTime()) { + return 5; + } + else if (type == builtins.duration()) { + return 6; + } else { + throw new IllegalStateException("Unexpected type: " + type); + } + } + + private boolean isTrue(Object object) { + return Boolean.TRUE.equals(object); + } + + private boolean isNothing(Object object) { + return isNothing(object, EnsoContext.get(this)); + } + + private boolean isNothing(Object object, EnsoContext ctx) { + return object == ctx.getBuiltins().nothing(); + } + + private final class Comparator implements java.util.Comparator { + + private final LessThanNode lessThanNode; + private final EqualsNode equalsNode; + private final TypeOfNode typeOfNode; + private final AnyToTextNode toTextNode; + private final boolean ascending; + private final ArrayListObj warnings = new ArrayListObj<>(); + + private Comparator(LessThanNode lessThanNode, EqualsNode equalsNode, TypeOfNode typeOfNode, + AnyToTextNode toTextNode, boolean ascending) { + this.lessThanNode = lessThanNode; + this.equalsNode = equalsNode; + this.typeOfNode = typeOfNode; + this.toTextNode = toTextNode; + this.ascending = ascending; + } + + @Override + public int compare(Object x, Object y) { + if (equalsNode.execute(x, y)) { + return 0; + } else { + // Check if x < y + Object xLessThanYRes = lessThanNode.execute(x, y); + if (isNothing(xLessThanYRes)) { + // x and y are incomparable - this can happen if x and y are different types + attachIncomparableValuesWarning(x, y); + return compareTypes(x, y); + } else if (isTrue(xLessThanYRes)) { + return ascending ? 
-1 : 1; + } else { + // Check if x > y + Object yLessThanXRes = lessThanNode.execute(y, x); + if (isTrue(yLessThanXRes)) { + return ascending ? 1 : -1; + } else { + // yLessThanXRes is either Nothing or False + attachIncomparableValuesWarning(y, x); + return compareTypes(y, x); + } + } + } + } + + private int compareTypes(Object x, Object y) { + int res =Integer.compare( + typeOrder(x, typeOfNode), + typeOrder(y, typeOfNode) + ); + return ascending ? res : -res; + } + + private void attachIncomparableValuesWarning(Object x, Object y) { + var xStr = toTextNode.execute(x).toString(); + var yStr = toTextNode.execute(y).toString(); + var payload = Text.create("Values " + xStr + " and " + yStr + " are incomparable"); + var sortNode = SortVectorNode.this; + var warn = Warning.create(EnsoContext.get(sortNode), payload, sortNode); + warnings.add(warn); + } + + private boolean encounteredWarnings() { + return warnings.size() > 0; + } + + private Object[] getWarnings() { + Warning[] warns = new Warning[warnings.size()]; + warns = warnings.toArray(warns.getClass()); + return warns; + } + } +} diff --git a/engine/runtime/src/main/java/org/enso/interpreter/runtime/util/Collections.java b/engine/runtime/src/main/java/org/enso/interpreter/runtime/util/Collections.java new file mode 100644 index 0000000000000..0917b46c635e9 --- /dev/null +++ b/engine/runtime/src/main/java/org/enso/interpreter/runtime/util/Collections.java @@ -0,0 +1,78 @@ +package org.enso.interpreter.runtime.util; + +import com.oracle.truffle.api.CompilerDirectives; +import com.oracle.truffle.api.profiles.BranchProfile; +import java.util.Arrays; + +public class Collections { + + /** PE-friendly implementation of ArrayList. 
*/ + public static final class ArrayListObj { + private Object[] data; + private int size; + + public ArrayListObj(int capacity) { + this.data = new Object[capacity]; + } + + public ArrayListObj() { + this(8); + } + + public void add(T value) { + add(value, BranchProfile.getUncached()); + } + + public void add(T value, BranchProfile capacityExceededProfile) { + if (size == data.length) { + capacityExceededProfile.enter(); + data = Arrays.copyOf(data, size * 2); + } + data[size++] = value; + } + + public void set(int index, T value) { + checkIndex(index); + data[index] = value; + } + + @SuppressWarnings("unchecked") + public T get(int index) { + checkIndex(index); + return (T) data[index]; + } + + @SuppressWarnings("unchecked") + public T remove(int index) { + checkIndex(index); + T result = (T) data[index]; + int lastIdx = size - 1; + int toMoveLen = lastIdx - index; + if (toMoveLen > 0) { + System.arraycopy(data, index + 1, data, index, toMoveLen); + } + data[lastIdx] = null; + size--; + return result; + } + + public int size() { + return size; + } + + public Object[] toArray() { + return Arrays.copyOf(data, size); + } + + public T[] toArray(Class newType) { + return Arrays.copyOf(data, size, newType); + } + + private void checkIndex(int index) { + if (!(0 <= index && index < size)) { + CompilerDirectives.transferToInterpreter(); + throw new IndexOutOfBoundsException(index); + } + } + } +}