diff --git a/research/pallas/aot.ipynb b/research/pallas/aot.ipynb new file mode 100644 index 0000000..cd10e99 --- /dev/null +++ b/research/pallas/aot.ipynb @@ -0,0 +1,1771 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import spyx\n", + "import haiku as hk\n", + "import optax" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "k = 25\n", + "\n", + "@jax.jit\n", + "def grad_superspike(x):\n", + " return 1 / (1 + k*jnp.abs(x))**2" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.,\n", + " 13., 14., 15.], dtype=float32)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V = jnp.arange(16, dtype=jnp.float32)\n", + "V" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([1.0000000e+00, 1.4792900e-03, 3.8446751e-04, 1.7313019e-04,\n", + " 9.8029610e-05, 6.2988162e-05, 4.3857726e-05, 3.2283056e-05,\n", + " 2.4751862e-05, 1.9578667e-05, 1.5872763e-05, 1.3127495e-05,\n", + " 1.1037406e-05, 9.4094621e-06, 8.1168173e-06, 7.0733363e-06], dtype=float32)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grad_superspike(V)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[16]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", + " \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[16]\u001b[39m = pjit[\n", + " name=grad_superspike\n", + " jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; c\u001b[35m:f32[16]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", + " \u001b[39m\u001b[22m\u001b[22md\u001b[35m:f32[16]\u001b[39m = abs c\n", + " e\u001b[35m:f32[16]\u001b[39m = mul 25.0 d\n", + " f\u001b[35m:f32[16]\u001b[39m = add 1.0 e\n", + " g\u001b[35m:f32[16]\u001b[39m = integer_pow[y=2] f\n", + " h\u001b[35m:f32[16]\u001b[39m = div 1.0 g\n", + " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(h,) }\n", + " ] a\n", + " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(b,) }" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.make_jaxpr(grad_superspike)(V)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "lowered = jax.jit(grad_superspike).lower(V)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "module @jit_grad_superspike attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n", + " func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = \"default\"}) -> (tensor<16xf32> {jax.result_info = \"\", mhlo.layout_mode = \"default\"}) {\n", + " %0 = call @grad_superspike(%arg0) : (tensor<16xf32>) -> tensor<16xf32>\n", + " return %0 : tensor<16xf32>\n", + " }\n", + " func.func private @grad_superspike(%arg0: tensor<16xf32>) -> tensor<16xf32> {\n", + " %0 = stablehlo.abs %arg0 : tensor<16xf32>\n", + " %1 = stablehlo.constant dense<2.500000e+01> : tensor\n", + " %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<16xf32>\n", + " %3 = stablehlo.multiply %2, %0 : tensor<16xf32>\n", + " %4 = stablehlo.constant dense<1.000000e+00> : tensor\n", + " %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<16xf32>\n", + " %6 = stablehlo.add %5, %3 : tensor<16xf32>\n", + " %7 = stablehlo.multiply %6, %6 : tensor<16xf32>\n", + " %8 = stablehlo.constant dense<1.000000e+00> : tensor\n", + " %9 = stablehlo.broadcast_in_dim %8, dims = [] : (tensor) -> tensor<16xf32>\n", + " %10 = stablehlo.divide %9, %7 : tensor<16xf32>\n", + " return %10 : tensor<16xf32>\n", + " }\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "print(lowered.as_text())" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "compiled = lowered.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_grad_superspike, is_scheduled=true, entry_computation_layout={(f32[16]{0})->f32[16]{0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"97c98a0ca68aa6e801146d4c1b162054\"}\n", + "\n", + "%fused_divide (param_0.3: f32[16]) -> f32[16] {\n", + " %constant_0_1 = f32[] constant(1)\n", + " %broadcast.4.1 = f32[16]{0} broadcast(f32[] %constant_0_1), dimensions={}\n", + " %param_0.3 = f32[16]{0} parameter(0)\n", + " %abs.2.1 = f32[16]{0} abs(f32[16]{0} %param_0.3), metadata={op_name=\"jit(grad_superspike)/jit(main)/jit(grad_superspike)/abs\" source_file=\"/tmp/ipykernel_73553/1167682913.py\" source_line=5}\n", + " %constant_1_1 = f32[] constant(25)\n", + " %broadcast.6.1 = f32[16]{0} broadcast(f32[] %constant_1_1), dimensions={}\n", + " %multiply.4.1 = f32[16]{0} multiply(f32[16]{0} %abs.2.1, f32[16]{0} %broadcast.6.1), metadata={op_name=\"jit(grad_superspike)/jit(main)/jit(grad_superspike)/mul\" source_file=\"/tmp/ipykernel_73553/1167682913.py\" source_line=5}\n", + " %add.2.1 = f32[16]{0} add(f32[16]{0} %multiply.4.1, f32[16]{0} %broadcast.4.1), metadata={op_name=\"jit(grad_superspike)/jit(main)/jit(grad_superspike)/add\" source_file=\"/tmp/ipykernel_73553/1167682913.py\" source_line=5}\n", + " %multiply.5.1 = f32[16]{0} multiply(f32[16]{0} %add.2.1, f32[16]{0} %add.2.1), metadata={op_name=\"jit(grad_superspike)/jit(main)/jit(grad_superspike)/integer_pow[y=2]\" source_file=\"/tmp/ipykernel_73553/1167682913.py\" source_line=5}\n", + " ROOT %divide.2.1 = f32[16]{0} divide(f32[16]{0} %broadcast.4.1, f32[16]{0} %multiply.5.1), metadata={op_name=\"jit(grad_superspike)/jit(main)/jit(grad_superspike)/div\" source_file=\"/tmp/ipykernel_73553/1167682913.py\" source_line=5}\n", + "}\n", + "\n", + "ENTRY %main.14 (Arg_0.1.0: f32[16]) -> f32[16] {\n", + " %Arg_0.1.0 = f32[16]{0} parameter(0)\n", + " ROOT %loop_divide_fusion = f32[16]{0} fusion(f32[16]{0} %Arg_0.1.0), kind=kLoop, calls=%fused_divide, metadata={op_name=\"jit(grad_superspike)/jit(main)/jit(grad_superspike)/div\" source_file=\"/tmp/ipykernel_73553/1167682913.py\" source_line=5}\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(compiled.as_text())" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'flops': 80.0,\n", + " 'bytes accessed': 128.0,\n", + " 'utilization0{}': 1.0,\n", + " 'utilization1{}': 4.0,\n", + " 'bytes accessed1{}': 256.0,\n", + " 'bytes accessedout{}': 64.0,\n", + " 'bytes accessed0{}': 64.0}]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "compiled.cost_analysis()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Pallas" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "from jax.experimental import pallas as pl\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import spyx\n", + "import haiku as hk\n", + "import optax" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def add_vectors_kernel(x_ref, y_ref, o_ref):\n", + " x, y = x_ref[...], y_ref[...]\n", + " o_ref[...] = x + y" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0, 2, 4, 6], dtype=int32)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@jax.jit\n", + "def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:\n", + " return pl.pallas_call(\n", + " add_vectors_kernel,\n", + " out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", + " )(x, y)\n", + "add_vectors(jnp.arange(4), jnp.arange(4))" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.,\n", + " 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21.,\n", + " 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,\n", + " 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43.,\n", + " 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54.,\n", + " 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65.,\n", + " 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76.,\n", + " 77., 78., 79., 80., 81., 82., 83., 84., 85., 86., 87.,\n", + " 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98.,\n", + " 99., 100., 101., 102., 103., 104., 105., 106., 107., 108., 109.,\n", + " 110., 111., 112., 113., 114., 115., 116., 117., 118., 119., 120.,\n", + " 121., 122., 123., 124., 125., 126., 127., 128., 129., 130., 131.,\n", + " 132., 133., 134., 135., 136., 137., 138., 139., 140., 141., 142.,\n", + " 143., 144., 145., 146., 147., 148., 149., 150., 151., 152., 153.,\n", + " 154., 155., 156., 157., 158., 159., 160., 161., 162., 163., 164.,\n", + " 165., 166., 167., 168., 169., 170., 171., 172., 173., 174., 175.,\n", + " 176., 177., 178., 179., 180., 181., 182., 183., 184., 185., 186.,\n", + " 187., 188., 189., 190., 191., 192., 193., 194., 195., 196., 197.,\n", + " 198., 199., 200., 201., 202., 203., 204., 205., 206., 207., 208.,\n", + " 209., 210., 211., 212., 213., 214., 215., 216., 217., 218., 219.,\n", + " 220., 221., 222., 223., 224., 225., 226., 227., 228., 229., 230.,\n", + " 231., 232., 233., 234., 235., 236., 237., 238., 239., 240., 241.,\n", + " 242., 243., 244., 245., 246., 247., 248., 249., 250., 251., 252.,\n", + " 253., 254., 255., 256., 257., 258., 259., 260., 261., 262., 263.,\n", + " 264., 265., 266., 267., 268., 269., 270., 271., 272., 273., 274.,\n", + " 275., 276., 277., 278., 279., 280., 281., 282., 283., 284., 285.,\n", + " 286., 287., 288., 289., 290., 291., 292., 293., 294., 295., 296.,\n", + " 297., 298., 299., 300., 301., 302., 303., 304., 305., 306., 307.,\n", + " 308., 309., 310., 311., 312., 313., 314., 315., 316., 317., 318.,\n", + " 319., 320., 321., 322., 323., 324., 325., 326., 327., 328., 329.,\n", + " 330., 331., 332., 333., 334., 335., 336., 337., 338., 339., 340.,\n", + " 341., 342., 343., 344., 345., 346., 347., 348., 349., 350., 351.,\n", + " 352., 353., 354., 355., 356., 357., 358., 359., 360., 361., 362.,\n", + " 363., 364., 365., 366., 367., 368., 369., 370., 371., 372., 373.,\n", + " 374., 375., 376., 377., 378., 379., 380., 381., 382., 383., 384.,\n", + " 385., 386., 387., 388., 389., 390., 391., 392., 393., 394., 395.,\n", + " 396., 397., 398., 399., 400., 401., 402., 403., 404., 405., 406.,\n", + " 407., 408., 409., 410., 411., 412., 413., 414., 415., 416., 417.,\n", + " 418., 419., 420., 421., 422., 423., 424., 425., 426., 427., 428.,\n", + " 429., 430., 431., 432., 433., 434., 435., 436., 437., 438., 439.,\n", + " 440., 441., 442., 443., 444., 445., 446., 447., 448., 449., 450.,\n", + " 451., 452., 453., 454., 455., 456., 457., 458., 459., 460., 461.,\n", + " 462., 463., 464., 465., 466., 467., 468., 469., 470., 471., 472.,\n", + " 473., 474., 475., 476., 477., 478., 479., 480., 481., 482., 483.,\n", + " 484., 485., 486., 487., 488., 489., 490., 491., 492., 493., 494.,\n", + " 495., 496., 497., 498., 499., 500., 501., 502., 503., 504., 505.,\n", + " 506., 507., 508., 509., 510., 511.], dtype=float32)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V = jnp.arange(512, dtype=jnp.float32)\n", + "V" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(16,)" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "V.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def superspike_kernel(x_ref, o_ref):\n", + " x = x_ref[...]\n", + " o_ref[...] = 1 / (1 + 25*jnp.abs(x))**2" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def pallas_superspike(x: jax.Array) -> jax.Array:\n", + " bspec = pl.BlockSpec(block_shape=(32,), index_map=lambda i: i)\n", + " return pl.pallas_call(\n", + " superspike_kernel,\n", + " out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype),\n", + " grid=(16,),\n", + " in_specs=[pl.BlockSpec(lambda i: i, (32,))],\n", + " out_specs=pl.BlockSpec(lambda i: i, (32,))\n", + " )(x,)" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([1.00000000e+00, 1.47928996e-03, 3.84467508e-04, 1.73130189e-04,\n", + " 9.80296099e-05, 6.29881615e-05, 4.38577263e-05, 3.22830565e-05,\n", + " 2.47518619e-05, 1.95786670e-05, 1.58727635e-05, 1.31274946e-05,\n", + " 1.10374058e-05, 9.40946211e-06, 8.11681730e-06, 7.07333629e-06,\n", + " 6.21886693e-06, 5.51037056e-06, 4.91639685e-06, 4.41352995e-06,\n", + " 3.98404791e-06, 3.61433581e-06, 3.29379668e-06, 3.01408181e-06,\n", + " 2.76854166e-06, 2.55182772e-06, 2.35959806e-06, 2.18829882e-06,\n", + " 2.03499781e-06, 1.89725961e-06, 1.77304651e-06, 1.66064399e-06,\n", + " 1.55860107e-06, 1.46568254e-06, 1.38083215e-06, 1.30314208e-06,\n", + " 1.23182895e-06, 1.16621345e-06, 1.10570420e-06, 1.04978506e-06,\n", + " 9.98002974e-07, 9.49959883e-07, 9.05304262e-07, 8.63724949e-07,\n", + " 8.24945744e-07, 7.88720627e-07, 7.54830353e-07, 7.23078358e-07,\n", + " 6.93288484e-07, 6.65302366e-07, 6.38977212e-07, 6.14184216e-07,\n", + " 5.90806678e-07, 5.68738926e-07, 5.47884838e-07, 5.28157102e-07,\n", + " 5.09476024e-07, 4.91768787e-07, 4.74968886e-07, 4.59015439e-07,\n", + " 4.43852429e-07, 4.29428582e-07, 4.15696519e-07, 4.02612784e-07,\n", + " 3.90137188e-07, 3.78232556e-07, 3.66864640e-07, 3.56001607e-07,\n", + " 3.45614040e-07, 3.35674542e-07, 3.26157760e-07, 3.17040019e-07,\n", + " 3.08299320e-07, 2.99915172e-07, 2.91868474e-07, 2.84141294e-07,\n", + " 2.76716946e-07, 2.69579857e-07, 2.62715361e-07, 2.56109757e-07,\n", + " 2.49750201e-07, 2.43624584e-07, 2.37721622e-07, 2.32030615e-07,\n", + " 2.26541559e-07, 2.21245003e-07, 2.16132065e-07, 2.11194319e-07,\n", + " 2.06423877e-07, 2.01813251e-07, 1.97355391e-07, 1.93043633e-07,\n", + " 1.88871638e-07, 1.84833453e-07, 1.80923394e-07, 1.77136116e-07,\n", + " 1.73466532e-07, 1.69909796e-07, 1.66461334e-07, 1.63116809e-07,\n", + " 1.59872073e-07, 1.56723203e-07, 1.53666463e-07, 1.50698270e-07,\n", + " 1.47815271e-07, 1.45014212e-07, 1.42292023e-07, 1.39645763e-07,\n", + " 1.37072661e-07, 1.34570016e-07, 1.32135298e-07, 1.29766050e-07,\n", + " 1.27459970e-07, 1.25214811e-07, 1.23028457e-07, 1.20898875e-07,\n", + " 1.18824104e-07, 1.16802291e-07, 1.14831643e-07, 1.12910456e-07,\n", + " 1.11037075e-07, 1.09209935e-07, 1.07427532e-07, 1.05688407e-07,\n", + " 1.03991169e-07, 1.02334496e-07, 1.00717095e-07, 9.91377362e-08,\n", + " 9.75952474e-08, 9.60884705e-08, 9.46163254e-08, 9.31777464e-08,\n", + " 9.17717387e-08, 9.03973074e-08, 8.90535290e-08, 8.77394939e-08,\n", + " 8.64543281e-08, 8.51971933e-08, 8.39672794e-08, 8.27638118e-08,\n", + " 8.15860233e-08, 8.04332032e-08, 7.93046482e-08, 7.81996832e-08,\n", + " 7.71176474e-08, 7.60579155e-08, 7.50198765e-08, 7.40029478e-08,\n", + " 7.30065537e-08, 7.20301472e-08, 7.10732024e-08, 7.01352008e-08,\n", + " 6.92156448e-08, 6.83140513e-08, 6.74299727e-08, 6.65629329e-08,\n", + " 6.57125199e-08, 6.48782930e-08, 6.40598543e-08, 6.32568060e-08,\n", + " 6.24687644e-08, 6.16953528e-08, 6.09362232e-08, 6.01910131e-08,\n", + " 5.94593956e-08, 5.87410298e-08, 5.80356101e-08, 5.73428167e-08,\n", + " 5.66623584e-08, 5.59939366e-08, 5.53372814e-08, 5.46921015e-08,\n", + " 5.40581446e-08, 5.34351408e-08, 5.28228554e-08, 5.22210222e-08,\n", + " 5.16294243e-08, 5.10478166e-08, 5.04759896e-08, 4.99137087e-08,\n", + " 4.93607786e-08, 4.88169789e-08, 4.82821214e-08, 4.77560036e-08,\n", + " 4.72384407e-08, 4.67292445e-08, 4.62282372e-08, 4.57352449e-08,\n", + " 4.52500970e-08, 4.47726265e-08, 4.43026771e-08, 4.38400818e-08,\n", + " 4.33847021e-08, 4.29363745e-08, 4.24949640e-08, 4.20603250e-08,\n", + " 4.16323189e-08, 4.12108143e-08, 4.07956797e-08, 4.03867837e-08,\n", + " 3.99840054e-08, 3.95872206e-08, 3.91963191e-08, 3.88111729e-08,\n", + " 3.84316792e-08, 3.80577241e-08, 3.76892011e-08, 3.73260001e-08,\n", + " 3.69680322e-08, 3.66151802e-08, 3.62673624e-08, 3.59244758e-08,\n", + " 3.55864280e-08, 3.52531302e-08, 3.49244935e-08, 3.46004292e-08,\n", + " 3.42808555e-08, 3.39656872e-08, 3.36548496e-08, 3.33482575e-08,\n", + " 3.30458363e-08, 3.27475078e-08, 3.24532010e-08, 3.21628448e-08,\n", + " 3.18763718e-08, 3.15937037e-08, 3.13147837e-08, 3.10395372e-08,\n", + " 3.07679073e-08, 3.04998231e-08, 3.02352312e-08, 2.99740641e-08,\n", + " 2.97162686e-08, 2.94617841e-08, 2.92105558e-08, 2.89625248e-08,\n", + " 2.87176416e-08, 2.84758510e-08, 2.82371015e-08, 2.80013399e-08,\n", + " 2.77685217e-08, 2.75385919e-08, 2.73115095e-08, 2.70872214e-08,\n", + " 2.68656883e-08, 2.66468589e-08, 2.64306941e-08, 2.62171476e-08,\n", + " 2.60061785e-08, 2.57977462e-08, 2.55918113e-08, 2.53883297e-08,\n", + " 2.51872674e-08, 2.49885819e-08, 2.47922411e-08, 2.45982026e-08,\n", + " 2.44064360e-08, 2.42169005e-08, 2.40295659e-08, 2.38443967e-08,\n", + " 2.36613591e-08, 2.34804194e-08, 2.33015491e-08, 2.31247146e-08,\n", + " 2.29498873e-08, 2.27770318e-08, 2.26061250e-08, 2.24371330e-08,\n", + " 2.22700294e-08, 2.21047838e-08, 2.19413732e-08, 2.17797655e-08,\n", + " 2.16199396e-08, 2.14618634e-08, 2.13055174e-08, 2.11508713e-08,\n", + " 2.09979039e-08, 2.08465902e-08, 2.06969073e-08, 2.05488284e-08,\n", + " 2.04023340e-08, 2.02574011e-08, 2.01140065e-08, 1.99721288e-08,\n", + " 1.98317469e-08, 1.96928411e-08, 1.95553884e-08, 1.94193710e-08,\n", + " 1.92847658e-08, 1.91515568e-08, 1.90197245e-08, 1.88892457e-08,\n", + " 1.87601064e-08, 1.86322868e-08, 1.85057711e-08, 1.83805362e-08,\n", + " 1.82565714e-08, 1.81338535e-08, 1.80123703e-08, 1.78921038e-08,\n", + " 1.77730382e-08, 1.76551573e-08, 1.75384454e-08, 1.74228845e-08,\n", + " 1.73084658e-08, 1.71951662e-08, 1.70829786e-08, 1.69718852e-08,\n", + " 1.68618719e-08, 1.67529244e-08, 1.66450285e-08, 1.65381717e-08,\n", + " 1.64323417e-08, 1.63275224e-08, 1.62237050e-08, 1.61208735e-08,\n", + " 1.60190172e-08, 1.59181237e-08, 1.58181788e-08, 1.57191735e-08,\n", + " 1.56210955e-08, 1.55239306e-08, 1.54276698e-08, 1.53323025e-08,\n", + " 1.52378163e-08, 1.51442006e-08, 1.50514463e-08, 1.49595394e-08,\n", + " 1.48684736e-08, 1.47782364e-08, 1.46888173e-08, 1.46002073e-08,\n", + " 1.45123975e-08, 1.44253782e-08, 1.43391379e-08, 1.42536685e-08,\n", + " 1.41689620e-08, 1.40850087e-08, 1.40017979e-08, 1.39193226e-08,\n", + " 1.38375746e-08, 1.37565443e-08, 1.36762237e-08, 1.35966047e-08,\n", + " 1.35176785e-08, 1.34394380e-08, 1.33618743e-08, 1.32849802e-08,\n", + " 1.32087488e-08, 1.31331719e-08, 1.30582398e-08, 1.29839481e-08,\n", + " 1.29102888e-08, 1.28372548e-08, 1.27648390e-08, 1.26930333e-08,\n", + " 1.26218325e-08, 1.25512294e-08, 1.24812152e-08, 1.24117872e-08,\n", + " 1.23429365e-08, 1.22746568e-08, 1.22069421e-08, 1.21397861e-08,\n", + " 1.20731825e-08, 1.20071268e-08, 1.19416104e-08, 1.18766295e-08,\n", + " 1.18121770e-08, 1.17482486e-08, 1.16848371e-08, 1.16219372e-08,\n", + " 1.15595444e-08, 1.14976526e-08, 1.14362564e-08, 1.13753513e-08,\n", + " 1.13149303e-08, 1.12549907e-08, 1.11955254e-08, 1.11365299e-08,\n", + " 1.10780007e-08, 1.10199307e-08, 1.09623164e-08, 1.09051523e-08,\n", + " 1.08484342e-08, 1.07921583e-08, 1.07363185e-08, 1.06809104e-08,\n", + " 1.06259312e-08, 1.05713749e-08, 1.05172377e-08, 1.04635154e-08,\n", + " 1.04102034e-08, 1.03572990e-08, 1.03047952e-08, 1.02526903e-08,\n", + " 1.02009796e-08, 1.01496598e-08, 1.00987254e-08, 1.00481730e-08,\n", + " 9.99800065e-09, 9.94820226e-09, 9.89877513e-09, 9.84971571e-09,\n", + " 9.80101955e-09, 9.75268488e-09, 9.70470548e-09, 9.65707958e-09,\n", + " 9.60980362e-09, 9.56287405e-09, 9.51628731e-09, 9.47003986e-09,\n", + " 9.42412903e-09, 9.37855127e-09, 9.33330391e-09, 9.28838251e-09,\n", + " 9.24378440e-09, 9.19950782e-09, 9.15554743e-09, 9.11190146e-09,\n", + " 9.06856723e-09, 9.02554209e-09, 8.98282160e-09, 8.94040308e-09,\n", + " 8.89828566e-09, 8.85646489e-09, 8.81493722e-09, 8.77370177e-09,\n", + " 8.73275496e-09, 8.69209416e-09, 8.65171668e-09, 8.61161986e-09,\n", + " 8.57180105e-09, 8.53225757e-09, 8.49298765e-09, 8.45398773e-09,\n", + " 8.41525605e-09, 8.37678993e-09, 8.33858671e-09, 8.30064462e-09,\n", + " 8.26296098e-09, 8.22553314e-09, 8.18835844e-09, 8.15143597e-09,\n", + " 8.11476220e-09, 8.07833622e-09, 8.04215361e-09, 8.00621436e-09,\n", + " 7.97051580e-09, 7.93505528e-09, 7.89983012e-09, 7.86483945e-09,\n", + " 7.83008147e-09, 7.79555265e-09, 7.76125120e-09, 7.72717623e-09,\n", + " 7.69332598e-09, 7.65969688e-09, 7.62628805e-09, 7.59309682e-09,\n", + " 7.56012231e-09, 7.52736184e-09, 7.49481366e-09, 7.46247597e-09,\n", + " 7.43034789e-09, 7.39842676e-09, 7.36670991e-09, 7.33519778e-09,\n", + " 7.30388638e-09, 7.27277660e-09, 7.24186355e-09, 7.21114812e-09,\n", + " 7.18062720e-09, 7.15030035e-09, 7.12016357e-09, 7.09021863e-09,\n", + " 7.06046110e-09, 7.03089142e-09, 7.00150649e-09, 6.97230584e-09,\n", + " 6.94328728e-09, 6.91444946e-09, 6.88579060e-09, 6.85731028e-09,\n", + " 6.82900536e-09, 6.80087675e-09, 6.77292045e-09, 6.74513645e-09,\n", + " 6.71752298e-09, 6.69007871e-09, 6.66280187e-09, 6.63569244e-09,\n", + " 6.60874733e-09, 6.58196653e-09, 6.55534826e-09, 6.52889076e-09,\n", + " 6.50259313e-09, 6.47645448e-09, 6.45047216e-09, 6.42464704e-09,\n", + " 6.39897602e-09, 6.37345909e-09, 6.34809361e-09, 6.32288000e-09,\n", + " 6.29781605e-09, 6.27290087e-09, 6.24813312e-09, 6.22351237e-09,\n", + " 6.19903595e-09, 6.17470430e-09, 6.15051521e-09, 6.12646822e-09], dtype=float32)" + ] + }, + "execution_count": 110, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pallas_superspike(V)" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([1.00000000e+00, 1.47928996e-03, 3.84467508e-04, 1.73130189e-04,\n", + " 9.80296099e-05, 6.29881615e-05, 4.38577263e-05, 3.22830565e-05,\n", + " 2.47518619e-05, 1.95786670e-05, 1.58727635e-05, 1.31274946e-05,\n", + " 1.10374058e-05, 9.40946211e-06, 8.11681730e-06, 7.07333629e-06,\n", + " 6.21886693e-06, 5.51037056e-06, 4.91639685e-06, 4.41352995e-06,\n", + " 3.98404791e-06, 3.61433581e-06, 3.29379668e-06, 3.01408181e-06,\n", + " 2.76854166e-06, 2.55182772e-06, 2.35959806e-06, 2.18829882e-06,\n", + " 2.03499781e-06, 1.89725961e-06, 1.77304651e-06, 1.66064399e-06,\n", + " 1.55860107e-06, 1.46568254e-06, 1.38083215e-06, 1.30314208e-06,\n", + " 1.23182895e-06, 1.16621345e-06, 1.10570420e-06, 1.04978506e-06,\n", + " 9.98002974e-07, 9.49959883e-07, 9.05304262e-07, 8.63724949e-07,\n", + " 8.24945744e-07, 7.88720627e-07, 7.54830353e-07, 7.23078358e-07,\n", + " 6.93288484e-07, 6.65302366e-07, 6.38977212e-07, 6.14184216e-07,\n", + " 5.90806678e-07, 5.68738926e-07, 5.47884838e-07, 5.28157102e-07,\n", + " 5.09476024e-07, 4.91768787e-07, 4.74968886e-07, 4.59015439e-07,\n", + " 4.43852429e-07, 4.29428582e-07, 4.15696519e-07, 4.02612784e-07,\n", + " 3.90137188e-07, 3.78232556e-07, 3.66864640e-07, 3.56001607e-07,\n", + " 3.45614040e-07, 3.35674542e-07, 3.26157760e-07, 3.17040019e-07,\n", + " 3.08299320e-07, 2.99915172e-07, 2.91868474e-07, 2.84141294e-07,\n", + " 2.76716946e-07, 2.69579857e-07, 2.62715361e-07, 2.56109757e-07,\n", + " 2.49750201e-07, 2.43624584e-07, 2.37721622e-07, 2.32030615e-07,\n", + " 2.26541559e-07, 2.21245003e-07, 2.16132065e-07, 2.11194319e-07,\n", + " 2.06423877e-07, 2.01813251e-07, 1.97355391e-07, 1.93043633e-07,\n", + " 1.88871638e-07, 1.84833453e-07, 1.80923394e-07, 1.77136116e-07,\n", + " 1.73466532e-07, 1.69909796e-07, 1.66461334e-07, 1.63116809e-07,\n", + " 1.59872073e-07, 1.56723203e-07, 1.53666463e-07, 1.50698270e-07,\n", + " 1.47815271e-07, 1.45014212e-07, 1.42292023e-07, 1.39645763e-07,\n", + " 1.37072661e-07, 1.34570016e-07, 1.32135298e-07, 1.29766050e-07,\n", + " 1.27459970e-07, 1.25214811e-07, 1.23028457e-07, 1.20898875e-07,\n", + " 1.18824104e-07, 1.16802291e-07, 1.14831643e-07, 1.12910456e-07,\n", + " 1.11037075e-07, 1.09209935e-07, 1.07427532e-07, 1.05688407e-07,\n", + " 1.03991169e-07, 1.02334496e-07, 1.00717095e-07, 9.91377362e-08,\n", + " 9.75952474e-08, 9.60884705e-08, 9.46163254e-08, 9.31777464e-08,\n", + " 9.17717387e-08, 9.03973074e-08, 8.90535290e-08, 8.77394939e-08,\n", + " 8.64543281e-08, 8.51971933e-08, 8.39672794e-08, 8.27638118e-08,\n", + " 8.15860233e-08, 8.04332032e-08, 7.93046482e-08, 7.81996832e-08,\n", + " 7.71176474e-08, 7.60579155e-08, 7.50198765e-08, 7.40029478e-08,\n", + " 7.30065537e-08, 7.20301472e-08, 7.10732024e-08, 7.01352008e-08,\n", + " 6.92156448e-08, 6.83140513e-08, 6.74299727e-08, 6.65629329e-08,\n", + " 6.57125199e-08, 6.48782930e-08, 6.40598543e-08, 6.32568060e-08,\n", + " 6.24687644e-08, 6.16953528e-08, 6.09362232e-08, 6.01910131e-08,\n", + " 5.94593956e-08, 5.87410298e-08, 5.80356101e-08, 5.73428167e-08,\n", + " 5.66623584e-08, 5.59939366e-08, 5.53372814e-08, 5.46921015e-08,\n", + " 5.40581446e-08, 5.34351408e-08, 5.28228554e-08, 5.22210222e-08,\n", + " 5.16294243e-08, 5.10478166e-08, 5.04759896e-08, 4.99137087e-08,\n", + " 4.93607786e-08, 4.88169789e-08, 4.82821214e-08, 4.77560036e-08,\n", + " 4.72384407e-08, 4.67292445e-08, 4.62282372e-08, 4.57352449e-08,\n", + " 4.52500970e-08, 4.47726265e-08, 4.43026771e-08, 4.38400818e-08,\n", + " 4.33847021e-08, 4.29363745e-08, 4.24949640e-08, 4.20603250e-08,\n", + " 4.16323189e-08, 4.12108143e-08, 4.07956797e-08, 4.03867837e-08,\n", + " 3.99840054e-08, 3.95872206e-08, 3.91963191e-08, 3.88111729e-08,\n", + " 3.84316792e-08, 3.80577241e-08, 3.76892011e-08, 3.73260001e-08,\n", + " 3.69680322e-08, 3.66151802e-08, 3.62673624e-08, 3.59244758e-08,\n", + " 3.55864280e-08, 3.52531302e-08, 3.49244935e-08, 3.46004292e-08,\n", + " 3.42808555e-08, 3.39656872e-08, 3.36548496e-08, 3.33482575e-08,\n", + " 3.30458363e-08, 3.27475078e-08, 3.24532010e-08, 3.21628448e-08,\n", + " 3.18763718e-08, 3.15937037e-08, 3.13147837e-08, 3.10395372e-08,\n", + " 3.07679073e-08, 3.04998231e-08, 3.02352312e-08, 2.99740641e-08,\n", + " 2.97162686e-08, 2.94617841e-08, 2.92105558e-08, 2.89625248e-08,\n", + " 2.87176416e-08, 2.84758510e-08, 2.82371015e-08, 2.80013399e-08,\n", + " 2.77685217e-08, 2.75385919e-08, 2.73115095e-08, 2.70872214e-08,\n", + " 2.68656883e-08, 2.66468589e-08, 2.64306941e-08, 2.62171476e-08,\n", + " 2.60061785e-08, 2.57977462e-08, 2.55918113e-08, 2.53883297e-08,\n", + " 2.51872674e-08, 2.49885819e-08, 2.47922411e-08, 2.45982026e-08,\n", + " 2.44064360e-08, 2.42169005e-08, 2.40295659e-08, 2.38443967e-08,\n", + " 2.36613591e-08, 2.34804194e-08, 2.33015491e-08, 2.31247146e-08,\n", + " 2.29498873e-08, 2.27770318e-08, 2.26061250e-08, 2.24371330e-08,\n", + " 2.22700294e-08, 2.21047838e-08, 2.19413732e-08, 2.17797655e-08,\n", + " 2.16199396e-08, 2.14618634e-08, 2.13055174e-08, 2.11508713e-08,\n", + " 2.09979039e-08, 2.08465902e-08, 2.06969073e-08, 2.05488284e-08,\n", + " 2.04023340e-08, 2.02574011e-08, 2.01140065e-08, 1.99721288e-08,\n", + " 1.98317469e-08, 1.96928411e-08, 1.95553884e-08, 1.94193710e-08,\n", + " 1.92847658e-08, 1.91515568e-08, 1.90197245e-08, 1.88892457e-08,\n", + " 1.87601064e-08, 1.86322868e-08, 1.85057711e-08, 1.83805362e-08,\n", + " 1.82565714e-08, 1.81338535e-08, 1.80123703e-08, 1.78921038e-08,\n", + " 1.77730382e-08, 1.76551573e-08, 1.75384454e-08, 1.74228845e-08,\n", + " 1.73084658e-08, 1.71951662e-08, 1.70829786e-08, 1.69718852e-08,\n", + " 1.68618719e-08, 1.67529244e-08, 1.66450285e-08, 1.65381717e-08,\n", + " 1.64323417e-08, 1.63275224e-08, 1.62237050e-08, 1.61208735e-08,\n", + " 1.60190172e-08, 1.59181237e-08, 1.58181788e-08, 1.57191735e-08,\n", + " 1.56210955e-08, 1.55239306e-08, 1.54276698e-08, 1.53323025e-08,\n", + " 1.52378163e-08, 1.51442006e-08, 1.50514463e-08, 1.49595394e-08,\n", + " 1.48684736e-08, 1.47782364e-08, 1.46888173e-08, 1.46002073e-08,\n", + " 1.45123975e-08, 1.44253782e-08, 1.43391379e-08, 1.42536685e-08,\n", + " 1.41689620e-08, 1.40850087e-08, 1.40017979e-08, 1.39193226e-08,\n", + " 1.38375746e-08, 1.37565443e-08, 1.36762237e-08, 1.35966047e-08,\n", + " 1.35176785e-08, 1.34394380e-08, 1.33618743e-08, 1.32849802e-08,\n", + " 1.32087488e-08, 1.31331719e-08, 1.30582398e-08, 1.29839481e-08,\n", + " 1.29102888e-08, 1.28372548e-08, 1.27648390e-08, 1.26930333e-08,\n", + " 1.26218325e-08, 1.25512294e-08, 1.24812152e-08, 1.24117872e-08,\n", + " 1.23429365e-08, 1.22746568e-08, 1.22069421e-08, 1.21397861e-08,\n", + " 1.20731825e-08, 1.20071268e-08, 1.19416104e-08, 1.18766295e-08,\n", + " 1.18121770e-08, 1.17482486e-08, 1.16848371e-08, 1.16219372e-08,\n", + " 1.15595444e-08, 1.14976526e-08, 1.14362564e-08, 1.13753513e-08,\n", + " 1.13149303e-08, 1.12549907e-08, 1.11955254e-08, 1.11365299e-08,\n", + " 1.10780007e-08, 1.10199307e-08, 1.09623164e-08, 1.09051523e-08,\n", + " 1.08484342e-08, 1.07921583e-08, 1.07363185e-08, 1.06809104e-08,\n", + " 1.06259312e-08, 1.05713749e-08, 1.05172377e-08, 1.04635154e-08,\n", + " 1.04102034e-08, 1.03572990e-08, 1.03047952e-08, 1.02526903e-08,\n", + " 1.02009796e-08, 1.01496598e-08, 1.00987254e-08, 1.00481730e-08,\n", + " 9.99800065e-09, 9.94820226e-09, 9.89877513e-09, 9.84971571e-09,\n", + " 9.80101955e-09, 9.75268488e-09, 9.70470548e-09, 9.65707958e-09,\n", + " 9.60980362e-09, 9.56287405e-09, 9.51628731e-09, 9.47003986e-09,\n", + " 9.42412903e-09, 9.37855127e-09, 9.33330391e-09, 9.28838251e-09,\n", + " 9.24378440e-09, 9.19950782e-09, 9.15554743e-09, 9.11190146e-09,\n", + " 9.06856723e-09, 9.02554209e-09, 8.98282160e-09, 8.94040308e-09,\n", + " 8.89828566e-09, 8.85646489e-09, 8.81493722e-09, 8.77370177e-09,\n", + " 8.73275496e-09, 8.69209416e-09, 8.65171668e-09, 8.61161986e-09,\n", + " 8.57180105e-09, 8.53225757e-09, 8.49298765e-09, 8.45398773e-09,\n", + " 8.41525605e-09, 8.37678993e-09, 8.33858671e-09, 8.30064462e-09,\n", + " 8.26296098e-09, 8.22553314e-09, 8.18835844e-09, 8.15143597e-09,\n", + " 8.11476220e-09, 8.07833622e-09, 8.04215361e-09, 8.00621436e-09,\n", + " 7.97051580e-09, 7.93505528e-09, 7.89983012e-09, 7.86483945e-09,\n", + " 7.83008147e-09, 7.79555265e-09, 7.76125120e-09, 7.72717623e-09,\n", + " 7.69332598e-09, 7.65969688e-09, 7.62628805e-09, 7.59309682e-09,\n", + " 7.56012231e-09, 7.52736184e-09, 7.49481366e-09, 7.46247597e-09,\n", + " 7.43034789e-09, 7.39842676e-09, 7.36670991e-09, 7.33519778e-09,\n", + " 7.30388638e-09, 7.27277660e-09, 7.24186355e-09, 7.21114812e-09,\n", + " 7.18062720e-09, 7.15030035e-09, 7.12016357e-09, 7.09021863e-09,\n", + " 7.06046110e-09, 7.03089142e-09, 7.00150649e-09, 6.97230584e-09,\n", + " 6.94328728e-09, 6.91444946e-09, 6.88579060e-09, 6.85731028e-09,\n", + " 6.82900536e-09, 6.80087675e-09, 6.77292045e-09, 6.74513645e-09,\n", + " 6.71752298e-09, 6.69007871e-09, 6.66280187e-09, 6.63569244e-09,\n", + " 6.60874733e-09, 6.58196653e-09, 6.55534826e-09, 6.52889076e-09,\n", + " 6.50259313e-09, 6.47645448e-09, 6.45047216e-09, 6.42464704e-09,\n", + " 6.39897602e-09, 6.37345909e-09, 6.34809361e-09, 6.32288000e-09,\n", + " 6.29781605e-09, 6.27290087e-09, 6.24813312e-09, 6.22351237e-09,\n", + " 6.19903595e-09, 6.17470430e-09, 6.15051521e-09, 6.12646822e-09], dtype=float32)" + ] + }, + "execution_count": 111, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grad_superspike(V)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "p_lowered = pallas_superspike.lower(V)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "module @jit_pallas_superspike attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n", + " func.func public @main(%arg0: tensor<512xf32> {mhlo.layout_mode = \"default\"}) -> (tensor<512xf32> {jax.result_info = \"\", mhlo.layout_mode = \"default\"}) {\n", + " %0 = call @wrapped(%arg0) : (tensor<512xf32>) -> tensor<512xf32>\n", + " return %0 : tensor<512xf32>\n", + " }\n", + " func.func private @wrapped(%arg0: tensor<512xf32>) -> tensor<512xf32> {\n", + " %0 = stablehlo.custom_call @__gpu$xla.gpu.triton(%arg0) {mhlo.backend_config = {debug = false, grid_x = 16 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = \"ML\\EFR\\0DMLIR19.0.0git\\00\\017\\07\\01\\05\\09!\\01\\03\\0F\\03\\13\\13\\17\\1B\\1F#'+/3\\05\\0D7;?CGK\\03\\BD\\99\\11\\01\\95\\07\\0F\\0F\\0F\\0F\\0B\\0F\\0F\\0B\\0F\\0F\\0F\\13\\0B\\0F\\0B\\0F\\0B\\1F\\13\\0B\\0B\\0B\\0B\\0F\\0F\\13\\0F\\0B\\13\\0B\\0F\\0F\\0B\\13\\0B\\0F\\0F\\0B\\17\\0F\\0F\\0B\\17\\0F\\0F\\0B\\17\\0F\\0F\\0B\\13\\0B\\0F\\0F\\0B\\17\\0F\\0F\\17\\0F\\17\\0B\\0F\\0B\\0F\\0B\\0F\\13\\1F\\0B\\0B\\0B\\0B\\05\\05YY\\01\\0F\\0F\\13\\13\\07\\13\\0B\\17\\03\\03A\\02\\C2\\04\\1F\\1D\\93\\0D\\1D/1\\11\\01\\05\\11\\01\\81\\05'\\15\\87\\1D\\1D\\8D\\0D\\05)\\11\\01\\01\\1D\\8F\\0D\\1D\\91\\0D\\03\\03)\\09\\05+\\157?\\05-\\11\\0B\\00\\05/\\13\\07\\10\\00\\00\\E0\\0F\\01\\05\\19\\19\\051\\0D\\0D\\053\\055\\153\\1D\\1D\\115\\17\\1B\\07\\01\\1D9;\\057\\17=\\09\\01\\059\\15AI\\1DCE\\05;\\17G\\03\\01\\05=\\15KQ\\1DMO\\05?\\17\\0B\\066\\01\\15SY\\1DUW\\05A\\17\\0B\\C64\\01\\15[a\\1D]_\\05C\\17\\0B\\C61\\01\\15ck\\1Deg\\05E\\17i\\9D\\01\\05G\\15ms\\1Doq\\05I\\17\\0BB.\\01\\15uy\\1D\\1Fw\\17\\0B\\8A-\\01\\1D\\1F{\\17}r\\08\\01\\05K\\11\\0B\\01\\05M\\1D\\85\\0D\\05O\\1D\\11\\89\\17\\1B\\09\\01\\13\\07\\10\\00\\00r\\10\\05Q\\05S\\05U\\05W#arith.overflow\\00#arith.fastmath\\00\\01\\02\\02\\1B\\03\\81\\01\\1B\\03\\81\\07\\0B\\1B\\03\\81\\0F\\01\\09\\05\\05\\0F\\0F\\01!tt.ptr\\00\\04\\8A\\04\\05\\01P\\01\\01\\07\\04f\\04\\03\\01\\05\\0BP\\01\\03\\07\\04:\\04\\03I\\93\\05\\1D\\1D\\00\\0DB\\01\\05\\03\\01\\15B\\01\\07\\03\\01\\17F\\01\\09\\03\\01\\05\\05\\07\\15B\\01\\07\\03\\01\\17F\\01\\09\\03\\01\\05\\05\\0B\\05B\\05\\0B\\03\\03\\03\\06\\05\\03\\03\\03\\09\\19F\\05\\09\\03\\03\\05\\0F\\11\\15B\\05\\0D\\03\\01\\03\\06\\05\\03\\03\\03\\15\\17F\\05\\09\\03\\03\\05\\13\\17\\03\\06\\05\\03\\09\\03\\01\\07\\06\\05\\03\\09\\05\\1B\\19\\09F\\05\\0F\\03\\05\\03\\1D\\0FF\\83\\11\\03\\05\\03\\1F\\15B\\0F\\13\\03\\07\\03\\06\\0F\\03\\05\\03#\\1BF\\0F\\15\\03\\05\\05%!\\15B\\15\\17\\03\\07\\03\\06\\15\\03\\05\\03)\\1DF\\15\\15\\03\\05\\05+'\\1BF\\0F\\15\\03\\05\\05--\\15B\\17\\17\\03\\07\\03\\06\\17\\03\\05\\031\\1FF\\17\\15\\03\\05\\053/\\05B\\03\\0B\\03\\03\\03\\06\\03\\03\\03\\03\\0D\\19F\\03\\09\\03\\03\\0579\\15B\\03\\0D\\03\\01\\03\\06\\03\\03\\03\\03=\\17F\\03\\09\\03\\03\\05;?\\03\\06\\03\\03\\09\\03\\03\\07\\06\\03\\03\\09\\05CA\\09F\\03\\0F\\03\\05\\03E\\11D\\03\\19\\05E5\\13\\00\\01\\06\\03\\01\\05\\01\\006\\0EY\\F9\\0B\\0B\\0B\\0B\\17\\8F\\15{)\\1F\\1D\\13G\\13G%\\F7\\0F!\\03\\13G%\\81\\0B\\0B\\0B\\0B\\0B\\13\\0F\\0D'\\1F\\0B\\0B\\0F\\17\\0D\\0F\\0D\\07\\11builtin\\00tt\\00arith\\00module\\00splat\\00make_range\\00addptr\\00load\\00func\\00get_program_id\\00extern_elementwise\\00store\\00return\\00constant\\00muli\\00addi\\00mulf\\00addf\\00divf\\00/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py\\00superspike_kernel\\00/tmp/ipykernel_73553/3067264020.py\\00run_cell\\00\\00tt.divisibility\\00public\\00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((CustomNode(Slice[(0, 32, 1)], [None, None]),)), (32,), ())], []),))]\\00pallas_superspike\\00/tmp/ipykernel_73553/2790088040.py\\00\\00/tmp/ipykernel_73553/2528866667.py\\00run_code\\00run_ast_nodes\\00run_cell_async\\00_pseudo_sync_runner\\00/usr/lib/python3/dist-packages/IPython/core/async_helpers.py\\00_run_cell\\00/home/legion/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py\\00__nv_fabsf\\00/abs\\00/mul\\00/add\\00/div\\00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((CustomNode(Slice[(0, 32, 1)], [None, None]),)), (32,), ())], []),))]\\00\\08_\\1B\\05\\01\\01\\0BO+\\01\\11[\\03\\13\\03\\09\\05V\\02\\05\\09\\13\\03\\07\\11\\01\\07\\07!\\01\\07\\01\\03\\09##\\7F\\81\\03\\8B\\05^\\02\\03%\\07\\01\\0F\\0F\", name = \"superspike_kernel\", num_stages = 3 : i32, num_warps = 4 : i32}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<512xf32>) -> tensor<512xf32>\n", + " return %0 : tensor<512xf32>\n", + " }\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "print(p_lowered.as_text())" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "p_compiled = p_lowered.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_pallas_superspike, is_scheduled=true, entry_computation_layout={(f32[512]{0})->f32[512]{0}}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}, frontend_attributes={fingerprint_before_lhs=\"6632de43e561b95501cdc7d88cb9d595\"}\n", + "\n", + "ENTRY %main.6 (Arg_0.1.0: f32[512]) -> f32[512] {\n", + " %Arg_0.1.0 = f32[512]{0} parameter(0)\n", + " ROOT %custom-call.0.0 = f32[512]{0} custom-call(f32[512]{0} %Arg_0.1.0), custom_call_target=\"__gpu$xla.gpu.triton\", operand_layout_constraints={f32[512]{0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name=\"jit(pallas_superspike)/jit(main)/jit(wrapped)/pallas_call[name=superspike_kernel which_linear=(False,) in_shapes=(ShapeDtypeStruct(shape=(512,), dtype=float32),) out_shapes=(ShapeDtypeStruct(shape=(512,), dtype=float32),) debug=False interpret=False grid_mapping=GridMapping(grid=(16,), block_mappings=(BlockMapping(block_shape=(32,), index_map_jaxpr={ \\033[34m\\033[22m\\033[1mlambda \\033[39m\\033[22m\\033[22m; a\\033[35m:i32[]\\033[39m. \\033[34m\\033[22m\\033[1mlet\\033[39m\\033[22m\\033[22m \\033[34m\\033[22m\\033[1min \\033[39m\\033[22m\\033[22m(a,) }, indexing_mode=), BlockMapping(block_shape=(32,), index_map_jaxpr={ \\033[34m\\033[22m\\033[1mlambda \\033[39m\\033[22m\\033[22m; a\\033[35m:i32[]\\033[39m. \\033[34m\\033[22m\\033[1mlet\\033[39m\\033[22m\\033[22m \\033[34m\\033[22m\\033[1min \\033[39m\\033[22m\\033[22m(a,) }, indexing_mode=)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0) input_output_aliases=() compiler_params={}]\" source_file=\"/tmp/ipykernel_73553/2790088040.py\" source_line=4}, backend_config={debug = false, grid_x = 16 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = \"ML\\EFR\\0DMLIR19.0.0git\\00\\017\\07\\01\\05\\09!\\01\\03\\0F\\03\\13\\13\\17\\1B\\1F#'+/3\\05\\0D7;?CGK\\03\\BD\\99\\11\\01\\95\\07\\0F\\0F\\0F\\0F\\0B\\0F\\0F\\0B\\0F\\0F\\0F\\13\\0B\\0F\\0B\\0F\\0B\\1F\\13\\0B\\0B\\0B\\0B\\0F\\0F\\13\\0F\\0B\\13\\0B\\0F\\0F\\0B\\13\\0B\\0F\\0F\\0B\\17\\0F\\0F\\0B\\17\\0F\\0F\\0B\\17\\0F\\0F\\0B\\13\\0B\\0F\\0F\\0B\\17\\0F\\0F\\17\\0F\\17\\0B\\0F\\0B\\0F\\0B\\0F\\13\\1F\\0B\\0B\\0B\\0B\\05\\05YY\\01\\0F\\0F\\13\\13\\07\\13\\0B\\17\\03\\03A\\02\\C2\\04\\1F\\1D\\93\\0D\\1D/1\\11\\01\\05\\11\\01\\81\\05'\\15\\87\\1D\\1D\\8D\\0D\\05)\\11\\01\\01\\1D\\8F\\0D\\1D\\91\\0D\\03\\03)\\09\\05+\\157?\\05-\\11\\0B\\00\\05/\\13\\07\\10\\00\\00\\E0\\0F\\01\\05\\19\\19\\051\\0D\\0D\\053\\055\\153\\1D\\1D\\115\\17\\1B\\07\\01\\1D9;\\057\\17=\\09\\01\\059\\15AI\\1DCE\\05;\\17G\\03\\01\\05=\\15KQ\\1DMO\\05?\\17\\0B\\066\\01\\15SY\\1DUW\\05A\\17\\0B\\C64\\01\\15[a\\1D]_\\05C\\17\\0B\\C61\\01\\15ck\\1Deg\\05E\\17i\\9D\\01\\05G\\15ms\\1Doq\\05I\\17\\0BB.\\01\\15uy\\1D\\1Fw\\17\\0B\\8A-\\01\\1D\\1F{\\17}r\\08\\01\\05K\\11\\0B\\01\\05M\\1D\\85\\0D\\05O\\1D\\11\\89\\17\\1B\\09\\01\\13\\07\\10\\00\\00r\\10\\05Q\\05S\\05U\\05W#arith.overflow\\00#arith.fastmath\\00\\01\\02\\02\\1B\\03\\81\\01\\1B\\03\\81\\07\\0B\\1B\\03\\81\\0F\\01\\09\\05\\05\\0F\\0F\\01!tt.ptr\\00\\04\\8A\\04\\05\\01P\\01\\01\\07\\04f\\04\\03\\01\\05\\0BP\\01\\03\\07\\04:\\04\\03I\\93\\05\\1D\\1D\\00\\0DB\\01\\05\\03\\01\\15B\\01\\07\\03\\01\\17F\\01\\09\\03\\01\\05\\05\\07\\15B\\01\\07\\03\\01\\17F\\01\\09\\03\\01\\05\\05\\0B\\05B\\05\\0B\\03\\03\\03\\06\\05\\03\\03\\03\\09\\19F\\05\\09\\03\\03\\05\\0F\\11\\15B\\05\\0D\\03\\01\\03\\06\\05\\03\\03\\03\\15\\17F\\05\\09\\03\\03\\05\\13\\17\\03\\06\\05\\03\\09\\03\\01\\07\\06\\05\\03\\09\\05\\1B\\19\\09F\\05\\0F\\03\\05\\03\\1D\\0FF\\83\\11\\03\\05\\03\\1F\\15B\\0F\\13\\03\\07\\03\\06\\0F\\03\\05\\03#\\1BF\\0F\\15\\03\\05\\05%!\\15B\\15\\17\\03\\07\\03\\06\\15\\03\\05\\03)\\1DF\\15\\15\\03\\05\\05+'\\1BF\\0F\\15\\03\\05\\05--\\15B\\17\\17\\03\\07\\03\\06\\17\\03\\05\\031\\1FF\\17\\15\\03\\05\\053/\\05B\\03\\0B\\03\\03\\03\\06\\03\\03\\03\\03\\0D\\19F\\03\\09\\03\\03\\0579\\15B\\03\\0D\\03\\01\\03\\06\\03\\03\\03\\03=\\17F\\03\\09\\03\\03\\05;?\\03\\06\\03\\03\\09\\03\\03\\07\\06\\03\\03\\09\\05CA\\09F\\03\\0F\\03\\05\\03E\\11D\\03\\19\\05E5\\13\\00\\01\\06\\03\\01\\05\\01\\006\\0EY\\F9\\0B\\0B\\0B\\0B\\17\\8F\\15{)\\1F\\1D\\13G\\13G%\\F7\\0F!\\03\\13G%\\81\\0B\\0B\\0B\\0B\\0B\\13\\0F\\0D'\\1F\\0B\\0B\\0F\\17\\0D\\0F\\0D\\07\\11builtin\\00tt\\00arith\\00module\\00splat\\00make_range\\00addptr\\00load\\00func\\00get_program_id\\00extern_elementwise\\00store\\00return\\00constant\\00muli\\00addi\\00mulf\\00addf\\00divf\\00/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py\\00superspike_kernel\\00/tmp/ipykernel_73553/3067264020.py\\00run_cell\\00\\00tt.divisibility\\00public\\00/get[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((CustomNode(Slice[(0, 32, 1)], [None, None]),)), (32,), ())], []),))]\\00pallas_superspike\\00/tmp/ipykernel_73553/2790088040.py\\00\\00/tmp/ipykernel_73553/2528866667.py\\00run_code\\00run_ast_nodes\\00run_cell_async\\00_pseudo_sync_runner\\00/usr/lib/python3/dist-packages/IPython/core/async_helpers.py\\00_run_cell\\00/home/legion/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py\\00__nv_fabsf\\00/abs\\00/mul\\00/add\\00/div\\00/swap[tree=PyTreeDef((CustomNode(NDIndexer[(PyTreeDef((CustomNode(Slice[(0, 32, 1)], [None, None]),)), (32,), ())], []),))]\\00\\08_\\1B\\05\\01\\01\\0BO+\\01\\11[\\03\\13\\03\\09\\05V\\02\\05\\09\\13\\03\\07\\11\\01\\07\\07!\\01\\07\\01\\03\\09##\\7F\\81\\03\\8B\\05^\\02\\03%\\07\\01\\0F\\0F\", name = \"superspike_kernel\", num_stages = 3 : i32, num_warps = 4 : i32}\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(p_compiled.as_text())" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "def grad_superspike2(x):\n", + " return 1 / (1 + k*jnp.abs(x))**2" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "273 µs ± 24.8 µs per loop (mean ± std. dev. of 20 runs, 1000 loops each)\n" + ] + } + ], + "source": [ + "%timeit -r20 grad_superspike2(V).block_until_ready()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "33.5 µs ± 4.39 µs per loop (mean ± std. dev. of 20 runs, 10000 loops each)\n" + ] + } + ], + "source": [ + "%timeit -r20 grad_superspike(V).block_until_ready()" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "27.9 µs ± 2.68 µs per loop (mean ± std. dev. of 20 runs, 10000 loops each)\n" + ] + } + ], + "source": [ + "%timeit -r20 pallas_superspike(V).block_until_ready()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LIF neuron" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class LIF(hk.RNNCore):\n", + " \"\"\"\n", + " Leaky Integrate and Fire neuron model inspired by the implementation in\n", + " snnTorch:\n", + "\n", + " https://snntorch.readthedocs.io/en/latest/snn.neurons_leaky.html\n", + " \n", + " \"\"\"\n", + "\n", + " def __init__(self, \n", + " hidden_shape: tuple, \n", + " beta=None,\n", + " threshold = 1.,\n", + " activation = superspike(),\n", + " name=\"LIF\"):\n", + "\n", + " \"\"\"\n", + " \n", + " :hidden_size: Size of preceding layer's outputs\n", + " :beta: decay rate. Set to float in range (0,1] for uniform decay across layer, otherwise it will be a normal\n", + " distribution centered on 0.5 with stddev of 0.25\n", + " :threshold: threshold for reset. Defaults to 1.\n", + " :activation: spyx.axn.Axon object, default is Heaviside with Straight-Through-Estimation.\n", + " \"\"\"\n", + " super().__init__(name=name)\n", + " self.hidden_shape = hidden_shape\n", + " self.beta = beta\n", + " self.threshold = threshold\n", + " self.spike = activation\n", + " \n", + " def __call__(self, x, V):\n", + " \"\"\"\n", + " :x: input vector coming from previous layer.\n", + " :V: neuron state tensor.\n", + "\n", + " \"\"\"\n", + " beta = self.beta # this line can probably be deleted, and the check changed to self.beta\n", + " if not beta:\n", + " beta = hk.get_parameter(\"beta\", self.hidden_shape,\n", + " init=hk.initializers.TruncatedNormal(0.25, 0.5))\n", + " beta = jnp.clip(beta, 0, 1)\n", + " else:\n", + " beta = hk.get_parameter(\"beta\", [],\n", + " init=hk.initializers.Constant(beta))\n", + " beta = jnp.clip(beta, 0, 1)\n", + " \n", + " # calculate whether spike is generated, and update membrane potential\n", + " spikes = self.spike(V-self.threshold)\n", + " V = beta*V + x - spikes * self.threshold\n", + " \n", + " return spikes, V\n", + "\n", + " def initial_state(self, batch_size): \n", + " return jnp.zeros((batch_size,) + self.hidden_shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "from spyx.axn import superspike" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [], + "source": [ + "activation_func = superspike()" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [], + "source": [ + "def lif_neuron(V, x): # carry, x\n", + " spikes = activation_func(V-1.1)\n", + " V = 0.9*V + x - spikes * 1.1\n", + "\n", + " return V, spikes # carry, y\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,\n", + " 0.5, 0.5, 0.5], dtype=float32)" + ] + }, + "execution_count": 59, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_in = jnp.ones(16) / 2\n", + "x_in" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([0.], dtype=float32)" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "v0 = jnp.zeros(1)\n", + "v0" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def run_lif(x_in):\n", + " return jax.lax.scan(\n", + " lif_neuron,\n", + " v0,\n", + " x_in\n", + " )\n", + "\n", + "@jax.jit\n", + "def run_lif_unrolled(x_in):\n", + " return jax.lax.scan(\n", + " lif_neuron,\n", + " v0,\n", + " x_in,\n", + " unroll=True\n", + " ) " + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[16]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", + " \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[1]\u001b[39m c\u001b[35m:f32[16,1]\u001b[39m = pjit[\n", + " name=run_lif\n", + " jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22md\u001b[35m:f32[1]\u001b[39m; e\u001b[35m:f32[16]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", + " \u001b[39m\u001b[22m\u001b[22mf\u001b[35m:f32[1]\u001b[39m g\u001b[35m:f32[16,1]\u001b[39m = scan[\n", + " _split_transpose=False\n", + " jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; h\u001b[35m:f32[1]\u001b[39m i\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", + " \u001b[39m\u001b[22m\u001b[22mj\u001b[35m:f32[1]\u001b[39m = sub h 1.100000023841858\n", + " k\u001b[35m:f32[1]\u001b[39m = pjit[\n", + " name=wrapped_fun\n", + " jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; l\u001b[35m:f32[1]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", + " \u001b[39m\u001b[22m\u001b[22mm\u001b[35m:f32[1]\u001b[39m = custom_vjp_call_jaxpr[\n", + " bwd=. at 0x7c885a4d8310>\n", + " fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; n\u001b[35m:f32[1]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", + " \u001b[39m\u001b[22m\u001b[22mo\u001b[35m:bool[1]\u001b[39m = gt n 0.0\n", + " p\u001b[35m:i32[1]\u001b[39m = pjit[\n", + " name=_where\n", + " jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; q\u001b[35m:bool[1]\u001b[39m r\u001b[35m:i32[]\u001b[39m s\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", + " \u001b[39m\u001b[22m\u001b[22mt\u001b[35m:i32[1]\u001b[39m = broadcast_in_dim[\n", + " broadcast_dimensions=()\n", + " shape=(1,)\n", + " ] r\n", + " u\u001b[35m:i32[1]\u001b[39m = broadcast_in_dim[\n", + " broadcast_dimensions=()\n", + " shape=(1,)\n", + " ] s\n", + " v\u001b[35m:i32[1]\u001b[39m = select_n q u t\n", + " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(v,) }\n", + " ] o 1 0\n", + " w\u001b[35m:f32[1]\u001b[39m = convert_element_type[\n", + " new_dtype=float32\n", + " weak_type=False\n", + " ] p\n", + " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(w,) }\n", + " fwd_jaxpr_thunk=.memoized at 0x7c885a4d8c10>\n", + " num_consts=0\n", + " out_trees=. at 0x7c885a4d8160>\n", + " symbolic_zeros=False\n", + " ] l\n", + " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(m,) }\n", + " ] j\n", + " x\u001b[35m:f32[1]\u001b[39m = mul 0.8999999761581421 h\n", + " y\u001b[35m:f32[1]\u001b[39m = add x i\n", + " z\u001b[35m:f32[1]\u001b[39m = mul k 1.100000023841858\n", + " ba\u001b[35m:f32[1]\u001b[39m = sub y z\n", + " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(ba, k) }\n", + " length=16\n", + " linear=(False, False)\n", + " num_carry=1\n", + " num_consts=0\n", + " reverse=False\n", + " unroll=1\n", + " ] d e\n", + " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(f, g) }\n", + " ] a\n", + " \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(b, c) }\n" + ] + } + ], + "source": [ + "print(jax.make_jaxpr(run_lif)(x_in))" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "module @jit_run_lif attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n", + " func.func public @main(%arg0: tensor<16xf32> {mhlo.layout_mode = \"default\"}) -> (tensor<1xf32> {jax.result_info = \"[0]\", mhlo.layout_mode = \"default\"}, tensor<16x1xf32> {jax.result_info = \"[1]\", mhlo.layout_mode = \"default\"}) {\n", + " %0 = stablehlo.constant dense<0.000000e+00> : tensor<1xf32>\n", + " %1 = stablehlo.constant dense<0.000000e+00> : tensor\n", + " %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor) -> tensor<16x1xf32>\n", + " %3 = stablehlo.constant dense<0> : tensor\n", + " %4:4 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %3, %iterArg_1 = %0, %iterArg_2 = %2) : tensor<16xf32>, tensor, tensor<1xf32>, tensor<16x1xf32>\n", + " cond {\n", + " %5 = stablehlo.constant dense<16> : tensor\n", + " %6 = stablehlo.compare LT, %iterArg_0, %5, SIGNED : (tensor, tensor) -> tensor\n", + " stablehlo.return %6 : tensor\n", + " } do {\n", + " %5 = stablehlo.constant dense<0> : tensor\n", + " %6 = stablehlo.compare LT, %iterArg_0, %5, SIGNED : (tensor, tensor) -> tensor\n", + " %7 = stablehlo.convert %iterArg_0 : tensor\n", + " %8 = stablehlo.constant dense<16> : tensor\n", + " %9 = stablehlo.add %7, %8 : tensor\n", + " %10 = stablehlo.select %6, %9, %iterArg_0 : tensor, tensor\n", + " %11 = stablehlo.dynamic_slice %iterArg, %10, sizes = [1] : (tensor<16xf32>, tensor) -> tensor<1xf32>\n", + " %12 = stablehlo.reshape %11 : (tensor<1xf32>) -> tensor\n", + " %13 = stablehlo.constant dense<1.100000e+00> : tensor\n", + " %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1xf32>\n", + " %15 = stablehlo.subtract %iterArg_1, %14 : tensor<1xf32>\n", + " %16 = func.call @wrapped_fun(%15) : (tensor<1xf32>) -> tensor<1xf32>\n", + " %17 = stablehlo.constant dense<0.899999976> : tensor\n", + " %18 = stablehlo.broadcast_in_dim %17, dims = [] : (tensor) -> tensor<1xf32>\n", + " %19 = stablehlo.multiply %18, %iterArg_1 : tensor<1xf32>\n", + " %20 = stablehlo.broadcast_in_dim %12, dims = [] : (tensor) -> tensor<1xf32>\n", + " %21 = stablehlo.add %19, %20 : tensor<1xf32>\n", + " %22 = stablehlo.constant dense<1.100000e+00> : tensor\n", + " %23 = stablehlo.broadcast_in_dim %22, dims = [] : (tensor) -> tensor<1xf32>\n", + " %24 = stablehlo.multiply %16, %23 : tensor<1xf32>\n", + " %25 = stablehlo.subtract %21, %24 : tensor<1xf32>\n", + " %26 = stablehlo.broadcast_in_dim %16, dims = [1] : (tensor<1xf32>) -> tensor<1x1xf32>\n", + " %27 = stablehlo.constant dense<0> : tensor\n", + " %28 = stablehlo.compare LT, %iterArg_0, %27, SIGNED : (tensor, tensor) -> tensor\n", + " %29 = stablehlo.convert %iterArg_0 : tensor\n", + " %30 = stablehlo.constant dense<16> : tensor\n", + " %31 = stablehlo.add %29, %30 : tensor\n", + " %32 = stablehlo.select %28, %31, %iterArg_0 : tensor, tensor\n", + " %33 = stablehlo.constant dense<0> : tensor\n", + " %34 = stablehlo.dynamic_update_slice %iterArg_2, %26, %32, %33 : (tensor<16x1xf32>, tensor<1x1xf32>, tensor, tensor) -> tensor<16x1xf32>\n", + " %35 = stablehlo.constant dense<1> : tensor\n", + " %36 = stablehlo.add %iterArg_0, %35 : tensor\n", + " stablehlo.return %iterArg, %36, %25, %34 : tensor<16xf32>, tensor, tensor<1xf32>, tensor<16x1xf32>\n", + " }\n", + " return %4#2, %4#3 : tensor<1xf32>, tensor<16x1xf32>\n", + " }\n", + " func.func private @wrapped_fun(%arg0: tensor<1xf32>) -> tensor<1xf32> {\n", + " %0 = stablehlo.constant dense<0.000000e+00> : tensor\n", + " %1 = stablehlo.broadcast_in_dim %0, dims = [] : (tensor) -> tensor<1xf32>\n", + " %2 = stablehlo.compare GT, %arg0, %1, FLOAT : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>\n", + " %3 = stablehlo.constant dense<1> : tensor\n", + " %4 = stablehlo.constant dense<0> : tensor\n", + " %5 = call @_where(%2, %3, %4) : (tensor<1xi1>, tensor, tensor) -> tensor<1xi32>\n", + " %6 = stablehlo.convert %5 : (tensor<1xi32>) -> tensor<1xf32>\n", + " return %6 : tensor<1xf32>\n", + " }\n", + " func.func private @_where(%arg0: tensor<1xi1>, %arg1: tensor, %arg2: tensor) -> tensor<1xi32> {\n", + " %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor) -> tensor<1xi32>\n", + " %1 = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<1xi32>\n", + " %2 = stablehlo.select %arg0, %0, %1 : tensor<1xi1>, tensor<1xi32>\n", + " return %2 : tensor<1xi32>\n", + " }\n", + "}\n", + "\n" + ] + } + ], + "source": [ + "print(run_lif.lower(x_in).as_text())" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_run_lif, is_scheduled=true, entry_computation_layout={(f32[16]{0})->(f32[1]{0}, f32[16,1]{1,0})}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true,true}, frontend_attributes={fingerprint_before_lhs=\"ccb804c474066cf7086fa9ad8bb7d353\"}\n", + "\n", + "%fused_dynamic_update_slice (param_0: f32[16,1], param_1.18: s32[], param_2.21: f32[1]) -> f32[16,1] {\n", + " %param_0 = f32[16,1]{1,0} parameter(0)\n", + " %param_2.21 = f32[1]{0} parameter(2)\n", + " %constant_6_2 = f32[1]{0} constant({-1.1})\n", + " %add.8.3 = f32[1]{0} add(f32[1]{0} %param_2.21, f32[1]{0} %constant_6_2), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %constant_0_2 = f32[1]{0} constant({0})\n", + " %compare.5.3 = pred[1]{0} compare(f32[1]{0} %add.8.3, f32[1]{0} %constant_0_2), direction=GT, metadata={op_name=\"jit(run_lif)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %constant_7_2 = s32[1]{0} constant({1})\n", + " %constant_8_2 = s32[1]{0} constant({0})\n", + " %select.5.3 = s32[1]{0} select(pred[1]{0} %compare.5.3, s32[1]{0} %constant_7_2, s32[1]{0} %constant_8_2), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.2.3 = f32[1]{0} convert(s32[1]{0} %select.5.3), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.99.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.2.3)\n", + " %param_1.18 = s32[] parameter(1)\n", + " %constant_32_1 = s32[] constant(0)\n", + " %compare.4.5 = pred[] compare(s32[] %param_1.18, s32[] %constant_32_1), direction=LT, metadata={op_name=\"jit(run_lif)/jit(main)/while/body/lt\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " %constant_31_1 = s32[] constant(16)\n", + " %add.6.5 = s32[] add(s32[] %param_1.18, s32[] %constant_31_1), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " %select.4.5 = s32[] select(pred[] %compare.4.5, s32[] %add.6.5, s32[] %param_1.18), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/select_n\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " ROOT %dynamic-update-slice.1.1 = f32[16,1]{1,0} dynamic-update-slice(f32[16,1]{1,0} %param_0, f32[1,1]{1,0} %bitcast.99.1, s32[] %select.4.5, s32[] %constant_32_1), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/dynamic_update_slice\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + "}\n", + "\n", + "%fused_subtract (param_0.14: f32[16], param_1.17: f32[1], param_2.18: s32[]) -> f32[1] {\n", + " %param_1.17 = f32[1]{0} parameter(1)\n", + " %constant_29_1 = f32[1]{0} constant({0.9})\n", + " %multiply.2.1 = f32[1]{0} multiply(f32[1]{0} %param_1.17, f32[1]{0} %constant_29_1), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %param_0.14 = f32[16]{0} parameter(0)\n", + " %param_2.18 = s32[] parameter(2)\n", + " %constant_32_2 = s32[] constant(0)\n", + " %compare.4.3 = pred[] compare(s32[] %param_2.18, s32[] %constant_32_2), direction=LT, metadata={op_name=\"jit(run_lif)/jit(main)/while/body/lt\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " %constant_31_2 = s32[] constant(16)\n", + " %add.6.3 = s32[] add(s32[] %param_2.18, s32[] %constant_31_2), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " %select.4.3 = s32[] select(pred[] %compare.4.3, s32[] %add.6.3, s32[] %param_2.18), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/select_n\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " %dynamic-slice.1.1 = f32[1]{0} dynamic-slice(f32[16]{0} %param_0.14, s32[] %select.4.3), dynamic_slice_sizes={1}, metadata={op_name=\"jit(run_lif)/jit(main)/while/body/dynamic_slice[slice_sizes=(1,)]\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " %add.7.1 = f32[1]{0} add(f32[1]{0} %multiply.2.1, f32[1]{0} %dynamic-slice.1.1), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %constant_6_1 = f32[1]{0} constant({-1.1})\n", + " %add.8.5 = f32[1]{0} add(f32[1]{0} %param_1.17, f32[1]{0} %constant_6_1), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %constant_0_1 = f32[1]{0} constant({0})\n", + " %compare.5.5 = pred[1]{0} compare(f32[1]{0} %add.8.5, f32[1]{0} %constant_0_1), direction=GT, metadata={op_name=\"jit(run_lif)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %constant_7_1 = s32[1]{0} constant({1})\n", + " %constant_8_1 = s32[1]{0} constant({0})\n", + " %select.5.5 = s32[1]{0} select(pred[1]{0} %compare.5.5, s32[1]{0} %constant_7_1, s32[1]{0} %constant_8_1), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.2.5 = f32[1]{0} convert(s32[1]{0} %select.5.5), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %constant_30_1 = f32[1]{0} constant({1.1})\n", + " %multiply.3.1 = f32[1]{0} multiply(f32[1]{0} %convert.2.5, f32[1]{0} %constant_30_1), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " ROOT %subtract.1.1 = f32[1]{0} subtract(f32[1]{0} %add.7.1, f32[1]{0} %multiply.3.1), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + "}\n", + "\n", + "%fused_add (param_0.13: s32[]) -> s32[] {\n", + " %param_0.13 = s32[] parameter(0)\n", + " %constant_28_1 = s32[] constant(1)\n", + " ROOT %add.5.1 = s32[] add(s32[] %param_0.13, s32[] %constant_28_1), metadata={op_name=\"jit(run_lif)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + "}\n", + "\n", + "%region_0.22 (arg_tuple.23.0: (s32[], f32[1], f32[16,1], f32[16])) -> (s32[], f32[1], f32[16,1], f32[16]) {\n", + " %arg_tuple.23.0 = (s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) parameter(0)\n", + " %get-tuple-element.4 = s32[] get-tuple-element((s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) %arg_tuple.23.0), index=0\n", + " %get-tuple-element.5 = f32[1]{0} get-tuple-element((s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) %arg_tuple.23.0), index=1\n", + " %get-tuple-element.11 = f32[16]{0} get-tuple-element((s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) %arg_tuple.23.0), index=3\n", + " %get-tuple-element.6 = f32[16,1]{1,0} get-tuple-element((s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) %arg_tuple.23.0), index=2\n", + " %loop_dynamic_update_slice_fusion = f32[16,1]{1,0} fusion(f32[16,1]{1,0} %get-tuple-element.6, s32[] %get-tuple-element.4, f32[1]{0} %get-tuple-element.5), kind=kLoop, calls=%fused_dynamic_update_slice, metadata={op_name=\"jit(run_lif)/jit(main)/while/body/dynamic_update_slice\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " %loop_subtract_fusion = f32[1]{0} fusion(f32[16]{0} %get-tuple-element.11, f32[1]{0} %get-tuple-element.5, s32[] %get-tuple-element.4), kind=kLoop, calls=%fused_subtract, control-predecessors={%loop_dynamic_update_slice_fusion}, metadata={op_name=\"jit(run_lif)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %loop_add_fusion = s32[] fusion(s32[] %get-tuple-element.4), kind=kLoop, calls=%fused_add, control-predecessors={%loop_dynamic_update_slice_fusion, %loop_subtract_fusion}, metadata={op_name=\"jit(run_lif)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " ROOT %tuple.2 = (s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) tuple(s32[] %loop_add_fusion, f32[1]{0} %loop_subtract_fusion, f32[16,1]{1,0} %loop_dynamic_update_slice_fusion, f32[16]{0} %get-tuple-element.11)\n", + "}\n", + "\n", + "%fused_compare (param_0.15: s32[]) -> pred[] {\n", + " %param_0.15 = s32[] parameter(0)\n", + " %constant_56_1 = s32[] constant(16)\n", + " ROOT %compare.6.1 = pred[] compare(s32[] %param_0.15, s32[] %constant_56_1), direction=LT, metadata={op_name=\"jit(run_lif)/jit(main)/while/cond/lt\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + "}\n", + "\n", + "%region_1.50 (arg_tuple.51.0: (s32[], f32[1], f32[16,1], f32[16])) -> pred[] {\n", + " %arg_tuple.51.0 = (s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) parameter(0)\n", + " %get-tuple-element.52.0 = s32[] get-tuple-element((s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) %arg_tuple.51.0), index=0\n", + " ROOT %loop_compare_fusion = pred[] fusion(s32[] %get-tuple-element.52.0), kind=kLoop, calls=%fused_compare, metadata={op_name=\"jit(run_lif)/jit(main)/while/cond/lt\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + "}\n", + "\n", + "%fused_broadcast () -> f32[16,1] {\n", + " %constant_4_1 = f32[] constant(0)\n", + " ROOT %broadcast.1.1 = f32[16,1]{1,0} broadcast(f32[] %constant_4_1), dimensions={}\n", + "}\n", + "\n", + "%wrapped_copy_computation (param_0.17: s32[]) -> s32[] {\n", + " %param_0.17 = s32[] parameter(0)\n", + " ROOT %copy.14 = s32[] copy(s32[] %param_0.17)\n", + "}\n", + "\n", + "%wrapped_copy_computation.1 (param_0.18: f32[1]) -> f32[1] {\n", + " %param_0.18 = f32[1]{0} parameter(0)\n", + " ROOT %copy.15 = f32[1]{0} copy(f32[1]{0} %param_0.18)\n", + "}\n", + "\n", + "ENTRY %main.64 (Arg_0.1.0: f32[16]) -> (f32[1], f32[16,1]) {\n", + " %constant_3_0 = f32[1]{0} constant({0})\n", + " %constant_2_0 = s32[] constant(0)\n", + " %Arg_0.1.0 = f32[16]{0} parameter(0)\n", + " %wrapped_copy.1 = f32[1]{0} fusion(f32[1]{0} %constant_3_0), kind=kLoop, calls=%wrapped_copy_computation.1\n", + " %wrapped_copy = s32[] fusion(s32[] %constant_2_0), kind=kLoop, calls=%wrapped_copy_computation\n", + " %loop_broadcast_fusion = f32[16,1]{1,0} fusion(), kind=kLoop, calls=%fused_broadcast\n", + " %tuple = (s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) tuple(s32[] %wrapped_copy, f32[1]{0} %wrapped_copy.1, f32[16,1]{1,0} %loop_broadcast_fusion, f32[16]{0} %Arg_0.1.0)\n", + " %while.58.0 = (s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) while((s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) %tuple), condition=%region_1.50, body=%region_0.22, metadata={op_name=\"jit(run_lif)/jit(main)/while[cond_nconsts=0 body_nconsts=1]\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}, backend_config={\"known_trip_count\":{\"n\":\"16\"}}\n", + " %get-tuple-element.60.0 = f32[1]{0} get-tuple-element((s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) %while.58.0), index=1, metadata={op_name=\"jit(run_lif)/jit(main)/while[cond_nconsts=0 body_nconsts=1]\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " %get-tuple-element.61.0 = f32[16,1]{1,0} get-tuple-element((s32[], f32[1]{0}, f32[16,1]{1,0}, f32[16]{0}) %while.58.0), index=2, metadata={op_name=\"jit(run_lif)/jit(main)/while[cond_nconsts=0 body_nconsts=1]\" source_file=\"/tmp/ipykernel_73553/458784223.py\" source_line=3}\n", + " ROOT %tuple.63.0 = (f32[1]{0}, f32[16,1]{1,0}) tuple(f32[1]{0} %get-tuple-element.60.0, f32[16,1]{1,0} %get-tuple-element.61.0)\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(run_lif.lower(x_in).compile().as_text())" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HloModule jit_run_lif_unrolled, is_scheduled=true, entry_computation_layout={(f32[16]{0})->(f32[1]{0}, f32[16,1]{1,0})}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true,true}, frontend_attributes={fingerprint_before_lhs=\"cb90a9e0f226453786315bab202e6172\"}\n", + "\n", + "%fused_concatenate (param_0.166: f32[16]) -> f32[16,1] {\n", + " %constant_191_1 = f32[1,1]{1,0} constant({ {0} })\n", + " %param_0.166 = f32[16]{0} parameter(0)\n", + " %bitcast.443.17 = f32[1,16]{1,0} bitcast(f32[16]{0} %param_0.166)\n", + " %slice.33.17 = f32[1,1]{1,0} slice(f32[1,16]{1,0} %bitcast.443.17), slice={[0:1], [0:1]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %constant_193_2 = f32[1,1]{1,0} constant({ {-1.1} })\n", + " %add.81.9 = f32[1,1]{1,0} add(f32[1,1]{1,0} %slice.33.17, f32[1,1]{1,0} %constant_193_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.51.7 = pred[1,1]{1,0} compare(f32[1,1]{1,0} %add.81.9, f32[1,1]{1,0} %constant_191_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %constant_195_2 = s32[1,1]{1,0} constant({ {1} })\n", + " %constant_196_2 = s32[1,1]{1,0} constant({ {0} })\n", + " %select.68.5 = s32[1,1]{1,0} select(pred[1,1]{1,0} %compare.51.7, s32[1,1]{1,0} %constant_195_2, s32[1,1]{1,0} %constant_196_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.50.3 = f32[1,1]{1,0} convert(s32[1,1]{1,0} %select.68.5), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %constant_192_2 = f32[1,1]{1,0} constant({ {0.9} })\n", + " %multiply.64.7 = f32[1,1]{1,0} multiply(f32[1,1]{1,0} %slice.33.17, f32[1,1]{1,0} %constant_192_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %bitcast.448.5 = f32[1]{0} bitcast(f32[1,1]{1,0} %multiply.64.7)\n", + " %slice.34.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[1:2]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(1,) limit_indices=(2,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.80.5 = f32[1]{0} add(f32[1]{0} %bitcast.448.5, f32[1]{0} %slice.34.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %constant_200_2 = f32[1,1]{1,0} constant({ {1.1} })\n", + " %multiply.65.3 = f32[1,1]{1,0} multiply(f32[1,1]{1,0} %convert.50.3, f32[1,1]{1,0} %constant_200_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %bitcast.466.3 = f32[1]{0} bitcast(f32[1,1]{1,0} %multiply.65.3)\n", + " %subtract.31.3 = f32[1]{0} subtract(f32[1]{0} %add.80.5, f32[1]{0} %bitcast.466.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %constant_126_2 = f32[1]{0} constant({-1.1}), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %add.83.3 = f32[1]{0} add(f32[1]{0} %subtract.31.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %constant_5_2 = f32[1]{0} constant({0})\n", + " %compare.52.3 = pred[1]{0} compare(f32[1]{0} %add.83.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %constant_127_1 = s32[1]{0} constant({1})\n", + " %constant_128_1 = s32[1]{0} constant({0})\n", + " %select.69.3 = s32[1]{0} select(pred[1]{0} %compare.52.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.51.3 = f32[1]{0} convert(s32[1]{0} %select.69.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.793.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.51.3)\n", + " %constant_61_2 = f32[1]{0} constant({0.9})\n", + " %multiply.66.3 = f32[1]{0} multiply(f32[1]{0} %subtract.31.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.35.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[2:3]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(2,) limit_indices=(3,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.82.3 = f32[1]{0} add(f32[1]{0} %multiply.66.3, f32[1]{0} %slice.35.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %constant_65_2 = f32[1]{0} constant({1.1})\n", + " %multiply.67.3 = f32[1]{0} multiply(f32[1]{0} %convert.51.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.32.3 = f32[1]{0} subtract(f32[1]{0} %add.82.3, f32[1]{0} %multiply.67.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.85.3 = f32[1]{0} add(f32[1]{0} %subtract.32.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.53.3 = pred[1]{0} compare(f32[1]{0} %add.85.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.70.3 = s32[1]{0} select(pred[1]{0} %compare.53.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.53.3 = f32[1]{0} convert(s32[1]{0} %select.70.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.797.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.53.3)\n", + " %multiply.68.3 = f32[1]{0} multiply(f32[1]{0} %subtract.32.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.36.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[3:4]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(3,) limit_indices=(4,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.84.3 = f32[1]{0} add(f32[1]{0} %multiply.68.3, f32[1]{0} %slice.36.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.69.3 = f32[1]{0} multiply(f32[1]{0} %convert.53.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.33.3 = f32[1]{0} subtract(f32[1]{0} %add.84.3, f32[1]{0} %multiply.69.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.87.3 = f32[1]{0} add(f32[1]{0} %subtract.33.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.54.3 = pred[1]{0} compare(f32[1]{0} %add.87.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.71.3 = s32[1]{0} select(pred[1]{0} %compare.54.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.54.3 = f32[1]{0} convert(s32[1]{0} %select.71.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.801.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.54.3)\n", + " %multiply.70.3 = f32[1]{0} multiply(f32[1]{0} %subtract.33.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.37.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[4:5]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(4,) limit_indices=(5,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.86.3 = f32[1]{0} add(f32[1]{0} %multiply.70.3, f32[1]{0} %slice.37.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.71.3 = f32[1]{0} multiply(f32[1]{0} %convert.54.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.34.3 = f32[1]{0} subtract(f32[1]{0} %add.86.3, f32[1]{0} %multiply.71.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.89.3 = f32[1]{0} add(f32[1]{0} %subtract.34.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.55.3 = pred[1]{0} compare(f32[1]{0} %add.89.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.72.3 = s32[1]{0} select(pred[1]{0} %compare.55.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.55.3 = f32[1]{0} convert(s32[1]{0} %select.72.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.805.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.55.3)\n", + " %multiply.72.3 = f32[1]{0} multiply(f32[1]{0} %subtract.34.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.38.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[5:6]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(5,) limit_indices=(6,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.88.3 = f32[1]{0} add(f32[1]{0} %multiply.72.3, f32[1]{0} %slice.38.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.73.3 = f32[1]{0} multiply(f32[1]{0} %convert.55.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.35.3 = f32[1]{0} subtract(f32[1]{0} %add.88.3, f32[1]{0} %multiply.73.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.91.3 = f32[1]{0} add(f32[1]{0} %subtract.35.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.56.3 = pred[1]{0} compare(f32[1]{0} %add.91.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.73.3 = s32[1]{0} select(pred[1]{0} %compare.56.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.56.3 = f32[1]{0} convert(s32[1]{0} %select.73.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.809.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.56.3)\n", + " %multiply.74.3 = f32[1]{0} multiply(f32[1]{0} %subtract.35.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.39.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[6:7]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(6,) limit_indices=(7,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.90.3 = f32[1]{0} add(f32[1]{0} %multiply.74.3, f32[1]{0} %slice.39.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.75.3 = f32[1]{0} multiply(f32[1]{0} %convert.56.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.36.3 = f32[1]{0} subtract(f32[1]{0} %add.90.3, f32[1]{0} %multiply.75.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.93.3 = f32[1]{0} add(f32[1]{0} %subtract.36.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.57.3 = pred[1]{0} compare(f32[1]{0} %add.93.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.75.3 = s32[1]{0} select(pred[1]{0} %compare.57.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.57.3 = f32[1]{0} convert(s32[1]{0} %select.75.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.813.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.57.3)\n", + " %multiply.76.3 = f32[1]{0} multiply(f32[1]{0} %subtract.36.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.40.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[7:8]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(7,) limit_indices=(8,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.92.3 = f32[1]{0} add(f32[1]{0} %multiply.76.3, f32[1]{0} %slice.40.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.77.3 = f32[1]{0} multiply(f32[1]{0} %convert.57.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.37.3 = f32[1]{0} subtract(f32[1]{0} %add.92.3, f32[1]{0} %multiply.77.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.95.3 = f32[1]{0} add(f32[1]{0} %subtract.37.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.58.3 = pred[1]{0} compare(f32[1]{0} %add.95.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.76.3 = s32[1]{0} select(pred[1]{0} %compare.58.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.58.3 = f32[1]{0} convert(s32[1]{0} %select.76.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.817.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.58.3)\n", + " %multiply.78.3 = f32[1]{0} multiply(f32[1]{0} %subtract.37.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.41.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[8:9]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(8,) limit_indices=(9,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.94.3 = f32[1]{0} add(f32[1]{0} %multiply.78.3, f32[1]{0} %slice.41.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.79.3 = f32[1]{0} multiply(f32[1]{0} %convert.58.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.38.3 = f32[1]{0} subtract(f32[1]{0} %add.94.3, f32[1]{0} %multiply.79.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.97.3 = f32[1]{0} add(f32[1]{0} %subtract.38.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.59.3 = pred[1]{0} compare(f32[1]{0} %add.97.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.77.3 = s32[1]{0} select(pred[1]{0} %compare.59.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.59.3 = f32[1]{0} convert(s32[1]{0} %select.77.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.821.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.59.3)\n", + " %multiply.80.3 = f32[1]{0} multiply(f32[1]{0} %subtract.38.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.42.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[9:10]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(9,) limit_indices=(10,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.96.3 = f32[1]{0} add(f32[1]{0} %multiply.80.3, f32[1]{0} %slice.42.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.81.3 = f32[1]{0} multiply(f32[1]{0} %convert.59.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.39.3 = f32[1]{0} subtract(f32[1]{0} %add.96.3, f32[1]{0} %multiply.81.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.99.3 = f32[1]{0} add(f32[1]{0} %subtract.39.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.60.3 = pred[1]{0} compare(f32[1]{0} %add.99.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.78.3 = s32[1]{0} select(pred[1]{0} %compare.60.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.60.3 = f32[1]{0} convert(s32[1]{0} %select.78.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.825.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.60.3)\n", + " %multiply.82.3 = f32[1]{0} multiply(f32[1]{0} %subtract.39.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.43.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[10:11]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(10,) limit_indices=(11,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.98.3 = f32[1]{0} add(f32[1]{0} %multiply.82.3, f32[1]{0} %slice.43.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.83.3 = f32[1]{0} multiply(f32[1]{0} %convert.60.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.40.3 = f32[1]{0} subtract(f32[1]{0} %add.98.3, f32[1]{0} %multiply.83.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.101.3 = f32[1]{0} add(f32[1]{0} %subtract.40.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.61.3 = pred[1]{0} compare(f32[1]{0} %add.101.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.79.3 = s32[1]{0} select(pred[1]{0} %compare.61.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.61.3 = f32[1]{0} convert(s32[1]{0} %select.79.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.829.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.61.3)\n", + " %multiply.84.3 = f32[1]{0} multiply(f32[1]{0} %subtract.40.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.44.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[11:12]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(11,) limit_indices=(12,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.100.3 = f32[1]{0} add(f32[1]{0} %multiply.84.3, f32[1]{0} %slice.44.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.85.3 = f32[1]{0} multiply(f32[1]{0} %convert.61.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.41.3 = f32[1]{0} subtract(f32[1]{0} %add.100.3, f32[1]{0} %multiply.85.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.103.3 = f32[1]{0} add(f32[1]{0} %subtract.41.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.62.3 = pred[1]{0} compare(f32[1]{0} %add.103.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.80.3 = s32[1]{0} select(pred[1]{0} %compare.62.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.62.3 = f32[1]{0} convert(s32[1]{0} %select.80.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.833.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.62.3)\n", + " %multiply.86.3 = f32[1]{0} multiply(f32[1]{0} %subtract.41.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.45.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[12:13]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(12,) limit_indices=(13,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.102.3 = f32[1]{0} add(f32[1]{0} %multiply.86.3, f32[1]{0} %slice.45.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.87.3 = f32[1]{0} multiply(f32[1]{0} %convert.62.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.42.3 = f32[1]{0} subtract(f32[1]{0} %add.102.3, f32[1]{0} %multiply.87.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.105.3 = f32[1]{0} add(f32[1]{0} %subtract.42.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.63.3 = pred[1]{0} compare(f32[1]{0} %add.105.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.81.3 = s32[1]{0} select(pred[1]{0} %compare.63.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.63.3 = f32[1]{0} convert(s32[1]{0} %select.81.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.837.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.63.3)\n", + " %multiply.88.3 = f32[1]{0} multiply(f32[1]{0} %subtract.42.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.46.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[13:14]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(13,) limit_indices=(14,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.104.3 = f32[1]{0} add(f32[1]{0} %multiply.88.3, f32[1]{0} %slice.46.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.89.3 = f32[1]{0} multiply(f32[1]{0} %convert.63.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.43.3 = f32[1]{0} subtract(f32[1]{0} %add.104.3, f32[1]{0} %multiply.89.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.107.3 = f32[1]{0} add(f32[1]{0} %subtract.43.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.64.3 = pred[1]{0} compare(f32[1]{0} %add.107.3, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.82.3 = s32[1]{0} select(pred[1]{0} %compare.64.3, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.64.3 = f32[1]{0} convert(s32[1]{0} %select.82.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.841.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.64.3)\n", + " %multiply.90.3 = f32[1]{0} multiply(f32[1]{0} %subtract.43.3, f32[1]{0} %constant_61_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.47.3 = f32[1]{0} slice(f32[16]{0} %param_0.166), slice={[14:15]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(14,) limit_indices=(15,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.106.3 = f32[1]{0} add(f32[1]{0} %multiply.90.3, f32[1]{0} %slice.47.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.91.3 = f32[1]{0} multiply(f32[1]{0} %convert.64.3, f32[1]{0} %constant_65_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.44.3 = f32[1]{0} subtract(f32[1]{0} %add.106.3, f32[1]{0} %multiply.91.3), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.109.2 = f32[1]{0} add(f32[1]{0} %subtract.44.3, f32[1]{0} %constant_126_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.66.1 = pred[1]{0} compare(f32[1]{0} %add.109.2, f32[1]{0} %constant_5_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.83.2 = s32[1]{0} select(pred[1]{0} %compare.66.1, s32[1]{0} %constant_127_1, s32[1]{0} %constant_128_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.65.1 = f32[1]{0} convert(s32[1]{0} %select.83.2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %bitcast.845.1 = f32[1,1]{1,0} bitcast(f32[1]{0} %convert.65.1)\n", + " ROOT %concatenate.2.1 = f32[16,1]{1,0} concatenate(f32[1,1]{1,0} %constant_191_1, f32[1,1]{1,0} %convert.50.3, f32[1,1]{1,0} %bitcast.793.1, f32[1,1]{1,0} %bitcast.797.1, f32[1,1]{1,0} %bitcast.801.1, /*index=5*/f32[1,1]{1,0} %bitcast.805.1, f32[1,1]{1,0} %bitcast.809.1, f32[1,1]{1,0} %bitcast.813.1, f32[1,1]{1,0} %bitcast.817.1, f32[1,1]{1,0} %bitcast.821.1, /*index=10*/f32[1,1]{1,0} %bitcast.825.1, f32[1,1]{1,0} %bitcast.829.1, f32[1,1]{1,0} %bitcast.833.1, f32[1,1]{1,0} %bitcast.837.1, f32[1,1]{1,0} %bitcast.841.1, /*index=15*/f32[1,1]{1,0} %bitcast.845.1), dimensions={0}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/concatenate[dimension=0]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + "}\n", + "\n", + "%fused_subtract (param_0.165: f32[16]) -> f32[1] {\n", + " %param_0.165 = f32[16]{0} parameter(0)\n", + " %bitcast.443.43 = f32[1,16]{1,0} bitcast(f32[16]{0} %param_0.165)\n", + " %slice.33.43 = f32[1,1]{1,0} slice(f32[1,16]{1,0} %bitcast.443.43), slice={[0:1], [0:1]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(0,) limit_indices=(1,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %constant_192_1 = f32[1,1]{1,0} constant({ {0.9} })\n", + " %multiply.64.17 = f32[1,1]{1,0} multiply(f32[1,1]{1,0} %slice.33.43, f32[1,1]{1,0} %constant_192_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %bitcast.448.15 = f32[1]{0} bitcast(f32[1,1]{1,0} %multiply.64.17)\n", + " %slice.34.13 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[1:2]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(1,) limit_indices=(2,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.80.15 = f32[1]{0} add(f32[1]{0} %bitcast.448.15, f32[1]{0} %slice.34.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %constant_193_1 = f32[1,1]{1,0} constant({ {-1.1} })\n", + " %add.81.23 = f32[1,1]{1,0} add(f32[1,1]{1,0} %slice.33.43, f32[1,1]{1,0} %constant_193_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %constant_191_2 = f32[1,1]{1,0} constant({ {0} })\n", + " %compare.51.21 = pred[1,1]{1,0} compare(f32[1,1]{1,0} %add.81.23, f32[1,1]{1,0} %constant_191_2), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %constant_195_1 = s32[1,1]{1,0} constant({ {1} })\n", + " %constant_196_1 = s32[1,1]{1,0} constant({ {0} })\n", + " %select.68.19 = s32[1,1]{1,0} select(pred[1,1]{1,0} %compare.51.21, s32[1,1]{1,0} %constant_195_1, s32[1,1]{1,0} %constant_196_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.50.17 = f32[1,1]{1,0} convert(s32[1,1]{1,0} %select.68.19), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %constant_200_1 = f32[1,1]{1,0} constant({ {1.1} })\n", + " %multiply.65.13 = f32[1,1]{1,0} multiply(f32[1,1]{1,0} %convert.50.17, f32[1,1]{1,0} %constant_200_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %bitcast.466.13 = f32[1]{0} bitcast(f32[1,1]{1,0} %multiply.65.13)\n", + " %subtract.31.13 = f32[1]{0} subtract(f32[1]{0} %add.80.15, f32[1]{0} %bitcast.466.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %constant_61_1 = f32[1]{0} constant({0.9})\n", + " %multiply.66.13 = f32[1]{0} multiply(f32[1]{0} %subtract.31.13, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.35.13 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[2:3]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(2,) limit_indices=(3,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.82.13 = f32[1]{0} add(f32[1]{0} %multiply.66.13, f32[1]{0} %slice.35.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %constant_126_1 = f32[1]{0} constant({-1.1}), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %add.83.17 = f32[1]{0} add(f32[1]{0} %subtract.31.13, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %constant_5_1 = f32[1]{0} constant({0})\n", + " %compare.52.17 = pred[1]{0} compare(f32[1]{0} %add.83.17, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %constant_127_2 = s32[1]{0} constant({1})\n", + " %constant_128_2 = s32[1]{0} constant({0})\n", + " %select.69.17 = s32[1]{0} select(pred[1]{0} %compare.52.17, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.51.17 = f32[1]{0} convert(s32[1]{0} %select.69.17), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %constant_65_1 = f32[1]{0} constant({1.1})\n", + " %multiply.67.13 = f32[1]{0} multiply(f32[1]{0} %convert.51.17, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.32.13 = f32[1]{0} subtract(f32[1]{0} %add.82.13, f32[1]{0} %multiply.67.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.68.9 = f32[1]{0} multiply(f32[1]{0} %subtract.32.13, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.36.9 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[3:4]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(3,) limit_indices=(4,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.84.9 = f32[1]{0} add(f32[1]{0} %multiply.68.9, f32[1]{0} %slice.36.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.85.13 = f32[1]{0} add(f32[1]{0} %subtract.32.13, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.53.13 = pred[1]{0} compare(f32[1]{0} %add.85.13, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.70.13 = s32[1]{0} select(pred[1]{0} %compare.53.13, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.53.13 = f32[1]{0} convert(s32[1]{0} %select.70.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.69.9 = f32[1]{0} multiply(f32[1]{0} %convert.53.13, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.33.9 = f32[1]{0} subtract(f32[1]{0} %add.84.9, f32[1]{0} %multiply.69.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.70.9 = f32[1]{0} multiply(f32[1]{0} %subtract.33.9, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.37.9 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[4:5]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(4,) limit_indices=(5,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.86.9 = f32[1]{0} add(f32[1]{0} %multiply.70.9, f32[1]{0} %slice.37.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.87.13 = f32[1]{0} add(f32[1]{0} %subtract.33.9, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.54.13 = pred[1]{0} compare(f32[1]{0} %add.87.13, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.71.13 = s32[1]{0} select(pred[1]{0} %compare.54.13, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.54.13 = f32[1]{0} convert(s32[1]{0} %select.71.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.71.9 = f32[1]{0} multiply(f32[1]{0} %convert.54.13, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.34.9 = f32[1]{0} subtract(f32[1]{0} %add.86.9, f32[1]{0} %multiply.71.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.72.5 = f32[1]{0} multiply(f32[1]{0} %subtract.34.9, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.38.5 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[5:6]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(5,) limit_indices=(6,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.88.5 = f32[1]{0} add(f32[1]{0} %multiply.72.5, f32[1]{0} %slice.38.5), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.89.9 = f32[1]{0} add(f32[1]{0} %subtract.34.9, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.55.9 = pred[1]{0} compare(f32[1]{0} %add.89.9, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.72.9 = s32[1]{0} select(pred[1]{0} %compare.55.9, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.55.9 = f32[1]{0} convert(s32[1]{0} %select.72.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.73.5 = f32[1]{0} multiply(f32[1]{0} %convert.55.9, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.35.5 = f32[1]{0} subtract(f32[1]{0} %add.88.5, f32[1]{0} %multiply.73.5), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.74.17 = f32[1]{0} multiply(f32[1]{0} %subtract.35.5, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.39.17 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[6:7]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(6,) limit_indices=(7,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.90.17 = f32[1]{0} add(f32[1]{0} %multiply.74.17, f32[1]{0} %slice.39.17), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.91.21 = f32[1]{0} add(f32[1]{0} %subtract.35.5, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.56.21 = pred[1]{0} compare(f32[1]{0} %add.91.21, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.73.21 = s32[1]{0} select(pred[1]{0} %compare.56.21, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.56.21 = f32[1]{0} convert(s32[1]{0} %select.73.21), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.75.17 = f32[1]{0} multiply(f32[1]{0} %convert.56.21, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.36.17 = f32[1]{0} subtract(f32[1]{0} %add.90.17, f32[1]{0} %multiply.75.17), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.76.13 = f32[1]{0} multiply(f32[1]{0} %subtract.36.17, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.40.13 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[7:8]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(7,) limit_indices=(8,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.92.13 = f32[1]{0} add(f32[1]{0} %multiply.76.13, f32[1]{0} %slice.40.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.93.17 = f32[1]{0} add(f32[1]{0} %subtract.36.17, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.57.17 = pred[1]{0} compare(f32[1]{0} %add.93.17, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.75.17 = s32[1]{0} select(pred[1]{0} %compare.57.17, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.57.17 = f32[1]{0} convert(s32[1]{0} %select.75.17), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.77.13 = f32[1]{0} multiply(f32[1]{0} %convert.57.17, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.37.13 = f32[1]{0} subtract(f32[1]{0} %add.92.13, f32[1]{0} %multiply.77.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.78.9 = f32[1]{0} multiply(f32[1]{0} %subtract.37.13, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.41.9 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[8:9]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(8,) limit_indices=(9,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.94.9 = f32[1]{0} add(f32[1]{0} %multiply.78.9, f32[1]{0} %slice.41.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.95.13 = f32[1]{0} add(f32[1]{0} %subtract.37.13, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.58.13 = pred[1]{0} compare(f32[1]{0} %add.95.13, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.76.13 = s32[1]{0} select(pred[1]{0} %compare.58.13, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.58.13 = f32[1]{0} convert(s32[1]{0} %select.76.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.79.9 = f32[1]{0} multiply(f32[1]{0} %convert.58.13, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.38.9 = f32[1]{0} subtract(f32[1]{0} %add.94.9, f32[1]{0} %multiply.79.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.80.13 = f32[1]{0} multiply(f32[1]{0} %subtract.38.9, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.42.13 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[9:10]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(9,) limit_indices=(10,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.96.13 = f32[1]{0} add(f32[1]{0} %multiply.80.13, f32[1]{0} %slice.42.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.97.17 = f32[1]{0} add(f32[1]{0} %subtract.38.9, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.59.17 = pred[1]{0} compare(f32[1]{0} %add.97.17, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.77.17 = s32[1]{0} select(pred[1]{0} %compare.59.17, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.59.17 = f32[1]{0} convert(s32[1]{0} %select.77.17), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.81.13 = f32[1]{0} multiply(f32[1]{0} %convert.59.17, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.39.13 = f32[1]{0} subtract(f32[1]{0} %add.96.13, f32[1]{0} %multiply.81.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.82.9 = f32[1]{0} multiply(f32[1]{0} %subtract.39.13, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.43.9 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[10:11]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(10,) limit_indices=(11,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.98.9 = f32[1]{0} add(f32[1]{0} %multiply.82.9, f32[1]{0} %slice.43.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.99.13 = f32[1]{0} add(f32[1]{0} %subtract.39.13, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.60.13 = pred[1]{0} compare(f32[1]{0} %add.99.13, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.78.13 = s32[1]{0} select(pred[1]{0} %compare.60.13, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.60.13 = f32[1]{0} convert(s32[1]{0} %select.78.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.83.9 = f32[1]{0} multiply(f32[1]{0} %convert.60.13, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.40.9 = f32[1]{0} subtract(f32[1]{0} %add.98.9, f32[1]{0} %multiply.83.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.84.5 = f32[1]{0} multiply(f32[1]{0} %subtract.40.9, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.44.5 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[11:12]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(11,) limit_indices=(12,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.100.5 = f32[1]{0} add(f32[1]{0} %multiply.84.5, f32[1]{0} %slice.44.5), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.101.9 = f32[1]{0} add(f32[1]{0} %subtract.40.9, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.61.9 = pred[1]{0} compare(f32[1]{0} %add.101.9, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.79.9 = s32[1]{0} select(pred[1]{0} %compare.61.9, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.61.9 = f32[1]{0} convert(s32[1]{0} %select.79.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.85.5 = f32[1]{0} multiply(f32[1]{0} %convert.61.9, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.41.5 = f32[1]{0} subtract(f32[1]{0} %add.100.5, f32[1]{0} %multiply.85.5), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.86.13 = f32[1]{0} multiply(f32[1]{0} %subtract.41.5, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.45.13 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[12:13]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(12,) limit_indices=(13,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.102.13 = f32[1]{0} add(f32[1]{0} %multiply.86.13, f32[1]{0} %slice.45.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.103.17 = f32[1]{0} add(f32[1]{0} %subtract.41.5, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.62.17 = pred[1]{0} compare(f32[1]{0} %add.103.17, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.80.17 = s32[1]{0} select(pred[1]{0} %compare.62.17, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.62.17 = f32[1]{0} convert(s32[1]{0} %select.80.17), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.87.13 = f32[1]{0} multiply(f32[1]{0} %convert.62.17, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.42.13 = f32[1]{0} subtract(f32[1]{0} %add.102.13, f32[1]{0} %multiply.87.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.88.9 = f32[1]{0} multiply(f32[1]{0} %subtract.42.13, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.46.9 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[13:14]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(13,) limit_indices=(14,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.104.9 = f32[1]{0} add(f32[1]{0} %multiply.88.9, f32[1]{0} %slice.46.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.105.13 = f32[1]{0} add(f32[1]{0} %subtract.42.13, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.63.13 = pred[1]{0} compare(f32[1]{0} %add.105.13, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.81.13 = s32[1]{0} select(pred[1]{0} %compare.63.13, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.63.13 = f32[1]{0} convert(s32[1]{0} %select.81.13), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.89.9 = f32[1]{0} multiply(f32[1]{0} %convert.63.13, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.43.9 = f32[1]{0} subtract(f32[1]{0} %add.104.9, f32[1]{0} %multiply.89.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.90.5 = f32[1]{0} multiply(f32[1]{0} %subtract.43.9, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.47.5 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[14:15]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(14,) limit_indices=(15,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.106.5 = f32[1]{0} add(f32[1]{0} %multiply.90.5, f32[1]{0} %slice.47.5), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.107.9 = f32[1]{0} add(f32[1]{0} %subtract.43.9, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.64.9 = pred[1]{0} compare(f32[1]{0} %add.107.9, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.82.9 = s32[1]{0} select(pred[1]{0} %compare.64.9, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.64.9 = f32[1]{0} convert(s32[1]{0} %select.82.9), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.91.5 = f32[1]{0} multiply(f32[1]{0} %convert.64.9, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %subtract.44.5 = f32[1]{0} subtract(f32[1]{0} %add.106.5, f32[1]{0} %multiply.91.5), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %multiply.92.1 = f32[1]{0} multiply(f32[1]{0} %subtract.44.5, f32[1]{0} %constant_61_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %slice.48.1 = f32[1]{0} slice(f32[16]{0} %param_0.165), slice={[15:16]}, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/slice[start_indices=(15,) limit_indices=(16,) strides=(1,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %add.108.1 = f32[1]{0} add(f32[1]{0} %multiply.92.1, f32[1]{0} %slice.48.1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/add\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %add.109.1 = f32[1]{0} add(f32[1]{0} %subtract.44.5, f32[1]{0} %constant_126_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=2}\n", + " %compare.66.2 = pred[1]{0} compare(f32[1]{0} %add.109.1, f32[1]{0} %constant_5_1), direction=GT, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/gt\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %select.83.1 = s32[1]{0} select(pred[1]{0} %compare.66.2, s32[1]{0} %constant_127_2, s32[1]{0} %constant_128_2), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/jit(_where)/select_n\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %convert.65.2 = f32[1]{0} convert(s32[1]{0} %select.83.1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/jit(wrapped_fun)/convert_element_type[new_dtype=float32 weak_type=False]\" source_file=\"/home/legion/.local/lib/python3.10/site-packages/spyx/axn.py\" source_line=6}\n", + " %multiply.93.1 = f32[1]{0} multiply(f32[1]{0} %convert.65.2, f32[1]{0} %constant_65_1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/mul\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " ROOT %subtract.45.1 = f32[1]{0} subtract(f32[1]{0} %add.108.1, f32[1]{0} %multiply.93.1), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + "}\n", + "\n", + "ENTRY %main.414 (Arg_0.1.0: f32[16]) -> (f32[1], f32[16,1]) {\n", + " %Arg_0.1.0 = f32[16]{0} parameter(0), metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/squeeze[dimensions=(0,)]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " %loop_subtract_fusion = f32[1]{0} fusion(f32[16]{0} %Arg_0.1.0), kind=kLoop, calls=%fused_subtract, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/sub\" source_file=\"/tmp/ipykernel_73553/3145203169.py\" source_line=3}\n", + " %loop_concatenate_fusion = f32[16,1]{1,0} fusion(f32[16]{0} %Arg_0.1.0), kind=kLoop, calls=%fused_concatenate, metadata={op_name=\"jit(run_lif_unrolled)/jit(main)/while/body/concatenate[dimension=0]\" source_file=\"/tmp/ipykernel_73553/283343340.py\" source_line=11}\n", + " ROOT %tuple.413.0 = (f32[1]{0}, f32[16,1]{1,0}) tuple(f32[1]{0} %loop_subtract_fusion, f32[16,1]{1,0} %loop_concatenate_fusion)\n", + "}\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(run_lif_unrolled.lower(x_in).compile().as_text())" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "151 µs ± 9.39 µs per loop (mean ± std. dev. of 20 runs, 1000 loops each)\n" + ] + } + ], + "source": [ + "%timeit -r20 run_lif(x_in)[0].block_until_ready()" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "32.9 µs ± 2.03 µs per loop (mean ± std. dev. of 20 runs, 10000 loops each)\n" + ] + } + ], + "source": [ + "%timeit -r20 run_lif_unrolled(x_in)[0].block_until_ready()" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def lif_kernel(V_ref, X_ref, out_X): # carry, x\n", + " \n", + " v, x = V_ref[...], X_ref[...]\n", + " v, x = lif_neuron(v, x)\n", + "\n", + " out_X[...] = x\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [], + "source": [ + "@jax.jit\n", + "def pallas_lif(v, x: jax.Array) -> jax.Array:\n", + " bspec = pl.BlockSpec(block_shape=(1,), index_map=lambda i: i)\n", + " return pl.pallas_call(\n", + " lif_kernel,\n", + " out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype),\n", + " grid=(1,),\n", + " in_specs=[pl.BlockSpec(lambda i: i, (1,)), pl.BlockSpec(lambda i: i, (1,))],\n", + " out_specs=pl.BlockSpec(lambda i: i, (1,))\n", + " )(v0, x_in)" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [ + { + "ename": "LoweringError", + "evalue": "Exception while lowering eqn:\n a\u001b[35m:f32[1]\u001b[39m = pjit[\n name=wrapped_fun\n jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; b\u001b[35m:f32[1]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22mc\u001b[35m:f32[1]\u001b[39m = custom_vjp_call_jaxpr[\n bwd=. at 0x7c885a4d8310>\n fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; d\u001b[35m:f32[1]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22me\u001b[35m:bool[1]\u001b[39m = gt d 0.0\n f\u001b[35m:i32[1]\u001b[39m = pjit[\n name=_where\n jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; g\u001b[35m:bool[1]\u001b[39m h\u001b[35m:i32[]\u001b[39m i\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22mj\u001b[35m:i32[1]\u001b[39m = broadcast_in_dim[\n broadcast_dimensions=()\n shape=(1,)\n ] h\n k\u001b[35m:i32[1]\u001b[39m = broadcast_in_dim[\n broadcast_dimensions=()\n shape=(1,)\n ] i\n l\u001b[35m:i32[1]\u001b[39m = select_n g k j\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(l,) }\n ] e 1 0\n m\u001b[35m:f32[1]\u001b[39m = convert_element_type[new_dtype=float32 weak_type=False] f\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(m,) }\n fwd_jaxpr_thunk=.memoized at 0x7c885a4d8c10>\n num_consts=0\n out_trees=. at 0x7c885a4d8160>\n symbolic_zeros=False\n ] b\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(c,) }\n] n\nWith context:\n LoweringRuleContext(context=ModuleContext(name='lif_kernel', grid_mapping=GridMapping(grid=(1,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\u001b[39m\u001b[22m\u001b[22m \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(a,) }, indexing_mode=), BlockMapping(block_shape=(1,), index_map_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\u001b[39m\u001b[22m\u001b[22m \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(a,) }, indexing_mode=), BlockMapping(block_shape=(1,), index_map_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\u001b[39m\u001b[22m\u001b[22m \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(a,) }, indexing_mode=)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0), program_ids=[]), avals_in=[ShapedArray(float32[1])], avals_out=[ShapedArray(float32[1])], block_infos=[None])\nWith inval types=[RankedTensorType(tensor<1xf32>)]\nIn jaxpr:\n{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:MemRef{float32[1]}\u001b[39m b\u001b[35m:MemRef{float32[1]}\u001b[39m c\u001b[35m:MemRef{float32[1]}\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22md\u001b[35m:f32[1]\u001b[39m <- \u001b[32ma[:]\n \u001b[39me\u001b[35m:f32[1]\u001b[39m <- \u001b[32mb[:]\n \u001b[39mf\u001b[35m:f32[1]\u001b[39m = sub d 1.100000023841858\n g\u001b[35m:f32[1]\u001b[39m = pjit[\n name=wrapped_fun\n jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; h\u001b[35m:f32[1]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22mi\u001b[35m:f32[1]\u001b[39m = custom_vjp_call_jaxpr[\n bwd=. at 0x7c885a4d8310>\n fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; j\u001b[35m:f32[1]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22mk\u001b[35m:bool[1]\u001b[39m = gt j 0.0\n l\u001b[35m:i32[1]\u001b[39m = pjit[\n name=_where\n jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; m\u001b[35m:bool[1]\u001b[39m n\u001b[35m:i32[]\u001b[39m o\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22mp\u001b[35m:i32[1]\u001b[39m = broadcast_in_dim[\n broadcast_dimensions=()\n shape=(1,)\n ] n\n q\u001b[35m:i32[1]\u001b[39m = broadcast_in_dim[\n broadcast_dimensions=()\n shape=(1,)\n ] o\n r\u001b[35m:i32[1]\u001b[39m = select_n m q p\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(r,) }\n ] k 1 0\n s\u001b[35m:f32[1]\u001b[39m = convert_element_type[\n new_dtype=float32\n weak_type=False\n ] l\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(s,) }\n fwd_jaxpr_thunk=.memoized at 0x7c885a4d8c10>\n num_consts=0\n out_trees=. at 0x7c885a4d8160>\n symbolic_zeros=False\n ] h\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(i,) }\n ] f\n \u001b[32mc[:]\u001b[39m <- g\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m() }", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/usr/lib/python3.10/runpy.py\u001b[0m in \u001b[0;36m_run_module_as_main\u001b[0;34m()\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmod_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0morigin\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 196\u001b[0;31m return _run_code(code, main_globals, None,\n\u001b[0m\u001b[1;32m 197\u001b[0m \"__main__\", mod_spec)\n", + "\u001b[0;32m/usr/lib/python3.10/runpy.py\u001b[0m in \u001b[0;36m_run_code\u001b[0;34m()\u001b[0m\n\u001b[1;32m 85\u001b[0m __spec__ = mod_spec)\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0mexec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_globals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrun_globals\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/ipykernel_launcher.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mapp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlaunch_new_instance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/traitlets/config/application.py\u001b[0m in \u001b[0;36mlaunch_instance\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0mapp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1043\u001b[0;31m \u001b[0mapp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1044\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py\u001b[0m in \u001b[0;36mstart\u001b[0;34m()\u001b[0m\n\u001b[1;32m 724\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 725\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mio_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 726\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py\u001b[0m in \u001b[0;36mstart\u001b[0;34m()\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0masyncio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_event_loop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masyncio_loop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masyncio_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_forever\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 216\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3.10/asyncio/base_events.py\u001b[0m in \u001b[0;36mrun_forever\u001b[0;34m()\u001b[0m\n\u001b[1;32m 602\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 603\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_once\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 604\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stopping\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3.10/asyncio/base_events.py\u001b[0m in \u001b[0;36m_run_once\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1908\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1909\u001b[0;31m \u001b[0mhandle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1910\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;31m# Needed to break cycles when an exception occurs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3.10/asyncio/events.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m()\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 80\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 81\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mSystemExit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mdispatch_queue\u001b[0;34m()\u001b[0m\n\u001b[1;32m 512\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 513\u001b[0;31m \u001b[0;32mawait\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_one\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 514\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mprocess_one\u001b[0;34m()\u001b[0m\n\u001b[1;32m 501\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 502\u001b[0;31m \u001b[0;32mawait\u001b[0m \u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 503\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mdispatch_shell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 408\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minspect\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misawaitable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 409\u001b[0;31m \u001b[0;32mawait\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 410\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mexecute_request\u001b[0;34m()\u001b[0m\n\u001b[1;32m 728\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minspect\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misawaitable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreply_content\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 729\u001b[0;31m \u001b[0mreply_content\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mawait\u001b[0m \u001b[0mreply_content\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 730\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py\u001b[0m in \u001b[0;36mdo_execute\u001b[0;34m()\u001b[0m\n\u001b[1;32m 428\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 429\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_cell\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstore_history\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstore_history\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msilent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msilent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 430\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py\u001b[0m in \u001b[0;36mrun_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_last_traceback\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 540\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_cell\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 541\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2913\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2914\u001b[0;31m result = self._run_cell(\n\u001b[0m\u001b[1;32m 2915\u001b[0m raw_cell, store_history, silent, shell_futures)\n", + "\u001b[0;32m/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36m_run_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2959\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2960\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mrunner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoro\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2961\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mBaseException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3/dist-packages/IPython/core/async_helpers.py\u001b[0m in \u001b[0;36m_pseudo_sync_runner\u001b[0;34m()\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m \u001b[0mcoro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 79\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_cell_async\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3184\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3185\u001b[0;31m has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n\u001b[0m\u001b[1;32m 3186\u001b[0m interactivity=interactivity, compiler=compiler, result=result)\n", + "\u001b[0;32m/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_ast_nodes\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3376\u001b[0m \u001b[0masy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3377\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;32mawait\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_code\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0masync_\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0masy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3378\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_code\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3456\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3457\u001b[0;31m \u001b[0mexec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode_obj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_global_ns\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3458\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_73553/852205781.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpallas_lif\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_in\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/tmp/ipykernel_73553/1207246950.py\u001b[0m in \u001b[0;36mpallas_lif\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mbspec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBlockSpec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mblock_shape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex_map\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m return pl.pallas_call(\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mlif_kernel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py\u001b[0m in \u001b[0;36mwrapped\u001b[0;34m()\u001b[0m\n\u001b[1;32m 588\u001b[0m for v in flat_out_shapes)\n\u001b[0;32m--> 589\u001b[0;31m grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(\n\u001b[0m\u001b[1;32m 590\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrid_spec\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflat_in_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflat_out_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py\u001b[0m in \u001b[0;36m_trace_to_jaxpr\u001b[0;34m()\u001b[0m\n\u001b[1;32m 488\u001b[0m \u001b[0mdebug\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr_in_tree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree_thunk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"pallas_call\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 489\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwrapped_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr_flat_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdebug\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 490\u001b[0m \u001b[0mjaxpr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_hoist_consts_to_refs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_73553/3887908263.py\u001b[0m in \u001b[0;36mlif_kernel\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mV_ref\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX_ref\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlif_neuron\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/tmp/ipykernel_73553/3145203169.py\u001b[0m in \u001b[0;36mlif_neuron\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mlif_neuron\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# carry, x\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mspikes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mactivation_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mV\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mV\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.9\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mV\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mspikes\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m1.1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m: NotImplementedError: Unimplemented primitive in Pallas GPU lowering: custom_vjp_call_jaxpr. Please file an issue on https://github.com/google/jax/issues.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py\u001b[0m in \u001b[0;36mlower_jaxpr_to_triton_ir\u001b[0;34m(ctx, jaxpr, block_infos, *args)\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0msource_info_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msource_info\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraceback\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 351\u001b[0;31m \u001b[0moutvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrule_ctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0minvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 352\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mLoweringError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py\u001b[0m in \u001b[0;36m_pjit_lowering_rule\u001b[0;34m(ctx, jaxpr, *args, **_)\u001b[0m\n\u001b[1;32m 2080\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2081\u001b[0;31m return lower_jaxpr_to_triton_ir(\n\u001b[0m\u001b[1;32m 2082\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblock_infos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py\u001b[0m in \u001b[0;36mlower_jaxpr_to_triton_ir\u001b[0;34m(ctx, jaxpr, block_infos, *args)\u001b[0m\n\u001b[1;32m 336\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0meqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprimitive\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtriton_lowering_rules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 337\u001b[0;31m raise NotImplementedError(\n\u001b[0m\u001b[1;32m 338\u001b[0m \u001b[0;34m\"Unimplemented primitive in Pallas GPU lowering: \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNotImplementedError\u001b[0m: Unimplemented primitive in Pallas GPU lowering: custom_vjp_call_jaxpr. Please file an issue on https://github.com/google/jax/issues.", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mLoweringError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_73553/852205781.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpallas_lif\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx_in\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + " \u001b[0;31m[... skipping hidden 23 frame]\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/pallas_call.py\u001b[0m in \u001b[0;36m_pallas_call_lowering\u001b[0;34m(ctx, interpret, *in_nodes, **params)\u001b[0m\n\u001b[1;32m 529\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 531\u001b[0;31m return pallas_call_registration.pallas_call_lowering(\n\u001b[0m\u001b[1;32m 532\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0min_nodes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minterpret\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minterpret\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 533\u001b[0m )\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/pallas_call_registration.py\u001b[0m in \u001b[0;36mpallas_call_lowering\u001b[0;34m(ctx, jaxpr, name, in_shapes, out_shapes, which_linear, interpret, debug, input_output_aliases, grid_mapping, compiler_params, *in_nodes)\u001b[0m\n\u001b[1;32m 304\u001b[0m \u001b[0mlowering_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_pallas_call_ptx_lowering\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 306\u001b[0;31m return lowering_fn(\n\u001b[0m\u001b[1;32m 307\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0min_nodes\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/pallas_call_registration.py\u001b[0m in \u001b[0;36m_pallas_call_ttir_lowering\u001b[0;34m(ctx, jaxpr, name, in_shapes, out_shapes, debug, input_output_aliases, grid_mapping, triton_params, num_warps, num_stages, *in_nodes)\u001b[0m\n\u001b[1;32m 203\u001b[0m )\n\u001b[1;32m 204\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m lowering_result = lowering.lower_jaxpr_to_triton_module(\n\u001b[0m\u001b[1;32m 206\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_shapes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mout_shapes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrid_mapping\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcuda_options\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m )\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py\u001b[0m in \u001b[0;36mlower_jaxpr_to_triton_module\u001b[0;34m(jaxpr, in_shapes, grid_mapping, name, cuda_options)\u001b[0m\n\u001b[1;32m 302\u001b[0m )\n\u001b[1;32m 303\u001b[0m ]\n\u001b[0;32m--> 304\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlower_jaxpr_to_triton_ir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mblock_infos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mentry\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marguments\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 305\u001b[0m \u001b[0mtt_dialect\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreturn_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mLoweringResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_grid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py\u001b[0m in \u001b[0;36mlower_jaxpr_to_triton_ir\u001b[0;34m(ctx, jaxpr, block_infos, *args)\u001b[0m\n\u001b[1;32m 349\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0msource_info_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msource_info\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraceback\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 351\u001b[0;31m \u001b[0moutvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrule_ctx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0minvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meqn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 352\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mLoweringError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0;31m# We only add the extra info to the innermost exception.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py\u001b[0m in \u001b[0;36m_pjit_lowering_rule\u001b[0;34m(ctx, jaxpr, *args, **_)\u001b[0m\n\u001b[1;32m 2079\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsts\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2080\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2081\u001b[0;31m return lower_jaxpr_to_triton_ir(\n\u001b[0m\u001b[1;32m 2082\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblock_infos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2083\u001b[0m )\n", + "\u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/pallas/triton/lowering.py\u001b[0m in \u001b[0;36mlower_jaxpr_to_triton_ir\u001b[0;34m(ctx, jaxpr, block_infos, *args)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 355\u001b[0m \u001b[0minval_types\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"type\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 356\u001b[0;31m raise LoweringError(\n\u001b[0m\u001b[1;32m 357\u001b[0m \u001b[0;34mf\"Exception while lowering eqn:\\n {eqn}\\nWith context:\\n \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;34mf\" {rule_ctx}\\nWith inval types={inval_types}\\nIn jaxpr:\\n{jaxpr}\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mLoweringError\u001b[0m: Exception while lowering eqn:\n a\u001b[35m:f32[1]\u001b[39m = pjit[\n name=wrapped_fun\n jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; b\u001b[35m:f32[1]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22mc\u001b[35m:f32[1]\u001b[39m = custom_vjp_call_jaxpr[\n bwd=. at 0x7c885a4d8310>\n fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; d\u001b[35m:f32[1]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22me\u001b[35m:bool[1]\u001b[39m = gt d 0.0\n f\u001b[35m:i32[1]\u001b[39m = pjit[\n name=_where\n jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; g\u001b[35m:bool[1]\u001b[39m h\u001b[35m:i32[]\u001b[39m i\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22mj\u001b[35m:i32[1]\u001b[39m = broadcast_in_dim[\n broadcast_dimensions=()\n shape=(1,)\n ] h\n k\u001b[35m:i32[1]\u001b[39m = broadcast_in_dim[\n broadcast_dimensions=()\n shape=(1,)\n ] i\n l\u001b[35m:i32[1]\u001b[39m = select_n g k j\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(l,) }\n ] e 1 0\n m\u001b[35m:f32[1]\u001b[39m = convert_element_type[new_dtype=float32 weak_type=False] f\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(m,) }\n fwd_jaxpr_thunk=.memoized at 0x7c885a4d8c10>\n num_consts=0\n out_trees=. at 0x7c885a4d8160>\n symbolic_zeros=False\n ] b\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(c,) }\n] n\nWith context:\n LoweringRuleContext(context=ModuleContext(name='lif_kernel', grid_mapping=GridMapping(grid=(1,), block_mappings=(BlockMapping(block_shape=(1,), index_map_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\u001b[39m\u001b[22m\u001b[22m \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(a,) }, indexing_mode=), BlockMapping(block_shape=(1,), index_map_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\u001b[39m\u001b[22m\u001b[22m \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(a,) }, indexing_mode=), BlockMapping(block_shape=(1,), index_map_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\u001b[39m\u001b[22m\u001b[22m \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(a,) }, indexing_mode=)), mapped_dims=(), num_index_operands=0, num_scratch_operands=0), program_ids=[]), avals_in=[ShapedArray(float32[1])], avals_out=[ShapedArray(float32[1])], block_infos=[None])\nWith inval types=[RankedTensorType(tensor<1xf32>)]\nIn jaxpr:\n{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:MemRef{float32[1]}\u001b[39m b\u001b[35m:MemRef{float32[1]}\u001b[39m c\u001b[35m:MemRef{float32[1]}\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22md\u001b[35m:f32[1]\u001b[39m <- \u001b[32ma[:]\n \u001b[39me\u001b[35m:f32[1]\u001b[39m <- \u001b[32mb[:]\n \u001b[39mf\u001b[35m:f32[1]\u001b[39m = sub d 1.100000023841858\n g\u001b[35m:f32[1]\u001b[39m = pjit[\n name=wrapped_fun\n jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; h\u001b[35m:f32[1]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22mi\u001b[35m:f32[1]\u001b[39m = custom_vjp_call_jaxpr[\n bwd=. at 0x7c885a4d8310>\n fun_jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; j\u001b[35m:f32[1]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22mk\u001b[35m:bool[1]\u001b[39m = gt j 0.0\n l\u001b[35m:i32[1]\u001b[39m = pjit[\n name=_where\n jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; m\u001b[35m:bool[1]\u001b[39m n\u001b[35m:i32[]\u001b[39m o\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n \u001b[39m\u001b[22m\u001b[22mp\u001b[35m:i32[1]\u001b[39m = broadcast_in_dim[\n broadcast_dimensions=()\n shape=(1,)\n ] n\n q\u001b[35m:i32[1]\u001b[39m = broadcast_in_dim[\n broadcast_dimensions=()\n shape=(1,)\n ] o\n r\u001b[35m:i32[1]\u001b[39m = select_n m q p\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(r,) }\n ] k 1 0\n s\u001b[35m:f32[1]\u001b[39m = convert_element_type[\n new_dtype=float32\n weak_type=False\n ] l\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(s,) }\n fwd_jaxpr_thunk=.memoized at 0x7c885a4d8c10>\n num_consts=0\n out_trees=. at 0x7c885a4d8160>\n symbolic_zeros=False\n ] h\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(i,) }\n ] f\n \u001b[32mc[:]\u001b[39m <- g\n \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m() }" + ] + } + ], + "source": [ + "pallas_lif(v0, x_in)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}