Update README.md
daniel-dodd authored Mar 22, 2023
1 parent d631a07 commit 6166252
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -94,9 +94,9 @@ And we see that `weight`'s parameter is no longer transformed under the `Identity`

Applying `stop_gradient` **within** the loss function prevents the flow of gradients during forward- or reverse-mode automatic differentiation.
```python
-# Create simulated data.
import jax

+# Create simulated data.
n = 100
key = jax.random.PRNGKey(123)
x = jax.random.uniform(key, (n, ))
@@ -158,7 +158,7 @@ from mytree import meta_map

# Function passed to `meta_map` has its argument as a `(meta, leaf)` tuple!
def if_trainable_then_10(meta_leaf):
-    meta_leaf
+    meta, leaf = meta_leaf
    if meta.get("trainable", True):
        return 10.0
    else:
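For reference, a minimal sketch of the pattern the first hunk illustrates, using plain JAX rather than the README's full example: wrapping a leaf in `jax.lax.stop_gradient` inside the loss function forces its gradient to zero under both forward- and reverse-mode differentiation. The `weight`/`bias` parameters below are illustrative, not taken from the repository.

```python
import jax
import jax.numpy as jnp

# Simulated data, mirroring the snippet above.
n = 100
key = jax.random.PRNGKey(123)
x = jax.random.uniform(key, (n,))
y = 2.0 * x + 1.0

def loss(params):
    weight, bias = params
    # Stop gradients flowing through `bias`; its gradient is then exactly zero.
    bias = jax.lax.stop_gradient(bias)
    return jnp.mean((weight * x + bias - y) ** 2)

grads = jax.grad(loss)((1.0, 0.0))
print(grads)  # the second entry (bias) is 0.0
```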

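As a usage note for the second hunk: the fix unpacks the `(meta, leaf)` tuple before reading the metadata, so `meta.get("trainable", True)` actually inspects the metadata dictionary. Below is a hedged sketch of the corrected function; the `meta_map(fn, tree)` call and the `return leaf` fallback are assumptions about the truncated example, not confirmed by this diff.

```python
from mytree import meta_map

def if_trainable_then_10(meta_leaf):
    # Unpack the (meta, leaf) tuple supplied by `meta_map`.
    meta, leaf = meta_leaf
    if meta.get("trainable", True):
        return 10.0
    else:
        return leaf  # assumed fallback: leave non-trainable leaves untouched

# Hypothetical call, assuming `tree` is a Mytree whose leaves carry
# "trainable" metadata:
# new_tree = meta_map(if_trainable_then_10, tree)
```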