Skip to content

Commit

Permalink
continue C code generator
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek committed Nov 8, 2024
1 parent 793e447 commit 1ced6f9
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 4 deletions.
106 changes: 102 additions & 4 deletions core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.bykn.bosatsu.codegen.clang

import org.typelevel.paiges.Doc
import cats.data.NonEmptyList
import scala.language.implicitConversions

sealed trait Code

Expand All @@ -25,6 +26,8 @@ object Code {

def cast(expr: Expression): Expression = Cast(this, expr)
}
private val asterisk = Doc.char('*')

object TypeIdent {
sealed trait ComplexType extends TypeIdent
case class StructType(name: String) extends ComplexType
Expand All @@ -34,7 +37,6 @@ object Code {

private val structDoc = Doc.text("struct ")
private val unionDoc = Doc.text("union ")
private val asterisk = Doc.char('*')

def toDoc(te: TypeIdent): Doc =
te match {
Expand All @@ -52,21 +54,46 @@ object Code {
def ret: Statement = Return(Some(this))

def apply(args: Expression*): Apply = Apply(this, args.toList)

def select(i: Ident): Select = Select(this, i)

def deref: Expression = Deref(this)

def bracket(arg: Expression): Expression = Bracket(this, arg)

def addr: Expression = AddrOf(this)
}

case class Ident(name: String) extends Expression
object Ident {
implicit def fromString(str: String): Ident = Ident(str)
}
case class IntLiteral(value: BigInt) extends Expression
case class Cast(tpe: TypeIdent, expr: Expression) extends Expression
case class Apply(fn: Expression, args: List[Expression]) extends Expression
case class Apply(fn: Expression, args: List[Expression]) extends Expression with Statement
case class Select(target: Expression, name: Ident) extends Expression
case class Deref(targets: Expression) extends Expression
case class Bracket(target: Expression, item: Expression) extends Expression
case class AddrOf(targets: Expression) extends Expression

sealed trait Statement extends Code

case class Param(tpe: TypeIdent, name: Ident) {
def toDoc: Doc = TypeIdent.toDoc(tpe) + Doc.space + Doc.text(name.name)
}

sealed trait Statement extends Code
case class Assignment(target: Expression, value: Expression) extends Statement
case class DeclareArray(tpe: TypeIdent, ident: Ident, values: Either[Int, List[Expression]]) extends Statement
case class DeclareVar(attrs: List[Attr], tpe: TypeIdent, ident: Ident, value: Option[Expression]) extends Statement
case class DeclareFn(attrs: List[Attr], returnTpe: TypeIdent, ident: Ident, args: List[Param], value: Option[Block]) extends Statement
case class Typedef(tpe: TypeIdent, name: Ident) extends Statement
case class DefineComplex(tpe: TypeIdent.ComplexType, elements: List[(TypeIdent, Ident)]) extends Statement
case class Return(expr: Option[Expression]) extends Statement
case class Block(items: NonEmptyList[Statement]) extends Statement
case class Block(items: NonEmptyList[Statement]) extends Statement {
def doWhile(cond: Expression): Statement = DoWhile(this, cond)
}
case class IfElse(ifs: NonEmptyList[(Expression, Block)], elseCond: Option[Block]) extends Statement
case class DoWhile(block: Block, whileCond: Expression) extends Statement

val returnVoid: Statement = Return(None)

Expand All @@ -80,12 +107,19 @@ object Code {
private val rightCurly = Doc.char('}')
private val leftPar = Doc.char('(')
private val rightPar = Doc.char(')')
private val leftBracket = Doc.char('[')
private val rightBracket = Doc.char(']')
private val dot = Doc.char('.')
private val returnSemi = Doc.text("return;")
private val returnSpace = Doc.text("return ")
private val ifDoc = Doc.text("if ")
private val elseIfDoc = Doc.text("else if ")
private val elseDoc = Doc.text("else ")
private val commaLine = Doc.char(',') + Doc.line
private val doDoc = Doc.text("do ")
private val whileDoc = Doc.text("while")
private val arrow = Doc.text("->")
private val ampDoc = Doc.char('&')

private def par(d: Doc): Doc = leftPar + d + rightPar

Expand All @@ -103,6 +137,7 @@ object Code {
def toDoc(c: Code): Doc =
c match {
case Ident(n) => Doc.text(n)
case IntLiteral(bi) => Doc.str(bi)
case Cast(tpe, expr) =>
val edoc = expr match {
case Ident(n) => Doc.text(n)
Expand All @@ -115,7 +150,50 @@ object Code {
case notIdent => par(toDoc(notIdent))
}
fnDoc + par(Doc.intercalate(commaLine, args.map(expr => toDoc(expr))).grouped.nested(4))
case Deref(expr) =>
asterisk + toDoc(expr)
case Select(target, Ident(nm)) =>
target match {
case Deref(noPar @ (Ident(_) | Apply(_, _) | Select(_, _))) => toDoc(noPar) + arrow + Doc.text(nm)
case Deref(notIdent) => par(toDoc(notIdent)) + arrow + Doc.text(nm)
case noPar @ (Ident(_) | Apply(_, _) | Select(_, _)) => toDoc(noPar) + dot + Doc.text(nm)
case notIdent => par(toDoc(notIdent)) + dot + Doc.text(nm)
}
case Bracket(target, item) =>
val left = target match {
case noPar @ (Ident(_) | Apply(_, _) | Select(_, _)) =>
toDoc(noPar)
case yesPar => par(toDoc(yesPar))
}
left + leftBracket + toDoc(item) + rightBracket
case AddrOf(expr) =>
val e = expr match {
case noPar @ (Ident(_) | Select(_, _)) =>
toDoc(noPar)
case yesPar => par(toDoc(yesPar))
}
ampDoc + e
// Statements
case Assignment(t, v) => toDoc(t) + (equalsDoc + (toDoc(v) + semiDoc))
case DeclareArray(tpe, nm, values) =>
// Foo bar[size] = {v(0), v(1), ...};
// or
// Foo bar[size];
val tpeName = TypeIdent.toDoc(tpe) + Doc.space + toDoc(nm)
values match {
case Right(init) =>
val len = init.size
val begin = tpeName + leftBracket + Doc.str(len) + rightBracket + equalsDoc + leftCurly;
val items =
if (init.isEmpty) Doc.empty
else {
((Doc.line + Doc.intercalate(commaLine, init.map(e => toDoc(e)))).nested(4) + Doc.line).grouped
}

begin + items + rightCurly + semiDoc
case Left(len) =>
tpeName + leftBracket + Doc.str(len) + rightBracket + semiDoc
}
case DeclareVar(attrs, tpe, ident, v) =>
val attrDoc =
if (attrs.isEmpty) Doc.empty
Expand All @@ -132,6 +210,24 @@ object Code {
case Some(rhs) => prefix + equalsDoc + toDoc(rhs) + semiDoc
case None => prefix + semiDoc
}
case DeclareFn(attrs, tpe, ident, args, v) =>
val attrDoc =
if (attrs.isEmpty) Doc.empty
else {
Doc.intercalate(Doc.space, attrs.map(a => Attr.toDoc(a))) + Doc.space
}

val paramDoc = Doc.intercalate(Doc.line, args.map(_.toDoc)).nested(4).grouped

val prefix = Doc.intercalate(Doc.space,
(attrDoc + TypeIdent.toDoc(tpe)) ::
(toDoc(ident) + par(paramDoc)) ::
Nil)

v match {
case Some(rhs) => prefix + Doc.space + toDoc(rhs)
case None => prefix + semiDoc
}
case Typedef(td, n) =>
typeDefDoc + TypeIdent.toDoc(td) + Doc.space + toDoc(n) + semiDoc
case DefineComplex(tpe, els) =>
Expand Down Expand Up @@ -164,5 +260,7 @@ object Code {
}

first + middle + end
case DoWhile(block, cond) =>
doDoc + toDoc(block) + Doc.space + whileDoc + par(toDoc(cond)) + semiDoc
}
}
27 changes: 27 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/codegen/clang/CodeTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ class CodeTest extends munit.FunSuite {
checkCode(DeclareVar(Nil, TypeIdent.Named("Foo").ptr, Ident("bar"), None), "Foo* bar;")
checkCode(DeclareVar(Attr.Static :: Nil, TypeIdent.Named("Foo").ptr, Ident("bar"), None), "static Foo* bar;")

checkCode(DeclareFn(Nil, TypeIdent.Named("Foo").ptr, Ident("bar"), Nil, None), "Foo* bar();")
checkCode(DeclareFn(Nil, TypeIdent.Named("Foo").ptr, Ident("bar"),
List(Param(TypeIdent.Named("bar"), Ident("baz"))), Some(block(Ident("baz").ret))), "Foo* bar(bar baz) {\n" +
" return baz;\n" +
"}")

checkCode(TypeIdent.Named("Foo").typedefAs(Ident("Baz")), "typedef Foo Baz;")
checkCode(TypeIdent.Named("Foo").cast(Ident("Baz")), "(Foo)Baz")

Expand Down Expand Up @@ -60,5 +66,26 @@ class CodeTest extends munit.FunSuite {
checkCode(Ident("foo")(Ident("bar"), Ident("baz")), "foo(bar, baz)")
checkCode(Ident("foo")(), "foo()")
checkCode(TypeIdent.Named("Foo").cast(Ident("foo"))(), "((Foo)foo)()")

checkCode(block(Ident("foo")()).doWhile(Ident("bar")), "do {\n" +
" foo()\n" +
"} while(bar);")

checkCode(Ident("foo").deref, "*foo")
checkCode(Ident("foo").select(Ident("bar")), "foo.bar")
checkCode(Ident("foo").select(Ident("bar")).select("baz"), "foo.bar.baz")
checkCode(Ident("foo").deref.select(Ident("bar")), "foo->bar")
checkCode(Ident("foo")().select(Ident("bar")), "foo().bar")
checkCode(Ident("foo")().deref.select(Ident("bar")), "foo()->bar")
checkCode(Ident("foo").bracket(Ident("bar")), "foo[bar]")
checkCode(Ident("foo").bracket(IntLiteral(BigInt(42))), "foo[42]")
checkCode(DeclareArray(TypeIdent.Named("Foo"), "bar", Right(List(Ident("baz"), IntLiteral(BigInt(42))))),
"Foo bar[2] = { baz, 42 };"
)
checkCode(DeclareArray(TypeIdent.Named("Foo"), "bar", Left(42)),
"Foo bar[42];"
)
checkCode(Ident("foo").addr, "&foo")
checkCode(Ident("foo").select("bar").addr, "&foo.bar")
}
}

0 comments on commit 1ced6f9

Please sign in to comment.