Skip to content

Commit

Permalink
properly mark static fields
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 5, 2022
1 parent e0427cb commit 029476b
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion palm_jax/palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class PaLM(Module):
embedding: np.ndarray
norm: Module
layers: List[List[Module]]
inv_freq: onp.ndarray
inv_freq: onp.ndarray = static_field()

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions palm_jax/palm_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class PaLM(Module):
embedding: np.ndarray
norm: Module
layers: List[List[Module]]
attn_bias: onp.ndarray
attn_bias: onp.ndarray = static_field()

def __init__(
self,
Expand All @@ -149,7 +149,7 @@ def __init__(
max_seq_len = 2048,
mask_value = -1e10
):
self.embedding = random.normal(key, (num_tokens, dim)) * 0.02
self.embedding = random.normal(key, (num_tokens, dim)) * 0.02

causal_mask = onp.tril(onp.ones((max_seq_len, max_seq_len)))
alibi_bias = calc_alibi_bias(max_seq_len, heads = heads)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'PaLM-jax',
packages = find_packages(exclude=[]),
version = '0.1.1',
version = '0.1.2',
license='MIT',
description = 'PaLM: Scaling Language Modeling with Pathways - Jax',
author = 'Phil Wang',
Expand Down

0 comments on commit 029476b

Please sign in to comment.