Commit: Correct errors

NeilGirdhar committed Jan 25, 2022
1 parent d6e5b22 commit 799958d

Showing 5 changed files with 43 additions and 23 deletions.
1 change: 1 addition & 0 deletions flax/linen/attention.py
@@ -337,6 +337,7 @@ def _apply(self,
                        name='out')(x)
     return out
 
+
 @dataclass  # Keep MyPy happy.
 class MultiHeadDotProductAttention(_BaseMultiHeadDotProductAttention):
   """Multi-head dot-product attention.

35 changes: 27 additions & 8 deletions flax/linen/linear.py
@@ -30,6 +30,7 @@
 PRNGKey = Any
 Shape = Tuple[int, ...]
 InexactDType = Type[np.inexact]
+NumericDType = Type[np.number]
 GenericDType = Type[np.generic]
 Array = Any
 Initializer = Callable[[PRNGKey, Shape, InexactDType], Array]
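
Background on the new alias: in NumPy's scalar-type hierarchy, np.number is the parent of both np.integer and np.inexact, so NumericDType admits integer dtypes that InexactDType rejects. A minimal illustration (plain NumPy, not part of the commit):

import numpy as np

# np.inexact covers only floating and complex-floating dtypes.
assert np.issubdtype(np.float32, np.inexact)
assert np.issubdtype(np.complex64, np.inexact)
assert not np.issubdtype(np.int32, np.inexact)

# np.number additionally covers integer dtypes, which is why the
# convolution layers below switch to the wider alias.
assert np.issubdtype(np.int32, np.number)
assert np.issubdtype(np.complex64, np.number)
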
@@ -66,6 +67,22 @@ def _canonicalize_dtypes(
   return returned_param_dtype, dtype
 
 
+def _canonicalize_numeric_dtypes(
+    input_dtype: NumericDType,
+    param_dtype: Optional[NumericDType],
+    computation_dtype: Optional[NumericDType]) -> Tuple[NumericDType,
+                                                        NumericDType]:
+  returned_param_dtype = input_dtype if param_dtype is None else param_dtype
+  dtype = (jnp.result_type(input_dtype, returned_param_dtype)
+           if computation_dtype is None else computation_dtype)
+
+  assert np.issubdtype(input_dtype, np.number)
+  if np.issubdtype(input_dtype, np.complexfloating):
+    assert np.issubdtype(returned_param_dtype, np.complexfloating)
+    assert np.issubdtype(dtype, np.complexfloating)
+  return returned_param_dtype, dtype
+
+
 class DenseGeneral(Module):
   """A linear transformation with flexible axes.
@@ -262,8 +279,8 @@ class Conv(Module):
   kernel_dilation: Union[None, int, Sequence[int]] = 1
   feature_group_count: int = 1
   use_bias: bool = True
-  dtype: Optional[InexactDType] = None
-  param_dtype: Optional[InexactDType] = None
+  dtype: Optional[NumericDType] = None
+  param_dtype: Optional[NumericDType] = None
   precision: Any = None
   kernel_init: Initializer = default_kernel_init
   bias_init: Initializer = zeros
@@ -282,8 +299,9 @@ def __call__(self, inputs: Array) -> Array:
     Returns:
       The convolved data.
     """
-    param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype,
-                                              self.dtype)
+    param_dtype, dtype = _canonicalize_numeric_dtypes(inputs.dtype,
+                                                      self.param_dtype,
+                                                      self.dtype)
     inputs = jnp.asarray(inputs, dtype)
 
     if isinstance(self.kernel_size, int):
@@ -384,8 +402,8 @@ class ConvTranspose(Module):
   padding: Union[str, Sequence[Tuple[int, int]]] = 'SAME'
   kernel_dilation: Optional[Sequence[int]] = None
   use_bias: bool = True
-  dtype: Optional[InexactDType] = None
-  param_dtype: Optional[InexactDType] = None
+  dtype: Optional[NumericDType] = None
+  param_dtype: Optional[NumericDType] = None
   precision: Any = None
   kernel_init: Initializer = default_kernel_init
   bias_init: Initializer = zeros
@@ -405,8 +423,9 @@ def __call__(self, inputs: Array) -> Array:
     Returns:
       The convolved data.
     """
-    param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype,
-                                              self.dtype)
+    param_dtype, dtype = _canonicalize_numeric_dtypes(inputs.dtype,
+                                                      self.param_dtype,
+                                                      self.dtype)
     inputs = jnp.asarray(inputs, dtype)
 
     kernel_size: Tuple[int, ...]
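
Since Conv and ConvTranspose now route through the numeric canonicalizer, a natural use is applying them directly to complex inputs. A hypothetical sketch, assuming this branch's dtype-inference semantics (initializer support for complex dtypes depends on the rest of the branch):

import jax
import jax.numpy as jnp
import flax.linen as nn

# dtype and param_dtype default to None, so both would be inferred
# from the complex64 input rather than pinned to float32.
model = nn.Conv(features=8, kernel_size=(3,))
x = jnp.ones((1, 16, 4), dtype=jnp.complex64)
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)
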
4 changes: 2 additions & 2 deletions flax/linen/module.py
@@ -554,9 +554,9 @@ def _customized_dataclass_transform(cls):
   # Remove 'parent' and 'name' from parents because we always want parent
   # and name to show up last in the dataclass args.
   if 'parent' in pdf:
-    pdf.pop('parent')
+    cls.__dataclass_fields__.pop('parent')  # pytype: disable=attribute-error
   if 'name' in pdf:
-    pdf.pop('name')
+    cls.__dataclass_fields__.pop('name')  # pytype: disable=attribute-error
 
   annotations['parent'] = parent_annotation
   cls.parent = dataclasses.field(repr=False, default=_unspecified_parent)
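
The replacement matters because dataclasses.fields() reads from the class's __dataclass_fields__ mapping; popping a key from a separate dict such as pdf leaves the field on the class. A standalone stdlib sketch of that behavior (not flax code):

import dataclasses

@dataclasses.dataclass
class Point:
  x: int = 0
  parent: object = None

pdf = dict(Point.__dataclass_fields__)
pdf.pop('parent')  # only mutates the local copy
assert any(f.name == 'parent' for f in dataclasses.fields(Point))

Point.__dataclass_fields__.pop('parent')  # mutates the class itself
assert all(f.name != 'parent' for f in dataclasses.fields(Point))
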
12 changes: 6 additions & 6 deletions flax/linen/normalization.py
@@ -155,8 +155,8 @@ class BatchNorm(Module):
     momentum: decay rate for the exponential moving average of
       the batch statistics.
     epsilon: a small float added to variance to avoid dividing by zero.
-    dtype: the dtype of the computation (default: float32).
-    param_dtype: the dtype passed to parameter initializers (default: float32).
+    dtype: the dtype of the computation (default: None).
+    param_dtype: the dtype passed to parameter initializers (default: None).
     use_bias: if True, bias (beta) is added.
     use_scale: if True, multiply by scale (gamma).
       When the next layer is linear (also e.g. nn.relu), this can be disabled
@@ -256,8 +256,8 @@ class LayerNorm(Module):
   Attributes:
     epsilon: A small float added to variance to avoid dividing by zero.
-    dtype: the dtype of the computation (default: float32).
-    param_dtype: the dtype passed to parameter initializers (default: float32).
+    dtype: the dtype of the computation (default: None).
+    param_dtype: the dtype passed to parameter initializers (default: None).
     use_bias: If True, bias (beta) is added.
     use_scale: If True, multiply by scale (gamma). When the next layer is linear
       (also e.g. nn.relu), this can be disabled since the scaling will be done
@@ -313,9 +313,9 @@ class GroupNorm(Module):
       proposed by the original group normalization paper.
     group_size: the number of channels in a group.
     epsilon: A small float added to variance to avoid dividing by zero.
-    dtype: the dtype of the computation (default: float32).
+    dtype: the dtype of the computation (default: None).
     param_dtype: the dtype passed to parameter initializers (default:
-      float32).
+      None).
     use_bias: If True, bias (beta) is added.
     use_scale: If True, multiply by scale (gamma). When the next layer is
       linear (also e.g. nn.relu), this can be disabled since the scaling will
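
The docstring updates reflect that dtype and param_dtype now default to None, so both are inferred from the input rather than pinned to float32. A sketch of the intended effect, assuming this branch's inference semantics:

import jax
import jax.numpy as jnp
import flax.linen as nn

layer = nn.LayerNorm()  # dtype=None, param_dtype=None
x = jnp.ones((4, 8), dtype=jnp.float16)
variables = layer.init(jax.random.PRNGKey(0), x)
# With the None defaults the computation follows the float16 input;
# the old float32 defaults would have upcast it.
y = layer.apply(variables, x)
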
14 changes: 7 additions & 7 deletions tests/linen/linen_module_test.py
@@ -310,7 +310,7 @@ def setup(self):
         # NOTE that keys still must be strings. This is to make a possible
         # future transition to automatically derived parameter names when assigned
         # as a dict easier (like we currently have with submodules).
-        # See a bit of discussion here: https://github.com/google/flax/issues/705#issuecomment-738761853
+        # See a bit of discussion here: https://github.com/google/flax/issues/705#issuecomment-738761853
         str(i): self.param(f'bias_{i}', initializers.ones, self.xshape)
         for i in range(4)}
   def __call__(self, x):
@@ -657,8 +657,8 @@ def __call__(self, x):
         # attributes
         features = 3
         use_bias = True
-        dtype = float32
-        param_dtype = float32
+        dtype = None
+        param_dtype = None
         precision = None
         kernel_init = init
         bias_init = zeros
@@ -667,8 +667,8 @@ def __call__(self, x):
         # attributes
         features = 2
         use_bias = True
-        dtype = float32
-        param_dtype = float32
+        dtype = None
+        param_dtype = None
         precision = None
         kernel_init = init
         bias_init = zeros
@@ -1394,14 +1394,14 @@ def test_rng_reuse_after_rewind(self):
     class C(nn.Module):
       @nn.compact
       def __call__(self):
-        # Some module that has dropouts in it, in general,
+        # Some module that has dropouts in it, in general,
         # it does more than just dropout!
         return self.make_rng('dropout')
 
     class A(nn.Module):
       @nn.compact
       def __call__(self):
-        # Some module that has dropouts in it, in general,
+        # Some module that has dropouts in it, in general,
         # it does more than just dropout!
         return C()()
