diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/EqualsNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/EqualsNode.java index 71d15c5e591d..4b130c470359 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/EqualsNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/meta/EqualsNode.java @@ -20,6 +20,8 @@ import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo; import org.enso.interpreter.runtime.callable.function.Function; import org.enso.interpreter.runtime.data.Type; +import org.enso.interpreter.runtime.data.atom.Atom; +import org.enso.interpreter.runtime.data.atom.StructsLibrary; import org.enso.interpreter.runtime.error.PanicException; import org.enso.interpreter.runtime.library.dispatch.TypeOfNode; import org.enso.interpreter.runtime.scope.ModuleScope; @@ -140,8 +142,15 @@ private static Object convertor(EnsoContext ctx, Function convFn, Object value) InvokeFunctionNode.build( argSchema, DefaultsExecutionMode.EXECUTE, ArgumentsExecutionMode.EXECUTE); var state = State.create(ctx); - return node.execute( - convFn, null, state, new Object[] {ctx.getBuiltins().comparable(), value}); + var by = + node.execute(convFn, null, state, new Object[] {ctx.getBuiltins().comparable(), value}); + if (by instanceof Atom atom + && atom.getConstructor() == ctx.getBuiltins().comparable().getBy()) { + var structs = StructsLibrary.getUncached(); + return structs.getField(atom, 1); + } else { + return null; + } } /** @@ -175,14 +184,14 @@ private static boolean findConversionImpl( UnresolvedConversion.build(selfScope).resolveFor(ctx, comparableType, thatType); var betweenBoth = UnresolvedConversion.build(selfScope).resolveFor(ctx, selfType, thatType); - if (isDefinedIn(selfScope, fromSelfType) - && isDefinedIn(selfScope, fromThatType) - && convertor(ctx, fromSelfType, self) == convertor(ctx, fromThatType, that) - && betweenBoth != null) { - return true; - } else { - return false; + if (isDefinedIn(selfScope, fromSelfType) && isDefinedIn(selfScope, fromThatType)) { + var c1 = convertor(ctx, fromSelfType, self); + var c2 = convertor(ctx, fromThatType, that); + if (c1 == c2 && c1 != null && betweenBoth != null) { + return true; + } } + return false; } @Specialization( diff --git a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/CustomComparatorNode.java b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/CustomComparatorNode.java index 7f5aeec7b15d..69ccb62628d0 100644 --- a/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/CustomComparatorNode.java +++ b/engine/runtime/src/main/java/org/enso/interpreter/node/expression/builtin/ordering/CustomComparatorNode.java @@ -4,6 +4,7 @@ import com.oracle.truffle.api.dsl.GenerateUncached; import com.oracle.truffle.api.dsl.NeverDefault; import com.oracle.truffle.api.dsl.Specialization; +import com.oracle.truffle.api.library.CachedLibrary; import com.oracle.truffle.api.nodes.Node; import org.enso.interpreter.node.callable.InvokeCallableNode.ArgumentsExecutionMode; import org.enso.interpreter.node.callable.InvokeCallableNode.DefaultsExecutionMode; @@ -13,6 +14,7 @@ import org.enso.interpreter.runtime.callable.argument.CallArgumentInfo; import org.enso.interpreter.runtime.data.Type; import org.enso.interpreter.runtime.data.atom.Atom; +import org.enso.interpreter.runtime.data.atom.StructsLibrary; import org.enso.interpreter.runtime.state.State; /** @@ -40,18 +42,25 @@ public static CustomComparatorNode getUncached() { @Specialization Type hasCustomComparatorCached( Atom atom, + @CachedLibrary(limit = "1") StructsLibrary structs, @Cached(value = "buildConvertionNode()", allowUncached = true) InvokeConversionNode convertNode, @Cached(value = "createConversion()", allowUncached = true) UnresolvedConversion conversion) { var ctx = EnsoContext.get(this); var comparableType = ctx.getBuiltins().comparable().getType(); var state = State.create(ctx); - Object res = + Object rawRes = convertNode.execute( null, state, conversion, comparableType, atom, new Object[] {comparableType, atom}); - return res instanceof Type result && result != ctx.getBuiltins().defaultComparator().getType() - ? result - : null; + if (rawRes instanceof Atom res + && res.getConstructor() == ctx.getBuiltins().comparable().getBy()) { + if (structs.getField(res, 1) instanceof Type result) { + if (result != ctx.getBuiltins().defaultComparator().getType()) { + return result; + } + } + } + return null; } @NeverDefault