-
-
Notifications
You must be signed in to change notification settings - Fork 15
/
functor.jl
114 lines (91 loc) · 3.3 KB
/
functor.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
    functor(x) -> (children, re)
    functor(T, x) -> (children, re)

Split `x` into its traversable `children` and a reconstruction function `re`, dispatching
on the type `T` (by default `typeof(x)`). Types without a more specific method are leaves:
their children are `()` and `re` ignores its argument, returning `x` unchanged.
"""
functor(x) = functor(typeof(x), x)

function functor(T, x)
    rebuild_leaf = _ -> x
    return ((), rebuild_leaf)
end

# A tuple is its own list of children; rebuilding simply returns the new children.
function functor(::Type{<:Tuple}, x)
    return (x, y -> y)
end

# A NamedTuple's children are its values keyed by the same names `L`.
function functor(::Type{<:NamedTuple{L}}, x) where {L}
    values_by_name = map(name -> getproperty(x, name), L)
    return (NamedTuple{L}(values_by_name), identity)
end

# Generic arrays are containers whose elements are the children ...
function functor(::Type{<:AbstractArray}, x)
    return (x, y -> y)
end

# ... but arrays of numbers are treated as leaves (no children; `re` returns `x`).
function functor(::Type{<:AbstractArray{<:Number}}, x)
    rebuild_leaf = _ -> x
    return ((), rebuild_leaf)
end

@static if VERSION >= v"1.6"
    # `f ∘ g` exposes its two component functions as named children.
    function functor(::Type{<:Base.ComposedFunction}, x)
        parts = (outer = x.outer, inner = x.inner)
        rebuild = y -> Base.ComposedFunction(y.outer, y.inner)
        return (parts, rebuild)
    end
end
"""
    makefunctor(m::Module, T, fs = fieldnames(T))

Define `Functors.functor` for type `T` inside module `m` (via `@eval`). The fields named in
`fs` become the children, as a `NamedTuple` in the order given by `fs`. On reconstruction,
fields of `T` not listed in `fs` are copied from the original object `x`.
"""
function makefunctor(m::Module, T, fs = fieldnames(T))
    # Each constructor argument must take the child at that field's position in `fs`,
    # because `fs` fixes the order of the children NamedTuple. Counting positions in
    # `fieldnames(T)` order instead would swap values for `@functor T (b, a)`.
    escargs = map(fieldnames(T)) do f
        i = findfirst(isequal(f), fs)
        i === nothing ? :(x.$f) : :(y[$i])
    end
    escfs = [:($f = x.$f) for f in fs]
    @eval m begin
        $Functors.functor(::Type{<:$T}, x) = ($(escfs...),), y -> $T($(escargs...))
    end
end
"""
    functorm(T, fs = nothing)

Build the expression that `@functor` expands to. The expression is a call to `makefunctor`
in the caller's module for type `T`, optionally restricted to the fields named in `fs`.

`fs` may be `nothing` (all fields), a tuple expression `(a, b, ...)`, or a single bare
field name `a`, which is treated as `(a,)`. Any other form raises an error.
"""
function functorm(T, fs = nothing)
    # `@functor T a` and `@functor T (a)` both parse the field list as a bare Symbol.
    fs isa Symbol && (fs = Expr(:tuple, fs))
    fs === nothing || Meta.isexpr(fs, :tuple) || error("@functor T (a, b)")
    # Turn `(a, b)` into the literal tuple `(:a, :b)`; with no field list, pass nothing
    # so that `makefunctor` falls back to `fieldnames(T)`.
    fsargs = fs === nothing ? [] : [:($(map(QuoteNode, fs.args)...),)]
    return :(makefunctor(@__MODULE__, $(esc(T)), $(fsargs...)))
end
"""
    @functor T
    @functor T (f1, f2, ...)

Make type `T` traversable by defining `Functors.functor` for it. All fields of `T` become
children, or only the listed fields when a tuple of field names is given. Expansion is
delegated to `functorm`.
"""
macro functor(args...)
    return functorm(args...)
end
"""
    isleaf(x)

Return `true` when `functor` reports no children for `x`, i.e. its children are exactly
`()`. Note that the check is `=== ()`, so an empty `NamedTuple` or empty array of children
does not count as a leaf.
"""
function isleaf(x)
    return children(x) === ()
end

"""
    children(x)

Return the children half of `functor(x)`, discarding the reconstruction function.
"""
function children(x)
    return first(functor(x))
end
# Default `walk` used by `fmap`: apply `f` to every child of `x` reported by `functor`,
# then rebuild an object of `x`'s shape from the mapped children.
function _default_walk(f, x)
    kids, rebuild = functor(x)
    mapped = map(f, kids)
    return rebuild(mapped)
end
# Sentinel type used as the default value of `fmap`'s `prune` keyword. Because the check
# is `prune isa NoKeyword`, any user-supplied value (including `nothing`) is
# distinguishable from "not supplied".
struct NoKeyword end
"""
    fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())

Recursively transform `x`: apply `f` to each node for which `exclude` returns `true`, and
recurse into every other node via `walk(g, node)`. The default walk maps `g` over the
`functor` children of the node and rebuilds it.

# Keywords
- `exclude`: predicate selecting the nodes passed to `f` (default: `isleaf`).
- `walk`: how to recurse into a non-excluded node.
- `cache`: memo keyed by object identity. A non-isbits object seen more than once is
  transformed once and the same result is reused.
- `prune`: if supplied, a repeated (already-visited) object is replaced by this value
  instead of its cached result.
"""
function fmap(f, x; exclude = isleaf, walk = _default_walk, cache = IdDict(),
              prune = NoKeyword())
    # isbits values (numbers, `nothing`, tuples of numbers, ...) have no identity: equal
    # values are `===`. Caching them would conflate unrelated leaves, e.g. with `prune`
    # set, the second `1` in `(1, 1)` would be replaced by `prune`. Memoize only non-isbits.
    usecache = !isbits(x)
    if usecache && haskey(cache, x)
        return prune isa NoKeyword ? cache[x] : prune
    end
    ret = if exclude(x)
        f(x)
    else
        walk(y -> fmap(f, y; exclude = exclude, walk = walk, cache = cache, prune = prune), x)
    end
    usecache && (cache[x] = ret)
    return ret
end
###
### Extras
###
"""
    fmapstructure(f, x; kwargs...)

Like `fmap`, but nodes are not reconstructed with `functor`'s rebuild function. Each
non-excluded node is replaced by the result of mapping over its children (so tuples,
NamedTuples, arrays, etc. as produced by `map`). Other keywords are forwarded to `fmap`.
"""
function fmapstructure(f, x; kwargs...)
    structural_walk(g, node) = map(g, children(node))
    return fmap(f, x; walk = structural_walk, kwargs...)
end
"""
    fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)

Collect `x` and, recursively, all of its `children` into `output` in depth-first pre-order.
Each distinct object (by identity) is pushed once. A node for which `exclude` returns
`true` is skipped together with everything beneath it.
"""
function fcollect(x; output = [], cache = Base.IdSet(), exclude = v -> false)
    # There is no ordered identity set, so two structures are kept: the IdSet `cache` gives
    # one entry per distinct object, and the Vector `output` preserves traversal order,
    # which downstream code relies on.
    (x in cache || exclude(x)) && return output
    push!(cache, x)
    push!(output, x)
    for child in children(x)
        fcollect(child; cache = cache, output = output, exclude = exclude)
    end
    return output
end
###
### Vararg forms
###
"""
    fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = IdDict(), prune = NoKeyword())

Multi-argument form of `fmap`. `x` and the `ys` are traversed in parallel. `f(x, ys...)` is
applied where `exclude(x)` holds; otherwise `walk(g, x, ys...)` recurses. Results are
memoized by the identity of `x` only, and only for non-isbits `x`.
"""
function fmap(f, x, ys...; exclude = isleaf, walk = _default_walk, cache = IdDict(),
              prune = NoKeyword())
    # Same rule as single-argument `fmap`: isbits values have no identity, so caching them
    # would conflate unrelated equal leaves (and wrongly apply `prune` to them).
    usecache = !isbits(x)
    if usecache && haskey(cache, x)
        return prune isa NoKeyword ? cache[x] : prune
    end
    ret = if exclude(x)
        f(x, ys...)
    else
        recurse = (xy...,) -> fmap(f, xy...; exclude = exclude, walk = walk,
                                   cache = cache, prune = prune)
        walk(recurse, x, ys...)
    end
    usecache && (cache[x] = ret)
    return ret
end
# Multi-argument default walk: the children of `x` are mapped together with the children of
# each `y` in `ys`. Every `y` is decomposed with `functor` dispatching on `x`'s type, so the
# `ys` are presumably expected to mirror `x`'s structure. The result is rebuilt like `x`.
function _default_walk(f, x, ys...)
    T = typeof(x)
    kids, rebuild = functor(x)
    others = map(y -> first(functor(T, y)), ys)
    return rebuild(map(f, kids, others...))
end
###
### FlexibleFunctors.jl
###
"""
    makeflexiblefunctor(m::Module, T, pfield)

Define `Functors.functor` for `T` in module `m` (via `@eval`) so that the children of an
instance `x` are the fields whose names are stored in `x`'s own property `pfield`. The set
of children can therefore differ per instance. All other fields are copied from `x` when
reconstructing.
"""
function makeflexiblefunctor(m::Module, T, pfield)
    quotedfield = QuoteNode(pfield)
    @eval m begin
        function $Functors.functor(::Type{<:$T}, x)
            # Field names to expose, read from the instance at call time. They are used as
            # NamedTuple names below, so they must form a Tuple of Symbols.
            pnames = getproperty(x, $quotedfield)
            rebuild = function (y)
                ctorargs = map(fn -> getproperty(fn in pnames ? y : x, fn), fieldnames($T))
                return $T(ctorargs...)
            end
            kids = NamedTuple{pnames}(map(p -> getproperty(x, p), pnames))
            return kids, rebuild
        end
    end
end
"""
    flexiblefunctorm(T, pfield = :params)

Build the expression that `@flexiblefunctor` expands to: a call to `makeflexiblefunctor` in
the caller's module for type `T`, naming the property `pfield` that holds each instance's
child field names. `pfield` must be a plain Symbol; anything else raises an error.
"""
function flexiblefunctorm(T, pfield = :params)
    if !(pfield isa Symbol)
        error("@flexiblefunctor T param_field")
    end
    # Quote the name so the generated call receives the Symbol itself, not a variable lookup.
    quotedfield = QuoteNode(pfield)
    return :(makeflexiblefunctor(@__MODULE__, $(esc(T)), $(esc(quotedfield))))
end
"""
    @flexiblefunctor T
    @flexiblefunctor T param_field

Make `T` traversable with a per-instance set of children. Each instance lists the names of
its child fields in its `param_field` property (default `params`). Expansion is delegated
to `flexiblefunctorm`.
"""
macro flexiblefunctor(args...)
    return flexiblefunctorm(args...)
end