Skip to content

Commit

Permalink
Better Struct support
Browse files Browse the repository at this point in the history
  • Loading branch information
sakehl committed Nov 2, 2023
1 parent 10d4e7e commit e368794
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 21 deletions.
127 changes: 122 additions & 5 deletions examples/concepts/c/structs.c
Original file line number Diff line number Diff line change
@@ -1,20 +1,48 @@
#include <assert.h>
#include <stddef.h>

struct point {
int x;
int y;
};

struct rect{
struct point p1, p2, p3;
};

struct polygon{
struct point* ps;
};

struct linked_list{
struct linked_list *p1;
int x;
};

/*@
context p != NULL ** Perm(p, write);
context Perm(p->x, write);
context Perm(p->y, write);
ensures p->x == 0;
ensures p->y == 0;
ensures \old(*p) == *p;
@*/
void alter_struct(struct point *p){
p->x = 0;
p->y = 0;
}

/*@
context p != NULL ** Perm(p, write) ** Perm(*p, write);
ensures p->x == \old(p->x + 1);
ensures p->y == \old(p->y + 1);
ensures \old(*p) == *p;
@*/
void alter_struct_1(struct point *p){
p->x = p->x+1;
p->y = p->y+1;
}

/*@
context Perm(p.x, 1\1);
context Perm(p.y, 1\1);
Expand All @@ -24,23 +52,112 @@ void alter_copy_struct(struct point p){
p.y = 0;
}

/*@
context Perm(p, 1\1);
@*/
void alter_copy_struct_2(struct point p){
p.x = 0;
p.y = 0;
}

/*@
context r != NULL ** Perm(r, 1\2) ** Perm(*r, 1\2);
ensures \result == (r->p1.x + r->p2.x + r->p3.x)/3;
@*/
int avr_x(struct rect *r){
return (r->p1.x + r->p2.x + r->p3.x)/3;
}
/*@
requires n >= 0;
requires inp != NULL && \pointer_length(inp) >= n;
requires (\forall* int i; 0 <= i && i < n; Perm(&inp[i], 1\10));
requires (\forall int i, int j; 0<=i && i<n && 0<=j && j<n; i != j ==> inp[i] != inp[j]);
requires (\forall* int i; 0 <= i && i < n; Perm(inp[i].x, 1\10));
ensures |\result| == n;
ensures (\forall int i; 0 <= i && i < n; \result[i] == inp[i].x);
ensures n>0 ==> \result == inp_to_seq(inp, n-1) + [inp[n-1].x];
pure seq<int> inp_to_seq(struct point *inp, int n) = n == 0 ? [t: int ] : inp_to_seq(inp, n-1) + [inp[n-1].x];
decreases |xs|;
ensures |xs| == 0 ==> \result == 0;
ensures |xs| > 0 ==> \result == sum_seq(xs[.. (|xs|-1)]) + xs[ |xs|-1 ];
pure int sum_seq(seq<int> xs) = |xs| == 0 ? 0 : sum_seq(xs[.. (|xs|-1)]) + xs[ |xs|-1 ];
@*/


/*@
requires len > 0;
context p != NULL ** Perm(p, 1\2) ** Perm(*p, 1\2);
context p->ps != NULL && \pointer_length(p->ps) >= len;
context (\forall* int i; 0<=i && i<len; Perm(&p->ps[i], 1\2));
context (\forall int i, int j; 0<=i && i<len && 0<=j && j<len; i != j ==> p->ps[i] != p->ps[j]);
context (\forall* int i; 0<=i && i<len; Perm(p->ps[i], 1\2));
// No clue why, but it hangs if we try for bigger numbers
ensures len == 3 ==> \result == sum_seq(inp_to_seq(p->ps, len))/len;
@*/
int avr_x_pol(struct polygon *p, int len){
int sum = 0;
//@ ghost seq<int> xs = inp_to_seq(p->ps, len);
/*@
loop_invariant 0<=i && i<=len;
loop_invariant p != NULL ** Perm(p, 1\2) ** Perm(*p, 1\2);
loop_invariant p->ps != NULL && \pointer_length(p->ps) >= len;
loop_invariant (\forall* int i; 0<=i && i<len; Perm(&p->ps[i], 1\2));
loop_invariant (\forall int i, int j; 0<=i && i<len && 0<=j && j<len; i != j ==> p->ps[i] != p->ps[j]);
loop_invariant (\forall* int i; 0<=i && i<len; Perm(p->ps[i], 1\2));
loop_invariant (\forall int i; 0<=i && i<len; p->ps[i].x == xs[i]);
loop_invariant sum == sum_seq(xs[..i]);
@*/
for(int i=0; i<len; i++){
sum += p->ps[i].x;
//@ assert xs[.. i+1][.. i] == xs[.. i];
}

//@ assert xs[..len] == xs;
return sum/len;
}


int main(){
struct point p;
// struct point pp[1];
struct point *pp;
pp = &p;

// assert (pp[0] != NULL );
//@ assert (pp[0] != NULL );
assert (pp != NULL );

p.x = 1;
p.y = 2;

// assert(p->x == 1);
// assert(p->y == 1);
assert(pp->x == 1);
assert(pp->y == 2);
alter_copy_struct(p);
assert(p.x == 1);
assert(p.y == 2);

alter_struct(pp);
// assert(p.x == 0);
assert(pp->x == 0);
assert(p.x == 0);
alter_struct_1(pp); //alter_struct_1(&p) is not supported yet
assert(p.x == 1 && p.y == 1);

struct point p1, p2, p3;
p1.x = 1; p1.y = 1;
p2.x = 2; p1.y = 2;
p3.x = 3; p1.y = 3;
struct rect r, *rr;
rr = &r;
r.p1 = p1;
r.p2 = p2;
r.p3 = p3;
assert(avr_x(rr) == 2);
struct point ps[3] = {p1, p2, p3};
struct polygon pol, *ppols;
ppols = &pol;
pol.ps = ps;
int avr_pol = avr_x_pol(ppols, 3);
//@ assert sum_seq(inp_to_seq(ppols->ps, 3)) == 6;
assert(avr_pol == 2);

return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@ trait ProverFunctionInvocationImpl[G] { this: ProverFunctionInvocation[G] =>
override def precedence: Int = Precedence.POSTFIX
override def layout(implicit ctx: Ctx): Doc =
Group(
Group(
Text(ctx.name(ref)) <>
"("
) <> Doc.args(args) <> ")"
"(" <> Doc.args(args) <> ")"
)

}
2 changes: 2 additions & 0 deletions src/col/vct/col/ast/lang/smt/SmtlibPowImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import vct.col.ast.{SmtlibPow, TBool, Type}
import vct.col.print._

trait SmtlibPowImpl[G] { this: SmtlibPow[G] =>
override def precedence: Int = Precedence.PVL_POW

override def t: Type[G] = left.t
override def layout(implicit ctx: Ctx): Doc = rassoc(left, "^", right)
}
54 changes: 41 additions & 13 deletions src/rewrite/vct/rewrite/lang/LangCToCol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import vct.col.ast.`type`.TFloats.ieee754_32bit
import vct.col.ast.util.ExpressionEqualityCheck.isConstantInt
import vct.col.rewrite.lang.LangSpecificToCol.NotAValue
import vct.col.origin.{AbstractApplicable, ArraySizeError, AssignLocalOk, Blame, CallableFailure, FrontendInvocationError, InterpretedOriginVariable, KernelBarrierInconsistent, KernelBarrierInvariantBroken, KernelBarrierNotEstablished, KernelPostconditionFailed, KernelPredicateNotInjective, Origin, PanicBlame, ParBarrierFailure, ParBarrierInconsistent, ParBarrierInvariantBroken, ParBarrierMayNotThrow, ParBarrierNotEstablished, ParBlockContractFailure, ParBlockFailure, ParBlockMayNotThrow, ParBlockPostconditionFailed, ParPreconditionFailed, ParPredicateNotInjective, PointerInsufficientPermission, ReceiverNotInjective, TrueSatisfiable, VerificationFailure}
import vct.col.ref.Ref
import vct.col.ref.{LazyRef, Ref}
import vct.col.resolve.lang.C
import vct.col.resolve.ctx.{BuiltinField, BuiltinInstanceMethod, CNameTarget, CStructTarget, RefADTFunction, RefAxiomaticDataType, RefCFunctionDefinition, RefCGlobalDeclaration, RefCLocalDeclaration, RefCParam, RefCStruct, RefCStructField, RefCudaBlockDim, RefCudaBlockIdx, RefCudaGridDim, RefCudaThreadIdx, RefCudaVec, RefCudaVecDim, RefCudaVecX, RefCudaVecY, RefCudaVecZ, RefFunction, RefInstanceFunction, RefInstanceMethod, RefInstancePredicate, RefModelAction, RefModelField, RefModelProcess, RefPredicate, RefProcedure, RefProverFunction, RefVariable, SpecInvocationTarget}
import vct.col.resolve.lang.C.nameFromDeclarator
Expand Down Expand Up @@ -145,9 +145,13 @@ case object LangCToCol {

case class UnsupportedMalloc(c: Expr[_]) extends UserError {
override def code: String = "unsupportedMalloc"

override def text: String = c.o.messageInContext("Only 'malloc' of the format '(t *) malloc(x*typeof(t)' is supported.")
}

case class UnsupportedStructPerm(o: Origin) extends UserError {
override def code: String = "unsupportedStructPerm"
override def text: String = o.messageInContext("Shorthand for Permissions for structs not possible, since the struct has a cyclic reference")
}
}

case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends LazyLogging {
Expand Down Expand Up @@ -1078,6 +1082,32 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz
}
}

// Allow a user to write `Perm(p, write)` instead of `Perm(p.x, write) ** Perm(p.y, write)` for `struct p {int x, y}`
def unwrapStructPerm(struct: AmbiguousLocation[Post], perm: Expr[Pre], structType: CTStruct[Pre], origin: Origin, visited: Seq[CTStruct[Pre]]= Seq()): Expr[Post] = {
if(visited.contains(structType)) throw UnsupportedStructPerm(origin) // We do not allow this notation for recursive structs
implicit val o: Origin = origin
val blame = PanicBlame("???")
val Seq(CStructDeclaration(_, fields)) = structType.ref.decl.decl.specs
val newPerm = rw.dispatch(perm)
val AmbiguousLocation(newExpr) = struct
val newFieldPerms = fields.map(member => {
val loc = AmbiguousLocation(
Deref[Post](
newExpr,
cStructFieldsSuccessor.ref((structType.ref.decl, member))
)(blame)
)(struct.blame)
member.specs.collectFirst {
case CSpecificationType(newStruct: CTStruct[Pre]) =>
// We recurse, since a field is another struct
Perm(loc, newPerm) &* unwrapStructPerm(loc, perm, newStruct, origin, structType +: visited)
}.getOrElse(Perm(loc, newPerm))
})

foldStar(newFieldPerms)
}


def createStructCopy(a: Expr[Post], target: CGlobalDeclaration[Pre]): Expr[Post] = {
implicit val o: Origin = a.o
val targetClass: Class[Post] = cStructSuccessor(target)
Expand Down Expand Up @@ -1345,13 +1375,6 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz
val RefCGlobalDeclaration(decls, initIdx) = e
implicit val o: Origin = inv.o

val arg = if(args.size == 1){
args.head match {
case IntegerValue(i) if i >= 0 && i < 3 => Some(i.toInt)
case _ => None
}
} else None

(e.name, args, givenMap, yields) match {
case (_, _, g, y) if g.nonEmpty || y.nonEmpty =>
case("free", Seq(xs), _, _) if isCPointer(xs.t) =>
Expand All @@ -1362,6 +1385,13 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz
case _ => ()
}

val arg = if (args.size == 1) {
args.head match {
case IntegerValue(i) if i >= 0 && i < 3 => Some(i.toInt)
case _ => None
}
} else None

(e.name, arg) match {
case ("get_local_id", Some(i)) => getCudaLocalThread(i, o)
case ("get_group_id", Some(i)) => getCudaGroupThread(i, o)
Expand All @@ -1385,9 +1415,7 @@ case class LangCToCol[Pre <: Generation](rw: LangSpecificToCol[Pre]) extends Laz
}

def structType(t: CTStruct[Pre]): Type[Post] = {
val target = t.ref.decl
implicit val o: Origin = t.o
val targetClass: Class[Post] = cStructSuccessor(target)
TClass[Post](targetClass.ref)
val targetClass = new LazyRef[Post, Class[Post]](cStructSuccessor(t.ref.decl))
TClass[Post](targetClass)(t.o)
}
}
6 changes: 6 additions & 0 deletions src/rewrite/vct/rewrite/lang/LangSpecificToCol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ case class LangSpecificToCol[Pre <: Generation]() extends Rewriter[Pre] with Laz
case global: GlobalThreadId[Pre] => c.cudaGlobalThreadId(global)
case cast: CCast[Pre] => c.cast(cast)

case Perm(a@AmbiguousLocation(expr), perm)
if c.getBaseType(expr.t).isInstanceOf[CTStruct[Pre]] =>
c.getBaseType(expr.t) match {
case structType: CTStruct[Pre] => c.unwrapStructPerm(dispatch(a).asInstanceOf[AmbiguousLocation[Post]], perm, structType, e.o)
}

case local: CPPLocal[Pre] => cpp.local(local)
case inv: CPPInvocation[Pre] => cpp.invocation(inv)

Expand Down

0 comments on commit e368794

Please sign in to comment.