From 831600165659ba722d7affe1f1a2fc1062fd6783 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dr=C3=A9=20van=20Oorschot?= Date: Fri, 30 Jun 2023 23:39:25 +0200 Subject: [PATCH 1/4] Add support for pure functions --- src/col/vct/col/ast/Node.scala | 21 +-- .../LLVMAmbiguousFunctionInvocationImpl.scala | 7 +- src/col/vct/col/ast/lang/LLVMGlobalImpl.scala | 7 + .../col/ast/lang/LLVMSpecFunctionImpl.scala | 31 ++++ src/col/vct/col/resolve/Resolve.scala | 141 ++++++++++-------- src/col/vct/col/resolve/ctx/Referrable.scala | 7 +- src/col/vct/col/resolve/lang/LLVM.scala | 28 ++++ .../vct/col/typerules/CoercingRewriter.scala | 5 +- src/colhelper/ColDefs.scala | 3 +- src/main/vct/main/stages/Resolution.scala | 12 +- src/parsers/vct/parsers/ColLLVMParser.scala | 17 +++ .../parsers/transform/LLVMContractToCol.scala | 39 +++++ .../vct/rewrite/lang/LangLLVMToCol.scala | 66 ++++++-- .../vct/rewrite/lang/LangSpecificToCol.scala | 1 + 14 files changed, 291 insertions(+), 94 deletions(-) create mode 100644 src/col/vct/col/ast/lang/LLVMGlobalImpl.scala create mode 100644 src/col/vct/col/ast/lang/LLVMSpecFunctionImpl.scala create mode 100644 src/col/vct/col/resolve/lang/LLVM.scala diff --git a/src/col/vct/col/ast/Node.scala b/src/col/vct/col/ast/Node.scala index 1c66c94088..8cc33b44ab 100644 --- a/src/col/vct/col/ast/Node.scala +++ b/src/col/vct/col/ast/Node.scala @@ -1081,36 +1081,35 @@ final case class BipInternal[G]()(implicit val o: Origin = DiagnosticOrigin) ext final case class BipPortSynchronization[G](ports: Seq[Ref[G, BipPort[G]]], wires: Seq[BipGlueDataWire[G]])(val blame: Blame[BipSynchronizationFailure])(implicit val o: Origin) extends GlobalDeclaration[G] with BipPortSynchronizationImpl[G] final case class BipTransitionSynchronization[G](transitions: Seq[Ref[G, BipTransition[G]]], wires: Seq[BipGlueDataWire[G]])(val blame: Blame[BipSynchronizationFailure])(implicit val o: Origin) extends GlobalDeclaration[G] with BipTransitionSynchronizationImpl[G] -final class LlvmFunctionContract[G](val value:String, val variableRefs:Seq[(String, Ref[G, Variable[G]])], val invokableRefs:Seq[(String, Ref[G, LlvmFunctionDefinition[G]])]) +final class LlvmFunctionContract[G](val value:String, val variableRefs:Seq[(String, Ref[G, Variable[G]])], val invokableRefs:Seq[(String, Ref[G, LlvmCallable[G]])]) (val blame: Blame[NontrivialUnsatisfiable]) (implicit val o: Origin) extends NodeFamily[G] with LLVMFunctionContractImpl[G] { var data: Option[ApplicableContract[G]] = None } - +sealed trait LlvmCallable[G] extends GlobalDeclaration[G] final class LlvmFunctionDefinition[G](val returnType: Type[G], val args: Seq[Variable[G]], val functionBody: Statement[G], val contract: LlvmFunctionContract[G], val pure: Boolean = false) (val blame: Blame[CallableFailure])(implicit val o: Origin) - extends GlobalDeclaration[G] with Applicable[G] with LLVMFunctionDefinitionImpl[G] - + extends LlvmCallable[G] with Applicable[G] with LLVMFunctionDefinitionImpl[G] +final class LlvmSpecFunction[G](val name: String, val returnType: Type[G], val args: Seq[Variable[G]], val typeArgs: Seq[Variable[G]], + val body: Option[Expr[G]], val contract: ApplicableContract[G], val inline: Boolean = false, val threadLocal: Boolean = false) + (val blame: Blame[ContractedFailure])(implicit val o: Origin) + extends LlvmCallable[G] with AbstractFunction[G] with LLVMSpecFunctionImpl[G] final case class LlvmFunctionInvocation[G](ref: Ref[G, LlvmFunctionDefinition[G]], args: Seq[Expr[G]], givenMap: Seq[(Ref[G, Variable[G]], Expr[G])], yields: Seq[(Expr[G], Ref[G, Variable[G]])]) (val blame: Blame[InvocationFailure])(implicit val o: Origin) extends Apply[G] with LLVMFunctionInvocationImpl[G] - final case class LlvmLoop[G](cond:Expr[G], contract:LlvmLoopContract[G], body:Statement[G]) (implicit val o: Origin) extends CompositeStatement[G] with LLVMLoopImpl[G] - sealed trait LlvmLoopContract[G] extends NodeFamily[G] with LLVMLoopContractImpl[G] final case class LlvmLoopInvariant[G](value:String, references:Seq[(String, Ref[G, Declaration[G]])]) (val blame: Blame[LoopInvariantFailure]) (implicit val o: Origin) extends LlvmLoopContract[G] with LLVMLoopInvariantImpl[G] - sealed trait LlvmExpr[G] extends Expr[G] with LLVMExprImpl[G] - final case class LlvmLocal[G](name: String)(val blame: Blame[DerefInsufficientPermission])(implicit val o: Origin) extends LlvmExpr[G] with LLVMLocalImpl[G] { var ref: Option[Ref[G, Variable[G]]] = None } @@ -1119,7 +1118,11 @@ final case class LlvmAmbiguousFunctionInvocation[G](name: String, givenMap: Seq[(Ref[G, Variable[G]], Expr[G])], yields: Seq[(Expr[G], Ref[G, Variable[G]])]) (val blame: Blame[InvocationFailure])(implicit val o: Origin) extends LlvmExpr[G] with LLVMAmbiguousFunctionInvocationImpl[G] { - var ref: Option[Ref[G, LlvmFunctionDefinition[G]]] = None + var ref: Option[Ref[G, LlvmCallable[G]]] = None +} + +final class LlvmGlobal[G](val value: String)(implicit val o: Origin) extends GlobalDeclaration[G] with LLVMGlobalImpl[G] { + var data: Option[GlobalDeclaration[G]] = None } sealed trait PVLType[G] extends Type[G] with PVLTypeImpl[G] final case class PVLNamedType[G](name: String, typeArgs: Seq[Type[G]])(implicit val o: Origin = DiagnosticOrigin) extends PVLType[G] with PVLNamedTypeImpl[G] { diff --git a/src/col/vct/col/ast/lang/LLVMAmbiguousFunctionInvocationImpl.scala b/src/col/vct/col/ast/lang/LLVMAmbiguousFunctionInvocationImpl.scala index 6bd9996497..d55b63c0bd 100644 --- a/src/col/vct/col/ast/lang/LLVMAmbiguousFunctionInvocationImpl.scala +++ b/src/col/vct/col/ast/lang/LLVMAmbiguousFunctionInvocationImpl.scala @@ -1,11 +1,12 @@ package vct.col.ast.lang -import vct.col.ast.{LlvmAmbiguousFunctionInvocation, Type} +import vct.col.ast.{LlvmAmbiguousFunctionInvocation, LlvmFunctionDefinition, LlvmSpecFunction, Type} import vct.col.print.{Ctx, Doc, DocUtil, Group, Precedence, Text} trait LLVMAmbiguousFunctionInvocationImpl[G] { this: LlvmAmbiguousFunctionInvocation[G] => - override lazy val t: Type[G] = ref match { - case Some(ref) => ref.decl.returnType + override lazy val t: Type[G] = ref.get.decl match { + case func: LlvmFunctionDefinition[G] => func.returnType + case func: LlvmSpecFunction[G] => func.returnType } override def precedence: Int = Precedence.POSTFIX diff --git a/src/col/vct/col/ast/lang/LLVMGlobalImpl.scala b/src/col/vct/col/ast/lang/LLVMGlobalImpl.scala new file mode 100644 index 0000000000..e8dd7407e6 --- /dev/null +++ b/src/col/vct/col/ast/lang/LLVMGlobalImpl.scala @@ -0,0 +1,7 @@ +package vct.col.ast.lang + +import vct.col.ast.LlvmGlobal + +trait LLVMGlobalImpl[G] { this: LlvmGlobal[G] => + +} diff --git a/src/col/vct/col/ast/lang/LLVMSpecFunctionImpl.scala b/src/col/vct/col/ast/lang/LLVMSpecFunctionImpl.scala new file mode 100644 index 0000000000..013e845c14 --- /dev/null +++ b/src/col/vct/col/ast/lang/LLVMSpecFunctionImpl.scala @@ -0,0 +1,31 @@ +package vct.col.ast.lang + +import vct.col.ast.LlvmSpecFunction +import vct.col.ast.declaration.category.AbstractFunctionImpl +import vct.col.ast.declaration.global.GlobalDeclarationImpl +import vct.col.print.{Ctx, Doc, Empty, Group, Show, Text} + +import scala.collection.immutable.ListMap + +trait LLVMSpecFunctionImpl[G] extends GlobalDeclarationImpl[G] with AbstractFunctionImpl[G] { + this: LlvmSpecFunction[G] => + + def layoutModifiers(implicit ctx: Ctx): Seq[Doc] = ListMap( + inline -> "inline", + threadLocal -> "thread_local", + ).filter(_._1).values.map(Text).map(Doc.inlineSpec).toSeq + + def layoutSpec(implicit ctx: Ctx): Doc = + Doc.stack(Seq( + contract, + Group( + Group(Doc.rspread(layoutModifiers) <> "pure" <+> returnType <+> ctx.name(this) <> + (if (typeArgs.nonEmpty) Text("<") <> Doc.args(typeArgs.map(ctx.name).map(Text)) <> ">" else Empty) <> + "(" <> Doc.args(args) <> ")") <> + body.map(Text(" =") <>> _ <> ";").getOrElse(Text(";")) + ), + )) + + override def layout(implicit ctx: Ctx): Doc = Doc.spec(Show.lazily(layoutSpec(_))) + +} diff --git a/src/col/vct/col/resolve/Resolve.scala b/src/col/vct/col/resolve/Resolve.scala index 753c423d67..9ce0b46647 100644 --- a/src/col/vct/col/resolve/Resolve.scala +++ b/src/col/vct/col/resolve/Resolve.scala @@ -10,7 +10,7 @@ import vct.col.origin._ import vct.col.resolve.ResolveReferences.scanScope import vct.col.ref.Ref import vct.col.resolve.ctx._ -import vct.col.resolve.lang.{C, Java, PVL, Spec} +import vct.col.resolve.lang.{C, Java, LLVM, PVL, Spec} import vct.col.resolve.Resolve.{MalformedBipAnnotation, SpecContractParser, SpecExprParser, getLit, isBip} import vct.col.resolve.lang.JavaAnnotationData.{BipComponent, BipData, BipGuard, BipInvariant, BipPort, BipPure, BipStatePredicate, BipTransition} import vct.col.rewrite.InitialGeneration @@ -31,13 +31,15 @@ case object Resolve { } trait SpecContractParser { - def parse[G](input: LlvmFunctionContract[G], o:Origin): ApplicableContract[G] + def parse[G](input: LlvmFunctionContract[G], o: Origin): ApplicableContract[G] + + def parse[G](input: LlvmGlobal[G], o: Origin): GlobalDeclaration[G] } def extractLiteral(e: Expr[_]): Option[String] = e match { case JavaStringValue(guardName, _) => Some(guardName) - case local @ JavaLocal(_) => + case local@JavaLocal(_) => local.ref match { case Some(RefJavaField(decls, id)) => decls.decls(id).init match { @@ -51,6 +53,7 @@ case object Resolve { case class UnexpectedComplicatedExpression(e: Expr[_]) extends UserError { override def code: String = "unexpectedComplicatedExpression" + override def text: String = e.o.messageInContext("This expression must either be a string literal or trivially resolve to one") } @@ -68,8 +71,10 @@ case object Resolve { case object ResolveTypes { sealed trait JavaClassPathEntry + case object JavaClassPathEntry { case object SourcePackageRoot extends JavaClassPathEntry + case class Path(root: java.nio.file.Path) extends JavaClassPathEntry } @@ -86,12 +91,12 @@ case object ResolveTypes { } def scanImport[G](imp: JavaImport[G], ctx: TypeResolutionContext[G]): Seq[Referrable[G]] /* importable? */ = imp match { - case imp @ JavaImport(/* static = */ true, JavaName(fullyQualifiedTypeName :+ staticMember), /* star = */ false) => + case imp@JavaImport(/* static = */ true, JavaName(fullyQualifiedTypeName :+ staticMember), /* star = */ false) => val staticType = Java.findJavaTypeName(fullyQualifiedTypeName, ctx) .getOrElse(throw NoSuchNameError("class", fullyQualifiedTypeName.mkString("."), imp)) Seq(Java.findStaticMember(staticType, staticMember) .getOrElse(throw NoSuchNameError("static member", (fullyQualifiedTypeName :+ staticMember).mkString("."), imp))) - case imp @ JavaImport(/* static = */ true, JavaName(fullyQualifiedTypeName), /* star = */ true) => + case imp@JavaImport(/* static = */ true, JavaName(fullyQualifiedTypeName), /* star = */ true) => val typeName = Java.findJavaTypeName(fullyQualifiedTypeName, ctx) .getOrElse(throw NoSuchNameError("class", fullyQualifiedTypeName.mkString("."), imp)) Java.getStaticMembers(typeName) @@ -101,42 +106,42 @@ case object ResolveTypes { def enterContext[G](node: Node[G], ctx: TypeResolutionContext[G]): TypeResolutionContext[G] = node match { case Program(decls) => - ctx.copy(stack=decls.flatMap(Referrable.from) +: ctx.stack) + ctx.copy(stack = decls.flatMap(Referrable.from) +: ctx.stack) case ns: JavaNamespace[G] => // Static imports need to be imported at this stage, because they influence how names are resolved. // E.g.: in the expression f.g, f is either a 1) variable, 2) parameter or 3) field. If none of those, it must be a // 4) statically imported field or typename, or 5) a non-static imported typename. If it's not that, it's a package name. // ctx.stack needs to be modified for this, and hence this importing is done in enterContext instead of in resolveOne. - val ctxWithNs = ctx.copy(namespace=Some(ns)) - ctxWithNs.copy(stack=(ns.declarations.flatMap(Referrable.from) ++ ns.imports.flatMap(scanImport(_, ctxWithNs))) +: ctx.stack) + val ctxWithNs = ctx.copy(namespace = Some(ns)) + ctxWithNs.copy(stack = (ns.declarations.flatMap(Referrable.from) ++ ns.imports.flatMap(scanImport(_, ctxWithNs))) +: ctx.stack) case Scope(locals, body) => ctx - .copy(stack = ((locals ++ scanScope(body, /* inGPUkernel = */false)).flatMap(Referrable.from)) +: ctx.stack) + .copy(stack = ((locals ++ scanScope(body, /* inGPUkernel = */ false)).flatMap(Referrable.from)) +: ctx.stack) case decl: Declarator[G] => - ctx.copy(stack=decl.declarations.flatMap(Referrable.from) +: ctx.stack) + ctx.copy(stack = decl.declarations.flatMap(Referrable.from) +: ctx.stack) case _ => ctx } def resolveOne[G](node: Node[G], ctx: TypeResolutionContext[G]): Unit = node match { - case javaClass @ JavaNamedType(genericNames) => + case javaClass@JavaNamedType(genericNames) => val names = genericNames.map(_._1) javaClass.ref = Some(Java.findJavaTypeName(names, ctx) .getOrElse(throw NoSuchNameError("class", names.mkString("."), javaClass))) - case t @ JavaTClass(ref, _) => + case t@JavaTClass(ref, _) => ref.tryResolve(name => ???) - case t @ CTypedefName(name) => + case t@CTypedefName(name) => t.ref = Some(C.findCTypeName(name, ctx).getOrElse( throw NoSuchNameError("struct", name, t) )) - case t @ PVLNamedType(name, typeArgs) => + case t@PVLNamedType(name, typeArgs) => t.ref = Some(PVL.findTypeName(name, ctx).getOrElse( throw NoSuchNameError("class", name, t))) - case t @ TModel(ref) => + case t@TModel(ref) => ref.tryResolve(name => Spec.findModel(name, ctx).getOrElse(throw NoSuchNameError("model", name, t))) - case t @ TClass(ref) => + case t@TClass(ref) => ref.tryResolve(name => Spec.findClass(name, ctx).getOrElse(throw NoSuchNameError("class", name, t))) - case t @ TAxiomatic(ref, _) => + case t@TAxiomatic(ref, _) => ref.tryResolve(name => Spec.findAdt(name, ctx).getOrElse(throw NoSuchNameError("adt", name, t))) - case t @ SilverPartialTAxiomatic(ref, partialTypeArgs) => + case t@SilverPartialTAxiomatic(ref, partialTypeArgs) => ref.tryResolve(name => Spec.findAdt(name, ctx).getOrElse(throw NoSuchNameError("adt", name, t))) partialTypeArgs.foreach(mapping => mapping._1.tryResolve(name => Spec.findAdtTypeArg(ref.decl, name).getOrElse(throw NoSuchNameError("type variable", name, t)))) case cls: Class[G] => @@ -158,8 +163,8 @@ case object ResolveTypes { } case deref: JavaDeref[G] => val ref = deref.obj match { - case javalocal : JavaLocal[G] => javalocal.ref - case javaderef : JavaDeref[G] => javaderef.ref + case javalocal: JavaLocal[G] => javalocal.ref + case javaderef: JavaDeref[G] => javaderef.ref case _ => None } ref match { @@ -181,18 +186,18 @@ case object ResolveReferences extends LazyLogging { resolve(program, ReferenceResolutionContext[G](jp, lsp)) } - def resolve[G](node: Node[G], ctx: ReferenceResolutionContext[G], inGPUKernel: Boolean=false): Seq[CheckError] = { + def resolve[G](node: Node[G], ctx: ReferenceResolutionContext[G], inGPUKernel: Boolean = false): Seq[CheckError] = { val inGPU = inGPUKernel || (node match { - case f: CFunctionDefinition[G] => f.specs.collectFirst{case _: CGpgpuKernelSpecifier[G] => ()}.isDefined + case f: CFunctionDefinition[G] => f.specs.collectFirst { case _: CGpgpuKernelSpecifier[G] => () }.isDefined case _ => false }) val childErrors = node match { - case l @ Let(binding, value, main) => + case l@Let(binding, value, main) => val innerCtx = enterContext(node, ctx, inGPU).copy(checkContext = l.enterCheckContext(ctx.checkContext)) resolve(binding, innerCtx) ++ - resolve(value, ctx) ++ - resolve(main, innerCtx) + resolve(value, ctx) ++ + resolve(main, innerCtx) case _ => val innerCtx = enterContext(node, ctx, inGPU) node.checkContextRecursor(ctx.checkContext, { (ctx, node) => @@ -200,7 +205,7 @@ case object ResolveReferences extends LazyLogging { }).flatten } - if(childErrors.nonEmpty) childErrors + if (childErrors.nonEmpty) childErrors else { resolveFlatly(node, ctx) node.check(ctx.checkContext) @@ -210,7 +215,7 @@ case object ResolveReferences extends LazyLogging { def scanScope[G](node: Node[G], inGPUKernel: Boolean): Seq[Declaration[G]] = node match { case _: Scope[G] => Nil // Remove shared memory locations from the body level of a GPU kernel, we want to reason about them at the top level - case CDeclarationStatement(decl) if !(inGPUKernel && decl.decl.specs.collectFirst{case GPULocal() => ()}.isDefined) + case CDeclarationStatement(decl) if !(inGPUKernel && decl.decl.specs.collectFirst { case GPULocal() => () }.isDefined) => Seq(decl) case JavaLocalDeclarationStatement(decl) => Seq(decl) case LocalDecl(v) => Seq(v) @@ -229,7 +234,7 @@ case object ResolveReferences extends LazyLogging { } def scanShared[G](node: Node[G]): Seq[Declaration[G]] = node.transSubnodes.collect { - case decl: CLocalDeclaration[G] if decl.decl.specs.collectFirst{case GPULocal() => ()}.isDefined => decl + case decl: CLocalDeclaration[G] if decl.decl.specs.collectFirst { case GPULocal() => () }.isDefined => decl } def scanJavaBipGuards[G](nodes: Seq[Declaration[G]]): Seq[(Expr[G], JavaMethod[G])] = nodes.collect { @@ -242,7 +247,7 @@ case object ResolveReferences extends LazyLogging { def enterContext[G](node: Node[G], ctx: ReferenceResolutionContext[G], inGPUKernel: Boolean = false): ReferenceResolutionContext[G] = (node match { case ns: JavaNamespace[G] => ctx - .copy(currentJavaNamespace=Some(ns)) + .copy(currentJavaNamespace = Some(ns)) .copy(stack = ns.imports.flatMap(ResolveTypes.scanImport[G](_, ctx.asTypeResolutionContext)) +: ctx.stack) .declare(ns.declarations) case cls: JavaClassOrInterface[G] => { @@ -269,36 +274,36 @@ case object ResolveReferences extends LazyLogging { case deref: JavaDeref[G] => return ctx .copy(topLevelJavaDeref = ctx.topLevelJavaDeref.orElse(Some(deref))) case cls: Class[G] => ctx - .copy(currentThis=Some(RefClass(cls))) + .copy(currentThis = Some(RefClass(cls))) .declare(cls.declarations) case app: ContractApplicable[G] => ctx - .copy(currentResult=Some(Referrable.from(app).head.asInstanceOf[ResultTarget[G]] /* PB TODO: ew */)) + .copy(currentResult = Some(Referrable.from(app).head.asInstanceOf[ResultTarget[G]] /* PB TODO: ew */)) .declare(app.declarations ++ app.body.map(scanLabels).getOrElse(Nil)) case method: JavaMethod[G] => ctx - .copy(currentResult=Some(RefJavaMethod(method))) + .copy(currentResult = Some(RefJavaMethod(method))) .declare(method.declarations ++ method.body.map(scanLabels).getOrElse(Nil)) case fields: JavaFields[G] => ctx - .copy(currentInitializerType=Some(fields.t)) + .copy(currentInitializerType = Some(fields.t)) case locals: JavaLocalDeclaration[G] => ctx - .copy(currentInitializerType=Some(locals.t)) + .copy(currentInitializerType = Some(locals.t)) case decl: JavaVariableDeclaration[G] => ctx - .copy(currentInitializerType=ctx.currentInitializerType.map(t => FuncTools.repeat((t: Type[G]) => TArray(t), decl.moreDims, t))) + .copy(currentInitializerType = ctx.currentInitializerType.map(t => FuncTools.repeat((t: Type[G]) => TArray(t), decl.moreDims, t))) case arr: JavaNewLiteralArray[G] => ctx - .copy(currentInitializerType=Some(FuncTools.repeat((t: Type[G]) => TArray(t), arr.dims, arr.baseType))) + .copy(currentInitializerType = Some(FuncTools.repeat((t: Type[G]) => TArray(t), arr.dims, arr.baseType))) case init: JavaLiteralArray[G] => ctx - .copy(currentInitializerType=Some(ctx.currentInitializerType.get match { + .copy(currentInitializerType = Some(ctx.currentInitializerType.get match { case TArray(elem) => elem case _ => throw WrongArrayInitializer(init) })) case func: CFunctionDefinition[G] => var res = ctx - .copy(currentResult=Some(RefCFunctionDefinition(func))) + .copy(currentResult = Some(RefCFunctionDefinition(func))) .declare(C.paramsFromDeclarator(func.declarator) ++ scanLabels(func.body) ++ func.contract.givenArgs ++ func.contract.yieldsArgs) - if(func.specs.collectFirst{case _: CGpgpuKernelSpecifier[G] => ()}.isDefined) + if (func.specs.collectFirst { case _: CGpgpuKernelSpecifier[G] => () }.isDefined) res = res.declare(scanShared(func.body)) res case func: CGlobalDeclaration[G] => - if(func.decl.contract.nonEmpty && func.decl.inits.size > 1) { + if (func.decl.contract.nonEmpty && func.decl.inits.size > 1) { throw MultipleForwardDeclarationContractError(func) } @@ -309,10 +314,13 @@ case object ResolveReferences extends LazyLogging { val info = C.getDeclaratorInfo(init.decl) ctx .declare(info.params.getOrElse(Nil)) - .copy(currentResult=info.params.map(_ => RefCGlobalDeclaration(func, idx))) + .copy(currentResult = info.params.map(_ => RefCGlobalDeclaration(func, idx))) } case func: LlvmFunctionDefinition[G] => ctx .copy(currentResult = Some(RefLlvmFunctionDefinition(func))) + case func: LlvmSpecFunction[G] => ctx + .copy(currentResult = Some(RefLlvmSpecFunction(func))) + .declare(func.args) case par: ParStatement[G] => ctx .declare(scanBlocks(par.impl).map(_.decl)) case Scope(locals, body) => ctx @@ -327,20 +335,22 @@ case object ResolveReferences extends LazyLogging { def resolveFlatly[G](node: Node[G], ctx: ReferenceResolutionContext[G]): Unit = node match { case local@CLocal(name) => local.ref = Some(C.findCName(name, ctx).getOrElse(throw NoSuchNameError("local", name, local))) - case local @ JavaLocal(name) => + case local@JavaLocal(name) => val start: Option[JavaNameTarget[G]] = if (ctx.javaBipGuardsEnabled) { Java.findJavaBipGuard(ctx, name).map(RefJavaBipGuard(_)) - } else { None } + } else { + None + } local.ref = Some(start.orElse( Java.findJavaName(name, ctx.asTypeResolutionContext) .orElse(Java.findJavaTypeName(Seq(name), ctx.asTypeResolutionContext) match { case Some(target: JavaNameTarget[G]) => Some(target) case None => None })) - .getOrElse( - if (ctx.topLevelJavaDeref.isEmpty) throw NoSuchNameError("local", name, local) - else RefUnloadedJavaNamespace(Seq(name)))) - case local @ PVLLocal(name) => + .getOrElse( + if (ctx.topLevelJavaDeref.isEmpty) throw NoSuchNameError("local", name, local) + else RefUnloadedJavaNamespace(Seq(name)))) + case local@PVLLocal(name) => local.ref = Some(PVL.findName(name, ctx).getOrElse(throw NoSuchNameError("local", name, local))) case local@Local(ref) => ref.tryResolve(name => Spec.findLocal(name, ctx).getOrElse(throw NoSuchNameError("local", name, local))) @@ -360,7 +370,7 @@ case object ResolveReferences extends LazyLogging { case Some(RefUnloadedJavaNamespace(_)) => true case _ => false })) throw NoSuchNameError("field", field, deref) - case deref @ PVLDeref(obj, field) => + case deref@PVLDeref(obj, field) => deref.ref = Some(PVL.findDeref(obj, field, ctx, deref.blame).getOrElse(throw NoSuchNameError("field", field, deref))) case deref@Deref(obj, field) => field.tryResolve(name => Spec.findField(obj, name).getOrElse(throw NoSuchNameError("field", name, deref))) @@ -428,11 +438,11 @@ case object ResolveReferences extends LazyLogging { case inv@SilverPartialADTFunctionInvocation(name, args, partialTypeArgs) => inv.ref = Some(Spec.findAdtFunction(name, ctx).getOrElse(throw NoSuchNameError("function", name, inv))) partialTypeArgs.foreach(mapping => mapping._1.tryResolve(name => Spec.findAdtTypeArg(inv.adt, name).getOrElse(throw NoSuchNameError("type variable", name, inv)))) - case inv @ InvokeProcedure(ref, _, _, _, givenMap, yields) => + case inv@InvokeProcedure(ref, _, _, _, givenMap, yields) => ref.tryResolve(name => Spec.findProcedure(name, ctx).getOrElse(throw NoSuchNameError("procedure", name, inv))) Spec.resolveGiven(givenMap, RefProcedure(ref.decl), inv) Spec.resolveYields(ctx, yields, RefProcedure(ref.decl), inv) - case inv @ ProcedureInvocation(ref, _, _, _, givenMap, yields) => + case inv@ProcedureInvocation(ref, _, _, _, givenMap, yields) => ref.tryResolve(name => Spec.findProcedure(name, ctx).getOrElse(throw NoSuchNameError("procedure", name, inv))) Spec.resolveGiven(givenMap, RefProcedure(ref.decl), inv) Spec.resolveYields(ctx, yields, RefProcedure(ref.decl), inv) @@ -444,11 +454,11 @@ case object ResolveReferences extends LazyLogging { ref.tryResolve(name => Spec.findPredicate(name, ctx).getOrElse(throw NoSuchNameError("predicate", name, inv))) case inv@SilverCurPredPerm(ref, _) => ref.tryResolve(name => Spec.findPredicate(name, ctx).getOrElse(throw NoSuchNameError("predicate", name, inv))) - case inv @ InvokeMethod(obj, ref, _, _, _, givenMap, yields) => + case inv@InvokeMethod(obj, ref, _, _, _, givenMap, yields) => ref.tryResolve(name => Spec.findMethod(obj, name).getOrElse(throw NoSuchNameError("method", name, inv))) Spec.resolveGiven(givenMap, RefInstanceMethod(ref.decl), inv) Spec.resolveYields(ctx, yields, RefInstanceMethod(ref.decl), inv) - case inv @ MethodInvocation(obj, ref, _, _, _, givenMap, yields) => + case inv@MethodInvocation(obj, ref, _, _, _, givenMap, yields) => ref.tryResolve(name => Spec.findMethod(obj, name).getOrElse(throw NoSuchNameError("method", name, inv))) Spec.resolveGiven(givenMap, RefInstanceMethod(ref.decl), inv) Spec.resolveYields(ctx, yields, RefInstanceMethod(ref.decl), inv) @@ -456,7 +466,7 @@ case object ResolveReferences extends LazyLogging { ref.tryResolve(name => Spec.findInstanceFunction(obj, name).getOrElse(throw NoSuchNameError("function", name, inv))) case inv@InstancePredicateApply(obj, ref, _, _) => ref.tryResolve(name => Spec.findInstancePredicate(obj, name).getOrElse(throw NoSuchNameError("predicate", name, inv))) - case inv @ CoalesceInstancePredicateApply(obj, ref, _, _) => + case inv@CoalesceInstancePredicateApply(obj, ref, _, _) => ref.tryResolve(name => Spec.findInstancePredicate(obj, name).getOrElse(throw NoSuchNameError("predicate", name, inv))) case defn: CFunctionDefinition[G] => @@ -525,7 +535,7 @@ case object ResolveReferences extends LazyLogging { def extractExpr(s: Option[Expr[_]]): (String, Origin) = s match { case None => ("true", ann.o) - case Some(s @ JavaStringValue(data, _)) => (data, s.o) + case Some(s@JavaStringValue(data, _)) => (data, s.o) case Some(n) => throw MalformedBipAnnotation(n, "pre- and post-conditions must be string literals") } @@ -568,9 +578,9 @@ case object ResolveReferences extends LazyLogging { case ann: JavaAnnotation[G] if isBip(ann, "Port") => val portType: BipPortType[G] = ann.expect("type") match { - case p @ JavaDeref(_, "enforceable") => BipEnforceable[G]()(p.o) - case p @ JavaDeref(_, "spontaneous") => BipSpontaneous[G]()(p.o) - case p @ JavaDeref(_, "internal") => BipInternal[G]()(p.o) + case p@JavaDeref(_, "enforceable") => BipEnforceable[G]()(p.o) + case p@JavaDeref(_, "spontaneous") => BipSpontaneous[G]()(p.o) + case p@JavaDeref(_, "internal") => BipInternal[G]()(p.o) case e => throw MalformedBipAnnotation(e, "Can be either PortType.enforceable, spontaneous, or internal") } ann.data = Some(BipPort[G](getLit(ann.expect("name")), portType)(ann.o)) @@ -578,7 +588,7 @@ case object ResolveReferences extends LazyLogging { case ann: JavaAnnotation[G] if isBip(ann, "Pure") => ann.data = Some(BipPure[G]()) - case portName @ JavaBipGlueName(JavaTClass(Ref(cls: JavaClass[G]), Nil), name) => + case portName@JavaBipGlueName(JavaTClass(Ref(cls: JavaClass[G]), Nil), name) => portName.data = Some((cls, getLit(name))) case contract: LlvmFunctionContract[G] => @@ -592,15 +602,18 @@ case object ResolveReferences extends LazyLogging { case Some(ref) => Some(ref._2) case None => throw NoSuchNameError("local", local.name, local) } + case RefLlvmSpecFunction(_) => + Some(Spec.findLocal(local.name, ctx).getOrElse(throw NoSuchNameError("local", local.name, local)).ref) } case inv: LlvmAmbiguousFunctionInvocation[G] => - inv.ref = ctx.currentResult.get match { - case RefLlvmFunctionDefinition(decl) => - decl.contract.invokableRefs.find(ref => ref._1 == inv.name) match { - case Some(ref) => Some(ref._2) - case None => throw NoSuchNameError("function", inv.name, inv) - } + inv.ref = LLVM.findCallable(inv.name, ctx) match { + case Some(callable) => Some(callable.ref) + case None => throw NoSuchNameError("function", inv.name, inv) } + case glob: LlvmGlobal[G] => + val decl = ctx.llvmSpecParser.parse(glob, glob.o) + glob.data = Some(decl) + resolve(decl, ctx) case _ => } } diff --git a/src/col/vct/col/resolve/ctx/Referrable.scala b/src/col/vct/col/resolve/ctx/Referrable.scala index 8a9ee40091..830f9f22ec 100644 --- a/src/col/vct/col/resolve/ctx/Referrable.scala +++ b/src/col/vct/col/resolve/ctx/Referrable.scala @@ -128,6 +128,8 @@ case object Referrable { case decl: VeyMontThread[G] => RefVeyMontThread(decl) case decl: JavaBipGlueContainer[G] => RefJavaBipGlueContainer() case decl: LlvmFunctionDefinition[G] => RefLlvmFunctionDefinition(decl) + case decl: LlvmGlobal[G] => RefLlvmGlobal(decl) + case decl: LlvmSpecFunction[G] => RefLlvmSpecFunction(decl) case decl: ProverType[G] => RefProverType(decl) case decl: ProverFunction[G] => RefProverFunction(decl) }) @@ -171,7 +173,7 @@ sealed trait SpecInvocationTarget[G] extends JavaInvocationTarget[G] with CNameTarget[G] with CDerefTarget[G] with CInvocationTarget[G] - with PVLInvocationTarget[G] + with PVLInvocationTarget[G] with LlvmInvocationTarget[G] sealed trait ThisTarget[G] extends Referrable[G] @@ -229,6 +231,9 @@ case class RefJavaBipStatePredicate[G](state: String, decl: JavaAnnotation[G]) e case class RefJavaBipGuard[G](decl: JavaMethod[G]) extends Referrable[G] with JavaNameTarget[G] case class RefJavaBipGlueContainer[G]() extends Referrable[G] // Bip glue jobs are not actually referrable case class RefLlvmFunctionDefinition[G](decl: LlvmFunctionDefinition[G]) extends Referrable[G] with LlvmInvocationTarget[G] with ResultTarget[G] +case class RefLlvmGlobal[G](decl: LlvmGlobal[G]) extends Referrable[G] + +case class RefLlvmSpecFunction[G](decl: LlvmSpecFunction[G]) extends Referrable[G] with SpecInvocationTarget[G] with ResultTarget[G] case class RefSeqProg[G](decl: VeyMontSeqProg[G]) extends Referrable[G] case class RefVeyMontThread[G](decl: VeyMontThread[G]) extends Referrable[G] with PVLNameTarget[G] case class RefProverType[G](decl: ProverType[G]) extends Referrable[G] with SpecTypeNameTarget[G] diff --git a/src/col/vct/col/resolve/lang/LLVM.scala b/src/col/vct/col/resolve/lang/LLVM.scala new file mode 100644 index 0000000000..e450dc9993 --- /dev/null +++ b/src/col/vct/col/resolve/lang/LLVM.scala @@ -0,0 +1,28 @@ +package vct.col.resolve.lang + +import vct.col.ast._ +import vct.col.resolve.NoSuchNameError +import vct.col.resolve.ctx.ReferenceResolutionContext +import vct.col.resolve.ctx._ +object LLVM { + + def findCallable[G](name: String, ctx: ReferenceResolutionContext[G]): Option[LlvmCallable[G]] = { + val callable = ctx.stack.flatten.collectFirst { + case RefLlvmGlobal(decl) if decl.data.nonEmpty => decl.data.get.collectFirst { + case f: LlvmSpecFunction[G] if f.name == name => f + } + } + callable.get match { + case Some(callable) => Some(callable) + case None => ctx.currentResult.get match { + case RefLlvmFunctionDefinition(decl) => + decl.contract.invokableRefs.find(ref => ref._1 == name) match { + case Some(ref) => Some(ref._2.decl) + case None => None + } + } + } + } + + +} diff --git a/src/col/vct/col/typerules/CoercingRewriter.scala b/src/col/vct/col/typerules/CoercingRewriter.scala index a6b0912489..20424f8290 100644 --- a/src/col/vct/col/typerules/CoercingRewriter.scala +++ b/src/col/vct/col/typerules/CoercingRewriter.scala @@ -1724,7 +1724,10 @@ abstract class CoercingRewriter[Pre <: Generation]() extends AbstractRewriter[Pr case definition: LlvmFunctionDefinition[Pre] => definition case typ: ProverType[Pre] => typ case func: ProverFunction[Pre] => func - } + case function: LlvmSpecFunction[Pre] => + new LlvmSpecFunction[Pre](function.name, function.returnType, function.args, function.typeArgs, function.body.map(coerce(_, function.returnType)), function.contract, function.inline, function.threadLocal)(function.blame) + case glob: LlvmGlobal[Pre] => glob + } } def coerce(region: ParRegion[Pre]): ParRegion[Pre] = { diff --git a/src/colhelper/ColDefs.scala b/src/colhelper/ColDefs.scala index 85ce4d714a..34e38fca49 100644 --- a/src/colhelper/ColDefs.scala +++ b/src/colhelper/ColDefs.scala @@ -71,7 +71,8 @@ object ColDefs { "JavaConstructor", "JavaMethod", "CFunctionDefinition", "PVLConstructor", - "LlvmFunctionDefinition" + "LlvmFunctionDefinition", + "LlvmSpecFunction" // Potentially ParBlocks and other execution contexts (lambdas?) should be a scope too. ), "SendDecl" -> Seq("ParBlock", "Loop"), diff --git a/src/main/vct/main/stages/Resolution.scala b/src/main/vct/main/stages/Resolution.scala index 7ab5de1c95..49eb31c29e 100644 --- a/src/main/vct/main/stages/Resolution.scala +++ b/src/main/vct/main/stages/Resolution.scala @@ -2,7 +2,7 @@ package vct.main.stages import com.typesafe.scalalogging.LazyLogging import hre.stages.Stage -import vct.col.ast.{AddrOf, ApplicableContract, CGlobalDeclaration, Expr, LlvmFunctionContract, Program, Refute, Verification, VerificationContext} +import vct.col.ast.{AddrOf, ApplicableContract, CGlobalDeclaration, Expr, GlobalDeclaration, LlvmFunctionContract, LlvmGlobal, Program, Refute, Verification, VerificationContext} import org.antlr.v4.runtime.CharStreams import vct.col.check.CheckError import vct.col.rewrite.lang.{LangSpecificToCol, LangTypesToCol} @@ -86,6 +86,16 @@ case class MyLocalLLVMSpecParser(blameProvider: BlameProvider) extends Resolve.S ColLLVMParser(originProvider, blameProvider) .parseFunctionContract[G](charStream)._1 } + + override def parse[G](input: LlvmGlobal[G], o: Origin): GlobalDeclaration[G] = { + val originProvider = ReadableOriginProvider(input.o match { + case o: LLVMOrigin => StringReadable(input.value, o.fileName) + case _ => StringReadable(input.value) + }) + val charStream = CharStreams.fromString(input.value) + ColLLVMParser(originProvider, blameProvider) + .parseGlobal(charStream)._1 + } } case class Resolution[G <: Generation] diff --git a/src/parsers/vct/parsers/ColLLVMParser.scala b/src/parsers/vct/parsers/ColLLVMParser.scala index c4c41d2f36..c9046e07af 100644 --- a/src/parsers/vct/parsers/ColLLVMParser.scala +++ b/src/parsers/vct/parsers/ColLLVMParser.scala @@ -58,4 +58,21 @@ case class ColLLVMParser(override val originProvider: OriginProvider, override v val contract = LLVMContractToCol[G](originProvider, blameProvider, errors).convert(tree) (contract, errors.map(_._3)) } + + def parseGlobal[G](stream: CharStream): (vct.col.ast.GlobalDeclaration[G], Seq[ExpectedError]) = { + val lexer = new LangLLVMSpecLexer(stream) + val tokens = new CommonTokenStream(lexer) + originProvider.setTokenStream(tokens) + val parser = new LLVMSpecParser(tokens) + // we're parsing a contract so set the parser to specLevel == 1 + parser.specLevel = 1 + + val (errors, tree) = noErrorsOrThrow(parser, lexer, originProvider) { + val errors = expectedErrors(tokens, LangLLVMSpecLexer.EXPECTED_ERROR_CHANNEL, LangLLVMSpecLexer.VAL_EXPECT_ERROR_OPEN, LangLLVMSpecLexer.VAL_EXPECT_ERROR_CLOSE) + val tree = parser.valGlobalDeclaration() + (errors, tree) + } + val global = LLVMContractToCol[G](originProvider, blameProvider, errors).convert(tree) + (global, errors.map(_._3)) + } } \ No newline at end of file diff --git a/src/parsers/vct/parsers/transform/LLVMContractToCol.scala b/src/parsers/vct/parsers/transform/LLVMContractToCol.scala index 04b16e3db0..4d521ac411 100644 --- a/src/parsers/vct/parsers/transform/LLVMContractToCol.scala +++ b/src/parsers/vct/parsers/transform/LLVMContractToCol.scala @@ -1,5 +1,6 @@ package vct.parsers.transform +import hre.data.BitString import org.antlr.v4.runtime.{ParserRuleContext, Token} import vct.antlr4.generated.LLVMSpecParser._ import vct.antlr4.generated.LLVMSpecParserPatterns @@ -10,6 +11,7 @@ import vct.col.ref.{Ref, UnresolvedRef} import vct.col.util.AstBuildHelpers.{ff, foldAnd, implies, tt} import scala.annotation.nowarn +import scala.collection.immutable.{AbstractSeq, LinearSeq} import scala.collection.mutable @nowarn("msg=match may not be exhaustive&msg=Some\\(") @@ -358,4 +360,41 @@ case class LLVMContractToCol[G](override val originProvider: OriginProvider, case ValExpressionList0(expr) => Seq(convert(expr)) case ValExpressionList1(head, _, tail) => convert(head) +: convert(tail) } + + def convert(implicit decl: ValGlobalDeclarationContext): GlobalDeclaration[G] = decl match { + case ValFunction(contract, modifiers, _, t, name, typeArgs, _, args, _, definition) => + val contractCollector = new ContractCollector[G]() + contract.foreach(convert(_, contractCollector)) + + val modifierCollector = new ModifierCollector() + modifiers.foreach(convert(_, modifierCollector)) + + val namedOrigin = SourceNameOrigin(convert(name), origin(decl)) + new LlvmSpecFunction( + convert(name), + convert(t), + args.map(convert(_)).getOrElse(Nil), + Nil, // TODO implement + convert(definition), + contractCollector.consumeApplicableContract(blame(decl)), + modifierCollector.consume(modifierCollector.inline))(blame(decl) + )(namedOrigin) + } + + def convert(implicit definition: ValPureDefContext): Option[Expr[G]] = definition match { + case ValPureAbstractBody(_) => None + case ValPureBody(_, expr, _) => Some(convert(expr)) + } + + def convert(mod: ValModifierContext, collector: ModifierCollector): Unit = mod match { + case ValModifier0(name) => name match { + case "pure" => collector.pure += mod + case "inline" => collector.inline += mod + case "thread_local" => collector.threadLocal += mod + case "bip_annotation" => collector.bipAnnotation += mod + } + case ValStatic(_) => collector.static += mod + } + + } \ No newline at end of file diff --git a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala index 06e11172c9..5b02e1f1c4 100644 --- a/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangLLVMToCol.scala @@ -13,7 +13,8 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends type Post = Rewritten[Pre] implicit val implicitRewriter: AbstractRewriter[Pre, Post] = rw - private val functionMap: SuccessionMap[LlvmFunctionDefinition[Pre], Procedure[Post]] = SuccessionMap() + private val llvmFunctionMap: SuccessionMap[LlvmFunctionDefinition[Pre], Procedure[Post]] = SuccessionMap() + private val specFunctionMap: SuccessionMap[LlvmSpecFunction[Pre], Function[Post]] = SuccessionMap() def rewriteLocal(local: LlvmLocal[Pre]): Expr[Post] = { @@ -33,31 +34,40 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends outArgs = Nil, typeArgs = Nil, body = if (func.pure) Some(GotoEliminator(func.functionBody match { case scope: Scope[Pre] => scope }).eliminate()) else Some(rw.dispatch(func.functionBody)), - //body = Some(GotoEliminator(func.functionBody match { case scope: Scope[Pre] => scope }).eliminate()), contract = rw.dispatch(func.contract.data.get), pure = func.pure )(func.blame) ) } - functionMap.update(func, procedure) + llvmFunctionMap.update(func, procedure) } - def rewriteAmbiguousFunctionInvocation(inv: LlvmAmbiguousFunctionInvocation[Pre]): ProcedureInvocation[Post] = { + def rewriteAmbiguousFunctionInvocation(inv: LlvmAmbiguousFunctionInvocation[Pre]): Invocation[Post] = { implicit val o: Origin = inv.o - new ProcedureInvocation[Post]( - ref = new LazyRef[Post, Procedure[Post]](functionMap(inv.ref.get.decl)), - args = inv.args.map(rw.dispatch), - givenMap = inv.givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, - yields = inv.yields.map { case (e, Ref(v)) => (rw.dispatch(e), rw.succ(v)) }, - outArgs = Seq.empty, - typeArgs = Seq.empty - )(inv.blame) + inv.ref.get.decl match { + case func: LlvmFunctionDefinition[Pre] => new ProcedureInvocation[Post]( + ref = new LazyRef[Post, Procedure[Post]](llvmFunctionMap(func)), + args = inv.args.map(rw.dispatch), + givenMap = inv.givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, + yields = inv.yields.map { case (e, Ref(v)) => (rw.dispatch(e), rw.succ(v)) }, + outArgs = Seq.empty, + typeArgs = Seq.empty + )(inv.blame) + case func: LlvmSpecFunction[Pre] => new FunctionInvocation[Post]( + ref = new LazyRef[Post, Function[Post]](specFunctionMap(func)), + args = inv.args.map(rw.dispatch), + givenMap = inv.givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, + yields = inv.yields.map { case (e, Ref(v)) => (rw.dispatch(e), rw.succ(v)) }, + typeArgs = Seq.empty + )(inv.blame) + } + } def rewriteFunctionInvocation(inv: LlvmFunctionInvocation[Pre]): ProcedureInvocation[Post] = { implicit val o: Origin = inv.o new ProcedureInvocation[Post]( - ref = new LazyRef[Post, Procedure[Post]](functionMap(inv.ref.decl)), + ref = new LazyRef[Post, Procedure[Post]](llvmFunctionMap(inv.ref.decl)), args = inv.args.map(rw.dispatch), givenMap = inv.givenMap.map { case (Ref(v), e) => (rw.succ(v), rw.dispatch(e)) }, yields = inv.yields.map { case (e, Ref(v)) => (rw.dispatch(e), rw.succ(v)) }, @@ -66,8 +76,36 @@ case class LangLLVMToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends )(inv.blame) } + def rewriteGlobal(decl: LlvmGlobal[Pre]): Unit = { + implicit val o: Origin = decl.o + rw.globalDeclarations.declare( + decl.data match { + case Some(data) => data match { + case function: LlvmSpecFunction[Pre] => + val rwFunction = new Function[Post]( + rw.dispatch(function.returnType), + rw.variables.collect { + function.args.foreach(rw.dispatch) + }._1, + rw.variables.collect { + function.typeArgs.foreach(rw.dispatch) + }._1, + function.body match { + case Some(body) => Some(rw.dispatch(body)) + case None => None + }, + rw.dispatch(function.contract), + function.inline, + function.threadLocal + )(function.blame) + specFunctionMap.update(function, rwFunction) + rwFunction + } + }) + } + def result(ref: RefLlvmFunctionDefinition[Pre])(implicit o: Origin): Expr[Post] = - Result[Post](functionMap.ref(ref.decl)) + Result[Post](llvmFunctionMap.ref(ref.decl)) /* Elimination works by replacing every goto with the block its referring too diff --git a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala index e62cf0bcfe..0b58c5f6c4 100644 --- a/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala +++ b/src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala @@ -68,6 +68,7 @@ case class LangSpecificToCol[Pre <: Generation]() extends Rewriter[Pre] with Laz case decl: CGlobalDeclaration[Pre] => c.rewriteGlobalDecl(decl) case decl: CLocalDeclaration[Pre] => ??? case func: LlvmFunctionDefinition[Pre] => llvm.rewriteFunctionDef(func) + case global: LlvmGlobal[Pre] => llvm.rewriteGlobal(global) case cls: Class[Pre] => currentClass.having(cls) { From 9537aea49fabcc6ef6005a30cf713a9a307ffc90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dr=C3=A9=20van=20Oorschot?= Date: Mon, 3 Jul 2023 15:32:55 +0200 Subject: [PATCH 2/4] Fix blames in Resolve.scala and fix bug in findCallable for LLVM resolver --- src/col/vct/col/ast/lang/LLVMGlobalImpl.scala | 3 + src/col/vct/col/resolve/Resolve.scala | 117 +++++++++--------- src/col/vct/col/resolve/lang/LLVM.scala | 15 ++- 3 files changed, 69 insertions(+), 66 deletions(-) diff --git a/src/col/vct/col/ast/lang/LLVMGlobalImpl.scala b/src/col/vct/col/ast/lang/LLVMGlobalImpl.scala index e8dd7407e6..0a95858f62 100644 --- a/src/col/vct/col/ast/lang/LLVMGlobalImpl.scala +++ b/src/col/vct/col/ast/lang/LLVMGlobalImpl.scala @@ -1,7 +1,10 @@ package vct.col.ast.lang import vct.col.ast.LlvmGlobal +import vct.col.print.{Ctx, Doc, Text} trait LLVMGlobalImpl[G] { this: LlvmGlobal[G] => + override def layout(implicit ctx: Ctx): Doc = Text(value) + } diff --git a/src/col/vct/col/resolve/Resolve.scala b/src/col/vct/col/resolve/Resolve.scala index 9ce0b46647..91ba3b3083 100644 --- a/src/col/vct/col/resolve/Resolve.scala +++ b/src/col/vct/col/resolve/Resolve.scala @@ -39,7 +39,7 @@ case object Resolve { def extractLiteral(e: Expr[_]): Option[String] = e match { case JavaStringValue(guardName, _) => Some(guardName) - case local@JavaLocal(_) => + case local @ JavaLocal(_) => local.ref match { case Some(RefJavaField(decls, id)) => decls.decls(id).init match { @@ -53,7 +53,6 @@ case object Resolve { case class UnexpectedComplicatedExpression(e: Expr[_]) extends UserError { override def code: String = "unexpectedComplicatedExpression" - override def text: String = e.o.messageInContext("This expression must either be a string literal or trivially resolve to one") } @@ -71,10 +70,8 @@ case object Resolve { case object ResolveTypes { sealed trait JavaClassPathEntry - case object JavaClassPathEntry { case object SourcePackageRoot extends JavaClassPathEntry - case class Path(root: java.nio.file.Path) extends JavaClassPathEntry } @@ -91,12 +88,12 @@ case object ResolveTypes { } def scanImport[G](imp: JavaImport[G], ctx: TypeResolutionContext[G]): Seq[Referrable[G]] /* importable? */ = imp match { - case imp@JavaImport(/* static = */ true, JavaName(fullyQualifiedTypeName :+ staticMember), /* star = */ false) => + case imp @ JavaImport(/* static = */ true, JavaName(fullyQualifiedTypeName :+ staticMember), /* star = */ false) => val staticType = Java.findJavaTypeName(fullyQualifiedTypeName, ctx) .getOrElse(throw NoSuchNameError("class", fullyQualifiedTypeName.mkString("."), imp)) Seq(Java.findStaticMember(staticType, staticMember) .getOrElse(throw NoSuchNameError("static member", (fullyQualifiedTypeName :+ staticMember).mkString("."), imp))) - case imp@JavaImport(/* static = */ true, JavaName(fullyQualifiedTypeName), /* star = */ true) => + case imp @ JavaImport(/* static = */ true, JavaName(fullyQualifiedTypeName), /* star = */ true) => val typeName = Java.findJavaTypeName(fullyQualifiedTypeName, ctx) .getOrElse(throw NoSuchNameError("class", fullyQualifiedTypeName.mkString("."), imp)) Java.getStaticMembers(typeName) @@ -106,42 +103,42 @@ case object ResolveTypes { def enterContext[G](node: Node[G], ctx: TypeResolutionContext[G]): TypeResolutionContext[G] = node match { case Program(decls) => - ctx.copy(stack = decls.flatMap(Referrable.from) +: ctx.stack) + ctx.copy(stack=decls.flatMap(Referrable.from) +: ctx.stack) case ns: JavaNamespace[G] => // Static imports need to be imported at this stage, because they influence how names are resolved. // E.g.: in the expression f.g, f is either a 1) variable, 2) parameter or 3) field. If none of those, it must be a // 4) statically imported field or typename, or 5) a non-static imported typename. If it's not that, it's a package name. // ctx.stack needs to be modified for this, and hence this importing is done in enterContext instead of in resolveOne. - val ctxWithNs = ctx.copy(namespace = Some(ns)) - ctxWithNs.copy(stack = (ns.declarations.flatMap(Referrable.from) ++ ns.imports.flatMap(scanImport(_, ctxWithNs))) +: ctx.stack) + val ctxWithNs = ctx.copy(namespace=Some(ns)) + ctxWithNs.copy(stack=(ns.declarations.flatMap(Referrable.from) ++ ns.imports.flatMap(scanImport(_, ctxWithNs))) +: ctx.stack) case Scope(locals, body) => ctx - .copy(stack = ((locals ++ scanScope(body, /* inGPUkernel = */ false)).flatMap(Referrable.from)) +: ctx.stack) + .copy(stack = ((locals ++ scanScope(body, /* inGPUkernel = */false)).flatMap(Referrable.from)) +: ctx.stack) case decl: Declarator[G] => - ctx.copy(stack = decl.declarations.flatMap(Referrable.from) +: ctx.stack) + ctx.copy(stack=decl.declarations.flatMap(Referrable.from) +: ctx.stack) case _ => ctx } def resolveOne[G](node: Node[G], ctx: TypeResolutionContext[G]): Unit = node match { - case javaClass@JavaNamedType(genericNames) => + case javaClass @ JavaNamedType(genericNames) => val names = genericNames.map(_._1) javaClass.ref = Some(Java.findJavaTypeName(names, ctx) .getOrElse(throw NoSuchNameError("class", names.mkString("."), javaClass))) - case t@JavaTClass(ref, _) => + case t @ JavaTClass(ref, _) => ref.tryResolve(name => ???) - case t@CTypedefName(name) => + case t @ CTypedefName(name) => t.ref = Some(C.findCTypeName(name, ctx).getOrElse( throw NoSuchNameError("struct", name, t) )) - case t@PVLNamedType(name, typeArgs) => + case t @ PVLNamedType(name, typeArgs) => t.ref = Some(PVL.findTypeName(name, ctx).getOrElse( throw NoSuchNameError("class", name, t))) - case t@TModel(ref) => + case t @ TModel(ref) => ref.tryResolve(name => Spec.findModel(name, ctx).getOrElse(throw NoSuchNameError("model", name, t))) - case t@TClass(ref) => + case t @ TClass(ref) => ref.tryResolve(name => Spec.findClass(name, ctx).getOrElse(throw NoSuchNameError("class", name, t))) - case t@TAxiomatic(ref, _) => + case t @ TAxiomatic(ref, _) => ref.tryResolve(name => Spec.findAdt(name, ctx).getOrElse(throw NoSuchNameError("adt", name, t))) - case t@SilverPartialTAxiomatic(ref, partialTypeArgs) => + case t @ SilverPartialTAxiomatic(ref, partialTypeArgs) => ref.tryResolve(name => Spec.findAdt(name, ctx).getOrElse(throw NoSuchNameError("adt", name, t))) partialTypeArgs.foreach(mapping => mapping._1.tryResolve(name => Spec.findAdtTypeArg(ref.decl, name).getOrElse(throw NoSuchNameError("type variable", name, t)))) case cls: Class[G] => @@ -163,8 +160,8 @@ case object ResolveTypes { } case deref: JavaDeref[G] => val ref = deref.obj match { - case javalocal: JavaLocal[G] => javalocal.ref - case javaderef: JavaDeref[G] => javaderef.ref + case javalocal : JavaLocal[G] => javalocal.ref + case javaderef : JavaDeref[G] => javaderef.ref case _ => None } ref match { @@ -186,18 +183,18 @@ case object ResolveReferences extends LazyLogging { resolve(program, ReferenceResolutionContext[G](jp, lsp)) } - def resolve[G](node: Node[G], ctx: ReferenceResolutionContext[G], inGPUKernel: Boolean = false): Seq[CheckError] = { + def resolve[G](node: Node[G], ctx: ReferenceResolutionContext[G], inGPUKernel: Boolean=false): Seq[CheckError] = { val inGPU = inGPUKernel || (node match { - case f: CFunctionDefinition[G] => f.specs.collectFirst { case _: CGpgpuKernelSpecifier[G] => () }.isDefined + case f: CFunctionDefinition[G] => f.specs.collectFirst{case _: CGpgpuKernelSpecifier[G] => ()}.isDefined case _ => false }) val childErrors = node match { - case l@Let(binding, value, main) => + case l @ Let(binding, value, main) => val innerCtx = enterContext(node, ctx, inGPU).copy(checkContext = l.enterCheckContext(ctx.checkContext)) resolve(binding, innerCtx) ++ - resolve(value, ctx) ++ - resolve(main, innerCtx) + resolve(value, ctx) ++ + resolve(main, innerCtx) case _ => val innerCtx = enterContext(node, ctx, inGPU) node.checkContextRecursor(ctx.checkContext, { (ctx, node) => @@ -205,7 +202,7 @@ case object ResolveReferences extends LazyLogging { }).flatten } - if (childErrors.nonEmpty) childErrors + if(childErrors.nonEmpty) childErrors else { resolveFlatly(node, ctx) node.check(ctx.checkContext) @@ -215,7 +212,7 @@ case object ResolveReferences extends LazyLogging { def scanScope[G](node: Node[G], inGPUKernel: Boolean): Seq[Declaration[G]] = node match { case _: Scope[G] => Nil // Remove shared memory locations from the body level of a GPU kernel, we want to reason about them at the top level - case CDeclarationStatement(decl) if !(inGPUKernel && decl.decl.specs.collectFirst { case GPULocal() => () }.isDefined) + case CDeclarationStatement(decl) if !(inGPUKernel && decl.decl.specs.collectFirst{case GPULocal() => ()}.isDefined) => Seq(decl) case JavaLocalDeclarationStatement(decl) => Seq(decl) case LocalDecl(v) => Seq(v) @@ -234,7 +231,7 @@ case object ResolveReferences extends LazyLogging { } def scanShared[G](node: Node[G]): Seq[Declaration[G]] = node.transSubnodes.collect { - case decl: CLocalDeclaration[G] if decl.decl.specs.collectFirst { case GPULocal() => () }.isDefined => decl + case decl: CLocalDeclaration[G] if decl.decl.specs.collectFirst{case GPULocal() => ()}.isDefined => decl } def scanJavaBipGuards[G](nodes: Seq[Declaration[G]]): Seq[(Expr[G], JavaMethod[G])] = nodes.collect { @@ -247,7 +244,7 @@ case object ResolveReferences extends LazyLogging { def enterContext[G](node: Node[G], ctx: ReferenceResolutionContext[G], inGPUKernel: Boolean = false): ReferenceResolutionContext[G] = (node match { case ns: JavaNamespace[G] => ctx - .copy(currentJavaNamespace = Some(ns)) + .copy(currentJavaNamespace=Some(ns)) .copy(stack = ns.imports.flatMap(ResolveTypes.scanImport[G](_, ctx.asTypeResolutionContext)) +: ctx.stack) .declare(ns.declarations) case cls: JavaClassOrInterface[G] => { @@ -274,36 +271,36 @@ case object ResolveReferences extends LazyLogging { case deref: JavaDeref[G] => return ctx .copy(topLevelJavaDeref = ctx.topLevelJavaDeref.orElse(Some(deref))) case cls: Class[G] => ctx - .copy(currentThis = Some(RefClass(cls))) + .copy(currentThis=Some(RefClass(cls))) .declare(cls.declarations) case app: ContractApplicable[G] => ctx - .copy(currentResult = Some(Referrable.from(app).head.asInstanceOf[ResultTarget[G]] /* PB TODO: ew */)) + .copy(currentResult=Some(Referrable.from(app).head.asInstanceOf[ResultTarget[G]] /* PB TODO: ew */)) .declare(app.declarations ++ app.body.map(scanLabels).getOrElse(Nil)) case method: JavaMethod[G] => ctx - .copy(currentResult = Some(RefJavaMethod(method))) + .copy(currentResult=Some(RefJavaMethod(method))) .declare(method.declarations ++ method.body.map(scanLabels).getOrElse(Nil)) case fields: JavaFields[G] => ctx - .copy(currentInitializerType = Some(fields.t)) + .copy(currentInitializerType=Some(fields.t)) case locals: JavaLocalDeclaration[G] => ctx - .copy(currentInitializerType = Some(locals.t)) + .copy(currentInitializerType=Some(locals.t)) case decl: JavaVariableDeclaration[G] => ctx - .copy(currentInitializerType = ctx.currentInitializerType.map(t => FuncTools.repeat((t: Type[G]) => TArray(t), decl.moreDims, t))) + .copy(currentInitializerType=ctx.currentInitializerType.map(t => FuncTools.repeat((t: Type[G]) => TArray(t), decl.moreDims, t))) case arr: JavaNewLiteralArray[G] => ctx - .copy(currentInitializerType = Some(FuncTools.repeat((t: Type[G]) => TArray(t), arr.dims, arr.baseType))) + .copy(currentInitializerType=Some(FuncTools.repeat((t: Type[G]) => TArray(t), arr.dims, arr.baseType))) case init: JavaLiteralArray[G] => ctx - .copy(currentInitializerType = Some(ctx.currentInitializerType.get match { + .copy(currentInitializerType=Some(ctx.currentInitializerType.get match { case TArray(elem) => elem case _ => throw WrongArrayInitializer(init) })) case func: CFunctionDefinition[G] => var res = ctx - .copy(currentResult = Some(RefCFunctionDefinition(func))) + .copy(currentResult=Some(RefCFunctionDefinition(func))) .declare(C.paramsFromDeclarator(func.declarator) ++ scanLabels(func.body) ++ func.contract.givenArgs ++ func.contract.yieldsArgs) - if (func.specs.collectFirst { case _: CGpgpuKernelSpecifier[G] => () }.isDefined) + if(func.specs.collectFirst{case _: CGpgpuKernelSpecifier[G] => ()}.isDefined) res = res.declare(scanShared(func.body)) res case func: CGlobalDeclaration[G] => - if (func.decl.contract.nonEmpty && func.decl.inits.size > 1) { + if(func.decl.contract.nonEmpty && func.decl.inits.size > 1) { throw MultipleForwardDeclarationContractError(func) } @@ -314,7 +311,7 @@ case object ResolveReferences extends LazyLogging { val info = C.getDeclaratorInfo(init.decl) ctx .declare(info.params.getOrElse(Nil)) - .copy(currentResult = info.params.map(_ => RefCGlobalDeclaration(func, idx))) + .copy(currentResult=info.params.map(_ => RefCGlobalDeclaration(func, idx))) } case func: LlvmFunctionDefinition[G] => ctx .copy(currentResult = Some(RefLlvmFunctionDefinition(func))) @@ -335,22 +332,20 @@ case object ResolveReferences extends LazyLogging { def resolveFlatly[G](node: Node[G], ctx: ReferenceResolutionContext[G]): Unit = node match { case local@CLocal(name) => local.ref = Some(C.findCName(name, ctx).getOrElse(throw NoSuchNameError("local", name, local))) - case local@JavaLocal(name) => + case local @ JavaLocal(name) => val start: Option[JavaNameTarget[G]] = if (ctx.javaBipGuardsEnabled) { Java.findJavaBipGuard(ctx, name).map(RefJavaBipGuard(_)) - } else { - None - } + } else { None } local.ref = Some(start.orElse( Java.findJavaName(name, ctx.asTypeResolutionContext) .orElse(Java.findJavaTypeName(Seq(name), ctx.asTypeResolutionContext) match { case Some(target: JavaNameTarget[G]) => Some(target) case None => None })) - .getOrElse( - if (ctx.topLevelJavaDeref.isEmpty) throw NoSuchNameError("local", name, local) - else RefUnloadedJavaNamespace(Seq(name)))) - case local@PVLLocal(name) => + .getOrElse( + if (ctx.topLevelJavaDeref.isEmpty) throw NoSuchNameError("local", name, local) + else RefUnloadedJavaNamespace(Seq(name)))) + case local @ PVLLocal(name) => local.ref = Some(PVL.findName(name, ctx).getOrElse(throw NoSuchNameError("local", name, local))) case local@Local(ref) => ref.tryResolve(name => Spec.findLocal(name, ctx).getOrElse(throw NoSuchNameError("local", name, local))) @@ -370,7 +365,7 @@ case object ResolveReferences extends LazyLogging { case Some(RefUnloadedJavaNamespace(_)) => true case _ => false })) throw NoSuchNameError("field", field, deref) - case deref@PVLDeref(obj, field) => + case deref @ PVLDeref(obj, field) => deref.ref = Some(PVL.findDeref(obj, field, ctx, deref.blame).getOrElse(throw NoSuchNameError("field", field, deref))) case deref@Deref(obj, field) => field.tryResolve(name => Spec.findField(obj, name).getOrElse(throw NoSuchNameError("field", name, deref))) @@ -438,11 +433,11 @@ case object ResolveReferences extends LazyLogging { case inv@SilverPartialADTFunctionInvocation(name, args, partialTypeArgs) => inv.ref = Some(Spec.findAdtFunction(name, ctx).getOrElse(throw NoSuchNameError("function", name, inv))) partialTypeArgs.foreach(mapping => mapping._1.tryResolve(name => Spec.findAdtTypeArg(inv.adt, name).getOrElse(throw NoSuchNameError("type variable", name, inv)))) - case inv@InvokeProcedure(ref, _, _, _, givenMap, yields) => + case inv @ InvokeProcedure(ref, _, _, _, givenMap, yields) => ref.tryResolve(name => Spec.findProcedure(name, ctx).getOrElse(throw NoSuchNameError("procedure", name, inv))) Spec.resolveGiven(givenMap, RefProcedure(ref.decl), inv) Spec.resolveYields(ctx, yields, RefProcedure(ref.decl), inv) - case inv@ProcedureInvocation(ref, _, _, _, givenMap, yields) => + case inv @ ProcedureInvocation(ref, _, _, _, givenMap, yields) => ref.tryResolve(name => Spec.findProcedure(name, ctx).getOrElse(throw NoSuchNameError("procedure", name, inv))) Spec.resolveGiven(givenMap, RefProcedure(ref.decl), inv) Spec.resolveYields(ctx, yields, RefProcedure(ref.decl), inv) @@ -454,11 +449,11 @@ case object ResolveReferences extends LazyLogging { ref.tryResolve(name => Spec.findPredicate(name, ctx).getOrElse(throw NoSuchNameError("predicate", name, inv))) case inv@SilverCurPredPerm(ref, _) => ref.tryResolve(name => Spec.findPredicate(name, ctx).getOrElse(throw NoSuchNameError("predicate", name, inv))) - case inv@InvokeMethod(obj, ref, _, _, _, givenMap, yields) => + case inv @ InvokeMethod(obj, ref, _, _, _, givenMap, yields) => ref.tryResolve(name => Spec.findMethod(obj, name).getOrElse(throw NoSuchNameError("method", name, inv))) Spec.resolveGiven(givenMap, RefInstanceMethod(ref.decl), inv) Spec.resolveYields(ctx, yields, RefInstanceMethod(ref.decl), inv) - case inv@MethodInvocation(obj, ref, _, _, _, givenMap, yields) => + case inv @ MethodInvocation(obj, ref, _, _, _, givenMap, yields) => ref.tryResolve(name => Spec.findMethod(obj, name).getOrElse(throw NoSuchNameError("method", name, inv))) Spec.resolveGiven(givenMap, RefInstanceMethod(ref.decl), inv) Spec.resolveYields(ctx, yields, RefInstanceMethod(ref.decl), inv) @@ -466,7 +461,7 @@ case object ResolveReferences extends LazyLogging { ref.tryResolve(name => Spec.findInstanceFunction(obj, name).getOrElse(throw NoSuchNameError("function", name, inv))) case inv@InstancePredicateApply(obj, ref, _, _) => ref.tryResolve(name => Spec.findInstancePredicate(obj, name).getOrElse(throw NoSuchNameError("predicate", name, inv))) - case inv@CoalesceInstancePredicateApply(obj, ref, _, _) => + case inv @ CoalesceInstancePredicateApply(obj, ref, _, _) => ref.tryResolve(name => Spec.findInstancePredicate(obj, name).getOrElse(throw NoSuchNameError("predicate", name, inv))) case defn: CFunctionDefinition[G] => @@ -535,7 +530,7 @@ case object ResolveReferences extends LazyLogging { def extractExpr(s: Option[Expr[_]]): (String, Origin) = s match { case None => ("true", ann.o) - case Some(s@JavaStringValue(data, _)) => (data, s.o) + case Some(s @ JavaStringValue(data, _)) => (data, s.o) case Some(n) => throw MalformedBipAnnotation(n, "pre- and post-conditions must be string literals") } @@ -578,9 +573,9 @@ case object ResolveReferences extends LazyLogging { case ann: JavaAnnotation[G] if isBip(ann, "Port") => val portType: BipPortType[G] = ann.expect("type") match { - case p@JavaDeref(_, "enforceable") => BipEnforceable[G]()(p.o) - case p@JavaDeref(_, "spontaneous") => BipSpontaneous[G]()(p.o) - case p@JavaDeref(_, "internal") => BipInternal[G]()(p.o) + case p @ JavaDeref(_, "enforceable") => BipEnforceable[G]()(p.o) + case p @ JavaDeref(_, "spontaneous") => BipSpontaneous[G]()(p.o) + case p @ JavaDeref(_, "internal") => BipInternal[G]()(p.o) case e => throw MalformedBipAnnotation(e, "Can be either PortType.enforceable, spontaneous, or internal") } ann.data = Some(BipPort[G](getLit(ann.expect("name")), portType)(ann.o)) @@ -588,7 +583,7 @@ case object ResolveReferences extends LazyLogging { case ann: JavaAnnotation[G] if isBip(ann, "Pure") => ann.data = Some(BipPure[G]()) - case portName@JavaBipGlueName(JavaTClass(Ref(cls: JavaClass[G]), Nil), name) => + case portName @ JavaBipGlueName(JavaTClass(Ref(cls: JavaClass[G]), Nil), name) => portName.data = Some((cls, getLit(name))) case contract: LlvmFunctionContract[G] => diff --git a/src/col/vct/col/resolve/lang/LLVM.scala b/src/col/vct/col/resolve/lang/LLVM.scala index e450dc9993..3a1f9298a2 100644 --- a/src/col/vct/col/resolve/lang/LLVM.scala +++ b/src/col/vct/col/resolve/lang/LLVM.scala @@ -4,15 +4,20 @@ import vct.col.ast._ import vct.col.resolve.NoSuchNameError import vct.col.resolve.ctx.ReferenceResolutionContext import vct.col.resolve.ctx._ + object LLVM { def findCallable[G](name: String, ctx: ReferenceResolutionContext[G]): Option[LlvmCallable[G]] = { - val callable = ctx.stack.flatten.collectFirst { - case RefLlvmGlobal(decl) if decl.data.nonEmpty => decl.data.get.collectFirst { - case f: LlvmSpecFunction[G] if f.name == name => f + // look in context + val callable = ctx.stack.flatten.map { + case RefLlvmGlobal(decl) => decl.data.get match { + case f: LlvmSpecFunction[G] if f.name == name => Some(f) + case _ => None } - } - callable.get match { + case _ => None + }.collectFirst { case Some(f) => f } + // if not present in context, might find it in the call site of the current function definition + callable match { case Some(callable) => Some(callable) case None => ctx.currentResult.get match { case RefLlvmFunctionDefinition(decl) => From 372ab12e11c73d8aaa9b4c84b2fa5fb4ff4b7a31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dr=C3=A9=20van=20Oorschot?= Date: Mon, 3 Jul 2023 17:21:43 +0200 Subject: [PATCH 3/4] Add branch instruction for spec language (acts as select) --- src/parsers/antlr4/LangLLVMSpecLexer.g4 | 3 +++ src/parsers/antlr4/LangLLVMSpecParser.g4 | 3 ++- .../vct/parsers/transform/LLVMContractToCol.scala | 10 +++++++--- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/parsers/antlr4/LangLLVMSpecLexer.g4 b/src/parsers/antlr4/LangLLVMSpecLexer.g4 index 8cf5e74cd0..1415f46b23 100644 --- a/src/parsers/antlr4/LangLLVMSpecLexer.g4 +++ b/src/parsers/antlr4/LangLLVMSpecLexer.g4 @@ -83,6 +83,9 @@ SDIV: 'sdiv'; ICMP: 'icmp'; CALL: 'call'; +// operators -> termops +BR: 'br'; + // compare predicates EQ_pred: 'eq'; NE_pred: 'ne'; diff --git a/src/parsers/antlr4/LangLLVMSpecParser.g4 b/src/parsers/antlr4/LangLLVMSpecParser.g4 index aff02a4852..878466e33d 100644 --- a/src/parsers/antlr4/LangLLVMSpecParser.g4 +++ b/src/parsers/antlr4/LangLLVMSpecParser.g4 @@ -11,6 +11,7 @@ instruction : binOpInstruction # binOpRule | compareInstruction # cmpOpRule | callInstruction # callOpRule + | branchInstruction #brOpRule ; constant @@ -43,7 +44,7 @@ compareInstruction: compOp Lparen compPred Comma expression Comma expression Rpa callInstruction: CALL Identifier Lparen expressionList Rparen; - +branchInstruction: BR Lparen expression Comma expression Comma expression Rparen; binOp : ADD # add diff --git a/src/parsers/vct/parsers/transform/LLVMContractToCol.scala b/src/parsers/vct/parsers/transform/LLVMContractToCol.scala index 4d521ac411..f4f1c617ab 100644 --- a/src/parsers/vct/parsers/transform/LLVMContractToCol.scala +++ b/src/parsers/vct/parsers/transform/LLVMContractToCol.scala @@ -101,16 +101,20 @@ case class LLVMContractToCol[G](override val originProvider: OriginProvider, case BinOpRule(binOp) => convert(binOp) case CmpOpRule(cmpOp) => convert(cmpOp) case CallOpRule(callOp) => convert(callOp) + case BrOpRule(brOp) => convert(brOp) + } + + def convert(implicit brOp:BranchInstructionContext): Expr[G] = brOp match { + case BranchInstruction0(_, _, testExpr, _, trueExpr, _, falseExpr, _) => + Select(convert(testExpr), convert(trueExpr), convert(falseExpr)) } def convert(implicit callOp: CallInstructionContext): Expr[G] = callOp match { - case CallInstruction0(_, id, _, exprList, _) => { + case CallInstruction0(_, id, _, exprList, _) => val args: Seq[Expr[G]] = convert(exprList) LlvmAmbiguousFunctionInvocation(id, args, Nil, Nil)(blame(callOp)) - } } - def convert(implicit binOp: BinOpInstructionContext): Expr[G] = binOp match { case BinOpInstruction0(op, _, lhs, _, rhs, _) => convert(op, lhs, rhs) } From a7369808e6c791a80a1ba46b6e8feb12119b7fea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dr=C3=A9=20van=20Oorschot?= Date: Mon, 3 Jul 2023 22:03:36 +0200 Subject: [PATCH 4/4] Add logical operators --- src/parsers/antlr4/LangLLVMSpecLexer.g4 | 4 ++++ src/parsers/antlr4/LangLLVMSpecParser.g4 | 4 ++++ .../parsers/transform/LLVMContractToCol.scala | 17 +++++++++++++++++ 3 files changed, 25 insertions(+) diff --git a/src/parsers/antlr4/LangLLVMSpecLexer.g4 b/src/parsers/antlr4/LangLLVMSpecLexer.g4 index 1415f46b23..d3ac5c1883 100644 --- a/src/parsers/antlr4/LangLLVMSpecLexer.g4 +++ b/src/parsers/antlr4/LangLLVMSpecLexer.g4 @@ -78,6 +78,10 @@ SUB: 'sub'; MUL: 'mul'; UDIV: 'udiv'; SDIV: 'sdiv'; +// bitwise +AND: 'and'; +OR: 'or'; +XOR: 'xor'; // operators -> other ICMP: 'icmp'; diff --git a/src/parsers/antlr4/LangLLVMSpecParser.g4 b/src/parsers/antlr4/LangLLVMSpecParser.g4 index 878466e33d..26d1249a10 100644 --- a/src/parsers/antlr4/LangLLVMSpecParser.g4 +++ b/src/parsers/antlr4/LangLLVMSpecParser.g4 @@ -5,6 +5,7 @@ expression | constant | identifier | valExpr + | expression valImpOp expression ; instruction @@ -52,6 +53,9 @@ binOp | MUL # mul | UDIV # udiv | SDIV # sdiv + | AND # and + | OR # or + | XOR # xor ; diff --git a/src/parsers/vct/parsers/transform/LLVMContractToCol.scala b/src/parsers/vct/parsers/transform/LLVMContractToCol.scala index f4f1c617ab..955fd67ba8 100644 --- a/src/parsers/vct/parsers/transform/LLVMContractToCol.scala +++ b/src/parsers/vct/parsers/transform/LLVMContractToCol.scala @@ -95,6 +95,10 @@ case class LLVMContractToCol[G](override val originProvider: OriginProvider, case Expression1(constant) => convert(constant) case Expression2(identifier) => convert(identifier) case Expression3(valExpr) => convert(valExpr) + case Expression4(e1, impOp, e2) => impOp match { + case ValImpOp0(_) => ??? + case ValImpOp1(_) => Implies(convert(e1), convert(e2)) + } } def convert(implicit inst: InstructionContext): Expr[G] = inst match { @@ -128,6 +132,19 @@ case class LLVMContractToCol[G](override val originProvider: OriginProvider, case Sub(_) => Minus(left, right) case Mul(_) => Mult(left, right) case Udiv(_) | Sdiv(_) => FloorDiv(left, right)(blame(op)) + // bitwise/boolean + case bitOp => left.t match { + case TBool() => bitOp match { + case LLVMSpecParserPatterns.And(_) => vct.col.ast.And(left, right) + case LLVMSpecParserPatterns.Or(_) => vct.col.ast.Or(left, right) + case Xor(_) => Neq(left, right) + } + case TInt() => bitOp match { + case LLVMSpecParserPatterns.And(_) => BitAnd(left, right) + case LLVMSpecParserPatterns.Or(_) => BitOr(left, right) + case Xor(_) => BitXor(left, right) + } + } } }