Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TODO] Nim now supports true lambdas; eg allows map(a~>a*localVar) (wo limitations of sugar.nim =>) #8679

Closed
wants to merge 7 commits into from
35 changes: 35 additions & 0 deletions lib/pure/collections/sequtils.nim
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
include "system/inclrtl"

import macros
import pure/lambda

when not defined(nimhygiene):
{.pragma: dirty.}
Expand Down Expand Up @@ -638,6 +639,34 @@ template foldr*(sequence, operation: untyped): untyped =
result = operation
result

# TODO: make public after merging with `map`
template map2(s: typed, lambda: untyped): untyped =
## like ``mapIt`` but with cleaner syntax: ``[1,2].mapIt(a ~> a*10)``
# runnableExamples:
# doAssert [1,2].map2(a~>a*10) == @[10,20]
makeLambda(lambda, lambda2)
type outType = type((
block:
var it: type(items(s))
lambda2(it)
))

when compiles(s.len):
var result: seq[outType]
block:
evalOnceAs(s2, s, compiles((let _ = s)))
var i = 0
var result = newSeq[outType](s2.len)
for it in s2:
result[i] = lambda2(it)
i += 1
result
else:
var result: seq[outType] = @[]
for it in s:
result.add lambda2(it)
result

template mapIt*(s, typ, op: untyped): untyped =
## Convenience template around the ``map`` proc to reduce typing.
##
Expand Down Expand Up @@ -1137,3 +1166,9 @@ when isMainModule:

when not defined(testing):
echo "Finished doc tests"

block map2Test:
# PENDING https://github.com/nim-lang/Nim/issues/7280
discard [1].map2(a ~> a)
# once map2 is public, remove this (will be covered by runnableExamples)
doAssert [1,2].map2(a ~> a*10) == @[10,20]
70 changes: 70 additions & 0 deletions lib/pure/lambda.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import macros

# TODO: make this public in subsequent PR
macro varargsToTuple(args:varargs[untyped]):untyped=
result = newTree(nnkTupleConstr)
if args.len != 1:
for a in args:
result.add a
else:
# PENDING https://github.com/nim-lang/Nim/issues/8706
if args[0].kind != nnkHiddenStdConv:
result.add args[0]

macro lambdaEval*(lambda: untyped, arg:tuple):untyped=
## allows using zero-cost lambda expressions ``a~>b`` in templates. Type
## inference is done at evaluation site, so user doesn't need to specify
## types in the lambda expression.
runnableExamples:
template testCallFun[T](fun: untyped, a:T): auto =
lambdaEval(fun, lambdaEval(fun, a))
doAssert testCallFun(x ~> x * 3, 2) == (2 * 3) * 3

if $lambda[0] != "~>":
error("Expected `~>` got " & $lambda[0])

var ret = newStmtList()
expectKind(arg, nnkTupleConstr)
case lambda[1].kind
of nnkPar: # (a, b) ~> expr
if lambda[1].len != arg.len:
error("size mismatch: lambda:" & $lambda[1].len & " arg:" & $arg.len)
for i in 0..<lambda[1].len:
ret.add newLetStmt(lambda[1][i], arg[i])
of nnkIdent: # a ~> expr
if arg.len != 1:
error("size mismatch: " & $arg.len)
ret.add newLetStmt(lambda[1], arg[0])
else:
error("expected " & ${nnkPar,nnkIdent} & " got " & $lambda[1].kind)

ret.add lambda[2]
result = newBlockStmt(ret)

macro makeLambda*(lambdaFun: untyped, lambdaAlias:untyped): untyped =
## convenience macro allowing one to use ``lambda(arg)`` instead of
## ``lambdaEval(fun, arg)``
runnableExamples:
block:
template testCallFun[T](fun: untyped, a:T): auto =
makeLambda(fun, lambda)
lambda(lambda(a))
doAssert testCallFun(x ~> x * 3, 2) == (2 * 3) * 3

block:
template testCallFun2[T](fun: untyped, a:T, b:T): auto =
makeLambda(fun, lambda)
lambda(a, b)
doAssert testCallFun2((u,v) ~> u*v, 10, 11) == 10 * 11

expectKind(lambdaAlias, nnkIdent)
result = quote do:
template `lambdaAlias`(args:varargs[untyped]): untyped =
lambdaEval(`lambdaFun`, varargsToTuple(args))

when isMainModule:
# PENDING https://github.com/nim-lang/Nim/issues/7280
block lambdaEvalTest:
discard lambdaEval(a ~> a, (0,))
block makeLambdaTest:
makeLambda(a ~> a, _)
100 changes: 100 additions & 0 deletions tests/stdlib/tlambda.nim
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
discard """
"""

import macros
import lambda

import typetraits

block: # 0-param lambda
template testLambda(fun: untyped): auto =
makeLambda(fun, lambda)
# doAssert: not compiles(lambda(10))
lambda()

doAssert testLambda(() ~> 100) == 100

block: # 1-param lambda
template testLambda[T](fun: untyped, a:T): auto =
makeLambda(fun, lambda)
lambda(a)

doAssert testLambda(a ~> a*10, 2) == 20

block: # 2-param lambda
template testLambda[T](fun: untyped, a:T, b:T): auto =
makeLambda(fun, lambda)
lambda(a,b)

doAssert testLambda((u1,u2) ~> u1*u2, 2, 3) == 2 * 3

block: # 3-param lambda
template testLambda[T](fun: untyped, a:T, b:T, c:T): auto =
makeLambda(fun, lambda)
when false:
# BUG: SIGSEGV: Illegal storage access
doAssert: not compiles(lambda(a,b))
# but `lambda(a,b)` correctly gives compile error
lambda(a,b,c)

doAssert testLambda((u1,u2,u3) ~> u1*u2*u3, 2, 3, 4) == 2 * 3 * 4

block: # multiple lambda application
template testLambda[T](fun: untyped, a:T): auto =
makeLambda(fun, lambda)
lambda(lambda(a))
doAssert testLambda(x ~> x * 3, 2) == (2 * 3) * 3

block: # lambda with local param
template testLambda[T](fun: untyped, a:T): auto =
makeLambda(fun, lambda)
lambda a
let x = 10
doAssert testLambda(u ~> u * x, 11) == 11 * x

block: # nested lambda
template testLambda1[T](fun: untyped, a:T): auto =
makeLambda(fun, lambda)
lambda(a)
template testLambda2[T](fun: untyped, a:T): auto =
makeLambda(fun, lambda)
lambda a
doAssert testLambda1(u ~> u + testLambda2(v ~> v*3, u), 100) == 100 + 100*3

block: # multiple lambdas
template testLambda[T](fun1: untyped, fun2: untyped, a:T, b:T): auto =
makeLambda(fun1, lambda1)
makeLambda(fun2, lambda2)
(lambda1(a,b), lambda2(a,b))

doAssert testLambda((u1,u2) ~> u1*u2, (u1,u2) ~> u1+u2, 2, 3) == (2 * 3, 2 + 3)

block:
template map2(s: typed, lambda: untyped): untyped =
## like ``mapIt`` but with cleaner syntax: ``[1,2].mapIt(a ~> a*10)``
makeLambda(lambda, lambda2)
type outType = type((
block:
var it: type(items(s))
lambda2(it)
))

when compiles(s.len):
block:
# Note: a more robust implementation would use `evalOnceAs`, see `mapIt`
let s2=s
var i = 0
var result = newSeq[outType](s2.len)
for it in s2:
result[i] = lambda2(it)
i += 1
result
else:
var result: seq[outType] = @[]
for it in s:
result.add lambda2(it)
result

doAssert [1,2].map2(a ~> a*10) == @[10, 20]
let foo=3
doAssert [1,2].map2(a ~> a*foo) == @[3, 6]