From 79e3d974830daae5f63371a7843fcbc190296719 Mon Sep 17 00:00:00 2001 From: Lars Date: Fri, 8 Dec 2023 11:57:28 +0100 Subject: [PATCH] Fix for no sequence indexing --- src/col/vct/col/typerules/CoercionUtils.scala | 2 ++ .../vct/rewrite/lang/LangCPPToCol.scala | 30 +++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/col/vct/col/typerules/CoercionUtils.scala b/src/col/vct/col/typerules/CoercionUtils.scala index fa6a53bc67..a3bc9b89c3 100644 --- a/src/col/vct/col/typerules/CoercionUtils.scala +++ b/src/col/vct/col/typerules/CoercionUtils.scala @@ -21,6 +21,8 @@ case object CoercionUtils { case (TResource(), TResourceVal()) => CoerceResourceResourceVal() case (TResourceVal(), TResource()) => CoerceResourceValResource() case (TBool(), TResource()) => CoerceBoolResource() + case (TBool(), TResourceVal()) => CoercionSequence(Seq(CoerceBoolResource(), CoerceResourceResourceVal())) + case (_, TAnyValue()) => CoerceSomethingAnyValue(source) diff --git a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala index 6a78c9370d..6c273afc4b 100644 --- a/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangCPPToCol.scala @@ -3,6 +3,7 @@ package vct.rewrite.lang import com.typesafe.scalalogging.LazyLogging import hre.util.ScopedStack import vct.col.ast.{CPPLocalDeclaration, Expr, InstanceField, Perm, TInt, _} +import vct.col.ast.util.ExpressionEqualityCheck.isConstantInt import vct.col.origin._ import vct.col.ref.Ref import vct.col.resolve.NotApplicable @@ -767,7 +768,13 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends L case _ => throw NotApplicable(inv) } case "sycl::range::get" => (classInstance, args) match { - case (Some(seq: LiteralSeq[Post]), Seq(arg)) => SeqSubscript(seq, rw.dispatch(arg))(SYCLRequestedRangeIndexOutOfBoundsBlame(seq, arg)) // Range coming from calling get_range() on a (local)accessor + case (Some(seq: LiteralSeq[Post]), Seq(arg)) => + + isConstantInt(arg) match { + case Some(i) if 0<= i && i < seq.values.size => seq.values(i.toInt) + case _ => ??? + } + // SeqSubscript(seq, rw.dispatch(arg))(SYCLRequestedRangeIndexOutOfBoundsBlame(seq, arg)) // Range coming from calling get_range() on a (local)accessor case _ => throw NotApplicable(inv) } @@ -1265,11 +1272,28 @@ case class LangCPPToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends L } private def getSimpleWorkItemId(inv: CPPInvocation[Pre], level: KernelScopeLevel) (implicit o: Origin) : Expr[Post] = { - SeqSubscript[Post](LiteralSeq(TInt(), currentDimensionIterVars(level).map(iterVar => iterVar.variable.get).toSeq), rw.dispatch(inv.args.head))(SYCLItemMethodSeqBoundFailureBlame(inv)) + val dim = inv.args match { + case Seq(dim) => dim + case _ => ??? + } + isConstantInt(dim) match { + case Some(i) if 0<= i && i < currentDimensionIterVars(level).size => currentDimensionIterVars(level)(i.toInt).variable.get + case _ => ??? + } + + // SeqSubscript[Post](LiteralSeq(TInt(), currentDimensionIterVars(level).map(iterVar => iterVar.variable.get).toSeq), rw.dispatch(inv.args.head))(SYCLItemMethodSeqBoundFailureBlame(inv)) } private def getSimpleWorkItemRange(inv: CPPInvocation[Pre], level: KernelScopeLevel)(implicit o: Origin): Expr[Post] = { - SeqSubscript[Post](LiteralSeq(TInt(), currentDimensionIterVars(level).map(iterVar => iterVar.to).toSeq), rw.dispatch(inv.args.head))(SYCLItemMethodSeqBoundFailureBlame(inv)) + val dim = inv.args match { + case Seq(dim) => dim + case _ => ??? + } + isConstantInt(dim) match { + case Some(i) if 0<= i && i < currentDimensionIterVars(level).size => currentDimensionIterVars(level)(i.toInt).variable.get + case _ => ??? + } + // SeqSubscript[Post](LiteralSeq(TInt(), currentDimensionIterVars(level).map(iterVar => iterVar.to).toSeq), rw.dispatch(inv.args.head))(SYCLItemMethodSeqBoundFailureBlame(inv)) } private def getSimpleWorkItemLinearId(inv: CPPInvocation[Pre], level: KernelScopeLevel)(implicit o: Origin): Expr[Post] = {