From fd441e3d9d291aee790d9b6b9f33a8f17643c5ea Mon Sep 17 00:00:00 2001 From: Lars Date: Fri, 8 Dec 2023 11:33:06 +0100 Subject: [PATCH] No sequences with get_id --- src/col/vct/col/typerules/CoercionUtils.scala | 1 + .../vct/rewrite/lang/LangCPPToCol.scala | 38 +++++++++++++++---- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/col/vct/col/typerules/CoercionUtils.scala b/src/col/vct/col/typerules/CoercionUtils.scala index 1b68c5ae92..08416d0c84 100644 --- a/src/col/vct/col/typerules/CoercionUtils.scala +++ b/src/col/vct/col/typerules/CoercionUtils.scala @@ -21,6 +21,7 @@ case object CoercionUtils { case (TResource(), TResourceVal()) => CoerceResourceResourceVal() case (TResourceVal(), TResource()) => CoerceResourceValResource() case (TBool(), TResource()) => CoerceBoolResource() + case (TBool(), TResourceVal()) => CoercionSequence(Seq(CoerceBoolResource(), CoerceResourceResourceVal())) case (_, TAnyValue()) => CoerceSomethingAnyValue(source) diff --git a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala index be0bba5348..f4fe0d7870 100644 --- a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala @@ -3,6 +3,7 @@ package vct.rewrite.lang import com.typesafe.scalalogging.LazyLogging import hre.util.{FuncTools, ScopedStack} import vct.col.ast._ +import vct.col.ast.util.ExpressionEqualityCheck.isConstantInt import vct.col.origin._ import vct.col.ref.Ref import vct.col.resolve.NotApplicable @@ -650,7 +651,13 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends L LiteralSeq[Post](TInt(), rangeIndexFields.map(f => Deref[Post](currentThis.get, f.ref)(SYCLAccessorRangeIndexFieldInsufficientReferencePermissionBlame(inv)))) case "sycl::accessor::get_range" => throw NotApplicable(inv) case "sycl::range::get" => (classInstance, args) match { - case (Some(seq: LiteralSeq[Post]), Seq(arg)) => SeqSubscript(seq, rw.dispatch(arg))(SYCLRequestedRangeIndexOutOfBoundsBlame(seq, arg)) // Range coming from calling get_range() on an accessor + case (Some(seq: LiteralSeq[Post]), Seq(arg)) => + isConstantInt(arg) match { + case Some(i) if 0<= i && i < seq.values.size => return seq.values(i.toInt) + case _ => + } + + SeqSubscript(seq, rw.dispatch(arg))(SYCLRequestedRangeIndexOutOfBoundsBlame(seq, arg)) // Range coming from calling get_range() on an accessor case _ => throw NotApplicable(inv) } @@ -1083,17 +1090,32 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends L } private def getSimpleWorkItemId(inv: CPPInvocation[Pre], level: KernelScopeLevel) (implicit o: Origin) : Expr[Post] = { - val givenValueLambda: Ref[Post, Procedure[Post]] => Seq[(Ref[Post, Variable[Post]], Expr[Post])] = procedureRef => - Seq((procedureRef.decl.contract.givenArgs.head.ref, LiteralSeq(TInt(), currentDimensions(level).map(iterVar => iterVar.variable.get)))) - - getSYCLWorkItemIdOrRange(inv, givenValueLambda) + val dim = inv.args match { + case Seq(dim) => dim + case _ => ??? + } + isConstantInt(dim) match { + case Some(i) if 0<= i && i < currentDimensions(level).size => currentDimensions(level)(i.toInt).variable.get + case _ => + } + + // // Fallback if we cannot solve it to an int + // SeqSubscript[Post](LiteralSeq(TInt(), currentDimensionIterVars(level).map(iterVar => iterVar.variable.get).toSeq), rw.dispatch(inv.args.head))(SYCLItemMethodSeqBoundFailureBlame(inv)) } private def getSimpleWorkItemRange(inv: CPPInvocation[Pre], level: KernelScopeLevel)(implicit o: Origin): Expr[Post] = { - val givenValueLambda: Ref[Post, Procedure[Post]] => Seq[(Ref[Post, Variable[Post]], Expr[Post])] = procedureRef => - Seq((procedureRef.decl.contract.givenArgs.head.ref, LiteralSeq(TInt(), currentDimensions(level).map(iterVar => iterVar.to)))) + val dim = inv.args match { + case Seq(dim) => dim + case _ => ??? + } + isConstantInt(dim) match { + case Some(i) if 0<= i && i < currentDimensions(level).size => currentDimensions(level)(i.toInt).variable.get + case _ => + } - getSYCLWorkItemIdOrRange(inv, givenValueLambda) + // // Fallback if we cannot solve it to an int + // SeqSubscript[Post](LiteralSeq(TInt(), currentDimensionIterVars(level).map(iterVar => iterVar.to).toSeq), rw.dispatch(inv.args.head))(SYCLItemMethodSeqBoundFailureBlame(inv)) + // Seq((procedureRef.decl.contract.givenArgs.head.ref, LiteralSeq(TInt(), currentDimensions(level).map(iterVar => iterVar.to)))) } private def getSimpleWorkItemLinearId(inv: CPPInvocation[Pre], level: KernelScopeLevel)(implicit o: Origin): Expr[Post] = {