-
-
Notifications
You must be signed in to change notification settings - Fork 5
/
autostruct.jl
244 lines (198 loc) · 7.82 KB
/
autostruct.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
"""
    @autostruct function MyLayer(d); ...; MyLayer(f1, f2, ...); end

This is a macro for easily defining new layers.
Recall that Flux layer is a callable `struct` which may contain parameter arrays.
Usually, the steps to define a new one are:
1. Define a `struct MyLayer` with the desired fields,
and tell Flux to look inside with `@layer MyLayer` (or on earlier versions, `@functor`).
2. Define a constructor function like `MyLayer(d::Int)`,
which initialises the parameters (say to `randn32(d, d)`)
and returns an instance of the `struct`, some `m::MyLayer`.
3. Define the forward pass, by making the struct callable: `(m::MyLayer)(x) = ...`
Given the function in step 2, this macro handles step 1. You still do step 3.
If you change the name or types of the fields, then the `struct` definition is
automatically replaced, without re-starting Julia.
This works because the `struct` itself is given an auto-generated name, and the name `MyLayer` is then bound to that new type.
(But existing instances of the old `struct` are not changed in any way!)
Writing `@autostruct :expand function MyLayer(d)` will use `@layer :expand MyLayer`,
and result in container-style pretty-printing.
See [AutoStructs.jl](https://github.com/CarloLucibello/AutoStructs.jl) for
a version of this macro aimed at non-Flux uses.
## Examples
```julia
@autostruct function MyModel(d::Int)
alpha, beta = [Dense(d=>d, tanh) for _ in 1:2] # arbitrary code here, not just keyword-like
beta.bias[:] .= 1/d
return MyModel(alpha, beta) # this must be very simple, no = signs allowed (return optional)
end
function (m::MyModel)(x) # forward pass looks just like a normal struct
y = m.alpha(x)
z = m.beta(y)
(x .+ y .+ z)./3
end
Flux.trainable(m::MyModel) = (; m.alpha) # if necessary, restrict which fields are trainable
Base.show(io::IO, m::MyModel) = # if desired, replace default printing "MyModel(...)"
print(io, "MyModel(", size(m.alpha.weight, 1), ")")
MyModel(2) isa MyModel # true
```
The `struct` defined by the macro here is something like this:
```julia
struct MyModel001{T1, T2}
alpha::T1
beta::T2
end
```
This can hold any objects, even `MyModel("hello", "world")`.
As you can see by looking at `methods(MyModel)`, there should never be an ambiguity
between the `struct`'s own constructor, and your `MyModel(d::Int)`.
You can also restrict the types allowed in the struct:
```julia
@autostruct :expand function MyOtherModel(d1, d2, act=identity)
gamma = Embedding(128 => d1)
delta = Dense(d1 => d2, act)
MyOtherModel(gamma::Embedding, delta::Dense) # struct will only hold these types
end
(m::MyOtherModel)(x) = softmax(m.delta(m.gamma(x))) # forward pass
methods(MyOtherModel) # will show 3 methods
```
Such restrictions change the struct like this:
```julia
struct MyOtherModel002{T1 <: Embedding, T2 <: Dense}
gamma::T1
delta::T2
end
```
If you need more constructor methods, the obvious syntax will not work, because the name
`MyModel` is bound to the generated `struct` type rather than to a function.
But you can add methods to the type itself, like this:
```julia
MyModel(str::String) = MyModel(parse(Int, str))
# ERROR: cannot define function MyModel; it already has a value
(::Type{MyModel})(str::String) = MyModel(parse(Int, str))
MyModel("4") # this works
```
## Compared to `@compact`
For comparison, the use of `@compact` to do much the same thing looks like this -- shorter,
but further from being ordinary Julia code.
```julia
function MyModel2(d::Int)
alpha, beta = [Dense(d=>d, tanh) for _ in 1:2]
beta.bias[:] .= 1/d
@compact(; alpha, beta) do x
y = alpha(x)
z = beta(y)
(x .+ y .+ z)./3
end
end
MyModel2(2) isa Fluxperimental.CompactLayer # no easy struct type
MyOtherModel2(d1, d2, act=identity) =
@compact(gamma = Embedding(128 => d1), delta=Dense(d1 => d2, act)) do x
softmax(delta(gamma(x)))
end
```
"""
macro autostruct(ex)
    # All the work is done by `_autostruct`; escaping the result makes the generated
    # struct, constructor and name binding resolve in the caller's scope.
    generated = _autostruct(ex)
    return esc(generated)
end
macro autostruct(ex1, ex2)
    # The only option accepted before the definition is the literal `:expand`,
    # which reaches the macro as a `QuoteNode`.
    if !(ex1 isa QuoteNode && ex1.value == :expand)
        throw("Expected either `@autostruct function` or `@autostruct :expand function`")
    end
    generated = _autostruct(ex2; expand=true)
    return esc(generated)
end
"""
    @autostruct MyType(field1, field2, ...)

Used like this, without a `function`, the macro creates a `struct` with the fields indicated.
`@autostruct MyType(field1, field2::Function)` expands to roughly this:
```julia
struct MyType003{T1, T2 <: Function}
    field1::T1
    field2::T2
end
Flux.@layer :expand MyType003
MyType = MyType003  # this allows re-definition
```
To use this as a Flux layer, you will also need to make it callable,
by writing for instance:
```julia
(m::MyType)(x) = m.field1(x) .+ x .|> m.field2
m1 = MyType(Chain(vcat, Dense(1 => 5, relu)), cbrt)
m1(-2)
```
"""
var"@autostruct"  # target of the docstring directly above: adds it as a further doc entry for the `@autostruct` macro
# Cache of generated struct definitions, filled by `_autostruct` at macro-expansion time.
# Key: `hash(return_line, UInt(expand))`. Value: `(gensym-ed struct name, definition Expr)`.
# Looking up an unchanged return line returns the same struct name as before.
const DEFINE = Dict{UInt, Tuple{Symbol, Expr}}()
"""
Marker type used as the argument type of the dummy constructor method that
`@autostruct MyStruct(fields...)` generates. Its only constructor throws, so no
instance is meant to exist and that dummy method is never meant to be called.
"""
struct _NoCall
    function _NoCall()
        error("this object is meant never to be created")
    end
end
"""
    _autostruct(expr; expand=nothing)

Return the (unescaped) expression that `@autostruct` expands to.

`expr` must be one of:
* a long-form definition `function MyStruct(args...) ... end` whose last expression
  (optionally written as `return ...`) is `MyStruct(field1, field2::Bound, ...)`;
  the signature may also carry a return-type annotation and/or `where` clauses;
* a bare call `MyStruct(field1, field2::Bound, ...)`, which is wrapped into a dummy
  `function MyStruct(_::_NoCall)` returning that call, then handled as above.

Each argument of the return line becomes a field of a parametric `struct` with a
`gensym`-ed name; `field::Bound` constrains that field's type parameter to `<: Bound`.
The generated code registers the struct with `Flux.@layer` (with `:expand` when
`expand` is true), defines a `Base.show` method printing `"MyStruct(...)"`, and binds
`MyStruct` to the new type. The struct definition is cached in `DEFINE` under
`hash(return_line, UInt(expand))`, so an unchanged return line reuses the same struct
name. Finally the user's function is renamed to the struct's name, which makes it an
additional constructor method of that struct.

`expand` defaults to `false` for the `function` form and `true` for the bare-call form.
Invalid input is reported by throwing a `String` (not an `Exception`).
Note that `expr` is mutated in place.
"""
function _autostruct(expr; expand=nothing)
    badinput = "Expected a function definition, like `@autostruct function MyStruct(...); ...`, or a call like `@autostruct MyStruct(...)`"
    if Meta.isexpr(expr, :call)  # one-line, @autostruct MyStruct(field)
        callee = expr.args[1]
        # A method taking the uninstantiable `_NoCall` can never run; it only lets the
        # bare-call form reuse the `function` path below.
        newex = :(function $callee(_::$_NoCall)
            $expr
        end)
        return _autostruct(newex; expand=something(expand, true))
    elseif !Meta.isexpr(expr, :function)
        throw(badinput)
    end
    expanded = something(expand, false)  # original path, @autostruct function MyStruct(...); ...

    # Locate the `MyStruct(args...)` call in the signature, looking through any
    # `where` clauses and a `(...)::ReturnType` annotation.
    call = expr.args[1]
    while Meta.isexpr(call, :where) || Meta.isexpr(call, :(::))
        call = call.args[1]
    end
    (Meta.isexpr(call, :call) && length(expr.args) >= 2) || throw(badinput)
    fun = call.args[1]

    # The last expression of the body (possibly `return`ed) describes the struct's fields.
    body = expr.args[2].args
    ret = isempty(body) ? nothing : body[end]
    if Meta.isexpr(ret, :return)
        ret = only(ret.args)
    end

    # Check the last line of the input expression:
    retmsg = "Last line of `@autostruct function $fun` must return `$fun(field1, field2, ...)`"
    Meta.isexpr(ret, :call) || throw(retmsg)
    ret.args[1] === fun || throw(retmsg)
    retargs = ret.args[2:end]
    for ex in retargs
        ex isa Symbol && continue
        # `field::Bound` is allowed; a bare `::Bound` has no field name and is rejected.
        Meta.isexpr(ex, :(::), 2) && ex.args[1] isa Symbol && continue
        throw("Last line of `@autostruct function $fun` must return `$fun(field1, field2, ...)` or `$fun(field1::T1, field2::T2, ...)`, but got $ex")
    end

    funargs = call.args[2:end]
    if length(retargs) == length(funargs) && all(ex -> ex isa Symbol, retargs) && all(ex -> ex isa Symbol, funargs)
        # This check only catches cases like MyFun(a) -> MyFun(A), not MyFun(as...) or MyFun(a, b=1) or MyFun(a; b=1)
        @warn "Function $(expr.args[1]) will be ambiguous with struct $ret. " *
              "Please add some type restrictions to the function, or to the return line (which sets struct fields)"
    end

    # Build the struct definition only when this return line (with this `expand`) is new;
    # otherwise reuse the cached one, keeping the same generated struct name.
    name, defex = get!(DEFINE, hash(ret, UInt(expanded))) do
        sname = gensym(fun)
        tparams = [Symbol("T#", i) for i in eachindex(retargs)]
        # we allow `return MyModel(alpha, beta::Chain)`: the field name is before the `::`
        fields = [:($(ex isa Symbol ? ex : ex.args[1])::$T) for (ex, T) in zip(retargs, tparams)]
        # plain `T#i`, or `T#i <: Bound` when the return line gave `field::Bound`
        types = [ex isa Symbol ? T : Expr(:(<:), T, ex.args[2]) for (ex, T) in zip(retargs, tparams)]
        layer = if expanded
            prefix = "$fun("
            quote
                $Flux.@layer :expand $sname
                # `$Flux` (not a bare `Flux`) so this works even when the caller has not bound the name `Flux`
                $Flux._show_pre_post(::$sname) = $prefix, ")"  # needs https://github.com/FluxML/Flux.jl/pull/2344
            end
        else
            :($Flux.@layer $sname)
        end
        label = "$fun(...)"
        structdef = quote
            struct $sname{$(types...)}
                $(fields...)
            end
            $layer
            $Base.show(io::IO, _::$sname) = $print(io, $label)
            $fun = $sname
        end
        (sname, structdef)
    end

    # Rename the user's function to the struct's generated name, so that it becomes
    # an additional constructor method of that struct:
    call.args[1] = name
    return quote
        $(defex.args...)  # struct definition
        $expr  # constructor function
    end
end