Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: compile training loop automatically using reactant #969

Merged
merged 23 commits into from
Oct 9, 2024

Conversation

avik-pal
Copy link
Member

@avik-pal avik-pal commented Oct 4, 2024

Much less intrusive than #673. We automatically check for get_device_type and use Reactant if paired with Enzyme.

TODOs

Future PRs

@avik-pal avik-pal mentioned this pull request Oct 4, 2024
18 tasks
Copy link
Contributor

github-actions bot commented Oct 4, 2024

Benchmark Results (ASV)

main 77f1048... main/77f1048de54cb7...
basics/overhead 0.0544 ± 0.0012 μs 0.0551 ± 0.0013 μs 0.987
time_to_load 1.26 ± 0.0072 s 1.27 ± 0.022 s 0.992

Benchmark Plots

A plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

@avik-pal avik-pal force-pushed the ap/reactant_training2 branch 2 times, most recently from 50ad0c2 to 744cc56 Compare October 4, 2024 02:27
@avik-pal
Copy link
Member Author

avik-pal commented Oct 4, 2024

Example Usage

using Lux, Reactant, Enzyme, Random
using Optimisers

dev = xla_device()

model = Chain(
    Dense(2 => 32, tanh),
    Dense(32 => 32, tanh),
    Dense(32 => 2)
)
ps, st = Lux.setup(Random.default_rng(), model) |> dev

x = randn(Float32, 2, 1024) |> dev
y = x .^ 2 .+ 1

function sse(model, ps, st, (x, y))
    z, stₙ = model(x, ps, st)
    diff = z .- y
    return sum(abs2, diff), stₙ, (;)
end

train_state = Training.TrainState(model, ps, st, Descent(0.01))

grads, loss, stats, train_state_compiled = Training.single_train_step!(
    AutoEnzyme(),
    sse,
    (x, y),
    train_state
)

grads, loss, stats, train_state_compiled = Training.single_train_step!(
    AutoEnzyme(),
    sse,
    (x, y),
    train_state_compiled
)

Using any other ADType will cause an error.

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lux Benchmarks

Benchmark suite Current: 5481cdc Previous: 04deedf Ratio
Dense(512 => 512, identity)(512 x 128)/forward/CPU/2 thread(s) 411958.5 ns 411750 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/4 thread(s) 322042 ns 322271 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/8 thread(s) 322354 ns 323042 ns 1.00
Dense(512 => 512, identity)(512 x 128)/forward/CPU/1 thread(s) 742333 ns 749375 ns 0.99
Dense(512 => 512, identity)(512 x 128)/forward/GPU/CUDA 43425 ns 43905 ns 0.99
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/2 thread(s) 592042 ns 1306583 ns 0.45
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/4 thread(s) 2431084 ns 465625 ns 5.22
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/8 thread(s) 14232375 ns 13617333 ns 1.05
Dense(512 => 512, identity)(512 x 128)/zygote/CPU/1 thread(s) 2275166 ns 2245750 ns 1.01
Dense(512 => 512, identity)(512 x 128)/zygote/GPU/CUDA 190617 ns 192831 ns 0.99
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/2 thread(s) 747750 ns 1394875 ns 0.54
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/4 thread(s) 2623417 ns 634729.5 ns 4.13
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/8 thread(s) 14186833 ns 14050875 ns 1.01
Dense(512 => 512, identity)(512 x 128)/enzyme/CPU/1 thread(s) 2233750 ns 2238000 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1548145.5 ns 1661542 ns 0.93
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1025000 ns 1196103.5 ns 0.86
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1528958 ns 1534187.5 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 2988750 ns 3005667 ns 0.99
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/forward/GPU/CUDA 210240 ns 209529 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12297958 ns 12111521 ns 1.02
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 8809104.5 ns 9554687 ns 0.92
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9213042 ns 9247000 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18575208 ns 18626583 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1920868 ns 1910271 ns 1.01
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17326250 ns 17307250 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 13955500 ns 14377958 ns 0.97
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14512333 ns 14526875 ns 1.00
Conv((3, 3), 2 => 2, identity)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 21826729.5 ns 21836458.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 123177667 ns 250439041.5 ns 0.49
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 148327500 ns 174592521 ns 0.85
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 116102750 ns 115955208.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 453770833 ns 447243084 ns 1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5476779 ns 5470843 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 611248042 ns 1228722500 ns 0.50
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 929680709 ns 543561875 ns 1.71
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 831844250 ns 830623396.5 ns 1.00
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1649532250 ns 1628878000 ns 1.01
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 35018194 ns 38000637 ns 0.92
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 667893521 ns 1136994583 ns 0.59
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1000178229 ns 679379084 ns 1.47
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1306998770.5 ns 1328113771 ns 0.98
Conv((3, 3), 64 => 64, relu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1730971229 ns 1733752146 ns 1.00
lenet(28, 28, 1, 32)/forward/CPU/2 thread(s) 825834 ns 1103375 ns 0.75
lenet(28, 28, 1, 32)/forward/CPU/4 thread(s) 1604584 ns 823209 ns 1.95
lenet(28, 28, 1, 32)/forward/CPU/8 thread(s) 3532833.5 ns 3578479 ns 0.99
lenet(28, 28, 1, 32)/forward/CPU/1 thread(s) 777979 ns 786500 ns 0.99
lenet(28, 28, 1, 32)/forward/GPU/CUDA 266898.5 ns 266091.5 ns 1.00
lenet(28, 28, 1, 32)/zygote/CPU/2 thread(s) 2700875 ns 2986021 ns 0.90
lenet(28, 28, 1, 32)/zygote/CPU/4 thread(s) 4132458.5 ns 2426000 ns 1.70
lenet(28, 28, 1, 32)/zygote/CPU/8 thread(s) 9412000 ns 10461250 ns 0.90
lenet(28, 28, 1, 32)/zygote/CPU/1 thread(s) 3133708 ns 3150042 ns 0.99
lenet(28, 28, 1, 32)/zygote/GPU/CUDA 1053497 ns 1055864 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 2193792 ns 2335042 ns 0.94
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1449292 ns 1537708 ns 0.94
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1705291 ns 1740000 ns 0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 4334812.5 ns 4348437.5 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 210237 ns 212286 ns 0.99
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 20482125 ns 20266645.5 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 16980666 ns 17701209 ns 0.96
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 18146833 ns 17495416 ns 1.04
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 26730917 ns 26797000 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1983964 ns 1973706 ns 1.01
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 45022979.5 ns 44317750 ns 1.02
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 40982021.5 ns 42027646 ns 0.98
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 41316458 ns 41325000 ns 1.00
Conv((3, 3), 2 => 2, gelu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 47738791 ns 47734917 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 4280771 ns 4664854 ns 0.92
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2853416.5 ns 2868521.5 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2998521 ns 3015958 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 8638667 ns 8658937.5 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 512477.5 ns 516555 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 40269500 ns 40579000.5 ns 0.99
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 34068542 ns 34830104 ns 0.98
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 34146937.5 ns 34148292 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 53546917 ns 53661812 ns 1.00
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 3090082.5 ns 2969951 ns 1.04
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 89992917 ns 109640958 ns 0.82
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 135652041.5 ns 84133666 ns 1.61
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 241274000 ns 255828791 ns 0.94
Conv((3, 3), 4 => 4, gelu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 96381833 ns 96388416 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 141558125 ns 270215792 ns 0.52
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 160423459 ns 186630271 ns 0.86
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 127894500 ns 128172709 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 490577000 ns 489605542 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/forward/GPU/CUDA 7105529 ns 7104246 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 880297750.5 ns 1502664042 ns 0.59
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 1204185958 ns 821183792 ns 1.47
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 1096989708 ns 1092397958.5 ns 1.00
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 2043412687.5 ns 2032173187.5 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 34000370.5 ns 33798333 ns 1.01
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 1668401875 ns 2027767896 ns 0.82
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 1844138125 ns 1563910958 ns 1.18
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 2075119917 ns 2210346833.5 ns 0.94
Conv((3, 3), 64 => 64, gelu)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 2586231167 ns 2560629834 ns 1.01
lenet(28, 28, 1, 128)/forward/CPU/2 thread(s) 1539333 ns 2006833 ns 0.77
lenet(28, 28, 1, 128)/forward/CPU/4 thread(s) 2612875 ns 1257333 ns 2.08
lenet(28, 28, 1, 128)/forward/CPU/8 thread(s) 7608479.5 ns 7451041.5 ns 1.02
lenet(28, 28, 1, 128)/forward/CPU/1 thread(s) 2429417 ns 2470458 ns 0.98
lenet(28, 28, 1, 128)/forward/GPU/CUDA 277595 ns 275531 ns 1.01
lenet(28, 28, 1, 128)/zygote/CPU/2 thread(s) 7693771 ns 9463416 ns 0.81
lenet(28, 28, 1, 128)/zygote/CPU/4 thread(s) 11629042 ns 6552500 ns 1.77
lenet(28, 28, 1, 128)/zygote/CPU/8 thread(s) 25087146 ns 25529541 ns 0.98
lenet(28, 28, 1, 128)/zygote/CPU/1 thread(s) 11747771 ns 11734125 ns 1.00
lenet(28, 28, 1, 128)/zygote/GPU/CUDA 1126864.5 ns 1130415 ns 1.00
vgg16(32, 32, 3, 32)/forward/CPU/2 thread(s) 185544521 ns 380676854.5 ns 0.49
vgg16(32, 32, 3, 32)/forward/CPU/4 thread(s) 286762104 ns 145328000 ns 1.97
vgg16(32, 32, 3, 32)/forward/CPU/8 thread(s) 238849459 ns 243564083 ns 0.98
vgg16(32, 32, 3, 32)/forward/CPU/1 thread(s) 453230104 ns 452336354.5 ns 1.00
vgg16(32, 32, 3, 32)/forward/GPU/CUDA 4878004 ns 4879283 ns 1.00
vgg16(32, 32, 3, 32)/zygote/CPU/2 thread(s) 646115875 ns 1156932333 ns 0.56
vgg16(32, 32, 3, 32)/zygote/CPU/4 thread(s) 997850500 ns 487570458 ns 2.05
vgg16(32, 32, 3, 32)/zygote/CPU/8 thread(s) 1020548042 ns 973572458 ns 1.05
vgg16(32, 32, 3, 32)/zygote/CPU/1 thread(s) 1405903916 ns 1399439834 ns 1.00
vgg16(32, 32, 3, 32)/zygote/GPU/CUDA 16517069 ns 16976929 ns 0.97
lenet(28, 28, 1, 64)/forward/CPU/2 thread(s) 1082416.5 ns 1062687.5 ns 1.02
lenet(28, 28, 1, 64)/forward/CPU/4 thread(s) 1987771 ns 971124.5 ns 2.05
lenet(28, 28, 1, 64)/forward/CPU/8 thread(s) 5710354 ns 6269583 ns 0.91
lenet(28, 28, 1, 64)/forward/CPU/1 thread(s) 1292417 ns 1393375 ns 0.93
lenet(28, 28, 1, 64)/forward/GPU/CUDA 271917 ns 277704.5 ns 0.98
lenet(28, 28, 1, 64)/zygote/CPU/2 thread(s) 6008167 ns 6494541.5 ns 0.93
lenet(28, 28, 1, 64)/zygote/CPU/4 thread(s) 12617937.5 ns 4635437.5 ns 2.72
lenet(28, 28, 1, 64)/zygote/CPU/8 thread(s) 19261125 ns 19450479 ns 0.99
lenet(28, 28, 1, 64)/zygote/CPU/1 thread(s) 6080125 ns 6080229 ns 1.00
lenet(28, 28, 1, 64)/zygote/GPU/CUDA 1135120.5 ns 1148981 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 23800833.5 ns 70442208 ns 0.34
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43546020.5 ns 35305229 ns 1.23
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39747791.5 ns 39532604 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 134027792 ns 132574604 ns 1.01
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1832530 ns 1848251 ns 0.99
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 185091125 ns 356785937.5 ns 0.52
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 270613834 ns 159371854 ns 1.70
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 254143208 ns 254893688 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 534168750 ns 535009020.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 16481853 ns 16489529.5 ns 1.00
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 296796938 ns 395707667 ns 0.75
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 396396750 ns 245564417 ns 1.61
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 713142792 ns 652089584 ns 1.09
Conv((3, 3), 32 => 32, identity)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 711324584 ns 712574333 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/2 thread(s) 654228834 ns 1191762375 ns 0.55
vgg16(32, 32, 3, 128)/forward/CPU/4 thread(s) 691374875 ns 434009729.5 ns 1.59
vgg16(32, 32, 3, 128)/forward/CPU/8 thread(s) 633147916 ns 631038834 ns 1.00
vgg16(32, 32, 3, 128)/forward/CPU/1 thread(s) 1779924875 ns 1771033395.5 ns 1.01
vgg16(32, 32, 3, 128)/forward/GPU/CUDA 12474301 ns 12471861 ns 1.00
vgg16(32, 32, 3, 128)/zygote/CPU/2 thread(s) 1890900125.5 ns 3670803208.5 ns 0.52
vgg16(32, 32, 3, 128)/zygote/CPU/4 thread(s) 2844497583 ns 1633483458 ns 1.74
vgg16(32, 32, 3, 128)/zygote/CPU/8 thread(s) 2698564000 ns 2737701958 ns 0.99
vgg16(32, 32, 3, 128)/zygote/CPU/1 thread(s) 5083178292 ns 5038709417 ns 1.01
vgg16(32, 32, 3, 128)/zygote/GPU/CUDA 49878971 ns 49641386 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3061062 ns 3412146 ns 0.90
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2050792 ns 2094750 ns 0.98
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2518833 ns 2533833.5 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 6007396 ns 6034292 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/forward/GPU/CUDA 580914.5 ns 586721 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 25707542 ns 26096750.5 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 18780458 ns 20315791.5 ns 0.92
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 19414188 ns 19312917 ns 1.01
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 39173750 ns 39366625 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2996538 ns 2989473.5 ns 1.00
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 35182271 ns 54095229 ns 0.65
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 82233188 ns 28393083 ns 2.90
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 176579334 ns 177757792 ns 0.99
Conv((3, 3), 4 => 4, relu)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 45418896.5 ns 45278750 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/2 thread(s) 1659959 ns 1778208 ns 0.93
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/4 thread(s) 1093666 ns 1204708 ns 0.91
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/8 thread(s) 1583625 ns 1564000 ns 1.01
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/CPU/1 thread(s) 3020417 ns 3038771 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/forward/GPU/CUDA 217039 ns 217944 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/2 thread(s) 12723500 ns 12531437.5 ns 1.02
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/4 thread(s) 9202000 ns 9964292 ns 0.92
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/8 thread(s) 9639333 ns 9707042 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/CPU/1 thread(s) 18926020.5 ns 18974500 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/zygote/GPU/CUDA 1947814 ns 1963028.5 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/2 thread(s) 17690687.5 ns 17644270.5 ns 1.00
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/4 thread(s) 14314083 ns 14745500 ns 0.97
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/8 thread(s) 14550166.5 ns 14639333 ns 0.99
Conv((3, 3), 2 => 2, relu)(64 x 64 x 2 x 128)/enzyme/CPU/1 thread(s) 22168625 ns 22173792 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 23646458.5 ns 70409562 ns 0.34
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 43551458.5 ns 34786542 ns 1.25
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 39640417 ns 39571499.5 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 132442291.5 ns 132610521 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 1820313 ns 1837717 ns 0.99
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 190914083 ns 360588187.5 ns 0.53
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 347341062.5 ns 237608334 ns 1.46
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 304813417 ns 299913354 ns 1.02
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 731122041 ns 725805833 ns 1.01
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 13916356.5 ns 13956738 ns 1.00
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 301421791.5 ns 418949812.5 ns 0.72
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 420410709 ns 251360792 ns 1.67
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 681250833 ns 712732021 ns 0.96
Conv((3, 3), 32 => 32, relu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 717801541 ns 717284542 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/2 thread(s) 1917291 ns 1912041.5 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/4 thread(s) 1507125 ns 1579125 ns 0.95
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/8 thread(s) 1552146 ns 1549791.5 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/CPU/1 thread(s) 2650416 ns 2657625 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/forward/GPU/CUDA 577487 ns 573525 ns 1.01
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/2 thread(s) 6162500 ns 9220000 ns 0.67
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/4 thread(s) 13031833 ns 5936166 ns 2.20
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/8 thread(s) 32447375 ns 31895937.5 ns 1.02
mlp7layer_bn(gelu)(32 x 256)/zygote/CPU/1 thread(s) 10165875 ns 10214937.5 ns 1.00
mlp7layer_bn(gelu)(32 x 256)/zygote/GPU/CUDA 1391351 ns 1399984.5 ns 0.99
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/2 thread(s) 18787354.5 ns 22182333.5 ns 0.85
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/4 thread(s) 27626166 ns 19138291.5 ns 1.44
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/8 thread(s) 49929625 ns 52527562.5 ns 0.95
mlp7layer_bn(gelu)(32 x 256)/enzyme/CPU/1 thread(s) 18807229 ns 18888042 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/CPU/2 thread(s) 69167 ns 791291.5 ns 0.08741026536996796
Dense(512 => 512, relu)(512 x 128)/forward/CPU/4 thread(s) 507666 ns 69958.5 ns 7.26
Dense(512 => 512, relu)(512 x 128)/forward/CPU/8 thread(s) 1007292 ns 997167 ns 1.01
Dense(512 => 512, relu)(512 x 128)/forward/CPU/1 thread(s) 725271 ns 724499.5 ns 1.00
Dense(512 => 512, relu)(512 x 128)/forward/GPU/CUDA 47866 ns 48324 ns 0.99
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/2 thread(s) 328458 ns 1508042 ns 0.22
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/4 thread(s) 1012854 ns 320291 ns 3.16
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/8 thread(s) 1408209 ns 1445145.5 ns 0.97
Dense(512 => 512, relu)(512 x 128)/zygote/CPU/1 thread(s) 2278291.5 ns 2258458.5 ns 1.01
Dense(512 => 512, relu)(512 x 128)/zygote/GPU/CUDA 215931.5 ns 216350 ns 1.00
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/2 thread(s) 444791 ns 1537083 ns 0.29
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/4 thread(s) 1071292 ns 428792 ns 2.50
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/8 thread(s) 1254999.5 ns 1444584 ns 0.87
Dense(512 => 512, relu)(512 x 128)/enzyme/CPU/1 thread(s) 2228417 ns 2250333 ns 0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/2 thread(s) 3046750 ns 3421750 ns 0.89
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/4 thread(s) 2038520.5 ns 2084312.5 ns 0.98
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/8 thread(s) 2500667 ns 2519375.5 ns 0.99
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/CPU/1 thread(s) 5987375 ns 6015021 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/forward/GPU/CUDA 587477 ns 584297 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/2 thread(s) 23590333 ns 24071521.5 ns 0.98
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/4 thread(s) 17304396 ns 18050833 ns 0.96
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/8 thread(s) 17177583 ns 17227375 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/CPU/1 thread(s) 37426916.5 ns 37583145.5 ns 1.00
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/zygote/GPU/CUDA 2913314 ns 2895440 ns 1.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/2 thread(s) 33402499.5 ns 52599188 ns 0.64
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/4 thread(s) 83336958.5 ns 27644250 ns 3.01
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/8 thread(s) 139112583 ns 170611917 ns 0.82
Conv((3, 3), 4 => 4, identity)(64 x 64 x 4 x 128)/enzyme/CPU/1 thread(s) 44466063 ns 44514250 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/2 thread(s) 122042083.5 ns 250102292 ns 0.49
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/4 thread(s) 147943958 ns 174510104 ns 0.85
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/8 thread(s) 115633729 ns 115645729 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/CPU/1 thread(s) 454541833.5 ns 448140124.5 ns 1.01
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/forward/GPU/CUDA 5461242 ns 5446378 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/2 thread(s) 473498792 ns 1105120833 ns 0.43
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/4 thread(s) 855268708 ns 467780729.5 ns 1.83
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/8 thread(s) 827066562.5 ns 825455520.5 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/CPU/1 thread(s) 1748960000 ns 1753431125 ns 1.00
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/zygote/GPU/CUDA 32293722 ns 35149612 ns 0.92
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/2 thread(s) 642440500 ns 1021983312.5 ns 0.63
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/4 thread(s) 923772709 ns 662517187.5 ns 1.39
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/8 thread(s) 1230540000 ns 1286071167 ns 0.96
Conv((3, 3), 64 => 64, identity)(64 x 64 x 64 x 128)/enzyme/CPU/1 thread(s) 1717797333 ns 1721665437.5 ns 1.00
mlp7layer_bn(relu)(32 x 256)/forward/CPU/2 thread(s) 1228792 ns 1312041 ns 0.94
mlp7layer_bn(relu)(32 x 256)/forward/CPU/4 thread(s) 960687.5 ns 928625 ns 1.03
mlp7layer_bn(relu)(32 x 256)/forward/CPU/8 thread(s) 962104 ns 903208 ns 1.07
mlp7layer_bn(relu)(32 x 256)/forward/CPU/1 thread(s) 1941146 ns 2032416 ns 0.96
mlp7layer_bn(relu)(32 x 256)/forward/GPU/CUDA 572882.5 ns 575428 ns 1.00
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/2 thread(s) 2970083 ns 5922771 ns 0.50
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/4 thread(s) 6001021 ns 2615500 ns 2.29
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/8 thread(s) 24213750 ns 24427083.5 ns 0.99
mlp7layer_bn(relu)(32 x 256)/zygote/CPU/1 thread(s) 7075625 ns 7104916.5 ns 1.00
mlp7layer_bn(relu)(32 x 256)/zygote/GPU/CUDA 1331747 ns 1363516 ns 0.98
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/2 thread(s) 6643438 ns 9705958.5 ns 0.68
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/4 thread(s) 12968500 ns 6499000 ns 2.00
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/8 thread(s) 30681146 ns 31929750 ns 0.96
mlp7layer_bn(relu)(32 x 256)/enzyme/CPU/1 thread(s) 7599417 ns 7614042 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/2 thread(s) 39000 ns 483291 ns 0.08069672309229843
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/4 thread(s) 423750 ns 31750 ns 13.35
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/8 thread(s) 1728854 ns 1795375 ns 0.96
Dense(128 => 128, gelu)(128 x 128)/forward/CPU/1 thread(s) 91666.5 ns 91542 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/forward/GPU/CUDA 28093 ns 28996 ns 0.97
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/2 thread(s) 175708.5 ns 392958 ns 0.45
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/4 thread(s) 418125 ns 175542 ns 2.38
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/8 thread(s) 4363375 ns 4708417 ns 0.93
Dense(128 => 128, gelu)(128 x 128)/zygote/CPU/1 thread(s) 273375 ns 273000 ns 1.00
Dense(128 => 128, gelu)(128 x 128)/zygote/GPU/CUDA 217037.5 ns 224707.5 ns 0.97
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/2 thread(s) 442000 ns 666333 ns 0.66
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/4 thread(s) 689625 ns 442250 ns 1.56
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/8 thread(s) 4473333 ns 4499167 ns 0.99
Dense(128 => 128, gelu)(128 x 128)/enzyme/CPU/1 thread(s) 510374.5 ns 510979.5 ns 1.00
Dense(128 => 128, relu)(128 x 128)/forward/CPU/2 thread(s) 13458.5 ns 430437.5 ns 0.0312670248293887
Dense(128 => 128, relu)(128 x 128)/forward/CPU/4 thread(s) 361958 ns 13583 ns 26.65
Dense(128 => 128, relu)(128 x 128)/forward/CPU/8 thread(s) 662333 ns 709208 ns 0.93
Dense(128 => 128, relu)(128 x 128)/forward/CPU/1 thread(s) 53333.5 ns 52584 ns 1.01
Dense(128 => 128, relu)(128 x 128)/forward/GPU/CUDA 28091 ns 29296 ns 0.96
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/2 thread(s) 25750 ns 337250 ns 0.07635285396590066
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/4 thread(s) 277604.5 ns 26375 ns 10.53
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/8 thread(s) 398771 ns 484812.5 ns 0.82
Dense(128 => 128, relu)(128 x 128)/zygote/CPU/1 thread(s) 151541.5 ns 151333 ns 1.00
Dense(128 => 128, relu)(128 x 128)/zygote/GPU/CUDA 206752 ns 213308.5 ns 0.97
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/2 thread(s) 46042 ns 352521 ns 0.13
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/4 thread(s) 293562.5 ns 45792 ns 6.41
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/8 thread(s) 583874.5 ns 487125 ns 1.20
Dense(128 => 128, relu)(128 x 128)/enzyme/CPU/1 thread(s) 151104.5 ns 151000 ns 1.00
vgg16(32, 32, 3, 64)/forward/CPU/2 thread(s) 319286916 ns 603223875 ns 0.53
vgg16(32, 32, 3, 64)/forward/CPU/4 thread(s) 430851208 ns 239241354 ns 1.80
vgg16(32, 32, 3, 64)/forward/CPU/8 thread(s) 374433583.5 ns 377713896 ns 0.99
vgg16(32, 32, 3, 64)/forward/CPU/1 thread(s) 872306875 ns 872019458 ns 1.00
vgg16(32, 32, 3, 64)/forward/GPU/CUDA 7676094 ns 7676104.5 ns 1.00
vgg16(32, 32, 3, 64)/zygote/CPU/2 thread(s) 1098654125 ns 2005520125 ns 0.55
vgg16(32, 32, 3, 64)/zygote/CPU/4 thread(s) 1617580312 ns 947653916.5 ns 1.71
vgg16(32, 32, 3, 64)/zygote/CPU/8 thread(s) 1598725833 ns 1551514604.5 ns 1.03
vgg16(32, 32, 3, 64)/zygote/CPU/1 thread(s) 2649602250 ns 2653038416 ns 1.00
vgg16(32, 32, 3, 64)/zygote/GPU/CUDA 27121930 ns 27180094 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/2 thread(s) 195250 ns 525604 ns 0.37
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/4 thread(s) 449792 ns 168333 ns 2.67
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/8 thread(s) 1696833 ns 1740625 ns 0.97
Dense(512 => 512, gelu)(512 x 128)/forward/CPU/1 thread(s) 871667 ns 875541 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/forward/GPU/CUDA 47277.5 ns 47837 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/2 thread(s) 1207770.5 ns 1943750 ns 0.62
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/4 thread(s) 2647125 ns 1100208 ns 2.41
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/8 thread(s) 14579895.5 ns 14661875 ns 0.99
Dense(512 => 512, gelu)(512 x 128)/zygote/CPU/1 thread(s) 2726959 ns 2836709 ns 0.96
Dense(512 => 512, gelu)(512 x 128)/zygote/GPU/CUDA 223041.5 ns 232330 ns 0.96
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/2 thread(s) 2305792 ns 2974229 ns 0.78
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/4 thread(s) 5778000 ns 2208583.5 ns 2.62
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/8 thread(s) 15078604.5 ns 15024229.5 ns 1.00
Dense(512 => 512, gelu)(512 x 128)/enzyme/CPU/1 thread(s) 3705500 ns 3751750 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/2 thread(s) 1576583.5 ns 1602291.5 ns 0.98
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/4 thread(s) 1175270.5 ns 1221084 ns 0.96
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/8 thread(s) 1231208 ns 1264750 ns 0.97
mlp7layer_bn(tanh)(32 x 256)/forward/CPU/1 thread(s) 2209917 ns 2362750 ns 0.94
mlp7layer_bn(tanh)(32 x 256)/forward/GPU/CUDA 572444.5 ns 576709 ns 0.99
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/2 thread(s) 3188250 ns 5931125 ns 0.54
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/4 thread(s) 4761542 ns 2866334 ns 1.66
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/8 thread(s) 24964209 ns 25035834 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/zygote/CPU/1 thread(s) 7279103.5 ns 6650208 ns 1.09
mlp7layer_bn(tanh)(32 x 256)/zygote/GPU/CUDA 1325716 ns 1379411 ns 0.96
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/2 thread(s) 8829125 ns 11605146 ns 0.76
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/4 thread(s) 14248542 ns 8767458 ns 1.63
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/8 thread(s) 35161250 ns 35255000 ns 1.00
mlp7layer_bn(tanh)(32 x 256)/enzyme/CPU/1 thread(s) 9525937.5 ns 9570000.5 ns 1.00
Dense(16 => 16, relu)(16 x 128)/forward/CPU/2 thread(s) 2500 ns 2541 ns 0.98
Dense(16 => 16, relu)(16 x 128)/forward/CPU/4 thread(s) 2458 ns 2292 ns 1.07
Dense(16 => 16, relu)(16 x 128)/forward/CPU/8 thread(s) 2833 ns 3000 ns 0.94
Dense(16 => 16, relu)(16 x 128)/forward/CPU/1 thread(s) 2458 ns 2333 ns 1.05
Dense(16 => 16, relu)(16 x 128)/forward/GPU/CUDA 24765 ns 25379.5 ns 0.98
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/2 thread(s) 7250 ns 7125 ns 1.02
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/4 thread(s) 7250 ns 7083 ns 1.02
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/8 thread(s) 7417 ns 7375 ns 1.01
Dense(16 => 16, relu)(16 x 128)/zygote/CPU/1 thread(s) 7292 ns 7270.5 ns 1.00
Dense(16 => 16, relu)(16 x 128)/zygote/GPU/CUDA 187906 ns 193729.5 ns 0.97
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/2 thread(s) 8542 ns 8334 ns 1.02
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/4 thread(s) 8291 ns 8500 ns 0.98
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/8 thread(s) 8625 ns 8417 ns 1.02
Dense(16 => 16, relu)(16 x 128)/enzyme/CPU/1 thread(s) 6083 ns 6084 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/2 thread(s) 11583 ns 10375.5 ns 1.12
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/4 thread(s) 15479 ns 14916 ns 1.04
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/8 thread(s) 10625 ns 11854 ns 0.90
Dense(16 => 16, gelu)(16 x 128)/forward/CPU/1 thread(s) 7334 ns 7625 ns 0.96
Dense(16 => 16, gelu)(16 x 128)/forward/GPU/CUDA 24957 ns 25646 ns 0.97
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/2 thread(s) 21750 ns 21708 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/4 thread(s) 21667 ns 21500 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/8 thread(s) 21708 ns 21750 ns 1.00
Dense(16 => 16, gelu)(16 x 128)/zygote/CPU/1 thread(s) 22000 ns 21875 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/zygote/GPU/CUDA 196771.5 ns 203851 ns 0.97
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/2 thread(s) 56667 ns 53417 ns 1.06
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/4 thread(s) 53584 ns 56583.5 ns 0.95
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/8 thread(s) 53917 ns 53583.5 ns 1.01
Dense(16 => 16, gelu)(16 x 128)/enzyme/CPU/1 thread(s) 51209 ns 51333 ns 1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/2 thread(s) 32375 ns 26895.5 ns 1.20
Dense(128 => 128, identity)(128 x 128)/forward/CPU/4 thread(s) 28583 ns 28333.5 ns 1.01
Dense(128 => 128, identity)(128 x 128)/forward/CPU/8 thread(s) 28875 ns 29000 ns 1.00
Dense(128 => 128, identity)(128 x 128)/forward/CPU/1 thread(s) 46167 ns 48291 ns 0.96
Dense(128 => 128, identity)(128 x 128)/forward/GPU/CUDA 25960 ns 26739 ns 0.97
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/2 thread(s) 44333 ns 220875 ns 0.20
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/4 thread(s) 276291 ns 44583 ns 6.20
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/8 thread(s) 4200084 ns 4132667 ns 1.02
Dense(128 => 128, identity)(128 x 128)/zygote/CPU/1 thread(s) 145500 ns 145458 ns 1.00
Dense(128 => 128, identity)(128 x 128)/zygote/GPU/CUDA 167846 ns 172310 ns 0.97
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/2 thread(s) 70542 ns 237312.5 ns 0.30
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/4 thread(s) 294250 ns 68625 ns 4.29
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/8 thread(s) 4092292 ns 4360708 ns 0.94
Dense(128 => 128, identity)(128 x 128)/enzyme/CPU/1 thread(s) 145542 ns 145917 ns 1.00
Dense(16 => 16, identity)(16 x 128)/forward/CPU/2 thread(s) 1959 ns 2292 ns 0.85
Dense(16 => 16, identity)(16 x 128)/forward/CPU/4 thread(s) 1916 ns 1750 ns 1.09
Dense(16 => 16, identity)(16 x 128)/forward/CPU/8 thread(s) 2417 ns 2166 ns 1.12
Dense(16 => 16, identity)(16 x 128)/forward/CPU/1 thread(s) 1834 ns 1520.5 ns 1.21
Dense(16 => 16, identity)(16 x 128)/forward/GPU/CUDA 23087 ns 23935 ns 0.96
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/2 thread(s) 5167 ns 5125 ns 1.01
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/4 thread(s) 4959 ns 5042 ns 0.98
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/8 thread(s) 5375 ns 5458 ns 0.98
Dense(16 => 16, identity)(16 x 128)/zygote/CPU/1 thread(s) 5459 ns 5084 ns 1.07
Dense(16 => 16, identity)(16 x 128)/zygote/GPU/CUDA 171785 ns 176841 ns 0.97
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/2 thread(s) 8250 ns 7292 ns 1.13
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/4 thread(s) 7292 ns 8166 ns 0.89
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/8 thread(s) 7750 ns 7541 ns 1.03
Dense(16 => 16, identity)(16 x 128)/enzyme/CPU/1 thread(s) 5167 ns 5167 ns 1
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/2 thread(s) 34011375 ns 80940833 ns 0.42
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/4 thread(s) 49736250 ns 41092709 ns 1.21
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/8 thread(s) 45704375 ns 45570541 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/CPU/1 thread(s) 153460084 ns 153559792 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/forward/GPU/CUDA 2633927 ns 2660311 ns 0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/2 thread(s) 451506229.5 ns 621714834 ns 0.73
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/4 thread(s) 429059958 ns 421739375 ns 1.02
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/8 thread(s) 414617250.5 ns 414510667 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/CPU/1 thread(s) 705577125 ns 697568292 ns 1.01
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/zygote/GPU/CUDA 15166479 ns 15148414 ns 1.00
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/2 thread(s) 747601083 ns 872377937.5 ns 0.86
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/4 thread(s) 842678875 ns 706482291.5 ns 1.19
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/8 thread(s) 1152094271 ns 1162546146 ns 0.99
Conv((3, 3), 32 => 32, gelu)(64 x 64 x 32 x 128)/enzyme/CPU/1 thread(s) 1173282708 ns 1175739375 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Copy link

codecov bot commented Oct 4, 2024

Codecov Report

Attention: Patch coverage is 73.17073% with 22 lines in your changes missing coverage. Please review.

Project coverage is 91.90%. Comparing base (77eb5fb) to head (5481cdc).

Files with missing lines Patch % Lines
ext/LuxReactantExt/training.jl 59.09% 18 Missing ⚠️
src/helpers/training.jl 82.60% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #969      +/-   ##
==========================================
- Coverage   92.61%   91.90%   -0.72%     
==========================================
  Files          58       60       +2     
  Lines        2886     2952      +66     
==========================================
+ Hits         2673     2713      +40     
- Misses        213      239      +26     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@avik-pal avik-pal force-pushed the ap/reactant_training2 branch 2 times, most recently from 92cc36b to a06b571 Compare October 5, 2024 00:27
@avik-pal
Copy link
Member Author

avik-pal commented Oct 5, 2024

EnzymeAD/Reactant.jl#161 is needed for the loss functions tests to pass

@avik-pal avik-pal force-pushed the ap/reactant_training2 branch 2 times, most recently from e4d6507 to ac05b19 Compare October 9, 2024 19:23
@avik-pal avik-pal changed the base branch from main to ap/remove_lossfunctions October 9, 2024 19:24
@avik-pal avik-pal force-pushed the ap/remove_lossfunctions branch 2 times, most recently from 72f5f07 to b6a3b35 Compare October 9, 2024 19:56
@avik-pal avik-pal force-pushed the ap/reactant_training2 branch 2 times, most recently from ff2aa65 to 9d1d81d Compare October 9, 2024 19:58
Base automatically changed from ap/remove_lossfunctions to main October 9, 2024 21:06
An error occurred while trying to automatically change base from ap/remove_lossfunctions to main October 9, 2024 21:06
@avik-pal avik-pal marked this pull request as ready for review October 9, 2024 21:17
@avik-pal avik-pal merged commit 1b0d6f8 into main Oct 9, 2024
18 of 23 checks passed
@avik-pal avik-pal deleted the ap/reactant_training2 branch October 9, 2024 22:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant