Skip to content

Commit

Permalink
Add C operators codegen (#1249)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek authored Nov 9, 2024
1 parent 48ae426 commit b08bc29
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 23 deletions.
153 changes: 132 additions & 21 deletions core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,73 @@ object Code {

def select(i: Ident): Select = Select(this, i)

def deref: Expression = Deref(this)
def deref: Expression = PrefixExpr(PrefixUnary.Deref, this)

def bracket(arg: Expression): Expression = Bracket(this, arg)

def addr: Expression = AddrOf(this)
def addr: Expression = PrefixExpr(PrefixUnary.Addr, this)

def stmt: Statement = Effect(this)

def bin(op: BinOp, rhs: Expression): Expression =
BinExpr(this, op, rhs)

def +(that: Expression): Expression = bin(BinOp.Add, that)
def -(that: Expression): Expression = bin(BinOp.Sub, that)
def *(that: Expression): Expression = bin(BinOp.Mult, that)
def /(that: Expression): Expression = bin(BinOp.Div, that)

def postInc: Expression = PostfixExpr(this, PostfixUnary.Inc)
def postDec: Expression = PostfixExpr(this, PostfixUnary.Dec)
}

sealed abstract class BinOp(repr: String) {
val toDoc: Doc = Doc.text(repr)
}
object BinOp {
case object Add extends BinOp("+")
case object Sub extends BinOp("-")
case object Mult extends BinOp("*")
case object Div extends BinOp("/")
case object Mod extends BinOp("%")

case object Eq extends BinOp("==")
case object NotEq extends BinOp("!=")
case object Lt extends BinOp("<")
case object LtEq extends BinOp("<=")
case object Gt extends BinOp(">")
case object GtEq extends BinOp(">=")

case object Or extends BinOp("||")
case object And extends BinOp("&&")

case object BitOr extends BinOp("|")
case object BitAnd extends BinOp("&")
case object BitXor extends BinOp("^")
case object LeftShift extends BinOp("<<")
case object RightShift extends BinOp(">>")
}

sealed abstract class PrefixUnary(val repr: String) {
val toDoc: Doc = Doc.text(repr)
}
object PrefixUnary {
// + - ! ~ ++ -- (type)* & sizeof
case object Neg extends PrefixUnary("-")
case object Not extends PrefixUnary("!")
case object BitNot extends PrefixUnary("~")
case object Inc extends PrefixUnary("++")
case object Dec extends PrefixUnary("--")
case object Addr extends PrefixUnary("&")
case object Deref extends PrefixUnary("*")
}

sealed abstract class PostfixUnary(val repr: String) {
val toDoc: Doc = Doc.text(repr)
}
object PostfixUnary {
case object Inc extends PostfixUnary("++")
case object Dec extends PostfixUnary("--")
}

case class Ident(name: String) extends Expression
Expand All @@ -70,12 +132,13 @@ object Code {
}
case class IntLiteral(value: BigInt) extends Expression
case class Cast(tpe: TypeIdent, expr: Expression) extends Expression
case class Apply(fn: Expression, args: List[Expression]) extends Expression with Statement
case class Apply(fn: Expression, args: List[Expression]) extends Expression
case class Select(target: Expression, name: Ident) extends Expression
case class Deref(targets: Expression) extends Expression
case class BinExpr(left: Expression, op: BinOp, right: Expression) extends Expression
case class PrefixExpr(op: PrefixUnary, target: Expression) extends Expression
case class PostfixExpr(target: Expression, op: PostfixUnary) extends Expression
case class Bracket(target: Expression, item: Expression) extends Expression
case class AddrOf(targets: Expression) extends Expression

case class Ternary(cond: Expression, whenTrue: Expression, whenFalse: Expression) extends Expression

case class Param(tpe: TypeIdent, name: Ident) {
def toDoc: Doc = TypeIdent.toDoc(tpe) + Doc.space + Doc.text(name.name)
Expand All @@ -94,6 +157,9 @@ object Code {
}
case class IfElse(ifs: NonEmptyList[(Expression, Block)], elseCond: Option[Block]) extends Statement
case class DoWhile(block: Block, whileCond: Expression) extends Statement
case class Effect(expr: Expression) extends Statement
case class While(cond: Expression, body: Block) extends Statement
case class Include(quote: Boolean, filename: String) extends Statement

val returnVoid: Statement = Return(None)

Expand All @@ -119,7 +185,12 @@ object Code {
private val doDoc = Doc.text("do ")
private val whileDoc = Doc.text("while")
private val arrow = Doc.text("->")
private val ampDoc = Doc.char('&')
private val questionDoc = Doc.text(" ? ")
private val colonDoc = Doc.text(" : ")
private val quoteDoc = Doc.char('"')
private def leftAngleDoc = BinOp.Lt.toDoc
private def rightAngleDoc = BinOp.Gt.toDoc
private val includeDoc = Doc.text("#include")

private def par(d: Doc): Doc = leftPar + d + rightPar

Expand All @@ -134,6 +205,15 @@ object Code {
leftCurly + inner + Doc.line + rightCurly
}

object Tight {
// These are the highest priority, so safe to not use a parens
def unapply(e: Expression): Option[Expression] =
e match {
case noPar @ (Ident(_) | Apply(_, _) | Select(_, _) | Bracket(_, _)) => Some(noPar)
case _ => None
}
}

def toDoc(c: Code): Doc =
c match {
case Ident(n) => Doc.text(n)
Expand All @@ -150,29 +230,48 @@ object Code {
case notIdent => par(toDoc(notIdent))
}
fnDoc + par(Doc.intercalate(commaLine, args.map(expr => toDoc(expr))).grouped.nested(4))
case Deref(expr) =>
asterisk + toDoc(expr)
case PostfixExpr(expr, op) =>
val left = expr match {
case Ident(n) => Doc.text(n)
case notIdent => par(toDoc(notIdent))
}
left + op.toDoc
case PrefixExpr(op, expr) =>
val right = expr match {
case Tight(n) => toDoc(n)
case usePar => par(toDoc(usePar))
}
op.toDoc + right
case BinExpr(left, op, right) =>
val leftD = left match {
case Tight(n) => toDoc(n)
case usePar => par(toDoc(usePar))
}
val rightD = right match {
case Tight(n) => toDoc(n)
case usePar => par(toDoc(usePar))
}
leftD + Doc.space + op.toDoc + Doc.space + rightD
case Select(target, Ident(nm)) =>
target match {
case Deref(noPar @ (Ident(_) | Apply(_, _) | Select(_, _))) => toDoc(noPar) + arrow + Doc.text(nm)
case Deref(notIdent) => par(toDoc(notIdent)) + arrow + Doc.text(nm)
case noPar @ (Ident(_) | Apply(_, _) | Select(_, _)) => toDoc(noPar) + dot + Doc.text(nm)
case PrefixExpr(PrefixUnary.Deref, Tight(noPar)) => toDoc(noPar) + arrow + Doc.text(nm)
case PrefixExpr(PrefixUnary.Deref, notIdent) => par(toDoc(notIdent)) + arrow + Doc.text(nm)
case Tight(noPar) => toDoc(noPar) + dot + Doc.text(nm)
case notIdent => par(toDoc(notIdent)) + dot + Doc.text(nm)
}
case Bracket(target, item) =>
val left = target match {
case noPar @ (Ident(_) | Apply(_, _) | Select(_, _)) =>
toDoc(noPar)
case Tight(noPar) => toDoc(noPar)
case yesPar => par(toDoc(yesPar))
}
left + leftBracket + toDoc(item) + rightBracket
case AddrOf(expr) =>
val e = expr match {
case noPar @ (Ident(_) | Select(_, _)) =>
toDoc(noPar)
case yesPar => par(toDoc(yesPar))
}
ampDoc + e
case Ternary(cond, t, f) =>
def d(e: Expression): Doc =
e match {
case noPar @ (Tight(_) | PrefixExpr(_, _) | BinExpr(_, _, _)) => toDoc(noPar)
case yesPar => par(toDoc(yesPar))
}
d(cond) + questionDoc + d(t) + colonDoc + d(f)
// Statements
case Assignment(t, v) => toDoc(t) + (equalsDoc + (toDoc(v) + semiDoc))
case DeclareArray(tpe, nm, values) =>
Expand Down Expand Up @@ -262,5 +361,17 @@ object Code {
first + middle + end
case DoWhile(block, cond) =>
doDoc + toDoc(block) + Doc.space + whileDoc + par(toDoc(cond)) + semiDoc
case Effect(expr) =>
toDoc(expr) + semiDoc
case While(expr, block) =>
whileDoc + Doc.space + par(toDoc(expr)) + Doc.space + toDoc(block)
case Include(useQuote, file) =>
val inc = if (useQuote) {
quoteDoc + Doc.text(file) + quoteDoc
}
else {
leftAngleDoc + Doc.text(file) + rightAngleDoc
}
includeDoc + Doc.space + inc
}
}
22 changes: 20 additions & 2 deletions core/src/test/scala/org/bykn/bosatsu/codegen/clang/CodeTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ class CodeTest extends munit.FunSuite {
checkCode(Ident("foo")(), "foo()")
checkCode(TypeIdent.Named("Foo").cast(Ident("foo"))(), "((Foo)foo)()")

checkCode(block(Ident("foo")()).doWhile(Ident("bar")), "do {\n" +
" foo()\n" +
checkCode(block(Ident("foo")().stmt).doWhile(Ident("bar")), "do {\n" +
" foo();\n" +
"} while(bar);")

checkCode(Ident("foo").deref, "*foo")
Expand All @@ -87,5 +87,23 @@ class CodeTest extends munit.FunSuite {
)
checkCode(Ident("foo").addr, "&foo")
checkCode(Ident("foo").select("bar").addr, "&foo.bar")
checkCode(Ternary(Ident("foo"), Ident("bar"), Ident("baz")), "foo ? bar : baz")
checkCode(Ternary(Ident("foo").select("q"), Ident("bar"), Ident("baz")), "foo.q ? bar : baz")
checkCode(Ternary(Ident("foo").deref, Ident("bar"), Ident("baz")), "*foo ? bar : baz")

checkCode(While(Ident("foo"), block(Ident("bar").stmt)),
"while (foo) {\n" +
" bar;\n" +
"}")

checkCode(Ident("i").postInc, "i++")
checkCode(Ident("i").postDec, "i--")
checkCode(Ident("i") + Ident("j"), "i + j")
checkCode(Ident("i") * Ident("j"), "i * j")
checkCode(Ident("i") - Ident("j"), "i - j")
checkCode(Ident("i") / Ident("j"), "i / j")

checkCode(Include(true, "foo"), "#include \"foo\"")
checkCode(Include(false, "foo"), "#include <foo>")
}
}

0 comments on commit b08bc29

Please sign in to comment.