feat: compile training loop automatically using reactant #969
Benchmark Results (ASV)
Benchmark Plots: A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
50ad0c2 to 744cc56
Example Usage

    using Lux, Reactant, Enzyme, Random
    using Optimisers

    # Move parameters, states, and data to the XLA device
    dev = xla_device()

    model = Chain(
        Dense(2 => 32, tanh),
        Dense(32 => 32, tanh),
        Dense(32 => 2)
    )
    ps, st = Lux.setup(Random.default_rng(), model) |> dev

    x = randn(Float32, 2, 1024) |> dev
    y = x .^ 2 .+ 1

    # Sum-of-squared-errors objective; returns (loss, updated state, stats)
    function sse(model, ps, st, (x, y))
        z, stₙ = model(x, ps, st)
        diff = z .- y
        return sum(abs2, diff), stₙ, (;)
    end

    train_state = Training.TrainState(model, ps, st, Descent(0.01))

    # First step: takes the initial train state and returns an updated one
    grads, loss, stats, train_state_compiled = Training.single_train_step!(
        AutoEnzyme(),
        sse,
        (x, y),
        train_state
    )

    # Later steps: pass the train state returned by the previous step
    grads, loss, stats, train_state_compiled = Training.single_train_step!(
        AutoEnzyme(),
        sse,
        (x, y),
        train_state_compiled
    )

Using any other ADType will cause an error.
Lux Benchmarks
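The Ratio column in the table that follows appears to be the current time divided by the previous time, so values above 1 mean the current commit was slower on that entry. A minimal sketch in plain Julia that recomputes the ratio for three rows copied from the table; the `classify` helper and its 10% tolerance band are hypothetical and not part of the benchmark tooling:

```julia
# Ratio as it appears in the table: current time / previous time (both in ns).
ratio(current_ns, previous_ns) = current_ns / previous_ns

# Hypothetical helper: label a ratio using an assumed tolerance band of 10%.
function classify(r; tol = 0.10)
    r > 1 + tol && return :slower
    r < 1 - tol && return :faster
    return :unchanged
end

# Three rows copied from the table: (name, current ns, previous ns).
rows = [
    ("Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s)", 411958.5, 411750.0),
    ("Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s)", 592042.0, 1306583.0),
    ("Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s)", 2431084.0, 465625.0),
]

for (name, cur, prev) in rows
    r = ratio(cur, prev)
    println(rpad(name, 70), round(r; digits = 2), "  ", classify(r))
end
```

Rounding each result to two decimals reproduces the Ratio values listed for those rows.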
Benchmark suite | Current: 5481cdc | Previous: 04deedf | Ratio |
---|---|---|---|
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s) | 411958.5 ns | 411750 ns | 1.00 |
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s) | 322042 ns | 322271 ns | 1.00 |
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s) | 322354 ns | 323042 ns | 1.00 |
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s) | 742333 ns | 749375 ns | 0.99 |
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA | 43425 ns | 43905 ns | 0.99 |
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s) | 592042 ns | 1306583 ns | 0.45 |
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s) | 2431084 ns | 465625 ns | 5.22 |
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s) | 14232375 ns | 13617333 ns | 1.05 |
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s) | 2275166 ns | 2245750 ns | 1.01 |
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA | 190617 ns | 192831 ns | 0.99 |
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s) | 747750 ns | 1394875 ns | 0.54 |
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s) | 2623417 ns | 634729.5 ns | 4.13 |
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s) | 14186833 ns | 14050875 ns | 1.01 |
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s) | 2233750 ns | 2238000 ns | 1.00 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) | 1548145.5 ns | 1661542 ns | 0.93 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) | 1025000 ns | 1196103.5 ns | 0.86 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) | 1528958 ns | 1534187.5 ns | 1.00 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) | 2988750 ns | 3005667 ns | 0.99 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA | 210240 ns | 209529 ns | 1.00 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) | 12297958 ns | 12111521 ns | 1.02 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) | 8809104.5 ns | 9554687 ns | 0.92 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) | 9213042 ns | 9247000 ns | 1.00 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) | 18575208 ns | 18626583 ns | 1.00 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA | 1920868 ns | 1910271 ns | 1.01 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) | 17326250 ns | 17307250 ns | 1.00 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) | 13955500 ns | 14377958 ns | 0.97 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) | 14512333 ns | 14526875 ns | 1.00 |
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) | 21826729.5 ns | 21836458.5 ns | 1.00 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) | 123177667 ns | 250439041.5 ns | 0.49 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) | 148327500 ns | 174592521 ns | 0.85 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) | 116102750 ns | 115955208.5 ns | 1.00 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) | 453770833 ns | 447243084 ns | 1.01 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA | 5476779 ns | 5470843 ns | 1.00 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) | 611248042 ns | 1228722500 ns | 0.50 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) | 929680709 ns | 543561875 ns | 1.71 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) | 831844250 ns | 830623396.5 ns | 1.00 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) | 1649532250 ns | 1628878000 ns | 1.01 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA | 35018194 ns | 38000637 ns | 0.92 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) | 667893521 ns | 1136994583 ns | 0.59 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) | 1000178229 ns | 679379084 ns | 1.47 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) | 1306998770.5 ns | 1328113771 ns | 0.98 |
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) | 1730971229 ns | 1733752146 ns | 1.00 |
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s) | 825834 ns | 1103375 ns | 0.75 |
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s) | 1604584 ns | 823209 ns | 1.95 |
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s) | 3532833.5 ns | 3578479 ns | 0.99 |
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s) | 777979 ns | 786500 ns | 0.99 |
lenet(28, 28, 1, 32)/forward/GPU/CUDA | 266898.5 ns | 266091.5 ns | 1.00 |
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s) | 2700875 ns | 2986021 ns | 0.90 |
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s) | 4132458.5 ns | 2426000 ns | 1.70 |
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s) | 9412000 ns | 10461250 ns | 0.90 |
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s) | 3133708 ns | 3150042 ns | 0.99 |
lenet(28, 28, 1, 32)/zygote/GPU/CUDA | 1053497 ns | 1055864 ns | 1.00 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) | 2193792 ns | 2335042 ns | 0.94 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) | 1449292 ns | 1537708 ns | 0.94 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) | 1705291 ns | 1740000 ns | 0.98 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) | 4334812.5 ns | 4348437.5 ns | 1.00 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA | 210237 ns | 212286 ns | 0.99 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) | 20482125 ns | 20266645.5 ns | 1.01 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) | 16980666 ns | 17701209 ns | 0.96 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) | 18146833 ns | 17495416 ns | 1.04 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) | 26730917 ns | 26797000 ns | 1.00 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA | 1983964 ns | 1973706 ns | 1.01 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) | 45022979.5 ns | 44317750 ns | 1.02 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) | 40982021.5 ns | 42027646 ns | 0.98 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) | 41316458 ns | 41325000 ns | 1.00 |
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) | 47738791 ns | 47734917 ns | 1.00 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) | 4280771 ns | 4664854 ns | 0.92 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) | 2853416.5 ns | 2868521.5 ns | 0.99 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) | 2998521 ns | 3015958 ns | 0.99 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) | 8638667 ns | 8658937.5 ns | 1.00 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA | 512477.5 ns | 516555 ns | 0.99 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) | 40269500 ns | 40579000.5 ns | 0.99 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) | 34068542 ns | 34830104 ns | 0.98 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) | 34146937.5 ns | 34148292 ns | 1.00 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) | 53546917 ns | 53661812 ns | 1.00 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA | 3090082.5 ns | 2969951 ns | 1.04 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) | 89992917 ns | 109640958 ns | 0.82 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) | 135652041.5 ns | 84133666 ns | 1.61 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) | 241274000 ns | 255828791 ns | 0.94 |
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) | 96381833 ns | 96388416 ns | 1.00 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) | 141558125 ns | 270215792 ns | 0.52 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) | 160423459 ns | 186630271 ns | 0.86 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) | 127894500 ns | 128172709 ns | 1.00 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) | 490577000 ns | 489605542 ns | 1.00 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA | 7105529 ns | 7104246 ns | 1.00 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) | 880297750.5 ns | 1502664042 ns | 0.59 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) | 1204185958 ns | 821183792 ns | 1.47 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) | 1096989708 ns | 1092397958.5 ns | 1.00 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) | 2043412687.5 ns | 2032173187.5 ns | 1.01 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA | 34000370.5 ns | 33798333 ns | 1.01 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) | 1668401875 ns | 2027767896 ns | 0.82 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) | 1844138125 ns | 1563910958 ns | 1.18 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) | 2075119917 ns | 2210346833.5 ns | 0.94 |
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) | 2586231167 ns | 2560629834 ns | 1.01 |
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s) | 1539333 ns | 2006833 ns | 0.77 |
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s) | 2612875 ns | 1257333 ns | 2.08 |
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s) | 7608479.5 ns | 7451041.5 ns | 1.02 |
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s) | 2429417 ns | 2470458 ns | 0.98 |
lenet(28, 28, 1, 128)/forward/GPU/CUDA | 277595 ns | 275531 ns | 1.01 |
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s) | 7693771 ns | 9463416 ns | 0.81 |
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s) | 11629042 ns | 6552500 ns | 1.77 |
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s) | 25087146 ns | 25529541 ns | 0.98 |
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s) | 11747771 ns | 11734125 ns | 1.00 |
lenet(28, 28, 1, 128)/zygote/GPU/CUDA | 1126864.5 ns | 1130415 ns | 1.00 |
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s) | 185544521 ns | 380676854.5 ns | 0.49 |
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s) | 286762104 ns | 145328000 ns | 1.97 |
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s) | 238849459 ns | 243564083 ns | 0.98 |
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s) | 453230104 ns | 452336354.5 ns | 1.00 |
vgg16(32, 32, 3, 32)/forward/GPU/CUDA | 4878004 ns | 4879283 ns | 1.00 |
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s) | 646115875 ns | 1156932333 ns | 0.56 |
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s) | 997850500 ns | 487570458 ns | 2.05 |
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s) | 1020548042 ns | 973572458 ns | 1.05 |
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s) | 1405903916 ns | 1399439834 ns | 1.00 |
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA | 16517069 ns | 16976929 ns | 0.97 |
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s) | 1082416.5 ns | 1062687.5 ns | 1.02 |
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s) | 1987771 ns | 971124.5 ns | 2.05 |
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s) | 5710354 ns | 6269583 ns | 0.91 |
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s) | 1292417 ns | 1393375 ns | 0.93 |
lenet(28, 28, 1, 64)/forward/GPU/CUDA | 271917 ns | 277704.5 ns | 0.98 |
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s) | 6008167 ns | 6494541.5 ns | 0.93 |
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s) | 12617937.5 ns | 4635437.5 ns | 2.72 |
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s) | 19261125 ns | 19450479 ns | 0.99 |
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s) | 6080125 ns | 6080229 ns | 1.00 |
lenet(28, 28, 1, 64)/zygote/GPU/CUDA | 1135120.5 ns | 1148981 ns | 0.99 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) | 23800833.5 ns | 70442208 ns | 0.34 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) | 43546020.5 ns | 35305229 ns | 1.23 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) | 39747791.5 ns | 39532604 ns | 1.01 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) | 134027792 ns | 132574604 ns | 1.01 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA | 1832530 ns | 1848251 ns | 0.99 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) | 185091125 ns | 356785937.5 ns | 0.52 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) | 270613834 ns | 159371854 ns | 1.70 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) | 254143208 ns | 254893688 ns | 1.00 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) | 534168750 ns | 535009020.5 ns | 1.00 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA | 16481853 ns | 16489529.5 ns | 1.00 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) | 296796938 ns | 395707667 ns | 0.75 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) | 396396750 ns | 245564417 ns | 1.61 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) | 713142792 ns | 652089584 ns | 1.09 |
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) | 711324584 ns | 712574333 ns | 1.00 |
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s) | 654228834 ns | 1191762375 ns | 0.55 |
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s) | 691374875 ns | 434009729.5 ns | 1.59 |
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s) | 633147916 ns | 631038834 ns | 1.00 |
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s) | 1779924875 ns | 1771033395.5 ns | 1.01 |
vgg16(32, 32, 3, 128)/forward/GPU/CUDA | 12474301 ns | 12471861 ns | 1.00 |
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s) | 1890900125.5 ns | 3670803208.5 ns | 0.52 |
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s) | 2844497583 ns | 1633483458 ns | 1.74 |
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s) | 2698564000 ns | 2737701958 ns | 0.99 |
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s) | 5083178292 ns | 5038709417 ns | 1.01 |
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA | 49878971 ns | 49641386 ns | 1.00 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) | 3061062 ns | 3412146 ns | 0.90 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) | 2050792 ns | 2094750 ns | 0.98 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) | 2518833 ns | 2533833.5 ns | 0.99 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) | 6007396 ns | 6034292 ns | 1.00 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA | 580914.5 ns | 586721 ns | 0.99 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) | 25707542 ns | 26096750.5 ns | 0.99 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) | 18780458 ns | 20315791.5 ns | 0.92 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) | 19414188 ns | 19312917 ns | 1.01 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) | 39173750 ns | 39366625 ns | 1.00 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA | 2996538 ns | 2989473.5 ns | 1.00 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) | 35182271 ns | 54095229 ns | 0.65 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) | 82233188 ns | 28393083 ns | 2.90 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) | 176579334 ns | 177757792 ns | 0.99 |
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) | 45418896.5 ns | 45278750 ns | 1.00 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) | 1659959 ns | 1778208 ns | 0.93 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) | 1093666 ns | 1204708 ns | 0.91 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) | 1583625 ns | 1564000 ns | 1.01 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) | 3020417 ns | 3038771 ns | 0.99 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA | 217039 ns | 217944 ns | 1.00 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) | 12723500 ns | 12531437.5 ns | 1.02 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) | 9202000 ns | 9964292 ns | 0.92 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) | 9639333 ns | 9707042 ns | 0.99 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) | 18926020.5 ns | 18974500 ns | 1.00 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA | 1947814 ns | 1963028.5 ns | 0.99 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) | 17690687.5 ns | 17644270.5 ns | 1.00 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) | 14314083 ns | 14745500 ns | 0.97 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) | 14550166.5 ns | 14639333 ns | 0.99 |
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) | 22168625 ns | 22173792 ns | 1.00 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) | 23646458.5 ns | 70409562 ns | 0.34 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) | 43551458.5 ns | 34786542 ns | 1.25 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) | 39640417 ns | 39571499.5 ns | 1.00 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) | 132442291.5 ns | 132610521 ns | 1.00 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA | 1820313 ns | 1837717 ns | 0.99 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) | 190914083 ns | 360588187.5 ns | 0.53 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) | 347341062.5 ns | 237608334 ns | 1.46 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) | 304813417 ns | 299913354 ns | 1.02 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) | 731122041 ns | 725805833 ns | 1.01 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA | 13916356.5 ns | 13956738 ns | 1.00 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) | 301421791.5 ns | 418949812.5 ns | 0.72 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) | 420410709 ns | 251360792 ns | 1.67 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) | 681250833 ns | 712732021 ns | 0.96 |
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) | 717801541 ns | 717284542 ns | 1.00 |
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s) | 1917291 ns | 1912041.5 ns | 1.00 |
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s) | 1507125 ns | 1579125 ns | 0.95 |
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s) | 1552146 ns | 1549791.5 ns | 1.00 |
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s) | 2650416 ns | 2657625 ns | 1.00 |
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA | 577487 ns | 573525 ns | 1.01 |
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s) | 6162500 ns | 9220000 ns | 0.67 |
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s) | 13031833 ns | 5936166 ns | 2.20 |
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s) | 32447375 ns | 31895937.5 ns | 1.02 |
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s) | 10165875 ns | 10214937.5 ns | 1.00 |
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA | 1391351 ns | 1399984.5 ns | 0.99 |
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s) | 18787354.5 ns | 22182333.5 ns | 0.85 |
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s) | 27626166 ns | 19138291.5 ns | 1.44 |
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s) | 49929625 ns | 52527562.5 ns | 0.95 |
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s) | 18807229 ns | 18888042 ns | 1.00 |
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s) | 69167 ns | 791291.5 ns | 0.08741026536996796 |
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s) | 507666 ns | 69958.5 ns | 7.26 |
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s) | 1007292 ns | 997167 ns | 1.01 |
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s) | 725271 ns | 724499.5 ns | 1.00 |
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA | 47866 ns | 48324 ns | 0.99 |
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s) | 328458 ns | 1508042 ns | 0.22 |
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s) | 1012854 ns | 320291 ns | 3.16 |
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s) | 1408209 ns | 1445145.5 ns | 0.97 |
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s) | 2278291.5 ns | 2258458.5 ns | 1.01 |
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA | 215931.5 ns | 216350 ns | 1.00 |
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s) | 444791 ns | 1537083 ns | 0.29 |
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s) | 1071292 ns | 428792 ns | 2.50 |
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s) | 1254999.5 ns | 1444584 ns | 0.87 |
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s) | 2228417 ns | 2250333 ns | 0.99 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) | 3046750 ns | 3421750 ns | 0.89 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) | 2038520.5 ns | 2084312.5 ns | 0.98 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) | 2500667 ns | 2519375.5 ns | 0.99 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) | 5987375 ns | 6015021 ns | 1.00 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA | 587477 ns | 584297 ns | 1.01 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) | 23590333 ns | 24071521.5 ns | 0.98 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) | 17304396 ns | 18050833 ns | 0.96 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) | 17177583 ns | 17227375 ns | 1.00 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) | 37426916.5 ns | 37583145.5 ns | 1.00 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA | 2913314 ns | 2895440 ns | 1.01 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) | 33402499.5 ns | 52599188 ns | 0.64 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) | 83336958.5 ns | 27644250 ns | 3.01 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) | 139112583 ns | 170611917 ns | 0.82 |
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) | 44466063 ns | 44514250 ns | 1.00 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) | 122042083.5 ns | 250102292 ns | 0.49 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) | 147943958 ns | 174510104 ns | 0.85 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) | 115633729 ns | 115645729 ns | 1.00 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) | 454541833.5 ns | 448140124.5 ns | 1.01 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA | 5461242 ns | 5446378 ns | 1.00 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) | 473498792 ns | 1105120833 ns | 0.43 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) | 855268708 ns | 467780729.5 ns | 1.83 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) | 827066562.5 ns | 825455520.5 ns | 1.00 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) | 1748960000 ns | 1753431125 ns | 1.00 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA | 32293722 ns | 35149612 ns | 0.92 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) | 642440500 ns | 1021983312.5 ns | 0.63 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) | 923772709 ns | 662517187.5 ns | 1.39 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) | 1230540000 ns | 1286071167 ns | 0.96 |
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) | 1717797333 ns | 1721665437.5 ns | 1.00 |
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s) | 1228792 ns | 1312041 ns | 0.94 |
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s) | 960687.5 ns | 928625 ns | 1.03 |
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s) | 962104 ns | 903208 ns | 1.07 |
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s) | 1941146 ns | 2032416 ns | 0.96 |
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA | 572882.5 ns | 575428 ns | 1.00 |
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s) | 2970083 ns | 5922771 ns | 0.50 |
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s) | 6001021 ns | 2615500 ns | 2.29 |
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s) | 24213750 ns | 24427083.5 ns | 0.99 |
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s) | 7075625 ns | 7104916.5 ns | 1.00 |
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA | 1331747 ns | 1363516 ns | 0.98 |
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s) | 6643438 ns | 9705958.5 ns | 0.68 |
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s) | 12968500 ns | 6499000 ns | 2.00 |
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s) | 30681146 ns | 31929750 ns | 0.96 |
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s) | 7599417 ns | 7614042 ns | 1.00 |
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s) | 39000 ns | 483291 ns | 0.08069672309229843 |
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s) | 423750 ns | 31750 ns | 13.35 |
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s) | 1728854 ns | 1795375 ns | 0.96 |
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s) | 91666.5 ns | 91542 ns | 1.00 |
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA | 28093 ns | 28996 ns | 0.97 |
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s) | 175708.5 ns | 392958 ns | 0.45 |
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s) | 418125 ns | 175542 ns | 2.38 |
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s) | 4363375 ns | 4708417 ns | 0.93 |
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s) | 273375 ns | 273000 ns | 1.00 |
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA | 217037.5 ns | 224707.5 ns | 0.97 |
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s) | 442000 ns | 666333 ns | 0.66 |
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s) | 689625 ns | 442250 ns | 1.56 |
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s) | 4473333 ns | 4499167 ns | 0.99 |
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s) | 510374.5 ns | 510979.5 ns | 1.00 |
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s) | 13458.5 ns | 430437.5 ns | 0.0312670248293887 |
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s) | 361958 ns | 13583 ns | 26.65 |
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s) | 662333 ns | 709208 ns | 0.93 |
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s) | 53333.5 ns | 52584 ns | 1.01 |
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA | 28091 ns | 29296 ns | 0.96 |
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s) | 25750 ns | 337250 ns | 0.07635285396590066 |
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s) | 277604.5 ns | 26375 ns | 10.53 |
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s) | 398771 ns | 484812.5 ns | 0.82 |
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s) | 151541.5 ns | 151333 ns | 1.00 |
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA | 206752 ns | 213308.5 ns | 0.97 |
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s) | 46042 ns | 352521 ns | 0.13 |
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s) | 293562.5 ns | 45792 ns | 6.41 |
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s) | 583874.5 ns | 487125 ns | 1.20 |
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s) | 151104.5 ns | 151000 ns | 1.00 |
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s) | 319286916 ns | 603223875 ns | 0.53 |
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s) | 430851208 ns | 239241354 ns | 1.80 |
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s) | 374433583.5 ns | 377713896 ns | 0.99 |
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s) | 872306875 ns | 872019458 ns | 1.00 |
vgg16(32, 32, 3, 64)/forward/GPU/CUDA | 7676094 ns | 7676104.5 ns | 1.00 |
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s) | 1098654125 ns | 2005520125 ns | 0.55 |
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s) | 1617580312 ns | 947653916.5 ns | 1.71 |
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s) | 1598725833 ns | 1551514604.5 ns | 1.03 |
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s) | 2649602250 ns | 2653038416 ns | 1.00 |
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA | 27121930 ns | 27180094 ns | 1.00 |
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s) | 195250 ns | 525604 ns | 0.37 |
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s) | 449792 ns | 168333 ns | 2.67 |
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s) | 1696833 ns | 1740625 ns | 0.97 |
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s) | 871667 ns | 875541 ns | 1.00 |
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA | 47277.5 ns | 47837 ns | 0.99 |
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s) | 1207770.5 ns | 1943750 ns | 0.62 |
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s) | 2647125 ns | 1100208 ns | 2.41 |
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s) | 14579895.5 ns | 14661875 ns | 0.99 |
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s) | 2726959 ns | 2836709 ns | 0.96 |
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA | 223041.5 ns | 232330 ns | 0.96 |
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s) | 2305792 ns | 2974229 ns | 0.78 |
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s) | 5778000 ns | 2208583.5 ns | 2.62 |
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s) | 15078604.5 ns | 15024229.5 ns | 1.00 |
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s) | 3705500 ns | 3751750 ns | 0.99 |
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s) |
1576583.5 ns |
1602291.5 ns |
0.98 |
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s) |
1175270.5 ns |
1221084 ns |
0.96 |
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s) |
1231208 ns |
1264750 ns |
0.97 |
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s) |
2209917 ns |
2362750 ns |
0.94 |
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA |
572444.5 ns |
576709 ns |
0.99 |
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s) |
3188250 ns |
5931125 ns |
0.54 |
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s) |
4761542 ns |
2866334 ns |
1.66 |
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s) |
24964209 ns |
25035834 ns |
1.00 |
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s) |
7279103.5 ns |
6650208 ns |
1.09 |
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA |
1325716 ns |
1379411 ns |
0.96 |
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s) |
8829125 ns |
11605146 ns |
0.76 |
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s) |
14248542 ns |
8767458 ns |
1.63 |
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s) |
35161250 ns |
35255000 ns |
1.00 |
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s) |
9525937.5 ns |
9570000.5 ns |
1.00 |
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s) |
2500 ns |
2541 ns |
0.98 |
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s) |
2458 ns |
2292 ns |
1.07 |
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s) |
2833 ns |
3000 ns |
0.94 |
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s) |
2458 ns |
2333 ns |
1.05 |
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA |
24765 ns |
25379.5 ns |
0.98 |
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s) |
7250 ns |
7125 ns |
1.02 |
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s) |
7250 ns |
7083 ns |
1.02 |
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s) |
7417 ns |
7375 ns |
1.01 |
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s) |
7292 ns |
7270.5 ns |
1.00 |
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA |
187906 ns |
193729.5 ns |
0.97 |
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s) |
8542 ns |
8334 ns |
1.02 |
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s) |
8291 ns |
8500 ns |
0.98 |
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s) |
8625 ns |
8417 ns |
1.02 |
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s) |
6083 ns |
6084 ns |
1.00 |
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s) |
11583 ns |
10375.5 ns |
1.12 |
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s) |
15479 ns |
14916 ns |
1.04 |
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s) |
10625 ns |
11854 ns |
0.90 |
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s) |
7334 ns |
7625 ns |
0.96 |
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA |
24957 ns |
25646 ns |
0.97 |
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s) |
21750 ns |
21708 ns |
1.00 |
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s) |
21667 ns |
21500 ns |
1.01 |
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s) |
21708 ns |
21750 ns |
1.00 |
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s) |
22000 ns |
21875 ns |
1.01 |
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA |
196771.5 ns |
203851 ns |
0.97 |
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s) |
56667 ns |
53417 ns |
1.06 |
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s) |
53584 ns |
56583.5 ns |
0.95 |
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s) |
53917 ns |
53583.5 ns |
1.01 |
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s) |
51209 ns |
51333 ns |
1.00 |
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s) |
32375 ns |
26895.5 ns |
1.20 |
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s) |
28583 ns |
28333.5 ns |
1.01 |
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s) |
28875 ns |
29000 ns |
1.00 |
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s) |
46167 ns |
48291 ns |
0.96 |
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA |
25960 ns |
26739 ns |
0.97 |
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s) |
44333 ns |
220875 ns |
0.20 |
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s) |
276291 ns |
44583 ns |
6.20 |
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s) |
4200084 ns |
4132667 ns |
1.02 |
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s) |
145500 ns |
145458 ns |
1.00 |
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA |
167846 ns |
172310 ns |
0.97 |
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s) |
70542 ns |
237312.5 ns |
0.30 |
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s) |
294250 ns |
68625 ns |
4.29 |
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s) |
4092292 ns |
4360708 ns |
0.94 |
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s) |
145542 ns |
145917 ns |
1.00 |
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s) |
1959 ns |
2292 ns |
0.85 |
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s) |
1916 ns |
1750 ns |
1.09 |
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s) |
2417 ns |
2166 ns |
1.12 |
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s) |
1834 ns |
1520.5 ns |
1.21 |
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA |
23087 ns |
23935 ns |
0.96 |
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s) |
5167 ns |
5125 ns |
1.01 |
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s) |
4959 ns |
5042 ns |
0.98 |
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s) |
5375 ns |
5458 ns |
0.98 |
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s) |
5459 ns |
5084 ns |
1.07 |
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA |
171785 ns |
176841 ns |
0.97 |
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s) |
8250 ns |
7292 ns |
1.13 |
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s) |
7292 ns |
8166 ns |
0.89 |
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s) |
7750 ns |
7541 ns |
1.03 |
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s) |
5167 ns |
5167 ns |
1 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) |
34011375 ns |
80940833 ns |
0.42 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) |
49736250 ns |
41092709 ns |
1.21 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) |
45704375 ns |
45570541 ns |
1.00 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) |
153460084 ns |
153559792 ns |
1.00 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA |
2633927 ns |
2660311 ns |
0.99 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) |
451506229.5 ns |
621714834 ns |
0.73 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) |
429059958 ns |
421739375 ns |
1.02 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) |
414617250.5 ns |
414510667 ns |
1.00 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) |
705577125 ns |
697568292 ns |
1.01 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA |
15166479 ns |
15148414 ns |
1.00 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) |
747601083 ns |
872377937.5 ns |
0.86 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) |
842678875 ns |
706482291.5 ns |
1.19 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) |
1152094271 ns |
1162546146 ns |
0.99 |
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) |
1173282708 ns |
1175739375 ns |
1.00 |
This comment was automatically generated by workflow using github-action-benchmark.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #969      +/-   ##
==========================================
- Coverage   92.61%   91.90%   -0.72%
==========================================
  Files          58       60       +2
  Lines        2886     2952      +66
==========================================
+ Hits         2673     2713      +40
- Misses        213      239      +26
☔ View full report in Codecov by Sentry.
EnzymeAD/Reactant.jl#161 is needed for the loss function tests to pass.
Much less intrusive than #673. We automatically check for get_device_type and use Reactant if paired with Enzyme.

TODOs
- xla_device and AutoEnzyme
- single_train_step!
- single_train_step
- compute_gradients
- We hack around a proxy for the optimizer similar to what I was doing in Compile training loop with Reactant #673.
- TracedRNumber: EnzymeAD/Reactant.jl#161

Future PRs