Skip to content

Commit

Permalink
Fix byFunc calling
Browse files Browse the repository at this point in the history
  • Loading branch information
Akirathan committed Mar 28, 2023
1 parent 82ac86b commit 432a87b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,15 @@ private Object sortPrimitiveVector(Object[] elems,
"interop.hasArrayElements(self)",
})
Object sortGeneric(State state, Object self, long ascending, Object comparatorsArray,
Object compareFuncsArray, Object byFunc,
Object compareFuncsArray, Object byFuncObj,
@CachedLibrary(limit = "10") InteropLibrary interop,
@CachedLibrary(limit = "5") WarningsLibrary warningsLib,
@Cached LessThanNode lessThanNode,
@Cached EqualsNode equalsNode,
@Cached TypeOfNode typeOfNode,
@Cached AnyToTextNode toTextNode,
@Cached(value = "build()", uncached = "build()") CallOptimiserNode callNode) {
Function byFunc = checkAndConvertByParameter(byFuncObj);
// Split into groups
List<Object> elems = readInteropArray(interop, warningsLib, self);
List<Type> comparators = readInteropArray(interop, warningsLib, comparatorsArray);
Expand All @@ -178,7 +179,7 @@ Object sortGeneric(State state, Object self, long ascending, Object comparatorsA
try {
for (var group : groups) {
Comparator javaComparator;
if (isPrimitiveGroup(group)) {
if (isNothing(byFunc) && isPrimitiveGroup(group)) {
javaComparator = new DefaultComparator(
lessThanNode,
equalsNode,
Expand All @@ -187,7 +188,7 @@ Object sortGeneric(State state, Object self, long ascending, Object comparatorsA
ascending > 0
);
} else {
Function compareFunc = isNothing(byFunc) ? group.compareFunc : (Function) byFunc;
Function compareFunc = isNothing(byFunc) ? group.compareFunc : byFunc;
javaComparator = new GenericComparator(
ascending > 0,
compareFunc,
Expand Down Expand Up @@ -216,6 +217,31 @@ Object sortGeneric(State state, Object self, long ascending, Object comparatorsA
}
}

/**
* Checks value given for {@code by} parameter and converts it to {@link Function}. Throw a
* dataflow error otherwise.
*/
private Function checkAndConvertByParameter(Object byFuncObj) {
var ctx = EnsoContext.get(this);
var err = DataflowError.withoutTrace(
ctx.getBuiltins().error().makeUnsupportedArgumentsError(
new Object[]{byFuncObj},
"Unsupported argument for `by`, expected a method with two arguments"
),
this
);
if (byFuncObj instanceof Function byFunc) {
var argCount = byFunc.getSchema().getArgumentsCount();
if (argCount == 2 || argCount == 3) {
return byFunc;
} else {
throw err;
}
} else {
throw err;
}
}

private List<Group> splitByComparators(List<Object> elements, List<Type> comparators,
List<Function> compareFuncs) {
assert elements.size() == comparators.size();
Expand Down Expand Up @@ -576,6 +602,7 @@ private final class GenericComparator extends Comparator {
* extracted from the comparator for the appropriate group.
*/
private final Function compareFunc;
private final boolean compareFuncHasSelf;
private final Type comparator;
private final CallOptimiserNode callNode;
private final State state;
Expand All @@ -601,14 +628,17 @@ private GenericComparator(
this.less = less;
this.equal = equal;
this.greater = greater;
assert compareFunc.getSchema().getArgumentsCount()
>= 2 : "compareFunc should take more than 2 arguments";
this.compareFuncHasSelf = compareFunc.getSchema().getArgumentInfos()[0].getName()
.equals("self");
}

@Override
public int compare(Object x, Object y) {
// We are calling a static method here, so we need to pass the Comparator type as the
// self (first) argument.
Object res = callNode.executeDispatch(compareFunc, null, state,
new Object[]{comparator, x, y});
// If compareFunc takes self parameter, it is `comparator`.
Object[] args = compareFuncHasSelf ? new Object[]{comparator, x, y} : new Object[]{x, y};
Object res = callNode.executeDispatch(compareFunc, null, state, args);
if (res == less) {
return ascending ? -1 : 1;
} else if (res == equal) {
Expand Down
6 changes: 2 additions & 4 deletions test/Tests/src/Data/Vector_Spec.enso
Original file line number Diff line number Diff line change
Expand Up @@ -558,10 +558,8 @@ spec = Test.group "Vectors" <|
[Time_Of_Day.new 12, Time_Of_Day.new 10 30].sort . should_equal [Time_Of_Day.new 10 30, Time_Of_Day.new 12]
[Date_Time.new 2000 12 30 12 30, Date_Time.new 2000 12 30 12 00].sort . should_equal [Date_Time.new 2000 12 30 12 00, Date_Time.new 2000 12 30 12 30]

["aa", 2].sort . should_fail_with Incomparable_Values
[2, Date.new 1999].sort . should_fail_with Incomparable_Values
[Date_Time.new 1999 1 1 12 30, Date.new 1999].sort . should_fail_with Incomparable_Values
[Date_Time.new 1999 1 1 12 30, Time_Of_Day.new 12 30].sort . should_fail_with Incomparable_Values
["aa", 2].sort . should_equal [2, "aa"]
[2, Date.new 1999].sort . should_equal [2, Date.new 1999]
Test.expect_panic_with ([3,2,1].to_array.sort 42) Type_Error

Test.specify "should leave the original vector unchanged" <|
Expand Down

0 comments on commit 432a87b

Please sign in to comment.