diff --git a/flax/linen/attention.py b/flax/linen/attention.py
index cfb83cd260..ff92ce7a66 100644
--- a/flax/linen/attention.py
+++ b/flax/linen/attention.py
@@ -337,6 +337,7 @@ def _apply(self,
                       name='out')(x)
     return out
 
 
+@dataclass  # Keep MyPy happy.
 class MultiHeadDotProductAttention(_BaseMultiHeadDotProductAttention):
   """Multi-head dot-product attention.
diff --git a/flax/linen/linear.py b/flax/linen/linear.py
index fa11bcd56d..b703c7243d 100644
--- a/flax/linen/linear.py
+++ b/flax/linen/linear.py
@@ -30,6 +30,7 @@
 PRNGKey = Any
 Shape = Tuple[int, ...]
 InexactDType = Type[np.inexact]
+NumericDType = Type[np.number]
 GenericDType = Type[np.generic]
 Array = Any
 Initializer = Callable[[PRNGKey, Shape, InexactDType], Array]
@@ -66,6 +67,22 @@ def _canonicalize_dtypes(
   return returned_param_dtype, dtype
 
 
+def _canonicalize_numeric_dtypes(
+    input_dtype: NumericDType,
+    param_dtype: Optional[NumericDType],
+    computation_dtype: Optional[NumericDType]) -> Tuple[NumericDType,
+                                                        NumericDType]:
+  returned_param_dtype = input_dtype if param_dtype is None else param_dtype
+  dtype = (jnp.result_type(input_dtype, returned_param_dtype)
+           if computation_dtype is None else computation_dtype)
+
+  assert np.issubdtype(input_dtype, np.number)
+  if np.issubdtype(input_dtype, np.complexfloating):
+    assert np.issubdtype(returned_param_dtype, np.complexfloating)
+    assert np.issubdtype(dtype, np.complexfloating)
+  return returned_param_dtype, dtype
+
+
 class DenseGeneral(Module):
   """A linear transformation with flexible axes.
@@ -262,8 +279,8 @@ class Conv(Module):
   kernel_dilation: Union[None, int, Sequence[int]] = 1
   feature_group_count: int = 1
   use_bias: bool = True
-  dtype: Optional[InexactDType] = None
-  param_dtype: Optional[InexactDType] = None
+  dtype: Optional[NumericDType] = None
+  param_dtype: Optional[NumericDType] = None
   precision: Any = None
   kernel_init: Initializer = default_kernel_init
   bias_init: Initializer = zeros
@@ -282,8 +299,9 @@ def __call__(self, inputs: Array) -> Array:
 
     Returns:
       The convolved data.
    """
-    param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype,
-                                              self.dtype)
+    param_dtype, dtype = _canonicalize_numeric_dtypes(inputs.dtype,
+                                                      self.param_dtype,
+                                                      self.dtype)
     inputs = jnp.asarray(inputs, dtype)
     if isinstance(self.kernel_size, int):
@@ -384,8 +402,8 @@ class ConvTranspose(Module):
   padding: Union[str, Sequence[Tuple[int, int]]] = 'SAME'
   kernel_dilation: Optional[Sequence[int]] = None
   use_bias: bool = True
-  dtype: Optional[InexactDType] = None
-  param_dtype: Optional[InexactDType] = None
+  dtype: Optional[NumericDType] = None
+  param_dtype: Optional[NumericDType] = None
   precision: Any = None
   kernel_init: Initializer = default_kernel_init
   bias_init: Initializer = zeros
@@ -405,8 +423,9 @@ def __call__(self, inputs: Array) -> Array:
 
     Returns:
      The convolved data.
    """
-    param_dtype, dtype = _canonicalize_dtypes(inputs.dtype, self.param_dtype,
-                                              self.dtype)
+    param_dtype, dtype = _canonicalize_numeric_dtypes(inputs.dtype,
+                                                      self.param_dtype,
+                                                      self.dtype)
     inputs = jnp.asarray(inputs, dtype)
     kernel_size: Tuple[int, ...]
diff --git a/flax/linen/module.py b/flax/linen/module.py
index 2cd5ad1195..6e74f005cb 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -554,9 +554,9 @@ def _customized_dataclass_transform(cls):
   # Remove 'parent' and 'name' from parents because we always want parent
   # and name to show up last in the dataclass args.
   if 'parent' in pdf:
-    pdf.pop('parent')
+    cls.__dataclass_fields__.pop('parent')  # pytype: disable=attribute-error
   if 'name' in pdf:
-    pdf.pop('name')
+    cls.__dataclass_fields__.pop('name')  # pytype: disable=attribute-error
   annotations['parent'] = parent_annotation
   cls.parent = dataclasses.field(repr=False, default=_unspecified_parent)
diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py
index a286378e64..fbb99912e9 100644
--- a/flax/linen/normalization.py
+++ b/flax/linen/normalization.py
@@ -155,8 +155,8 @@ class BatchNorm(Module):
     momentum: decay rate for the exponential moving average of the batch
      statistics.
    epsilon: a small float added to variance to avoid dividing by zero.
-    dtype: the dtype of the computation (default: float32).
-    param_dtype: the dtype passed to parameter initializers (default: float32).
+    dtype: the dtype of the computation (default: None).
+    param_dtype: the dtype passed to parameter initializers (default: None).
     use_bias: if True, bias (beta) is added.
     use_scale: if True, multiply by scale (gamma).
       When the next layer is linear (also e.g. nn.relu), this can be disabled
@@ -256,8 +256,8 @@ class LayerNorm(Module):
 
   Attributes:
     epsilon: A small float added to variance to avoid dividing by zero.
-    dtype: the dtype of the computation (default: float32).
-    param_dtype: the dtype passed to parameter initializers (default: float32).
+    dtype: the dtype of the computation (default: None).
+    param_dtype: the dtype passed to parameter initializers (default: None).
     use_bias: If True, bias (beta) is added.
     use_scale: If True, multiply by scale (gamma). When the next layer is linear
      (also e.g. nn.relu), this can be disabled since the scaling will be done
@@ -313,9 +313,9 @@ class GroupNorm(Module):
      proposed by the original group normalization paper.
    group_size: the number of channels in a group.
    epsilon: A small float added to variance to avoid dividing by zero.
-    dtype: the dtype of the computation (default: float32).
+    dtype: the dtype of the computation (default: None).
     param_dtype: the dtype passed to parameter initializers (default:
-      float32).
+      None).
     use_bias: If True, bias (beta) is added.
     use_scale: If True, multiply by scale (gamma). When the next layer is
       linear (also e.g. nn.relu), this can be disabled since the scaling will
diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py
index 6f8266f7b9..cd8d9ca56b 100644
--- a/tests/linen/linen_module_test.py
+++ b/tests/linen/linen_module_test.py
@@ -310,7 +310,7 @@ def setup(self):
       # NOTE that keys still must be strings. This is to make a possible
      # future transition to automatically derived parameter names when assigned
      # as a dict easier (like we currently have with submodules).
-      # See a bit of discussion here: https://github.com/google/flax/issues/705#issuecomment-738761853 
+      # See a bit of discussion here: https://github.com/google/flax/issues/705#issuecomment-738761853
       str(i): self.param(f'bias_{i}', initializers.ones, self.xshape)
       for i in range(4)}
 
   def __call__(self, x):
@@ -657,8 +657,8 @@ def __call__(self, x):
        # attributes
        features = 3
        use_bias = True
-        dtype = float32
-        param_dtype = float32
+        dtype = None
+        param_dtype = None
        precision = None
        kernel_init = init
        bias_init = zeros
@@ -667,8 +667,8 @@ def __call__(self, x):
        # attributes
        features = 2
        use_bias = True
-        dtype = float32
-        param_dtype = float32
+        dtype = None
+        param_dtype = None
        precision = None
        kernel_init = init
        bias_init = zeros
@@ -1394,14 +1394,14 @@ def test_rng_reuse_after_rewind(self):
 
     class C(nn.Module):
       @nn.compact
       def __call__(self):
-        # Some module that has dropouts in it, in general, 
+        # Some module that has dropouts in it, in general,
         # it does more than just dropout!
         return self.make_rng('dropout')
 
     class A(nn.Module):
       @nn.compact
       def __call__(self):
-        # Some module that has dropouts in it, in general, 
+        # Some module that has dropouts in it, in general,
        # it does more than just dropout!
        return C()()
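
For reference, a rough usage sketch of what the widened dtype handling above enables (a minimal sketch assuming this patch is applied; the shapes, dtypes, and the choice of a zeros kernel initializer are illustrative and not taken from the patch):

# Hypothetical usage sketch, not part of the patch: a Conv over complex input.
# kernel_init is set to zeros only to avoid assuming complex support in the
# default variance-scaling initializer.
import jax
import jax.numpy as jnp
from flax import linen as nn

conv = nn.Conv(features=4, kernel_size=(3,),
               kernel_init=nn.initializers.zeros)
x = jnp.ones((1, 8, 2), dtype=jnp.complex64)     # (batch, length, channels)
variables = conv.init(jax.random.PRNGKey(0), x)  # params follow the input dtype
y = conv.apply(variables, x)
# With _canonicalize_numeric_dtypes, param_dtype falls back to the input dtype
# and the computation dtype to jnp.result_type(input, params), so the kernel,
# bias, and output should all come out as complex64 here.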