Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Commit

Permalink
Update benchmarks.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-dodd committed Mar 19, 2023
1 parent c5d22d2 commit e07b798
Showing 1 changed file with 16 additions and 25 deletions.
41 changes: 16 additions & 25 deletions benchmarks/benchmarks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "3a7f5936",
"metadata": {},
"outputs": [],
Expand All @@ -24,33 +24,24 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"id": "536f66e1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/danieldodd/miniconda3/lib/python3.10/site-packages/flax/core/frozen_dict.py:169: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.\n",
" jax.tree_util.register_keypaths(\n"
]
}
],
"outputs": [],
"source": [
"from mytree import Mytree, param, Softplus\n",
"from mytree import Mytree, param_field, Softplus\n",
"\n",
"class Mytree_SubFoo(Mytree):\n",
" a: Float[Array, \"...\"] = param(Softplus)\n",
" b: Float[Array, \"...\"] = param(Softplus)\n",
" a: Float[Array, \"...\"] = param_field(bijector=Softplus)\n",
" b: Float[Array, \"...\"] = param_field(bijector=Softplus)\n",
"\n",
" def __init__(self, a, b):\n",
" self.a = a\n",
" self.b = b\n",
"\n",
"class Mytree_Foo(Mytree):\n",
" b: list[Mytree_SubFoo]\n",
" a: Float[Array, \"...\"] = param(Softplus)\n",
" a: Float[Array, \"...\"] = param_field(bijector=Softplus)\n",
"\n",
" def __init__(self, a, b):\n",
" self.a = a\n",
Expand All @@ -68,7 +59,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "1ce3a220",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -144,7 +135,7 @@
"source": [
"Run on a M1 Pro CPU.\n",
"\n",
"- **Initialisation**: is slower for mytree, due to it unpacking metadata, and working out what attributes are leaves of the nested pytree structure.\n",
"- **Initialisation**: is faster for mytree, despite it unpacking metadata, and working out what attributes are leaves of the nested pytree structure.\n",
"- **Transformations**: is faster for mytree.\n",
"- **Replacing attributes**: is faster for mytree implimentation.\n",
"\n",
Expand All @@ -153,7 +144,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "8db39ca8",
"metadata": {},
"outputs": [
Expand All @@ -163,14 +154,14 @@
"text": [
"\n",
" mytree:\n",
"57.6 ms ± 2.55 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
"820 ms ± 5.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"1.3 µs ± 31.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"50.7 ms ± 1.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
"862 ms ± 13.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"1.34 µs ± 31.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n",
"\n",
" pytree:\n",
"50.8 ms ± 3.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
"821 ms ± 4.82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"1.66 µs ± 27.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n"
"51.8 ms ± 626 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
"893 ms ± 24.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"1.7 µs ± 44.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)\n"
]
}
],
Expand Down

0 comments on commit e07b798

Please sign in to comment.