From 432a87b80383c53d95607a9a8fce7c0a5da91762 Mon Sep 17 00:00:00 2001 From: Pavel Marek Date: Tue, 28 Mar 2023 09:03:18 +0200 Subject: [PATCH] Fix byFunc calling --- .../builtin/ordering/SortVectorNode.java | 44 ++++++++++++++++--- test/Tests/src/Data/Vector_Spec.enso | 6 +-- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java index 564f9dd6d5ab0..4f151bc120cf1 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/SortVectorNode.java @@ -154,7 +154,7 @@ private Object sortPrimitiveVector(Object[] elems, "interop.hasArrayElements(self)", }) Object sortGeneric(State state, Object self, long ascending, Object comparatorsArray, - Object compareFuncsArray, Object byFunc, + Object compareFuncsArray, Object byFuncObj, @CachedLibrary(limit = "10") InteropLibrary interop, @CachedLibrary(limit = "5") WarningsLibrary warningsLib, @Cached LessThanNode lessThanNode, @@ -162,6 +162,7 @@ Object sortGeneric(State state, Object self, long ascending, Object comparatorsA @Cached TypeOfNode typeOfNode, @Cached AnyToTextNode toTextNode, @Cached(value = "build()", uncached = "build()") CallOptimiserNode callNode) { + Function byFunc = checkAndConvertByParameter(byFuncObj); // Split into groups List elems = readInteropArray(interop, warningsLib, self); List comparators = readInteropArray(interop, warningsLib, comparatorsArray); @@ -178,7 +179,7 @@ Object sortGeneric(State state, Object self, long ascending, Object comparatorsA try { for (var group : groups) { Comparator javaComparator; - if (isPrimitiveGroup(group)) { + if (isNothing(byFunc) && isPrimitiveGroup(group)) { javaComparator = new DefaultComparator( lessThanNode, equalsNode, @@ -187,7 +188,7 @@ Object sortGeneric(State state, Object self, long ascending, Object comparatorsA ascending > 0 ); } else { - Function compareFunc = isNothing(byFunc) ? group.compareFunc : (Function) byFunc; + Function compareFunc = isNothing(byFunc) ? group.compareFunc : byFunc; javaComparator = new GenericComparator( ascending > 0, compareFunc, @@ -216,6 +217,31 @@ Object sortGeneric(State state, Object self, long ascending, Object comparatorsA } } + /** + * Checks value given for {@code by} parameter and converts it to {@link Function}. Throw a + * dataflow error otherwise. + */ + private Function checkAndConvertByParameter(Object byFuncObj) { + var ctx = EnsoContext.get(this); + var err = DataflowError.withoutTrace( + ctx.getBuiltins().error().makeUnsupportedArgumentsError( + new Object[]{byFuncObj}, + "Unsupported argument for `by`, expected a method with two arguments" + ), + this + ); + if (byFuncObj instanceof Function byFunc) { + var argCount = byFunc.getSchema().getArgumentsCount(); + if (argCount == 2 || argCount == 3) { + return byFunc; + } else { + throw err; + } + } else { + throw err; + } + } + private List splitByComparators(List elements, List comparators, List compareFuncs) { assert elements.size() == comparators.size(); @@ -576,6 +602,7 @@ private final class GenericComparator extends Comparator { * extracted from the comparator for the appropriate group. */ private final Function compareFunc; + private final boolean compareFuncHasSelf; private final Type comparator; private final CallOptimiserNode callNode; private final State state; @@ -601,14 +628,17 @@ private GenericComparator( this.less = less; this.equal = equal; this.greater = greater; + assert compareFunc.getSchema().getArgumentsCount() + >= 2 : "compareFunc should take more than 2 arguments"; + this.compareFuncHasSelf = compareFunc.getSchema().getArgumentInfos()[0].getName() + .equals("self"); } @Override public int compare(Object x, Object y) { - // We are calling a static method here, so we need to pass the Comparator type as the - // self (first) argument. - Object res = callNode.executeDispatch(compareFunc, null, state, - new Object[]{comparator, x, y}); + // If compareFunc takes self parameter, it is `comparator`. + Object[] args = compareFuncHasSelf ? new Object[]{comparator, x, y} : new Object[]{x, y}; + Object res = callNode.executeDispatch(compareFunc, null, state, args); if (res == less) { return ascending ? -1 : 1; } else if (res == equal) { diff --git a/test/Tests/src/Data/Vector_Spec.enso b/test/Tests/src/Data/Vector_Spec.enso index 10e37ff7d4f45..dc28ea03d583e 100644 --- a/test/Tests/src/Data/Vector_Spec.enso +++ b/test/Tests/src/Data/Vector_Spec.enso @@ -558,10 +558,8 @@ spec = Test.group "Vectors" <| [Time_Of_Day.new 12, Time_Of_Day.new 10 30].sort . should_equal [Time_Of_Day.new 10 30, Time_Of_Day.new 12] [Date_Time.new 2000 12 30 12 30, Date_Time.new 2000 12 30 12 00].sort . should_equal [Date_Time.new 2000 12 30 12 00, Date_Time.new 2000 12 30 12 30] - ["aa", 2].sort . should_fail_with Incomparable_Values - [2, Date.new 1999].sort . should_fail_with Incomparable_Values - [Date_Time.new 1999 1 1 12 30, Date.new 1999].sort . should_fail_with Incomparable_Values - [Date_Time.new 1999 1 1 12 30, Time_Of_Day.new 12 30].sort . should_fail_with Incomparable_Values + ["aa", 2].sort . should_equal [2, "aa"] + [2, Date.new 1999].sort . should_equal [2, Date.new 1999] Test.expect_panic_with ([3,2,1].to_array.sort 42) Type_Error Test.specify "should leave the original vector unchanged" <|