Skip to content

Commit

Permalink
Require parens on dot apply (#1027)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnynek authored Sep 2, 2023
1 parent ac20cad commit cca3891
Show file tree
Hide file tree
Showing 17 changed files with 90 additions and 90 deletions.
14 changes: 7 additions & 7 deletions core/src/main/resources/bosatsu/predef.bosatsu
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ def concat(front: List[a], back: List[a]) -> List[a]:
case _: reverse_concat(reverse(front), back)

def map_List(lst: List[a], fn: a -> b) -> List[b]:
lst.foldLeft([], \t, a -> [fn(a), *t]).reverse
lst.foldLeft([], \t, a -> [fn(a), *t]).reverse()

def flat_map_List(lst: List[a], fn: a -> List[b]) -> List[b]:
lst.foldLeft([], \t, a -> fn(a).reverse_concat(t)).reverse
lst.foldLeft([], \t, a -> fn(a).reverse_concat(t)).reverse()

#############
# Some utilities for dealing with functions
Expand Down Expand Up @@ -274,10 +274,10 @@ def add_item(ord: Order[a], tree: Tree[a], item: a) -> Tree[a]:
case EQ: Branch(s, h, item, left, right)
case LT:
left = loop(left)
branch(s.add(1), item0, left, right).balance
branch(s.add(1), item0, left, right).balance()
case GT:
right = loop(right)
branch(s.add(1), item0, left, right).balance
branch(s.add(1), item0, left, right).balance()

loop(tree)

Expand Down Expand Up @@ -308,13 +308,13 @@ def remove_item(ord: Order[a], tree: Tree[a], item: a) -> Tree[a]:
case Empty: left
case _:
right = loop(right)
branch(size.sub(1), key, left, right).balance
branch(size.sub(1), key, left, right).balance()
case LT:
left = loop(left)
branch(size.sub(1), key, left, right).balance
branch(size.sub(1), key, left, right).balance()
case GT:
right = loop(right)
branch(size.sub(1), key, left, right).balance
branch(size.sub(1), key, left, right).balance()

loop(tree)

Expand Down
18 changes: 8 additions & 10 deletions core/src/main/scala/org/bykn/bosatsu/Declaration.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@ sealed abstract class Declaration {
(args.head.toDoc + Doc.char('.') + fnDoc, args.tail)
}

body match {
case Nil => prefix
case notEmpty =>
prefix + Doc.char('(') + Doc.intercalate(Doc.text(", "), notEmpty.map(_.toDoc)) + Doc.char(')')
}
prefix + Doc.char('(') + Doc.intercalate(Doc.text(", "), body.map(_.toDoc)) + Doc.char(')')
case ApplyOp(left, Identifier.Operator(opStr), right) =>
left.toDoc space Doc.text(opStr) space right.toDoc
case Binding(b) =>
Expand Down Expand Up @@ -1169,23 +1165,25 @@ object Declaration {
* This is where we parse application, either direct, or dot-style
*/
val applied: P[NonBinding] = {
val params = recNonBind.parensLines1Cut
// here we are using . syntax foo.bar(1, 2)
// we also allow foo.(anyExpression)(1, 2)
val fn = varP.orElse(recNonBind.parensCut)
val slashcontinuation = ((maybeSpace ~ P.char('\\') ~ toEOL1).backtrack ~ Parser.maybeSpacesAndLines).?.void
// 0 or more args
val params0 = recNonBind.parensLines0Cut
val dotApply: P[NonBinding => NonBinding] =
(slashcontinuation.with1 *> P.char('.') *> (fn ~ params.?))
(slashcontinuation.with1 *> P.char('.') *> (fn ~ params0))
.region
.map { case (r2, (fn, argsOpt)) =>
val args = argsOpt.fold(List.empty[NonBinding])(_.toList)
.map { case (r2, (fn, args)) =>

{ (head: NonBinding) => Apply(fn, NonEmptyList(head, args), ApplyKind.Dot)(head.region + r2) }
}

// 1 or more args
val params1 = recNonBind.parensLines1Cut
// here we directly call a function foo(1, 2)
val applySuffix: P[NonBinding => NonBinding] =
params
params1
.region
.map { case (r, args) =>
{ (fn: NonBinding) => Apply(fn, args, ApplyKind.Parens)(fn.region + r) }
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/bykn/bosatsu/MainModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ abstract class MainModule[IO[_]](implicit
val (r, c) = pf.locations.toLineCol(pf.position).get
val ctx = pf.showContext(color)
List(
s"failed to parse $path at line ${r + 1}, column ${c + 1}",
s"failed to parse $path:${r + 1}:${c + 1}",
ctx.render(80)
)
case MainCommand.ParseError.FileError(path, err) =>
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/bykn/bosatsu/PackageError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ object PackageError {
val candidates =
nearest(ex.name, candidateMap, 3)
.map { case (n, r) =>
val pos = lm.toLineCol(r.start).map { case (l, c) => s" at line: ${l + 1}, column: ${c + 1}" }.getOrElse("")
val pos = lm.toLineCol(r.start).map { case (l, c) => s":${l + 1}:${c + 1}" }.getOrElse("")
s"${n.asString}$pos"
}
val candstr = candidates.mkString("\n\t", "\n\t", "\n")
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ object Parser {
item.nonEmptyListOfWs(maybeSpacesAndLines)
.parensCut

def parensLines0Cut: P[List[T]] =
parens(nonEmptyListToList(item.nonEmptyListOfWs(maybeSpacesAndLines)))
/**
* either: a, b, c, ..
* or (a, b, c, ) where we allow newlines:
Expand Down
24 changes: 12 additions & 12 deletions core/src/test/scala/org/bykn/bosatsu/EvaluationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Value._
import LocationMap.Colorize
import org.scalatest.funsuite.AnyFunSuite

class EvalationTest extends AnyFunSuite with ParTest {
class EvaluationTest extends AnyFunSuite with ParTest {

import TestUtils._

Expand Down Expand Up @@ -1078,7 +1078,7 @@ package A
e = empty_Dict(string_Order)
e1 = e.clear_Dict.add_key("hello2", "world2")
e1 = e.clear_Dict().add_key("hello2", "world2")
main = e1.get_key("hello")
"""), "A", VOption.none)
Expand All @@ -1099,7 +1099,7 @@ package A
e1 = empty_Dict(string_Order)
e2 = e1.add_key("hello", "world").add_key("hello1", "world1")
lst = e2.items
lst = e2.items()
main = match lst:
case [("hello", "world"), ("hello1", "world1")]: "good"
Expand All @@ -1111,7 +1111,7 @@ package A
e1 = {}
e2 = e1.add_key("hello", "world").add_key("hello1", "world1")
lst = e2.items
lst = e2.items()
main = match lst:
case [("hello", "world"), ("hello1", "world1")]: "good"
Expand All @@ -1125,7 +1125,7 @@ e = {
"hello": "world",
"hello1":
"world1" }
lst = e.items
lst = e.items()
main = match lst:
case [("hello", "world"), ("hello1", "world1")]: "good"
Expand All @@ -1138,7 +1138,7 @@ package A
pairs = [("hello", "world"), ("hello1", "world1")]
e = { k: v for (k, v) in pairs }
lst = e.items
lst = e.items()
main = match lst:
case [("hello", "world"), ("hello1", "world1")]: "good"
Expand All @@ -1156,7 +1156,7 @@ def is_hello(s):
case _: False
e = { k: v for (k, v) in pairs if is_hello(k) }
lst = e.items
lst = e.items()
main = match lst:
case [("hello", res)]: res
Expand Down Expand Up @@ -1655,7 +1655,7 @@ struct RecordSet[shape](
)
def get[shape: (* -> *) -> *, t](sh: shape[RecordValue], RecordGetter(_, getter): RecordGetter[shape, t]) -> t:
RecordValue(result) = sh.getter
RecordValue(result) = sh.getter()
result
def create_field[shape: (* -> *) -> *, t](rf: RecordField[t], fn: shape[RecordValue] -> t):
Expand Down Expand Up @@ -1769,7 +1769,7 @@ rs0 = rs.restructure(\PS(a, PS(b, PS(c, _))) -> ps(c, ps(b, ps(a, ps("Plus 2".in
tests = TestSuite("reordering",
[
Assertion(equal_rows.equal_List(rs0.list_of_rows, [[REBool(RecordValue(False)), REInt(RecordValue(1)), REString(RecordValue("a")), REInt(RecordValue(3))]]), "swap")
Assertion(equal_rows.equal_List(rs0.list_of_rows(), [[REBool(RecordValue(False)), REInt(RecordValue(1)), REString(RecordValue("a")), REInt(RecordValue(3))]]), "swap")
]
)
"""), "RecordSet/Library", 1)
Expand Down Expand Up @@ -2427,7 +2427,7 @@ struct Queue[a](front: List[a], back: List[a])
def fold_Queue(Queue(f, b): Queue[a], binit: b, fold_fn: b -> a -> b) -> b:
front = f.foldLeft(binit, fold_fn)
b.reverse.foldLeft(front, fold_fn)
b.reverse().foldLeft(front, fold_fn)
test = Assertion(Queue([1], [2]).fold_Queue(0, add).eq_Int(3), "foldQueue")
"""), "QueueTest", 1)
Expand Down Expand Up @@ -2458,7 +2458,7 @@ test = Assertion(substitute.eq_Int(42), "basis substitution")
runBosatsuTest(List("""
package A
three = 2.(x -> add(x, 1))
three = 2.(x -> add(x, 1))()
test = Assertion(three.eq_Int(3), "let inside apply")
"""), "A", 1)
Expand Down Expand Up @@ -2709,7 +2709,7 @@ struct RecordGetter[shape, t](
# shape is (* -> *) -> *
def get[shape](sh: shape[RecordValue], RecordGetter(getter): RecordGetter[shape, t]) -> t:
RecordValue(result) = sh.getter
RecordValue(result) = sh.getter()
result
""")) { case PackageError.TypeErrorIn(_, _) => ()
}
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/scala/org/bykn/bosatsu/ParserTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -619,11 +619,11 @@ x""")
Apply(mkVar("x"), NonEmptyList.of(mkVar("f")), AParens))

parseTestAll(parser(""),
"f.x",
"f.x()",
Apply(mkVar("x"), NonEmptyList.of(mkVar("f")), ADot))

parseTestAll(parser(""),
"f(foo).x",
"f(foo).x()",
Apply(mkVar("x"), NonEmptyList.of(Apply(mkVar("f"), NonEmptyList.of(mkVar("foo")), AParens)), ADot))

parseTestAll(parser(""),
Expand Down
12 changes: 6 additions & 6 deletions test_workspace/AvlTree.bosatsu
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def add_item(ord: Order[a], tree: Tree[a], item: a) -> Tree[a]:
EQ: Branch(s, h, item, left, right)
LT:
left = loop(left)
branch(s.add(1), item0, left, right).balance
branch(s.add(1), item0, left, right).balance()
GT:
right = loop(right)
branch(s.add(1), item0, left, right).balance
branch(s.add(1), item0, left, right).balance()

loop(tree)

Expand Down Expand Up @@ -133,13 +133,13 @@ def remove_item(ord: Order[a], tree: Tree[a], item: a) -> Tree[a]:
Empty: left
_:
right = loop(right)
branch(size.sub(1), key, left, right).balance
branch(size.sub(1), key, left, right).balance()
LT:
left = loop(left)
branch(size.sub(1), key, left, right).balance
branch(size.sub(1), key, left, right).balance()
GT:
right = loop(right)
branch(size.sub(1), key, left, right).balance
branch(size.sub(1), key, left, right).balance()

loop(tree)

Expand Down Expand Up @@ -265,7 +265,7 @@ size_tests = (
TestSuite('size tests', [
add_increases_size(Empty, 1, "Empty.add(1)"),
add_increases_size(single_i(1), 2, "single(1).add(2)"),
Assertion(single_i(1).size.eq_i(single_i(1).add_i(1).size), "single(1) + 1 has same size"),
Assertion(single_i(1).size().eq_i(single_i(1).add_i(1).size()), "single(1) + 1 has same size"),
rem_decreases_size(single_i(1), 1, "single(1) - 1"),
rem_decreases_size(single_i(2).add_i(3), 2, "single(2) + 3 - 2"),
])
Expand Down
58 changes: 29 additions & 29 deletions test_workspace/BinNat.bosatsu
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def toInt(b: BinNat) -> Int:
def toNat(b: BinNat) -> Nat:
recur b:
Zero: NatZero
Odd(n): NatSucc(toNat(n).times2_Nat)
Even(n): NatSucc(NatSucc(toNat(n).times2_Nat))
Odd(n): NatSucc(toNat(n).times2_Nat())
Even(n): NatSucc(NatSucc(toNat(n).times2_Nat()))

# Convert a built in integer to a BinNat. <= 0 is converted to 0
def toBinNat(n: Int) -> BinNat:
Expand Down Expand Up @@ -75,16 +75,16 @@ def add_BinNat(left: BinNat, right: BinNat) -> BinNat:
Even(add_BinNat(left, right))
Even(right):
# 2left + 1 + 2(right + 1) = 2((left + right) + 1) + 1
Odd(add_BinNat(left, right.next))
Odd(add_BinNat(left, right.next()))
Even(left) as even:
match right:
Zero: even
Odd(right):
# 2(left + 1) + 2right + 1 = 2((left + right) + 1) + 1
Odd(add_BinNat(left, right.next))
Odd(add_BinNat(left, right.next()))
Even(right):
# 2(left + 1) + 2(right + 1) = 2((left + right + 1) + 1)
Even(add_BinNat(left, right.next))
Even(add_BinNat(left, right.next()))

# multiply by 2
def times2(b: BinNat) -> BinNat:
Expand Down Expand Up @@ -144,27 +144,27 @@ def fib(b: BinNat) -> BinNat:
loop(toNat(b), one, one)

def round_trip_law(i, msg):
Assertion(i.toBinNat.toInt.eq_Int(i), msg)
Assertion(i.toBinNat().toInt().eq_Int(i), msg)

def next_law(i, msg):
Assertion(i.toBinNat.next.toInt.eq_Int(i.add(1)), msg)
Assertion(i.toBinNat().next().toInt().eq_Int(i.add(1)), msg)

def times2_law(i, msg):
Assertion(i.toBinNat.times2.toInt.eq_Int(i.times(2)), msg)
Assertion(i.toBinNat().times2().toInt().eq_Int(i.times(2)), msg)

one = Odd(Zero)
two = one.next
three = two.next
four = three.next
two = one.next()
three = two.next()
four = three.next()


test = TestSuite(
"BinNat tests", [
Assertion(Zero.toInt.eq_Int(0), "0.toBinNat"),
Assertion(Odd(Zero).toInt.eq_Int(1), "1.toBinNat"),
Assertion(Even(Zero).toInt.eq_Int(2), "2.toBinNat"),
Assertion(Odd(Odd(Zero)).toInt.eq_Int(3), "3.toBinNat"),
Assertion(Even(Odd(Zero)).toInt.eq_Int(4), "4.toBinNat"),
Assertion(Zero.toInt().eq_Int(0), "0.toBinNat"),
Assertion(Odd(Zero).toInt().eq_Int(1), "1.toBinNat"),
Assertion(Even(Zero).toInt().eq_Int(2), "2.toBinNat"),
Assertion(Odd(Odd(Zero)).toInt().eq_Int(3), "3.toBinNat"),
Assertion(Even(Odd(Zero)).toInt().eq_Int(4), "4.toBinNat"),
TestSuite("round trip laws", [ round_trip_law(i, m) for (i, m) in [
(0, "roundtrip 0"),
(1, "roundtrip 1"),
Expand All @@ -184,24 +184,24 @@ test = TestSuite(
(10, "10.next"),
(113, "113.next"),
]]),
Assertion(0.toBinNat.next.prev.toInt.eq_Int(0), "0.next.prev == 0"),
Assertion(5.toBinNat.next.prev.toInt.eq_Int(5), "5.next.prev == 5"),
Assertion(10.toBinNat.next.prev.toInt.eq_Int(10), "10.next.prev == 10"),
Assertion(10.toBinNat.add_BinNat(11.toBinNat).toInt.eq_Int(21), "add_BinNat(10, 11) == 21"),
Assertion(0.toBinNat().next().prev().toInt().eq_Int(0), "0.next().prev == 0"),
Assertion(5.toBinNat().next().prev().toInt().eq_Int(5), "5.next().prev == 5"),
Assertion(10.toBinNat().next().prev().toInt().eq_Int(10), "10.next().prev == 10"),
Assertion(10.toBinNat().add_BinNat(11.toBinNat()).toInt().eq_Int(21), "add_BinNat(10, 11) == 21"),
TestSuite("times2 law", [times2_law(i, msg) for (i, msg) in [
(0, "0 * 2"),
(1, "1 * 2"),
(2, "2 * 2"),
(5, "5 * 2"),
(10, "10 * 2"),
]]),
Assertion(10.toBinNat.times_BinNat(11.toBinNat).toInt.eq_Int(110), "10*11 = 110"),
Assertion(0.toBinNat.times_BinNat(11.toBinNat).toInt.eq_Int(0), "0*11 = 0"),
Assertion(fold_left_BinNat(\n, _ -> n.next, Zero, 10.toBinNat).toInt.eq_Int(10), "1 + ... + 1 = 10"),
Assertion(fold_left_BinNat(\n1, n2 -> n1.add_BinNat(n2), Zero, 4.toBinNat).toInt.eq_Int(6), "1+2+3=6"),
Assertion(fib(Zero).toInt.eq_Int(1), "fib(0) == 1"),
Assertion(fib(one).toInt.eq_Int(1), "fib(1) == 1"),
Assertion(fib(two).toInt.eq_Int(2), "fib(2) == 2"),
Assertion(fib(three).toInt.eq_Int(3), "fib(3) == 3"),
Assertion(fib(four).toInt.eq_Int(5), "fib(4) == 5"),
Assertion(10.toBinNat().times_BinNat(11.toBinNat()).toInt().eq_Int(110), "10*11 = 110"),
Assertion(0.toBinNat().times_BinNat(11.toBinNat()).toInt().eq_Int(0), "0*11 = 0"),
Assertion(fold_left_BinNat(\n, _ -> n.next(), Zero, 10.toBinNat()).toInt().eq_Int(10), "1 + ... + 1 = 10"),
Assertion(fold_left_BinNat(\n1, n2 -> n1.add_BinNat(n2), Zero, 4.toBinNat()).toInt().eq_Int(6), "1+2+3=6"),
Assertion(fib(Zero).toInt().eq_Int(1), "fib(0) == 1"),
Assertion(fib(one).toInt().eq_Int(1), "fib(1) == 1"),
Assertion(fib(two).toInt().eq_Int(2), "fib(2) == 2"),
Assertion(fib(three).toInt().eq_Int(3), "fib(3) == 3"),
Assertion(fib(four).toInt().eq_Int(5), "fib(4) == 5"),
])
2 changes: 1 addition & 1 deletion test_workspace/List.bosatsu
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ sortTest = (
tests = TestSuite("List tests", [
Assertion([1, 2, 3] =*= [1, 2, 3], "list [1, 2, 3]"),
Assertion(not([1, 2, 3] =*= [1, 2]), "list [1, 2, 3] != [1, 2]"),
Assertion(range(6).sum.eq_Int(15), "range(6).sum == 1 + 2 + 3 + 4 + 5 = 15"),
Assertion(range(6).sum().eq_Int(15), "range(6).sum == 1 + 2 + 3 + 4 + 5 = 15"),
Assertion(range(6).exists(v -> v.eq_Int(5)), "range(6) does have 5"),
Assertion(not(range(6).exists(v -> v.eq_Int(6))), "range(6) does not have 6"),
headTest,
Expand Down
Loading

0 comments on commit cca3891

Please sign in to comment.