From 251f71a9bb9a963937e379cdb0a7ed80fc54f50c Mon Sep 17 00:00:00 2001 From: jason <94618524+mellowcroc@users.noreply.github.com> Date: Fri, 29 Nov 2024 17:08:52 +0900 Subject: [PATCH] feat: support generating trace on gpu --- Cargo.lock | 870 ++++++++++++++++-- Cargo.toml | 4 + crates/prover/Cargo.toml | 4 + .../prover/src/core/backend/gpu/gen_trace.rs | 312 +++++++ .../src/core/backend/gpu/gen_trace.wgsl | 279 ++++++ crates/prover/src/core/backend/gpu/mod.rs | 1 + crates/prover/src/core/backend/mod.rs | 1 + crates/prover/src/examples/poseidon/mod.rs | 7 + 8 files changed, 1420 insertions(+), 58 deletions(-) create mode 100644 crates/prover/src/core/backend/gpu/gen_trace.rs create mode 100644 crates/prover/src/core/backend/gpu/gen_trace.wgsl create mode 100644 crates/prover/src/core/backend/gpu/mod.rs diff --git a/Cargo.lock b/Cargo.lock index c14183025..0686249f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,18 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -20,6 +32,21 @@ dependencies = [ "as-slice", ] +[[package]] +name = "allocator-api2" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anes" version = "0.1.6" @@ -117,6 +144,15 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "ash" +version = "0.38.0+1.3.281" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bb44936d800fea8f016d7f2311c6a4f97aebd5dc86f09906139ec848cf3a46f" +dependencies = [ + "libloading", +] + [[package]] name = "autocfg" version = "1.2.0" @@ -135,6 +171,33 @@ dependencies = [ "serde", ] +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + [[package]] name = "blake2" version = "0.10.6" @@ -157,6 +220,12 @@ dependencies = [ "constant_time_eq", ] +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "block-buffer" version = "0.10.4" @@ -174,9 +243,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.15.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" +checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" dependencies = [ "bytemuck_derive", ] @@ -189,7 +258,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", ] [[package]] @@ -210,6 +279,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" + [[package]] name = "ciborium" version = "0.2.2" @@ -262,6 +337,16 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +[[package]] +name = "codespan-reporting" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e" +dependencies = [ + "termcolor", + "unicode-width", +] + [[package]] name = "console_error_panic_hook" version = "0.1.7" @@ -278,6 +363,33 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + [[package]] name = "cpufeatures" version = "0.2.12" @@ -397,6 +509,15 @@ dependencies = [ "subtle", ] +[[package]] +name = "document-features" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb6969eaabd2421f8a2775cfd2471a2b634372b4a25d41e3bd647b79912850a0" +dependencies = [ + "litrs", +] + [[package]] name = "downcast-rs" version = "1.2.1" @@ -412,7 +533,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", ] [[package]] @@ -438,7 +559,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", ] [[package]] @@ -460,6 +581,63 @@ dependencies = [ "log", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "flume" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" +dependencies = [ + "futures-core", + "futures-sink", + "nanorand", + "spin", +] + +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + [[package]] name = "generic-array" version = "0.14.7" @@ -483,6 +661,89 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "gl_generator" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a95dfc23a2b4a9a2f5ab41d194f8bfda3cabec42af4e39f08c339eb2a0c124d" +dependencies = [ + "khronos_api", + "log", + "xml-rs", +] + +[[package]] +name = "glow" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d51fa363f025f5c111e03f13eda21162faeacb6911fe8caa0c0349f9cf0c4483" +dependencies = [ + "js-sys", + "slotmap", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "glutin_wgl_sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a4e1951bbd9434a81aa496fe59ccc2235af3820d27b85f9314e279609211e2c" +dependencies = [ + "gl_generator", +] + +[[package]] +name = "gpu-alloc" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbcd2dba93594b227a1f57ee09b8b9da8892c34d55aa332e034a228d0fe6a171" +dependencies = [ + "bitflags 2.6.0", + "gpu-alloc-types", +] + +[[package]] +name = "gpu-alloc-types" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98ff03b468aa837d70984d55f5d3f846f6ec31fe34bbb97c4f85219caeee1ca4" +dependencies = [ + "bitflags 2.6.0", +] + +[[package]] +name = "gpu-allocator" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c151a2a5ef800297b4e79efa4f4bec035c5f51d5ae587287c9b952bdf734cacd" +dependencies = [ + "log", + "presser", + "thiserror", + "windows", +] + +[[package]] +name = "gpu-descriptor" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c08c1f623a8d0b722b8b99f821eb0ba672a1618f0d3b16ddbee1cedd2dd8557" +dependencies = [ + "bitflags 2.6.0", + "gpu-descriptor-types", + "hashbrown 0.14.5", +] + +[[package]] +name = "gpu-descriptor-types" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdf242682df893b86f33a73828fb09ca4b2d3bb6cc95249707fc684d27484b91" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "half" version = "2.4.1" @@ -493,6 +754,22 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] + +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" + [[package]] name = "hermit-abi" version = "0.3.9" @@ -505,6 +782,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hexf-parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfa686283ad6dd069f105e5ab091b04c62850d3e4cf5d67debad1933f55023df" + [[package]] name = "hmac" version = "0.12.1" @@ -514,6 +797,16 @@ dependencies = [ "digest", ] +[[package]] +name = "indexmap" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" +dependencies = [ + "equivalent", + "hashbrown 0.15.2", +] + [[package]] name = "is-terminal" version = "0.4.12" @@ -549,15 +842,38 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jni-sys" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" + [[package]] name = "js-sys" -version = "0.3.70" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] +[[package]] +name = "khronos-egl" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6aae1df220ece3c0ada96b8153459b67eebe9ae9212258bb0134ae60416fdf76" +dependencies = [ + "libc", + "libloading", + "pkg-config", +] + +[[package]] +name = "khronos_api" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2db585e1d738fc771bf08a151420d3ed193d9d895a36df7f6f8a9456b911ddc" + [[package]] name = "lazy_static" version = "1.4.0" @@ -570,12 +886,47 @@ version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = "libloading" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" +dependencies = [ + "cfg-if", + "windows-targets", +] + +[[package]] +name = "litrs" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ce301924b7887e9d637144fdade93f9dfff9b60981d4ac161db09720d39aa5" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + [[package]] name = "matchers" version = "0.1.0" @@ -591,6 +942,21 @@ version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +[[package]] +name = "metal" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" +dependencies = [ + "bitflags 2.6.0", + "block", + "core-graphics-types", + "foreign-types", + "log", + "objc", + "paste", +] + [[package]] name = "minicov" version = "0.3.5" @@ -601,6 +967,45 @@ dependencies = [ "walkdir", ] +[[package]] +name = "naga" +version = "23.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d5941e45a15b53aad4375eedf02033adb7a28931eedc31117faffa52e6a857e" +dependencies = [ + "arrayvec", + "bit-set", + "bitflags 2.6.0", + "cfg_aliases", + "codespan-reporting", + "hexf-parse", + "indexmap", + "log", + "rustc-hash", + "spirv", + "termcolor", + "thiserror", + "unicode-xid", +] + +[[package]] +name = "nanorand" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" +dependencies = [ + "getrandom", +] + +[[package]] +name = "ndk-sys" +version = "0.5.0+25.2.9519653" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c196769dd60fd4f363e11d948139556a344e79d451aeb2fa2fd040738ef7691" +dependencies = [ + "jni-sys", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -639,11 +1044,20 @@ dependencies = [ "autocfg", ] +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", +] + [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "oorandom" @@ -657,6 +1071,29 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + [[package]] name = "paste" version = "1.0.15" @@ -669,6 +1106,12 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +[[package]] +name = "pkg-config" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + [[package]] name = "plotters" version = "0.3.5" @@ -697,21 +1140,39 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "pollster" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22686f4785f02a4fcc856d3b3bb19bf6c8160d103f7a99cc258bddd0251dc7f2" + [[package]] name = "ppv-lite86" version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "presser" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa" + [[package]] name = "proc-macro2" -version = "1.0.81" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] +[[package]] +name = "profiling" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afbdc74edc00b6f6a218ca6a5364d6226a259d4b8ea1af4a0ea063f27e179f4d" + [[package]] name = "quote" version = "1.0.36" @@ -747,6 +1208,18 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +[[package]] +name = "range-alloc" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8a99fddc9f0ba0a85884b8d14e3592853e787d581ca1816c91349b10e4eeab" + +[[package]] +name = "raw-window-handle" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20675572f6f24e9e76ef639bc5552774ed45f1c30e2951e1e99c59888861c539" + [[package]] name = "rayon" version = "1.10.0" @@ -767,6 +1240,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "regex" version = "1.10.4" @@ -811,6 +1293,12 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +[[package]] +name = "renderdoc-sys" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832" + [[package]] name = "rfc6979" version = "0.4.0" @@ -821,6 +1309,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.4.0" @@ -851,6 +1345,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "semver" version = "1.0.23" @@ -874,7 +1374,7 @@ checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", ] [[package]] @@ -908,12 +1408,39 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "slotmap" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbff4acf519f630b3a3ddcfaea6c06b42174d9a44bc70c620e9ed1649d58b82a" +dependencies = [ + "version_check", +] + [[package]] name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "spirv" +version = "0.3.0+sdk-1.3.268.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -948,7 +1475,7 @@ checksum = "bbc159a1934c7be9761c237333a57febe060ace2bc9e3b337a59a37af206d19f" dependencies = [ "starknet-curve", "starknet-ff", - "syn 2.0.60", + "syn 2.0.89", ] [[package]] @@ -974,6 +1501,12 @@ dependencies = [ "serde", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "stwo-prover" version = "0.1.1" @@ -986,9 +1519,12 @@ dependencies = [ "criterion", "downcast-rs", "educe", + "flume", "hex", "itertools 0.12.1", "num-traits", + "once_cell", + "pollster", "rand", "rayon", "serde", @@ -999,6 +1535,7 @@ dependencies = [ "tracing", "tracing-subscriber", "wasm-bindgen-test", + "wgpu", ] [[package]] @@ -1020,15 +1557,24 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.60" +version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "test-log" version = "0.2.15" @@ -1048,27 +1594,27 @@ checksum = "c8f546451eaa38373f549093fe9fd05e7d2bade739e2ddf834b9968621d60107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", ] [[package]] name = "thiserror" -version = "1.0.59" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.59" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", ] [[package]] @@ -1110,7 +1656,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", ] [[package]] @@ -1164,6 +1710,18 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-width" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "valuable" version = "0.1.0" @@ -1194,9 +1752,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -1205,24 +1763,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.43" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e9300f63a621e96ed275155c108eb6f843b6a26d053f122ab69724559dc8ed" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -1232,9 +1790,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1242,22 +1800,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.93" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-bindgen-test" @@ -1282,17 +1840,123 @@ checksum = "4b8220be1fa9e4c889b30fd207d4906657e7e90b12e0e6b0c8b8d8709f5de021" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", ] [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wgpu" +version = "23.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80f70000db37c469ea9d67defdc13024ddf9a5f1b89cb2941b812ad7cde1735a" +dependencies = [ + "arrayvec", + "cfg_aliases", + "document-features", + "js-sys", + "log", + "naga", + "parking_lot", + "profiling", + "raw-window-handle", + "smallvec", + "static_assertions", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "wgpu-core", + "wgpu-hal", + "wgpu-types", +] + +[[package]] +name = "wgpu-core" +version = "23.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d63c3c478de8e7e01786479919c8769f62a22eec16788d8c2ac77ce2c132778a" +dependencies = [ + "arrayvec", + "bit-vec", + "bitflags 2.6.0", + "cfg_aliases", + "document-features", + "indexmap", + "log", + "naga", + "once_cell", + "parking_lot", + "profiling", + "raw-window-handle", + "rustc-hash", + "smallvec", + "thiserror", + "wgpu-hal", + "wgpu-types", +] + +[[package]] +name = "wgpu-hal" +version = "23.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "89364b8a0b211adc7b16aeaf1bd5ad4a919c1154b44c9ce27838213ba05fd821" dependencies = [ + "android_system_properties", + "arrayvec", + "ash", + "bit-set", + "bitflags 2.6.0", + "block", + "bytemuck", + "cfg_aliases", + "core-graphics-types", + "glow", + "glutin_wgl_sys", + "gpu-alloc", + "gpu-allocator", + "gpu-descriptor", "js-sys", + "khronos-egl", + "libc", + "libloading", + "log", + "metal", + "naga", + "ndk-sys", + "objc", + "once_cell", + "parking_lot", + "profiling", + "range-alloc", + "raw-window-handle", + "renderdoc-sys", + "rustc-hash", + "smallvec", + "thiserror", "wasm-bindgen", + "web-sys", + "wgpu-types", + "windows", + "windows-core", +] + +[[package]] +name = "wgpu-types" +version = "23.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "610f6ff27778148c31093f3b03abc4840f9636d58d597ca2f5977433acfe0068" +dependencies = [ + "bitflags 2.6.0", + "js-sys", + "web-sys", ] [[package]] @@ -1326,6 +1990,70 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" +dependencies = [ + "windows-core", + "windows-targets", +] + +[[package]] +name = "windows-core" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-result", + "windows-strings", + "windows-targets", +] + +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -1337,9 +2065,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -1353,51 +2081,77 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "xml-rs" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af310deaae937e48a26602b730250b4949e125f468f11e6990be3e5304ddd96f" + +[[package]] +name = "zerocopy" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] [[package]] name = "zeroize" @@ -1416,5 +2170,5 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.89", ] diff --git a/Cargo.toml b/Cargo.toml index 0f314a496..f94ed1660 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,10 @@ num-traits = "0.2.17" thiserror = "1.0.56" bytemuck = "1.14.3" tracing = "0.1.40" +wgpu = "23.0.0" +flume = "0.11.0" +pollster = "0.3" +once_cell = "1.20.2" [profile.bench] codegen-units = 1 diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index a9b80e9e4..fea6f1e14 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -26,6 +26,10 @@ thiserror.workspace = true tracing.workspace = true rayon = { version = "1.10.0", optional = true } serde = { version = "1.0", features = ["derive"] } +wgpu.workspace = true +flume.workspace = true +pollster.workspace = true +once_cell.workspace = true [dev-dependencies] aligned = "0.4.2" diff --git a/crates/prover/src/core/backend/gpu/gen_trace.rs b/crates/prover/src/core/backend/gpu/gen_trace.rs new file mode 100644 index 000000000..5b3782f3d --- /dev/null +++ b/crates/prover/src/core/backend/gpu/gen_trace.rs @@ -0,0 +1,312 @@ +use std::time::Instant; + +use bytemuck::{Pod, Zeroable}; +use wgpu::util::DeviceExt; + +// use crate::constraint_framework::EvalAtRow; +// use crate::examples::poseidon::PoseidonElements; + +const N_STATE: u32 = 16; +#[allow(dead_code)] +const N_INSTANCES_PER_ROW: u32 = 8; +#[allow(dead_code)] +const N_COLUMNS: u32 = N_INSTANCES_PER_ROW * N_COLUMNS_PER_REP; +#[allow(dead_code)] +const N_HALF_FULL_ROUNDS: u32 = 4; +#[allow(dead_code)] +const FULL_ROUNDS: u32 = 2 * N_HALF_FULL_ROUNDS; +#[allow(dead_code)] +const N_PARTIAL_ROUNDS: u32 = 14; +const N_LANES: u32 = 16; +#[allow(dead_code)] +const N_COLUMNS_PER_REP: u32 = N_STATE * (1 + FULL_ROUNDS) + N_PARTIAL_ROUNDS; +#[allow(dead_code)] +const LOG_N_LANES: u32 = 4; + +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +#[repr(C)] +struct Complex { + real: f32, + imag: f32, +} + +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +#[repr(C)] +struct GenTraceInput { + log_size: u32, +} + +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +#[repr(C)] +struct BaseColumn { + data: [PackedM31; N_STATE as usize], + length: u32, +} + +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +#[repr(C)] +struct PackedM31 { + data: [u32; N_LANES as usize], +} + +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +#[repr(C)] +struct LookupData { + initial_state: [[BaseColumn; N_STATE as usize]; N_INSTANCES_PER_ROW as usize], + final_state: [[BaseColumn; N_STATE as usize]; N_INSTANCES_PER_ROW as usize], +} + +#[derive(Debug, Clone, Copy, Pod, Zeroable)] +#[repr(C)] +pub struct DebugData { + index: [u32; 16], + values: [u32; 16], + counter: u32, +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +struct GenTraceOutput { + data: [PackedM31; N_STATE as usize], + trace: [BaseColumn; N_COLUMNS as usize], + lookup_data: LookupData, +} + +pub trait ByteSerialize: Sized { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + (self as *const Self) as *const u8, + std::mem::size_of::(), + ) + } + } + + fn from_bytes(bytes: &[u8]) -> &Self { + assert!(bytes.len() >= std::mem::size_of::()); + unsafe { &*(bytes.as_ptr() as *const Self) } + } +} + +impl ByteSerialize for GenTraceInput {} +impl ByteSerialize for BaseColumn {} +impl ByteSerialize for DebugData {} +impl ByteSerialize for GenTraceOutput {} + +pub async fn gen_trace() { + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::HighPerformance, + compatible_surface: None, + force_fallback_adapter: false, + }) + .await + .unwrap(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("Device"), + required_features: wgpu::Features::SHADER_INT64, + required_limits: wgpu::Limits::default(), + memory_hints: wgpu::MemoryHints::Performance, + }, + None, + ) + .await + .unwrap(); + + let input_data: GenTraceInput = GenTraceInput { log_size: 7 }; + + // Create buffers + let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Input Buffer"), + contents: bytemuck::cast_slice(&[input_data]), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + }); + + let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Output Buffer"), + size: (N_STATE as usize * std::mem::size_of::()) as wgpu::BufferAddress, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let debug_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Debug Buffer"), + size: (std::mem::size_of::()) as wgpu::BufferAddress, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // Load shader + let shader_source = include_str!("gen_trace.wgsl"); + let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Gen Trace Shader"), + source: wgpu::ShaderSource::Wgsl(shader_source.into()), + }); + + // Bind group layout + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + entries: &[ + // Binding 0: Input buffer + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + // Binding 1: Output buffer + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + // Binding 2: Debug buffer + wgpu::BindGroupLayoutEntry { + binding: 2, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + label: Some("Gen Trace Bind Group Layout"), + }); + + // Create bind group + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: output_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: debug_buffer.as_entire_binding(), + }, + ], + label: Some("Gen Trace Bind Group"), + }); + + // Pipeline layout + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + label: Some("Gen Trace Pipeline Layout"), + }); + + // Compute pipeline + let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Gen Trace Compute Pipeline"), + layout: Some(&pipeline_layout), + module: &shader_module, + entry_point: Some("gen_trace"), + cache: None, + compilation_options: Default::default(), + }); + + // Create encoder + let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Gen Trace Command Encoder"), + }); + + // === GPU FFT Timing Start === + let gpu_start = Instant::now(); + + // Dispatch the compute shader + { + let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Gen Trace Compute Pass"), + timestamp_writes: None, + }); + compute_pass.set_pipeline(&compute_pipeline); + compute_pass.set_bind_group(0, &bind_group, &[]); + + // Workgroup size defined in shader + let workgroup_size = 256u32; + + compute_pass.dispatch_workgroups(workgroup_size, 1, 1); + } + + // Copy output to staging buffer for read access + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Staging Buffer"), + size: (N_STATE as usize * std::mem::size_of::()) as wgpu::BufferAddress, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, staging_buffer.size()); + + // create storage buffer for debug data + let debug_staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Debug Staging Buffer"), + size: (std::mem::size_of::()) as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + encoder.copy_buffer_to_buffer( + &debug_buffer, + 0, + &debug_staging_buffer, + 0, + debug_staging_buffer.size(), + ); + + // Submit the commands + queue.submit(Some(encoder.finish())); + + { + let buffer_slice = debug_staging_buffer.slice(..); + let (sender, receiver) = flume::bounded(1); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap()); + device.poll(wgpu::Maintain::wait()).panic_on_timeout(); + + if let Ok(Ok(())) = receiver.recv_async().await { + let data = buffer_slice.get_mapped_range(); + let result = *DebugData::from_bytes(&data); + drop(data); + debug_staging_buffer.unmap(); + + println!("Debug data: {:?}", result); + } + } + + // Wait for the GPU to finish and map the staging buffer + let buffer_slice = staging_buffer.slice(..); + let (sender, receiver) = flume::bounded(1); + buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap()); + device.poll(wgpu::Maintain::wait()).panic_on_timeout(); + + if let Ok(Ok(())) = receiver.recv_async().await { + let data = buffer_slice.get_mapped_range(); + let result = *GenTraceOutput::from_bytes(&data); + + drop(data); + staging_buffer.unmap(); + + println!("Output: {:?}", result); + } + + println!("Poseidon generate trace time: {:?}", gpu_start.elapsed()); +} diff --git a/crates/prover/src/core/backend/gpu/gen_trace.wgsl b/crates/prover/src/core/backend/gpu/gen_trace.wgsl new file mode 100644 index 000000000..a066fe03a --- /dev/null +++ b/crates/prover/src/core/backend/gpu/gen_trace.wgsl @@ -0,0 +1,279 @@ +const MODULUS_BITS: u32 = 31u; +const P: u32 = 2147483647u; + +// Define constants +const N_STATE: u32 = 16; +const N_INSTANCES_PER_ROW: u32 = 8; +const N_COLUMNS: u32 = N_INSTANCES_PER_ROW * N_COLUMNS_PER_REP; +const N_HALF_FULL_ROUNDS: u32 = 4; +const FULL_ROUNDS: u32 = 2u * N_HALF_FULL_ROUNDS; +const N_PARTIAL_ROUNDS: u32 = 14; +const N_LANES: u32 = 16; +const N_COLUMNS_PER_REP: u32 = N_STATE * (1 + FULL_ROUNDS) + N_PARTIAL_ROUNDS; +const LOG_N_LANES: u32 = 4; + +// Initialize EXTERNAL_ROUND_CONSTS with explicit values +var EXTERNAL_ROUND_CONSTS: array, FULL_ROUNDS> = array, FULL_ROUNDS>( + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), +); + +// Initialize INTERNAL_ROUND_CONSTS with explicit values +var INTERNAL_ROUND_CONSTS: array = array( + 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234 +); + +// Create ColumnVec struct +struct ColumnVec { + data: array, N_INSTANCES_PER_ROW>, + length: u32, +} + +struct BaseColumn { + data: array, + length: u32, +} + +struct PackedM31 { + data: array, +} + +struct GenTraceInput { + log_size: u32, +} + +struct DebugData { + index: array, + values: array, + counter: atomic, +} + +struct LookupData { + initial_state: array, N_INSTANCES_PER_ROW>, + final_state: array, N_INSTANCES_PER_ROW>, +} + +struct GenTraceOutput { + data: array, + trace: array, + lookup_data: LookupData, +} + +@group(0) @binding(0) +var input: GenTraceInput; + +// Output buffer +@group(0) @binding(1) +var output: GenTraceOutput; + +@group(0) @binding(2) +var debug_buffer: DebugData; + +fn from_u32(value: u32) -> PackedM31 { + var packedM31 = PackedM31(); + for (var i = 0u; i < N_LANES; i++) { + packedM31.data[i] = value; + } + return packedM31; +} + +fn add(a: PackedM31, b: PackedM31) -> PackedM31 { + var packedM31 = PackedM31(); + for (var i = 0u; i < N_LANES; i++) { + packedM31.data[i] = partial_reduce(a.data[i] + b.data[i]); + } + return packedM31; +} + +fn mul(a: PackedM31, b: PackedM31) -> PackedM31 { + var packedM31 = PackedM31(); + for (var i = 0u; i < N_LANES; i++) { + var temp: u64 = u64(a.data[i]); + temp = temp * u64(b.data[i]); + packedM31.data[i] = full_reduce(temp); + } + return packedM31; +} + +// Partial reduce for values in [0, 2P) +fn partial_reduce(val: u32) -> u32 { + let reduced = val - P; + return select(val, reduced, reduced < val); +} + +fn full_reduce(val: u64) -> u32 { + let first_shift = val >> MODULUS_BITS; + let first_sum = first_shift + val + 1; + let second_shift = first_sum >> MODULUS_BITS; + let final_sum = second_shift + val; + return u32(final_sum & u64(P)); +} + +// Function to apply pow5 operation +fn pow5(x: PackedM31) -> PackedM31 { + return mul(mul(mul(x, x), mul(x, x)), x); +} + +/// Applies the external round matrix. +/// See 5.1 and Appendix B. +fn apply_external_round_matrix(state: array) -> array { + // Applies circ(2M4, M4, M4, M4). + var modified_state = state; + for (var i = 0u; i < 4u; i++) { + let partial_state = array( + state[4 * i], + state[4 * i + 1], + state[4 * i + 2], + state[4 * i + 3], + ); + let modified_partial_state = apply_m4(partial_state); + modified_state[4 * i] = modified_partial_state[0]; + modified_state[4 * i + 1] = modified_partial_state[1]; + modified_state[4 * i + 2] = modified_partial_state[2]; + modified_state[4 * i + 3] = modified_partial_state[3]; + } + for (var j = 0u; j < 4u; j++) { + let s = add(add(modified_state[j], modified_state[j + 4]), add(modified_state[j + 8], modified_state[j + 12])); + for (var i = 0u; i < 4u; i++) { + modified_state[4 * i + j] = add(modified_state[4 * i + j], s); + } + } + return modified_state; +} + +// Applies the internal round matrix. +// mu_i = 2^{i+1} + 1. +// See 5.2. +fn apply_internal_round_matrix(state: array) -> array { + var sum = state[0]; + for (var i = 1u; i < N_STATE; i++) { + sum = add(sum, state[i]); + } + + var result = array(); + for (var i = 0u; i < N_STATE; i++) { + let factor = partial_reduce(1u << (i + 1)); + result[i] = add(mul(from_u32(factor), state[i]), sum); + } + + return result; +} + +/// Applies the M4 MDS matrix described in 5.1. +fn apply_m4(x: array) -> array { + let t0 = add(x[0], x[1]); + let t02 = add(t0, t0); + let t1 = add(x[2], x[3]); + let t12 = add(t1, t1); + let t2 = add(add(x[1], x[1]), t1); + let t3 = add(add(x[3], x[3]), t0); + let t4 = add(add(t12, t12), t3); + let t5 = add(add(t02, t02), t2); + let t6 = add(t3, t5); + let t7 = add(t2, t4); + return array(t6, t5, t7, t4); +} + +fn store_debug_value(index: u32, value: u32) { + let debug_idx = atomicAdd(&debug_buffer.counter, 1u); + debug_buffer.index[debug_idx] = index; + debug_buffer.values[debug_idx] = value; +} + +@compute @workgroup_size(256) +fn gen_trace(@builtin(global_invocation_id) GlobalInvocationID: vec3) { + if (GlobalInvocationID.x != 0u) { + return; + } + + let log_size = input.log_size; + + if (log_size < LOG_N_LANES) { + return; + } + + for (var vec_index = 0u; vec_index < (1u << (log_size - LOG_N_LANES)); vec_index++) { + var col_index = 0u; + + for (var rep_i = 0u; rep_i < N_INSTANCES_PER_ROW; rep_i++) { + var state: array = initialize_state(vec_index, rep_i); + + for (var i = 0u; i < N_STATE; i++) { + output.trace[col_index].data[vec_index] = state[i]; + col_index += 1u; + } + + for (var i = 0u; i < N_STATE; i++) { + output.lookup_data.initial_state[rep_i][i].data[vec_index] = state[i]; + } + + // 4 full rounds + for (var i = 0u; i < N_HALF_FULL_ROUNDS; i++) { + for (var j = 0u; j < N_STATE; j++) { + state[j] = add(state[j], from_u32(EXTERNAL_ROUND_CONSTS[i][j])); + } + state = apply_external_round_matrix(state); + for (var j = 0u; j < N_STATE; j++) { + state[j] = pow5(state[j]); + } + for (var j = 0u; j < N_STATE; j++) { + output.trace[col_index].data[vec_index] = state[j]; + col_index += 1u; + } + } + // Partial rounds + for (var i = 0u; i < N_PARTIAL_ROUNDS; i++) { + state[0] = add(state[0], from_u32(INTERNAL_ROUND_CONSTS[i])); + state = apply_internal_round_matrix(state); + state[0] = pow5(state[0]); + output.trace[col_index].data[vec_index] = state[0]; + col_index += 1u; + } + // 4 full rounds + for (var i = 0u; i < N_HALF_FULL_ROUNDS; i++) { + for (var j = 0u; j < N_STATE; j++) { + state[j] = add(state[j], from_u32(EXTERNAL_ROUND_CONSTS[i + N_HALF_FULL_ROUNDS][j])); + } + state = apply_external_round_matrix(state); + for (var j = 0u; j < N_STATE; j++) { + state[j] = pow5(state[j]); + } + for (var j = 0u; j < N_STATE; j++) { + output.trace[col_index].data[vec_index] = state[j]; + col_index += 1u; + } + } + + for (var j = 0u; j < N_STATE; j++) { + output.lookup_data.final_state[rep_i][j].data[vec_index] = state[j]; + } + } + } +} + +// Function to initialize the state array +fn initialize_state(vec_index: u32, rep_i: u32) -> array { + var state: array; + + for (var state_i = 0u; state_i < N_STATE; state_i++) { + // Initialize each element of the state array + var packed_value = PackedM31(); + + for (var i = 0u; i < N_LANES; i++) { + // Calculate the value based on vec_index, state_i, and rep_i + let value: u32 = vec_index * 16u + i + state_i + rep_i; + // Here, you would typically pack this value into a PackedBaseField equivalent + // For simplicity, we'll just assign it directly + packed_value.data[i] = value; // Replace with actual packing logic if needed + } + state[state_i] = packed_value; + } + + return state; +} diff --git a/crates/prover/src/core/backend/gpu/mod.rs b/crates/prover/src/core/backend/gpu/mod.rs new file mode 100644 index 000000000..7a0c69372 --- /dev/null +++ b/crates/prover/src/core/backend/gpu/mod.rs @@ -0,0 +1 @@ +pub mod gen_trace; diff --git a/crates/prover/src/core/backend/mod.rs b/crates/prover/src/core/backend/mod.rs index 288ecc123..f8e84738c 100644 --- a/crates/prover/src/core/backend/mod.rs +++ b/crates/prover/src/core/backend/mod.rs @@ -15,6 +15,7 @@ use super::proof_of_work::GrindOps; use super::vcs::ops::MerkleOps; pub mod cpu; +pub mod gpu; pub mod simd; pub trait Backend: diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 51b671580..76184621c 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -483,6 +483,13 @@ mod tests { ); } + #[test] + fn test_gpu_poseidon_constraints() { + use crate::core::backend::gpu::gen_trace::gen_trace as gen_trace_gpu; + + pollster::block_on(gen_trace_gpu()); + } + #[test_log::test] fn test_simd_poseidon_prove() { // Note: To see time measurement, run test with