diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 00000000..42465404 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,1280 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.10.4" +manifest_format = "2.0" +project_hash = "855a6b353814d602d218bf0de92bfa65f664d3c0" + +[[deps.AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] + + [deps.AbstractFFTs.extensions] + AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "4.0.4" +weakdeps = ["StaticArrays"] + + [deps.Adapt.extensions] + AdaptStaticArraysExt = "StaticArrays" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + +[[deps.BFloat16s]] +deps = ["LinearAlgebra", "Printf", "Random", "Test"] +git-tree-sha1 = "2c7cc21e8678eff479978a0a2ef5ce2f51b63dff" +uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" +version = "0.5.0" + +[[deps.BangBang]] +deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] +git-tree-sha1 = "7aa7ad1682f3d5754e3491bb59b8103cae28e3a3" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.3.40" + + [deps.BangBang.extensions] + BangBangChainRulesCoreExt = "ChainRulesCore" + BangBangDataFramesExt = "DataFrames" + 
BangBangStaticArraysExt = "StaticArrays" + BangBangStructArraysExt = "StructArrays" + BangBangTypedTablesExt = "TypedTables" + + [deps.BangBang.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] +git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.BitFlags]] +git-tree-sha1 = "2dc09997850d68179b69dafb58ae806167a32b1b" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.8" + +[[deps.BytePairEncoding]] +deps = ["Artifacts", "Base64", "DataStructures", "DoubleArrayTries", "LazyArtifacts", "StructWalk", "TextEncodeBase", "Unicode"] +git-tree-sha1 = "295253961b9bcb1020bfd8711c7b51311dbfa102" +uuid = "a4280ba5-8788-555a-8ca8-4a8c3d966a71" +version = "0.4.1" + +[[deps.CEnum]] +git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" +uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" +version = "0.5.0" + +[[deps.CUDA]] +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "Crayons", "DataFrames", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LLVMLoopInfo", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "NVTX", "Preferences", "PrettyTables", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "StaticArrays", "Statistics"] +git-tree-sha1 = "6e945e876652f2003e6ca74e19a3c45017d3e9f6" +uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" +version = "5.4.2" + + [deps.CUDA.extensions] + ChainRulesCoreExt = "ChainRulesCore" + EnzymeCoreExt = "EnzymeCore" + SpecialFunctionsExt = "SpecialFunctions" + + [deps.CUDA.weakdeps] + ChainRulesCore = 
"d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" + +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "c48f9da18efd43b6b7adb7ee1f93fe5f2926c339" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.9.0+0" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "5db9da5fdeaa708c22ba86b82c49528f402497f2" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.3.3" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "bcba305388e16aa5c879e896726db9e71b4942c6" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.14.0+1" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "cbf7d75f8c58b147bdf6acea2e5bc96cececa6d4" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "9.0.0+1" + +[[deps.ChainRules]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] +git-tree-sha1 = "450ba466228b4dd2f620dc50d36cef7eb23ff8a7" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "1.67.0" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra"] +git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.23.0" +weakdeps = ["SparseArrays"] + + [deps.ChainRulesCore.extensions] + ChainRulesCoreSparseArraysExt = "SparseArrays" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.4" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = 
"b10d0b65641d57b8b4d5e234446582de5047050d" +uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.5" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.11" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["TOML", "UUIDs"] +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.15.0" +weakdeps = ["Dates", "LinearAlgebra"] + + [deps.Compat.extensions] + CompatLinearAlgebraExt = "LinearAlgebra" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.1.1+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" + + [deps.CompositionsBase.extensions] + CompositionsBaseInverseFunctionsExt = "InverseFunctions" + + [deps.CompositionsBase.weakdeps] + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "6cbbd4d241d7e6579ab354737f4dd95ca43946e1" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.4.1" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.5" + + [deps.ConstructionBase.extensions] + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" + + [deps.ConstructionBase.weakdeps] + IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[[deps.ContextVariablesX]] +deps = ["Compat", 
"Logging", "UUIDs"] +git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" +uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" +version = "0.1.3" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.16.0" + +[[deps.DataDeps]] +deps = ["HTTP", "Libdl", "Reexport", "SHA", "Scratch", "p7zip_jll"] +git-tree-sha1 = "8ae085b71c462c2cb1cfedcb10c3c877ec6cf03f" +uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +version = "0.7.13" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.20" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" +version = "1.9.1" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] 
+git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.DoubleArrayTries]] +deps = ["OffsetArrays", "Preferences", "StringViews"] +git-tree-sha1 = "9667af23bda5ce51bad3dd759812c398a58d8b9d" +uuid = "abbaa0e5-f788-499c-92af-c35ff4258c82" +version = "0.1.0" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "dcb08a0d93ec0b1cdc4af184b26b591e9695423a" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.10" + +[[deps.ExprTools]] +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" +uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +version = "0.1.10" + +[[deps.FLoops]] +deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] +git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576" +uuid = "cc61a311-1640-44b5-9fba-1b764f453329" +version = "0.2.1" + +[[deps.FLoopsBase]] +deps = ["ContextVariablesX"] +git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" +uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" +version = "0.1.1" + +[[deps.Fetch]] +deps = ["Base64", "HTTP", "JSON3", "Random", "StructTypes", "p7zip_jll"] +git-tree-sha1 = "781292162fd5bfe8d001210f9dddbb6baa509bf4" +uuid = 
"bb354801-46f6-40b6-9c3d-d42d7a74c775" +version = "0.1.4" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "0653c0a2396a6da5bc4766c43041ef5fd3efbe57" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.11.0" + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" + + [deps.FillArrays.weakdeps] + PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.5" + +[[deps.Flux]] +deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "a5475163b611812d073171583982c42ea48d22b0" +uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" +version = "0.14.15" + + [deps.Flux.extensions] + FluxAMDGPUExt = "AMDGPU" + FluxCUDAExt = "CUDA" + FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] + FluxMetalExt = "Metal" + + [deps.Flux.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" +weakdeps = ["StaticArrays"] + + [deps.ForwardDiff.extensions] + ForwardDiffStaticArraysExt = 
"StaticArrays" + +[[deps.FuncPipelines]] +git-tree-sha1 = "6484a27c35ecc680948c7dc7435c97f12c2bfaf7" +uuid = "9ed96fbb-10b6-44d4-99a6-7e2a3dc8861b" +version = "0.2.3" + +[[deps.Functors]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "d3e63d9fa13f8eaa2f06f64949e2afc593ff52c2" +uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +version = "0.4.10" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.GPUArrays]] +deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "38cb19b8a3e600e509dc36a6396ac74266d108c1" +uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +version = "10.1.1" + +[[deps.GPUArraysCore]] +deps = ["Adapt"] +git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +uuid = "46192b85-c4d5-4398-a991-12ede77f4527" +version = "0.1.6" + +[[deps.GPUCompiler]] +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "518ebd058c9895de468a8c255797b0c53fdb44dd" +uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" +version = "0.26.5" + +[[deps.HTML_Entities]] +deps = ["StrTables"] +git-tree-sha1 = "c4144ed3bc5f67f595622ad03c0e39fa6c70ccc7" +uuid = "7693890a-d069-55fe-a829-b4a6d304f0ee" +version = "1.0.1" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "d1d712be3164d61d1fb98e7ce9bcbc6cc06b45ed" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.10.8" + +[[deps.HuggingFaceApi]] +deps = ["Dates", "Downloads", "JSON3", "LibGit2", "OhMyArtifacts", "Pkg", "SHA"] +git-tree-sha1 = "bcf9b0ee12839d9bbee389ec13cd926845a2d39f" +uuid = "3cc741c3-0c9d-4fbe-84fa-cdec264173de" +version = "0.1.0" + +[[deps.IRTools]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = 
"950c3717af761bc3ff906c2e8e52bd83390b6ec2" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.14" + +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InlineStrings]] +deps = ["Parsers"] +git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InternedStrings]] +deps = ["Random", "Test"] +git-tree-sha1 = "eb05b5625bc5d821b8075a77e4c421933e20c76b" +uuid = "7d512f48-7fb1-5a58-b986-67e6dc259f01" +version = "0.7.0" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.JSON3]] +deps = ["Dates", "Mmap", "Parsers", "PrecompileTools", "StructTypes", "UUIDs"] +git-tree-sha1 = "eb3edce0ed4fa32f75a0a11217433c31d56bd48b" +uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" +version = "1.14.0" + + [deps.JSON3.extensions] + JSON3ArrowExt = ["ArrowTypes"] + + [deps.JSON3.weakdeps] + ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" + +[[deps.JuliaNVTXCallbacks_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = 
"af433a10f3942e882d3c671aacb203e006a5808f" +uuid = "9c1d0b0a-7046-5b2e-a33f-ea22f176ac7e" +version = "0.2.1+0" + +[[deps.JuliaVariables]] +deps = ["MLStyle", "NameResolution"] +git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" +uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" +version = "0.2.4" + +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "db02395e4c374030c53dc28f3c1d33dec35f7272" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.19" + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[[deps.LLVM]] +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] +git-tree-sha1 = "065c36f95709dd4a676dc6839a35d6fa6f192f24" +uuid = "929cbde3-209d-540e-8aea-75f648917ca0" +version = "7.1.0" +weakdeps = ["BFloat16s"] + + [deps.LLVM.extensions] + BFloat16sExt = "BFloat16s" + +[[deps.LLVMExtra_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" +uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" +version = "0.0.29+0" + +[[deps.LLVMLoopInfo]] +git-tree-sha1 = "2e5c102cfc41f48ae4740c7eca7743cc7e7b75ea" +uuid = "8b046642-f1f6-4319-8d3c-209ddc03c586" +version = "1.0.0" + +[[deps.LRUCache]] +git-tree-sha1 = "b3cc6698599b10e652832c2f23db3cab99d51b59" +uuid = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" +version = "1.6.1" +weakdeps = ["Serialization"] + + [deps.LRUCache.extensions] + SerializationExt = ["Serialization"] + +[[deps.LaTeXStrings]] +git-tree-sha1 = "50901ebc375ed41dbf8058da26f9de442febbbec" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.1" + +[[deps.LazyArtifacts]] +deps = ["Artifacts", "Pkg"] +uuid = 
"4af54fe1-eca0-43a8-85a7-787d91b784e3" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.4" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "8.4.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibGit2_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] +uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" +version = "1.6.4+0" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.11.0+1" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.Libiconv_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" +uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" +version = "1.17.0+0" + +[[deps.LightXML]] +deps = ["Libdl", "XML2_jll"] +git-tree-sha1 = "3a994404d3f6709610701c7dabfc03fed87a81f8" +uuid = "9c8b4983-aa76-5018-a973-4c85ecc9e179" +version = "0.9.1" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LogExpFunctions]] +deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.27" + + [deps.LogExpFunctions.extensions] + LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" + LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" + LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = 
"3587e190-3f89-42d0-90ee-14403ec27112" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.3" + +[[deps.MLStyle]] +git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" +uuid = "d8e11817-5142-5d16-987a-aa16d5891078" +version = "0.4.17" + +[[deps.MLUtils]] +deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"] +git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4" +uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" +version = "0.4.4" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.13" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "NetworkOptions", "Random", "Sockets"] +git-tree-sha1 = "c067a280ddc25f196b5e7df3877c6b226d390aaf" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.9" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.2+1" + +[[deps.MicroCollections]] +deps = ["BangBang", "InitialValues", "Setfield"] +git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.1.4" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.2.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2023.1.10" + +[[deps.NNlib]] +deps = ["Adapt", "Atomix", "ChainRulesCore", 
"GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "3d4617f943afe6410206a5294a95948c8d1b35bd" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.9.17" + + [deps.NNlib.extensions] + NNlibAMDGPUExt = "AMDGPU" + NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] + NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" + + [deps.NNlib.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + +[[deps.NVTX]] +deps = ["Colors", "JuliaNVTXCallbacks_jll", "Libdl", "NVTX_jll"] +git-tree-sha1 = "53046f0483375e3ed78e49190f1154fa0a4083a1" +uuid = "5da4648a-3479-48b8-97b9-01cb529c0a1f" +version = "0.3.4" + +[[deps.NVTX_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "ce3269ed42816bf18d500c9f63418d4b0d9f5a3b" +uuid = "e98f9f5b-d649-5603-91fd-7774390e6439" +version = "3.1.0+2" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NameResolution]] +deps = ["PrettyPrint"] +git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" +uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" +version = "0.1.5" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.NeuralAttentionlib]] +deps = ["Adapt", "CUDA", "ChainRulesCore", "GPUArrays", "GPUArraysCore", "LinearAlgebra", "NNlib", "Requires", "Static", "cuDNN"] +git-tree-sha1 = "2fa97194916782f95683cb3a6c3b2cb841a07325" +repo-rev = "master" +repo-url = "https://github.com/jarbus/NeuralAttentionlib.jl" +uuid = "12afc1b8-fad6-47e1-9132-84abc478905f" +version = "0.2.12" + +[[deps.OffsetArrays]] +git-tree-sha1 = "e64b4f5ea6b7389f6f046d13d4896a8f9c1ba71e" +uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" +version = "1.14.0" +weakdeps = 
["Adapt"] + + [deps.OffsetArrays.extensions] + OffsetArraysAdaptExt = "Adapt" + +[[deps.OhMyArtifacts]] +deps = ["Dates", "Downloads", "Pidfile", "Pkg", "Printf", "SHA", "Scratch", "TOML"] +git-tree-sha1 = "1ae208c3919548b9e7e6783ba294289cd204b4cb" +uuid = "cf8be1f4-309d-442e-839d-29d2a0af6cb7" +version = "0.3.1" + +[[deps.OneHotArrays]] +deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] +git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" +uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" +version = "0.2.5" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.23+4" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+2" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "38cb508d080d21dc1128f7fb04f20387ed4c0af4" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.3" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.13+1" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optimisers]] +deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" +uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" +version = "0.3.3" + +[[deps.OrderedCollections]] +git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.3" + +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = 
"65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.8.1" + +[[deps.PartialFunctions]] +deps = ["MacroTools"] +git-tree-sha1 = "47b49a4dbc23b76682205c646252c0f9e1eb75af" +uuid = "570af359-4316-4cb7-8c74-252c00c2016b" +version = "1.2.0" + +[[deps.Pickle]] +deps = ["BFloat16s", "DataStructures", "InternedStrings", "Mmap", "Serialization", "SparseArrays", "StridedViews", "StringEncodings", "ZipFile"] +git-tree-sha1 = "e99da19b86b7e1547b423fc1721b260cfbe83acb" +uuid = "fbb45041-c46e-462f-888f-7c521cafbc2c" +version = "0.3.5" + +[[deps.Pidfile]] +deps = ["FileWatching", "Test"] +git-tree-sha1 = "2d8aaf8ee10df53d0dfb9b8ee44ae7c04ced2b03" +uuid = "fa939f87-e72e-5be4-a000-7fc836dbe307" +version = "1.3.0" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.10.0" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.1" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.3" + +[[deps.PrettyPrint]] +git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" +uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" +version = "0.2.0" + +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "PrecompileTools", "Printf", "Reexport", 
"StringManipulation", "Tables"] +git-tree-sha1 = "66b20dd35966a748321d3b2537c4584cf40387c7" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.3.2" + +[[deps.PrimitiveOneHot]] +deps = ["Adapt", "ChainRulesCore", "NNlib", "Requires"] +git-tree-sha1 = "679f66ad280909b4ab590e382b4da6eb45c5955f" +uuid = "13d12f88-f12b-451e-9b9f-13b97e01cc85" +version = "0.1.4" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressLogging]] +deps = ["Logging", "SHA", "UUIDs"] +git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" +uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" +version = "0.1.4" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.Random123]] +deps = ["Random", "RandomNumbers"] +git-tree-sha1 = "4743b43e5a9c4a2ede372de7061eed81795b12e7" +uuid = "74087812-796a-5b5d-8853-05524746bad3" +version = "1.7.0" + +[[deps.RandomNumbers]] +deps = ["Random", "Requires"] +git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" +uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" +version = "1.5.3" + +[[deps.RealDot]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" +uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" +version = "0.1.0" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.RustRegex]] +deps = ["rure_jll"] +git-tree-sha1 = "16be5e710d7b980678ec0d8c61d4c00e9a5591e3" +uuid = "cdf36688-0c6d-42c6-a883-5d2df16e9e88" +version = "0.1.0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Scratch]] +deps = ["Dates"] 
+git-tree-sha1 = "3bac05bc7e74a75fd9cba4295cde4045d9fe2386" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.1" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "90b4f68892337554d31cdcdbe19e48989f26c7e6" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.3" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.2.1" + +[[deps.SparseArrays]] +deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +version = "1.10.0" + +[[deps.SparseInverseSubset]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" +uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" +version = "0.1.2" + +[[deps.SpecialFunctions]] +deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.4.0" +weakdeps = 
["ChainRulesCore"] + + [deps.SpecialFunctions.extensions] + SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version = "0.1.15" + +[[deps.Static]] +deps = ["IfElse"] +git-tree-sha1 = "d2fdac9ff3906e27f7a618d47b676941baa6c80c" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.8.10" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] +git-tree-sha1 = "9ae599cd7529cfce7fea36cf00a62cfc56f0f37c" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.9.4" +weakdeps = ["ChainRulesCore", "Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysChainRulesCoreExt = "ChainRulesCore" + StaticArraysStatisticsExt = "Statistics" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.2" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +version = "1.10.0" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + +[[deps.StrTables]] +deps = ["Dates"] +git-tree-sha1 = "5998faae8c6308acc25c25896562a1e66a3bb038" +uuid = "9700d1a9-a7c8-5760-9816-a99fda30bb8f" +version = "1.0.1" + +[[deps.StridedViews]] +deps = ["LinearAlgebra", "PackageExtensionCompat"] +git-tree-sha1 = "5b765c4e401693ab08981989f74a36a010aa1d8e" +uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" +version 
= "0.2.2" +weakdeps = ["CUDA"] + + [deps.StridedViews.extensions] + StridedViewsCUDAExt = "CUDA" + +[[deps.StringEncodings]] +deps = ["Libiconv_jll"] +git-tree-sha1 = "b765e46ba27ecf6b44faf70df40c57aa3a547dcb" +uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" +version = "0.3.7" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + +[[deps.StringViews]] +git-tree-sha1 = "f7b06677eae2571c888fd686ba88047d8738b0e3" +uuid = "354b36f9-a18e-4713-926e-db85100087ba" +version = "1.3.3" + +[[deps.StructArrays]] +deps = ["ConstructionBase", "DataAPI", "Tables"] +git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" +uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +version = "0.6.18" +weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] + + [deps.StructArrays.extensions] + StructArraysAdaptExt = "Adapt" + StructArraysGPUArraysCoreExt = "GPUArraysCore" + StructArraysSparseArraysExt = "SparseArrays" + StructArraysStaticArraysExt = "StaticArrays" + +[[deps.StructTypes]] +deps = ["Dates", "UUIDs"] +git-tree-sha1 = "ca4bccb03acf9faaf4137a9abc1881ed1841aa70" +uuid = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" +version = "1.10.0" + +[[deps.StructWalk]] +deps = ["ConstructionBase"] +git-tree-sha1 = "ef626534f40a9d99b3dafdbd54cfe411ad86e3b8" +uuid = "31cdf514-beb7-4750-89db-dda9d2eb8d3d" +version = "0.2.1" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.SuiteSparse_jll]] +deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] +uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" +version = "7.2.1+1" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.3" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = 
"3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.11.1" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.0" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TextEncodeBase]] +deps = ["FuncPipelines", "PartialFunctions", "PrimitiveOneHot", "RustRegex", "StaticArrays", "StructWalk", "Unicode", "WordTokenizers"] +git-tree-sha1 = "4753ea70646cb276a4db65952e59103fbc2b0576" +uuid = "f92c20c0-9f2a-4705-8116-881385faba05" +version = "0.7.0" + +[[deps.TimerOutputs]] +deps = ["ExprTools", "Printf"] +git-tree-sha1 = "5a13ae8a41237cff5ecf34f73eb1b8f42fff6531" +uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" +version = "0.5.24" + +[[deps.TranscodingStreams]] +git-tree-sha1 = "a947ea21087caba0a798c5e494d0bb78e3a1a3a0" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.10.9" +weakdeps = ["Random", "Test"] + + [deps.TranscodingStreams.extensions] + TestExt = ["Test", "Random"] + +[[deps.Transducers]] +deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] +git-tree-sha1 = "3064e780dbb8a9296ebb3af8f440f787bb5332af" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.80" + + [deps.Transducers.extensions] + TransducersBlockArraysExt = "BlockArrays" + TransducersDataFramesExt = "DataFrames" + TransducersLazyArraysExt = "LazyArrays" + TransducersOnlineStatsBaseExt = "OnlineStatsBase" + TransducersReferenceablesExt = "Referenceables" + + [deps.Transducers.weakdeps] + 
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" + DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" + OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338" + Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" + +[[deps.Tricks]] +git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" +uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" +version = "0.1.8" + +[[deps.URIs]] +git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.1" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "d9f5962fecd5ccece07db1ff006fb0b5271bdfdd" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.4" + +[[deps.ValSplit]] +deps = ["ExprTools", "Tricks"] +git-tree-sha1 = "3e1d94627f9276c40034c80dc5ab29ac1a3b06c0" +uuid = "0625e100-946b-11ec-09cd-6328dd093154" +version = "0.1.1" + +[[deps.WordTokenizers]] +deps = ["DataDeps", "HTML_Entities", "StrTables", "Unicode"] +git-tree-sha1 = "01dd4068c638da2431269f49a5964bf42ff6c9d2" +uuid = "796a5d58-b03d-544a-977e-18100b691f6e" +version = "0.5.6" + +[[deps.XML2_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] +git-tree-sha1 = "52ff2af32e591541550bd753c0da8b9bc92bb9d9" +uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" +version = "2.12.7+0" + +[[deps.ZipFile]] +deps = ["Libdl", "Printf", "Zlib_jll"] +git-tree-sha1 = "f492b7fe1698e623024e873244f10d89c95c340a" +uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" +version = "0.10.1" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.13+1" + +[[deps.Zygote]] 
+deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" +version = "0.6.70" + + [deps.Zygote.extensions] + ZygoteColorsExt = "Colors" + ZygoteDistancesExt = "Distances" + ZygoteTrackerExt = "Tracker" + + [deps.Zygote.weakdeps] + Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" + Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[[deps.ZygoteRules]] +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.5" + +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDA_Runtime_Discovery", "CUDNN_jll"] +git-tree-sha1 = "4909e87d6d62c29a897d54d9001c63932e41cb0e" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.3.2" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.8.0+1" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.52.0+1" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+2" + +[[deps.rure_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "a24449573502225e7833277f99a8e2c19801f5a7" +uuid = "2a13b4fb-3cbe-5d55-9db2-86fcb16976f1" +version = "0.2.2+0" diff --git a/Project.toml b/Project.toml index 673125c7..07554030 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,6 @@ LightXML = "9c8b4983-aa76-5018-a973-4c85ecc9e179" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mmap = 
"a63ad114-7e13-5084-954f-fe012c677804" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" NeuralAttentionlib = "12afc1b8-fad6-47e1-9132-84abc478905f" Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -42,6 +41,7 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" ValSplit = "0625e100-946b-11ec-09cd-6328dd093154" WordTokenizers = "796a5d58-b03d-544a-977e-18100b691f6e" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [compat] BytePairEncoding = "0.4" @@ -52,7 +52,7 @@ DataStructures = "0.18" DoubleArrayTries = "0.1" Fetch = "0.1.3" FillArrays = "0.13, 1" -Flux = "0.13.4" +Flux = "0.13, 0.14" FuncPipelines = "0.2.3" Functors = "0.2, 0.3, 0.4" HTTP = "0.9, 1" @@ -60,8 +60,7 @@ HuggingFaceApi = "0.1" JSON3 = "1.12" LRUCache = "1.5" LightXML = "0.9" -NNlib = "0.8" -NNlibCUDA = "0.2" +NNlib = "0.8, 0.9" NeuralAttentionlib = "0.2.12" Pickle = "0.3" PrimitiveOneHot = "0.1" diff --git a/README.md b/README.md index 13940354..a8966c4b 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +# THIS FORK IS CUSTOMIZED FOR DECODER_ONLY MODELS. +
Transformers.jl
"""
    compute_shmem_size(d, Bs, T::Type = Float32)

Return the number of bytes of dynamic shared memory needed for head dimension `d`
and tile size `Bs` when the buffers hold elements of type `T`.

The element count `3 * Bs * d + 4 * d + Bs * Bs` is the sum of the four buffers that
`causal_flash_attention_kernel` allocates: two `(Bs + 2, d)` buffers, one `(d, Bs)`
buffer and one `(Bs, Bs)` buffer.

`T` defaults to `Float32`, which reproduces the previous two-argument behaviour.
Pass the element type of the attention inputs so the size matches what the kernel
actually allocates.
"""
@inline function compute_shmem_size(d, Bs, ::Type{T} = Float32) where {T}
    return (3 * Bs * d + 4 * d + Bs * Bs) * sizeof(T)
end
"""
    setMaxShmem(shmem, T::Type = Float16)

Raise the dynamic shared-memory limit of `causal_flash_attention_kernel`, compiled
for four `CuDeviceArray{T, 4, 1}` arguments, to `shmem` KiB (`shmem * 1024` bytes).

The attribute is set on the compiled kernel function, so it only affects launches of
that specialization. `T` defaults to `Float16`, the element type that was previously
hard-coded.
"""
function setMaxShmem(shmem, ::Type{T} = Float16) where {T}
    # The kernel defined in this file is `causal_flash_attention_kernel`; compiling any
    # other name would fail with an UndefVarError.
    kernel = cufunction(causal_flash_attention_kernel,
                        NTuple{4, CuDeviceArray{T, 4, 1}})
    return CUDA.cuFuncSetAttribute(kernel.fun,
                                   CUDA.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                                   shmem * 1024)
end

"""
    _checkbounds(Q, K, V)

Validate the shapes of 4-dimensional attention inputs and return `nothing`.

Throws `DimensionMismatch` when

- `K` and `V` differ in shape,
- `Q` and `K` differ in their 3rd or 4th dimension, or
- `Q` and `K` differ in their 1st (hidden) dimension.

The 2nd dimension of `Q` is deliberately not compared with that of `K`.
"""
function _checkbounds(Q, K, V)
    sQ, sK, sV = size(Q), size(K), size(V)
    sK == sV || throw(DimensionMismatch("K and V must have the same shape"))
    # A chained `a != b != c` means `a != b && b != c`; with `sK == sV` already
    # established that form can never be true, so every pair is compared explicitly.
    sQ[3:4] == sK[3:4] ||
        throw(DimensionMismatch("Q, K and V must have the same batch size and head size"))
    sQ[1] == sK[1] ||
        throw(DimensionMismatch("Q, K and V must have the same hidden dimension"))
    return nothing
end

"""
    mod1_pow2(x, y)

One-based remainder of `x` modulo `y`, i.e. `mod1(x, y)`, computed with a bit mask.
Only valid when `y` is a power of two.
"""
@inline function mod1_pow2(x, y)
    r = x & (y - 1)
    return ifelse(r == 0, y, r)
end
"""
    causal_flash_attention_kernel(Q, K, V, O)

CUDA kernel that writes causally masked softmax attention of `Q`, `K`, `V` into `O`,
processing the keys/values tile by tile with a running maximum and running
log-sum-exp.

Launch contract, as set up by `causal_flash_attention`:

- `blockDim().x == Bs` is the tile size. Thread `tx` owns row `tx` of the `q`, `o`
  and `s` buffers.
- `blockIdx().x` selects the query tile; `blockIdx().y` and `blockIdx().z` select the
  slice along the 3rd and 4th array dimensions.
- Dynamic shared memory must hold `3 * Bs * d + 4 * d + Bs * Bs` elements of
  `eltype(Q)`.
- `d = size(K, 1)` must be a power of two (`mod1_pow2` and `>> power` rely on it).
- There is no tail handling, so the sequence length must be a multiple of `Bs`.
- The mask compares linear offsets into `Q` and `K`, so both are assumed to have
  identical shapes/strides.
"""
function causal_flash_attention_kernel(Q, K, V, O)
    d = size(K, 1)
    power = trailing_zeros(d)   # log2(d), valid because d is a power of two
    tx = threadIdx().x
    Bs = blockDim().x           # assume Br == Bc

    # allocate shared memory: four buffers carved out of one dynamic allocation
    T = eltype(Q)
    shmem_offset = 0
    q = CuDynamicSharedArray(T, (Bs + 2, d), shmem_offset) # pad 2 rows to avoid bank conflicts
    shmem_offset += sizeof(q)
    o = CuDynamicSharedArray(T, (Bs + 2, d), shmem_offset) # pad 2 rows to avoid bank conflicts
    shmem_offset += sizeof(o)
    k = CuDynamicSharedArray(T, (d, Bs), shmem_offset)     # unpadded; reused for the V tile
    shmem_offset += sizeof(k)
    s = CuDynamicSharedArray(T, (Bs, Bs), shmem_offset)

    # Linear offsets of this block's query tile and of the first key tile.
    Q_offset = d * Bs * (blockIdx().x - 1) +
               stride(Q, 3) * (blockIdx().y - 1) +
               stride(Q, 4) * (blockIdx().z - 1)
    K_offset = stride(K, 3) * (blockIdx().y - 1) + stride(K, 4) * (blockIdx().z - 1)

    # load Q to shared memory (done only once), clear o, and load the first K tile
    for i in 0:(d - 1)
        idx = i * Bs + tx
        row = mod1_pow2(idx, d)
        col = (idx - row) >> power + 1
        @inbounds q[col, row] = Q[idx + Q_offset]
        # `o` is padded to (Bs + 2, d) and later read as o[tx, m]; zero exactly those
        # slots. Clearing by linear index would leave some of them uninitialised.
        @inbounds o[col, row] = zero(T)
        @inbounds k[idx] = K[idx + K_offset]
    end

    sync_threads()

    # initialize lseᵢ and mᵢ
    lseᵢ = -typemax(T)
    mᵢ = -typemax(T)

    # the loop over key/value tiles is serial
    nblocks = cld(size(K, 2), Bs)
    for j in 1:nblocks
        # initialize mᵢⱼ
        mᵢⱼ = lseᵢ

        # compute s = QᵀK for this tile; s is (Bs, Bs)
        for n in 1:Bs
            # Causal mask: with identical Q/K strides this offset comparison is true
            # exactly when key `n` of this tile lies after this thread's query.
            if Q_offset + tx < K_offset + n
                @inbounds s[tx, n] = -typemax(T)
                continue
            end

            tmp = zero(T)
            for m in 1:d
                @inbounds tmp = CUDA.fma(q[tx, m], k[m, n], tmp)
            end
            @inbounds s[tx, n] = tmp
            mᵢⱼ = max(mᵢⱼ, tmp)
        end

        # every thread must be done reading the K tile before V overwrites it
        sync_threads()

        # compute P̃ᵢⱼ and lᵢⱼ
        lᵢⱼ = zero(T)
        for n in 1:Bs
            @inbounds tmp = exp(s[tx, n] - mᵢⱼ)
            @inbounds s[tx, n] = tmp
            lᵢⱼ += tmp
        end

        # load V to shared memory, which shares the same memory with k
        for i in 0:(d - 1)
            idx = i * Bs + tx
            row = mod1_pow2(idx, d)
            col = (idx - row) >> power + 1
            @inbounds k[row, col] = V[idx + K_offset]
        end

        sync_threads()

        # update o: rescale the running output and accumulate P̃ᵢⱼ * V
        for m in 1:d
            @inbounds tmp = o[tx, m] * exp(mᵢ - mᵢⱼ)
            for n in 1:Bs
                @inbounds tmp = CUDA.fma(s[tx, n], k[m, n], tmp) # k[m, n] * s[tx, n]
            end
            @inbounds o[tx, m] = tmp
        end

        mᵢ = mᵢⱼ
        lseᵢ = mᵢⱼ + log(exp(lseᵢ - mᵢⱼ) + lᵢⱼ)

        K_offset += Bs * d

        # All threads must finish reading the V tile before `k` is refilled with the
        # next K tile; without this barrier a fast thread races with slower ones.
        sync_threads()

        # prefetch the next K tile; there is none after the last iteration, and
        # reading one would run past the end of the sequence
        if j < nblocks
            for i in 0:(d - 1)
                idx = i * Bs + tx
                @inbounds k[idx] = K[idx + K_offset]
            end
        end
        sync_threads()
    end

    # final normalisation by the accumulated log-sum-exp
    for m in 1:d
        @inbounds o[tx, m] = o[tx, m] * exp(mᵢ - lseᵢ)
    end
    sync_threads()

    # write to O
    for i in 0:(d - 1)
        idx = i * Bs + tx
        row = mod1_pow2(idx, d)
        col = (idx - row) >> power + 1
        @inbounds O[idx + Q_offset] = o[col, row]
    end

    return nothing
end

"""
    causal_flash_attention(Q::CuArray{T,4}, K::CuArray{T,4}, V::CuArray{T,4})

Launch `causal_flash_attention_kernel` on inputs laid out as `(d, N, H, B)` and return
the output array, which has the same shape as `Q`.

Throws `DimensionMismatch` if `Q` and `K` differ in shape (the kernel's causal mask and
grid assume they match) and `ArgumentError` if `d` is not a power of two.
"""
function causal_flash_attention(Q::CuArray{T, 4}, K::CuArray{T, 4}, V::CuArray{T, 4}) where {T}
    _checkbounds(Q, K, V)
    size(Q) == size(K) ||
        throw(DimensionMismatch("Q and K must have the same shape for causal attention"))
    d, N, H, B = size(Q)
    ispow2(d) ||
        throw(ArgumentError("the first dimension of Q must be a power of two, got $d"))

    O = similar(Q)
    isempty(Q) && return O

    kernel = @cuda launch=false causal_flash_attention_kernel(Q, K, V, O)
    # `compute_shmem_size(d, Bs)` counts bytes for Float32 elements, while the kernel
    # allocates its buffers with eltype T; rescale so the two agree.
    get_shmem = Bs -> (compute_shmem_size(d, Bs) ÷ sizeof(Float32)) * sizeof(T)
    config = launch_configuration(kernel.fun; shmem=get_shmem, max_threads=256)

    Bs = min(N, config.threads)
    # The kernel has no tail handling: shrink the tile until it divides N exactly.
    while N % Bs != 0
        Bs -= 1
    end
    threads = (Bs, 1, 1)
    blocks = (cld(N, Bs), H, B)
    shmem = get_shmem(Bs)

    kernel(Q, K, V, O; threads=threads, blocks=blocks, shmem=shmem)
    return O
end
"""
    causal_flash_attention(n_heads::Int, Q, K, V)

Run causal flash attention on 3-dimensional `(d*h, n, b)` inputs.

Each input is split into `n_heads` heads with `Transformers_to_Flash`, passed to the
4-dimensional `causal_flash_attention` method, and the result is merged back with
`Flash_to_Transformers`. Returns a `NamedTuple` with a single `hidden_state` field.

Throws `ArgumentError` unless all three inputs are 3-dimensional.
"""
function causal_flash_attention(n_heads::Int, Q, K, V)
    # Explicit throw rather than `@assert`, which is not meant for input validation.
    ndims(Q) == ndims(K) == ndims(V) == 3 ||
        throw(ArgumentError("Q, K, and V should be of size (d*h, n, b)"))
    Q_fa = Transformers_to_Flash(n_heads, Q)
    K_fa = Transformers_to_Flash(n_heads, K)
    V_fa = Transformers_to_Flash(n_heads, V)
    O_fa = causal_flash_attention(Q_fa, K_fa, V_fa)
    return Flash_to_Transformers(O_fa)
end

"""
    Transformers_to_Flash(n_heads::Int, arr)

Reshape a `(d*h, N, B)` array to `(d, h, N, B)` and permute it to `(d, N, h, B)`,
where `h == n_heads`.

Throws `DimensionMismatch` if `size(arr, 1)` is not divisible by `n_heads`.
"""
function Transformers_to_Flash(n_heads::Int, arr)
    # Integer division with an explicit check instead of `Int(a / b)`, which goes
    # through floating point and fails with an opaque InexactError.
    d, rest = divrem(size(arr, 1), n_heads)
    rest == 0 || throw(DimensionMismatch(
        "first dimension ($(size(arr, 1))) must be divisible by n_heads ($n_heads)"))
    N, B = size(arr, 2), size(arr, 3)
    arr_4d = reshape(arr, d, n_heads, N, B)
    return permutedims(arr_4d, (1, 3, 2, 4))
end

"""
    Flash_to_Transformers(arr)

Inverse layout change of `Transformers_to_Flash`: permute a `(d, N, h, B)` array to
`(d, h, N, B)`, merge the first two dimensions, and wrap the `(d*h, N, B)` result in a
`NamedTuple` as `hidden_state`.
"""
function Flash_to_Transformers(arr)
    arr = permutedims(arr, (1, 3, 2, 4))
    N, B = size(arr, 3), size(arr, 4)
    return (hidden_state = reshape(arr, :, N, B),)
end

"""
    CausalFlashMultiheadQKVAttenOp(head, p = nothing)

Attention op that dispatches to `causal_flash_attention` with `head` heads.

`p` is stored but not read by any method defined here. Any `mask` passed to
`get_attention_func_args` is dropped; only `(op.head, q, k, v)` reach the attention
function.
"""
struct CausalFlashMultiheadQKVAttenOp{F} <: AbstractAttenOp
    head::Int
    p::F
end
CausalFlashMultiheadQKVAttenOp(head) = CausalFlashMultiheadQKVAttenOp(head, nothing)
NeuralAttentionlib.get_attention_func(::CausalFlashMultiheadQKVAttenOp) = causal_flash_attention
NeuralAttentionlib.get_attention_func_args(op::CausalFlashMultiheadQKVAttenOp, q, k, v, mask = nothing) = (op.head, q, k, v)
argument_names(::CausalFlashMultiheadQKVAttenOp) = (:hidden_state, :attention_mask)
apply_on_namedtuple(op::CausalFlashMultiheadQKVAttenOp, nt::NamedTuple) = apply_attention_op(op, nt)
"""
    TransformerDecoderBlock(attention, feedforward)

Decoder block made of two sub-layers applied in sequence: `attention`, then
`feedforward`. There is no cross-attention sub-layer.
"""
struct TransformerDecoderBlock{A, F} <: AbstractTransformerBlock
    attention::A
    feedforward::F
end
@functor TransformerDecoderBlock

# The block accepts the merged argument names of its two sub-layers.
argument_names(b::TransformerDecoderBlock) = Base.merge_names(
    argument_names(b.attention),
    argument_names(b.feedforward)
)

# Apply `attention` to `nt`, then `feedforward` to the NamedTuple it returns.
(b::TransformerDecoderBlock)(nt::NamedTuple) =
    apply_on_namedtuple(b.feedforward, apply_on_namedtuple(b.attention, nt))
# Decoder blocks whose two sub-layers (attention, feedforward) are each wrapped in a
# pre-norm or post-norm residual.
const PreNormTransformerDecoderBlock{A, LN1, F, LN2} =
    TransformerDecoderBlock{PreNormResidual{A, LN1}, PreNormResidual{F, LN2}}
const PostNormTransformerDecoderBlock{A, LN1, F, LN2} =
    TransformerDecoderBlock{PostNormResidual{A, LN1}, PostNormResidual{F, LN2}}

# Print the block name followed by the attention layer, its norm, the feedforward
# layer and its norm, separated by ", ".
function Base.show(io::IO, t::PreNormTransformerDecoderBlock)
    print(io, "PreNormTransformerDecoderBlock(")
    parts = (t.attention.layer, t.attention.norm, t.feedforward.layer, t.feedforward.norm)
    for (i, part) in enumerate(parts)
        i > 1 && print(io, ", ")
        show(io, part)
    end
    print(io, ')')
end
function Base.show(io::IO, t::PostNormTransformerDecoderBlock)
    print(io, "PostNormTransformerDecoderBlock(")
    parts = (t.attention.layer, t.attention.norm, t.feedforward.layer, t.feedforward.norm)
    for (i, part) in enumerate(parts)
        i > 1 && print(io, ", ")
        show(io, part)
    end
    print(io, ')')
end
-Flux._show_children(t::PostNormTransformerDecoderBlock) = (t.attention.layer, t.attention.norm, t.crossattention.layer, t.crossattention.norm, t.feedforward.layer, t.feedforward.norm) +Flux._show_children(t::PreNormTransformerDecoderBlock) = (t.attention.layer, t.attention.norm, #=t.crossattention.layer, t.crossattention.norm,=# t.feedforward.layer, t.feedforward.norm) +Flux._show_children(t::PostNormTransformerDecoderBlock) = (t.attention.layer, t.attention.norm, #= t.crossattention.layer, t.crossattention.norm, =# t.feedforward.layer, t.feedforward.norm) ############################################# @@ -147,7 +151,13 @@ function (sa::SelfAttention)(nt::NamedTuple) qkv = apply_on_namedtuple(sa.qkv_proj, nt) a = apply_on_namedtuple(sa.attention_op, qkv) y = apply_on_namedtuple(sa.o_proj, a) - return y + return y + # NOTE: instead of returning y, we return a copy of y, because + # there is some sort of memory leak when using distributed for Jevo specifically, + # I suspect related to gradients. This cuts off gradient flow. + #hidden_state = zeros(Float32, size(y.hidden_state)) |> Flux.gpu + #hidden_state .= y.hidden_state + #return (hidden_state = hidden_state, attention_mask = y.attention_mask) end struct CrossAttention{A, Q, KV, O} <: LayerStruct @@ -515,18 +525,14 @@ function PostNormTransformerDecoderBlock( ) sa = SelfAttention(head, hidden_size, head_hidden_size; dropout = attention_dropout, causal = true, return_score = return_self_attention_score) - ca = CrossAttention(head, hidden_size, head_hidden_size; dropout = cross_attention_dropout, return_score) ff1 = Dense(act, hidden_size, intermediate_size) ff2 = Dense(intermediate_size, hidden_size) return TransformerDecoderBlock( PostNormResidual( - DropoutLayer(sa, dropout), + sa, LayerNorm(hidden_size)), PostNormResidual( - DropoutLayer(ca, dropout), - LayerNorm(hidden_size)), - PostNormResidual( - DropoutLayer(Chain(ff1, ff2), dropout), + Chain(ff1, ff2), LayerNorm(hidden_size))) end