Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VCLLVM] Add support for pure functions #1049

Merged
merged 5 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/col/vct/col/ast/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1144,36 +1144,35 @@ final case class BipInternal[G]()(implicit val o: Origin = DiagnosticOrigin) ext
final case class BipPortSynchronization[G](ports: Seq[Ref[G, BipPort[G]]], wires: Seq[BipGlueDataWire[G]])(val blame: Blame[BipSynchronizationFailure])(implicit val o: Origin) extends GlobalDeclaration[G] with BipPortSynchronizationImpl[G]
final case class BipTransitionSynchronization[G](transitions: Seq[Ref[G, BipTransition[G]]], wires: Seq[BipGlueDataWire[G]])(val blame: Blame[BipSynchronizationFailure])(implicit val o: Origin) extends GlobalDeclaration[G] with BipTransitionSynchronizationImpl[G]

final class LlvmFunctionContract[G](val value:String, val variableRefs:Seq[(String, Ref[G, Variable[G]])], val invokableRefs:Seq[(String, Ref[G, LlvmFunctionDefinition[G]])])
final class LlvmFunctionContract[G](val value:String, val variableRefs:Seq[(String, Ref[G, Variable[G]])], val invokableRefs:Seq[(String, Ref[G, LlvmCallable[G]])])
(val blame: Blame[NontrivialUnsatisfiable])
(implicit val o: Origin) extends NodeFamily[G] with LLVMFunctionContractImpl[G] {
var data: Option[ApplicableContract[G]] = None
}

sealed trait LlvmCallable[G] extends GlobalDeclaration[G]
final class LlvmFunctionDefinition[G](val returnType: Type[G],
val args: Seq[Variable[G]],
val functionBody: Statement[G],
val contract: LlvmFunctionContract[G],
val pure: Boolean = false)
(val blame: Blame[CallableFailure])(implicit val o: Origin)
extends GlobalDeclaration[G] with Applicable[G] with LLVMFunctionDefinitionImpl[G]

extends LlvmCallable[G] with Applicable[G] with LLVMFunctionDefinitionImpl[G]
final class LlvmSpecFunction[G](val name: String, val returnType: Type[G], val args: Seq[Variable[G]], val typeArgs: Seq[Variable[G]],
val body: Option[Expr[G]], val contract: ApplicableContract[G], val inline: Boolean = false, val threadLocal: Boolean = false)
(val blame: Blame[ContractedFailure])(implicit val o: Origin)
extends LlvmCallable[G] with AbstractFunction[G] with LLVMSpecFunctionImpl[G]
final case class LlvmFunctionInvocation[G](ref: Ref[G, LlvmFunctionDefinition[G]],
args: Seq[Expr[G]],
givenMap: Seq[(Ref[G, Variable[G]], Expr[G])],
yields: Seq[(Expr[G], Ref[G, Variable[G]])])
(val blame: Blame[InvocationFailure])(implicit val o: Origin) extends Apply[G] with LLVMFunctionInvocationImpl[G]

final case class LlvmLoop[G](cond:Expr[G], contract:LlvmLoopContract[G], body:Statement[G])
(implicit val o: Origin) extends CompositeStatement[G] with LLVMLoopImpl[G]

sealed trait LlvmLoopContract[G] extends NodeFamily[G] with LLVMLoopContractImpl[G]
final case class LlvmLoopInvariant[G](value:String, references:Seq[(String, Ref[G, Declaration[G]])])
(val blame: Blame[LoopInvariantFailure])
(implicit val o: Origin) extends LlvmLoopContract[G] with LLVMLoopInvariantImpl[G]

sealed trait LlvmExpr[G] extends Expr[G] with LLVMExprImpl[G]

final case class LlvmLocal[G](name: String)(val blame: Blame[DerefInsufficientPermission])(implicit val o: Origin) extends LlvmExpr[G] with LLVMLocalImpl[G] {
var ref: Option[Ref[G, Variable[G]]] = None
}
Expand All @@ -1182,7 +1181,11 @@ final case class LlvmAmbiguousFunctionInvocation[G](name: String,
givenMap: Seq[(Ref[G, Variable[G]], Expr[G])],
yields: Seq[(Expr[G], Ref[G, Variable[G]])])
(val blame: Blame[InvocationFailure])(implicit val o: Origin) extends LlvmExpr[G] with LLVMAmbiguousFunctionInvocationImpl[G] {
var ref: Option[Ref[G, LlvmFunctionDefinition[G]]] = None
var ref: Option[Ref[G, LlvmCallable[G]]] = None
}

final class LlvmGlobal[G](val value: String)(implicit val o: Origin) extends GlobalDeclaration[G] with LLVMGlobalImpl[G] {
var data: Option[GlobalDeclaration[G]] = None
}
sealed trait PVLType[G] extends Type[G] with PVLTypeImpl[G]
final case class PVLNamedType[G](name: String, typeArgs: Seq[Type[G]])(implicit val o: Origin = DiagnosticOrigin) extends PVLType[G] with PVLNamedTypeImpl[G] {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package vct.col.ast.lang

import vct.col.ast.{LlvmAmbiguousFunctionInvocation, Type}
import vct.col.ast.{LlvmAmbiguousFunctionInvocation, LlvmFunctionDefinition, LlvmSpecFunction, Type}
import vct.col.print.{Ctx, Doc, DocUtil, Group, Precedence, Text}

trait LLVMAmbiguousFunctionInvocationImpl[G] { this: LlvmAmbiguousFunctionInvocation[G] =>
override lazy val t: Type[G] = ref match {
case Some(ref) => ref.decl.returnType
override lazy val t: Type[G] = ref.get.decl match {
case func: LlvmFunctionDefinition[G] => func.returnType
case func: LlvmSpecFunction[G] => func.returnType
}

override def precedence: Int = Precedence.POSTFIX
Expand Down
10 changes: 10 additions & 0 deletions src/col/vct/col/ast/lang/LLVMGlobalImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package vct.col.ast.lang

import vct.col.ast.LlvmGlobal
import vct.col.print.{Ctx, Doc, Text}

trait LLVMGlobalImpl[G] { this: LlvmGlobal[G] =>

override def layout(implicit ctx: Ctx): Doc = Text(value)

}
31 changes: 31 additions & 0 deletions src/col/vct/col/ast/lang/LLVMSpecFunctionImpl.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package vct.col.ast.lang

import vct.col.ast.LlvmSpecFunction
import vct.col.ast.declaration.category.AbstractFunctionImpl
import vct.col.ast.declaration.global.GlobalDeclarationImpl
import vct.col.print.{Ctx, Doc, Empty, Group, Show, Text}

import scala.collection.immutable.ListMap

trait LLVMSpecFunctionImpl[G] extends GlobalDeclarationImpl[G] with AbstractFunctionImpl[G] {
this: LlvmSpecFunction[G] =>

def layoutModifiers(implicit ctx: Ctx): Seq[Doc] = ListMap(
inline -> "inline",
threadLocal -> "thread_local",
).filter(_._1).values.map(Text).map(Doc.inlineSpec).toSeq

def layoutSpec(implicit ctx: Ctx): Doc =
Doc.stack(Seq(
contract,
Group(
Group(Doc.rspread(layoutModifiers) <> "pure" <+> returnType <+> ctx.name(this) <>
(if (typeArgs.nonEmpty) Text("<") <> Doc.args(typeArgs.map(ctx.name).map(Text)) <> ">" else Empty) <>
"(" <> Doc.args(args) <> ")") <>
body.map(Text(" =") <>> _ <> ";").getOrElse(Text(";"))
),
))

override def layout(implicit ctx: Ctx): Doc = Doc.spec(Show.lazily(layoutSpec(_)))

}
24 changes: 16 additions & 8 deletions src/col/vct/col/resolve/Resolve.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import vct.col.origin._
import vct.col.resolve.ResolveReferences.scanScope
import vct.col.ref.Ref
import vct.col.resolve.ctx._
import vct.col.resolve.lang.{C, CPP, Java, PVL, Spec}
import vct.col.resolve.lang.{C, CPP, Java, LLVM, PVL, Spec}
import vct.col.resolve.Resolve.{MalformedBipAnnotation, SpecContractParser, SpecExprParser, getLit, isBip}
import vct.col.resolve.lang.JavaAnnotationData.{BipComponent, BipData, BipGuard, BipInvariant, BipPort, BipPure, BipStatePredicate, BipTransition}
import vct.col.rewrite.InitialGeneration
Expand All @@ -31,7 +31,9 @@ case object Resolve {
}

trait SpecContractParser {
def parse[G](input: LlvmFunctionContract[G], o:Origin): ApplicableContract[G]
def parse[G](input: LlvmFunctionContract[G], o: Origin): ApplicableContract[G]

def parse[G](input: LlvmGlobal[G], o: Origin): GlobalDeclaration[G]
}

def extractLiteral(e: Expr[_]): Option[String] = e match {
Expand Down Expand Up @@ -335,6 +337,9 @@ case object ResolveReferences extends LazyLogging {
}
case func: LlvmFunctionDefinition[G] => ctx
.copy(currentResult = Some(RefLlvmFunctionDefinition(func)))
case func: LlvmSpecFunction[G] => ctx
.copy(currentResult = Some(RefLlvmSpecFunction(func)))
.declare(func.args)
case par: ParStatement[G] => ctx
.declare(scanBlocks(par.impl).map(_.decl))
case Scope(locals, body) => ctx
Expand Down Expand Up @@ -625,15 +630,18 @@ case object ResolveReferences extends LazyLogging {
case Some(ref) => Some(ref._2)
case None => throw NoSuchNameError("local", local.name, local)
}
case RefLlvmSpecFunction(_) =>
Some(Spec.findLocal(local.name, ctx).getOrElse(throw NoSuchNameError("local", local.name, local)).ref)
}
case inv: LlvmAmbiguousFunctionInvocation[G] =>
inv.ref = ctx.currentResult.get match {
case RefLlvmFunctionDefinition(decl) =>
decl.contract.invokableRefs.find(ref => ref._1 == inv.name) match {
case Some(ref) => Some(ref._2)
case None => throw NoSuchNameError("function", inv.name, inv)
}
inv.ref = LLVM.findCallable(inv.name, ctx) match {
case Some(callable) => Some(callable.ref)
case None => throw NoSuchNameError("function", inv.name, inv)
}
case glob: LlvmGlobal[G] =>
val decl = ctx.llvmSpecParser.parse(glob, glob.o)
glob.data = Some(decl)
resolve(decl, ctx)
case _ =>
}
}
12 changes: 8 additions & 4 deletions src/col/vct/col/resolve/ctx/Referrable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ case object Referrable {
case decl: VeyMontThread[G] => RefVeyMontThread(decl)
case decl: JavaBipGlueContainer[G] => RefJavaBipGlueContainer()
case decl: LlvmFunctionDefinition[G] => RefLlvmFunctionDefinition(decl)
case decl: LlvmGlobal[G] => RefLlvmGlobal(decl)
case decl: LlvmSpecFunction[G] => RefLlvmSpecFunction(decl)
case decl: ProverType[G] => RefProverType(decl)
case decl: ProverFunction[G] => RefProverFunction(decl)
})
Expand Down Expand Up @@ -185,11 +187,10 @@ sealed trait LlvmInvocationTarget[G] extends Referrable[G]
sealed trait SpecInvocationTarget[G]
extends JavaInvocationTarget[G]
with CNameTarget[G]
with CDerefTarget[G]
with CInvocationTarget[G]
with CPPNameTarget[G]
with CPPInvocationTarget[G]
with CDerefTarget[G] with CInvocationTarget[G]
with CPPNameTarget[G] with CPPInvocationTarget[G]
with PVLInvocationTarget[G]
with LlvmInvocationTarget[G]

sealed trait ThisTarget[G] extends Referrable[G]

Expand Down Expand Up @@ -253,6 +254,9 @@ case class RefJavaBipStatePredicate[G](state: String, decl: JavaAnnotation[G]) e
case class RefJavaBipGuard[G](decl: JavaMethod[G]) extends Referrable[G] with JavaNameTarget[G]
case class RefJavaBipGlueContainer[G]() extends Referrable[G] // Bip glue jobs are not actually referrable
case class RefLlvmFunctionDefinition[G](decl: LlvmFunctionDefinition[G]) extends Referrable[G] with LlvmInvocationTarget[G] with ResultTarget[G]
case class RefLlvmGlobal[G](decl: LlvmGlobal[G]) extends Referrable[G]

case class RefLlvmSpecFunction[G](decl: LlvmSpecFunction[G]) extends Referrable[G] with SpecInvocationTarget[G] with ResultTarget[G]
case class RefSeqProg[G](decl: VeyMontSeqProg[G]) extends Referrable[G]
case class RefVeyMontThread[G](decl: VeyMontThread[G]) extends Referrable[G] with PVLNameTarget[G]
case class RefProverType[G](decl: ProverType[G]) extends Referrable[G] with SpecTypeNameTarget[G]
Expand Down
33 changes: 33 additions & 0 deletions src/col/vct/col/resolve/lang/LLVM.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package vct.col.resolve.lang

import vct.col.ast._
import vct.col.resolve.NoSuchNameError
import vct.col.resolve.ctx.ReferenceResolutionContext
import vct.col.resolve.ctx._

object LLVM {

def findCallable[G](name: String, ctx: ReferenceResolutionContext[G]): Option[LlvmCallable[G]] = {
// look in context
val callable = ctx.stack.flatten.map {
case RefLlvmGlobal(decl) => decl.data.get match {
case f: LlvmSpecFunction[G] if f.name == name => Some(f)
case _ => None
}
case _ => None
}.collectFirst { case Some(f) => f }
// if not present in context, might find it in the call site of the current function definition
callable match {
case Some(callable) => Some(callable)
case None => ctx.currentResult.get match {
case RefLlvmFunctionDefinition(decl) =>
decl.contract.invokableRefs.find(ref => ref._1 == name) match {
case Some(ref) => Some(ref._2.decl)
case None => None
}
}
}
}


}
5 changes: 4 additions & 1 deletion src/col/vct/col/typerules/CoercingRewriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1766,7 +1766,10 @@ abstract class CoercingRewriter[Pre <: Generation]() extends AbstractRewriter[Pr
case definition: LlvmFunctionDefinition[Pre] => definition
case typ: ProverType[Pre] => typ
case func: ProverFunction[Pre] => func
}
case function: LlvmSpecFunction[Pre] =>
new LlvmSpecFunction[Pre](function.name, function.returnType, function.args, function.typeArgs, function.body.map(coerce(_, function.returnType)), function.contract, function.inline, function.threadLocal)(function.blame)
case glob: LlvmGlobal[Pre] => glob
}
}

def coerce(region: ParRegion[Pre]): ParRegion[Pre] = {
Expand Down
3 changes: 2 additions & 1 deletion src/colhelper/ColDefs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ object ColDefs {
"CFunctionDefinition",
"CPPFunctionDefinition",
"PVLConstructor",
"LlvmFunctionDefinition"
"LlvmFunctionDefinition",
"LlvmSpecFunction"
// Potentially ParBlocks and other execution contexts (lambdas?) should be a scope too.
),
"SendDecl" -> Seq("ParBlock", "Loop"),
Expand Down
11 changes: 11 additions & 0 deletions src/main/vct/main/stages/Resolution.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package vct.main.stages

import com.typesafe.scalalogging.LazyLogging
import hre.stages.Stage
import vct.col.ast.{AddrOf, ApplicableContract, CGlobalDeclaration, Expr, GlobalDeclaration, LlvmFunctionContract, LlvmGlobal, Program, Refute, Verification, VerificationContext}
import org.antlr.v4.runtime.CharStreams
import vct.col.ast._
import vct.col.check.CheckError
Expand Down Expand Up @@ -79,6 +80,16 @@ case class MyLocalLLVMSpecParser(blameProvider: BlameProvider) extends Resolve.S
ColLLVMParser(originProvider, blameProvider)
.parseFunctionContract[G](charStream)._1
}

override def parse[G](input: LlvmGlobal[G], o: Origin): GlobalDeclaration[G] = {
val originProvider = ReadableOriginProvider(input.o match {
case o: LLVMOrigin => StringReadable(input.value, o.fileName)
case _ => StringReadable(input.value)
})
val charStream = CharStreams.fromString(input.value)
ColLLVMParser(originProvider, blameProvider)
.parseGlobal(charStream)._1
}
}

case class Resolution[G <: Generation]
Expand Down
7 changes: 7 additions & 0 deletions src/parsers/antlr4/LangLLVMSpecLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,18 @@ SUB: 'sub';
MUL: 'mul';
UDIV: 'udiv';
SDIV: 'sdiv';
// bitwise
AND: 'and';
OR: 'or';
XOR: 'xor';

// operators -> other
ICMP: 'icmp';
CALL: 'call';

// operators -> termops
BR: 'br';

// compare predicates
EQ_pred: 'eq';
NE_pred: 'ne';
Expand Down
7 changes: 6 additions & 1 deletion src/parsers/antlr4/LangLLVMSpecParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ expression
| constant
| identifier
| valExpr
| <assoc=right> expression valImpOp expression
;

instruction
: binOpInstruction # binOpRule
| compareInstruction # cmpOpRule
| callInstruction # callOpRule
| branchInstruction #brOpRule
;

constant
Expand Down Expand Up @@ -43,14 +45,17 @@ compareInstruction: compOp Lparen compPred Comma expression Comma expression Rpa

callInstruction: CALL Identifier Lparen expressionList Rparen;


branchInstruction: BR Lparen expression Comma expression Comma expression Rparen;

binOp
: ADD # add
| SUB # sub
| MUL # mul
| UDIV # udiv
| SDIV # sdiv
| AND # and
| OR # or
| XOR # xor
;


Expand Down
17 changes: 17 additions & 0 deletions src/parsers/vct/parsers/ColLLVMParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,21 @@ case class ColLLVMParser(override val originProvider: OriginProvider, override v
val contract = LLVMContractToCol[G](originProvider, blameProvider, errors).convert(tree)
(contract, errors.map(_._3))
}

def parseGlobal[G](stream: CharStream): (vct.col.ast.GlobalDeclaration[G], Seq[ExpectedError]) = {
val lexer = new LangLLVMSpecLexer(stream)
val tokens = new CommonTokenStream(lexer)
originProvider.setTokenStream(tokens)
val parser = new LLVMSpecParser(tokens)
// we're parsing a contract so set the parser to specLevel == 1
parser.specLevel = 1

val (errors, tree) = noErrorsOrThrow(parser, lexer, originProvider) {
val errors = expectedErrors(tokens, LangLLVMSpecLexer.EXPECTED_ERROR_CHANNEL, LangLLVMSpecLexer.VAL_EXPECT_ERROR_OPEN, LangLLVMSpecLexer.VAL_EXPECT_ERROR_CLOSE)
val tree = parser.valGlobalDeclaration()
(errors, tree)
}
val global = LLVMContractToCol[G](originProvider, blameProvider, errors).convert(tree)
(global, errors.map(_._3))
}
}
Loading