Remove float32 dtype assumption #1803
Conversation
Also, if possible, it would be nice to get #1812 in before this pull request to make type checking easier.
- Infer dtypes from inputs where possible.
- LSTM dtype assumption persists; this is repaired in a separate pull request.
Are some internal issues holding this back?

Currently this PR is blocked on internal testing. I'm looking into it today, though.
flax/linen/linear.py (outdated):
```python
param_dtype: Optional[InexactDType],
computation_dtype: Optional[InexactDType]) -> Tuple[InexactDType,
                                                    InexactDType]:
  returned_param_dtype = input_dtype if param_dtype is None else param_dtype
```
`param_dtype` should still be `float32` by default (see the FLIP).
Okay, sorry for missing this!
Do you think it makes more sense to simply make the `param_dtype` and `computation_dtype` default to `float32` on the modules? That way these helper functions can remain completely type agnostic, and just defer to the input dtype if `None` is passed for the other types (roughly as sketched below).

Also, a downside to this change is that making a module anything narrower than `float32` now requires:

- changing both `dtype` and `param_dtype`, and
- passing the appropriate inputs.

With the original PR, only the input would have to be made narrow, and everything else would be inferred.
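For concreteness, here is a rough sketch of what I mean by a type-agnostic helper (the name and signature are only illustrative, not the code in this PR):

```python
from typing import Optional, Tuple

import jax.numpy as jnp


# Illustrative only: defer to the input dtype whenever None is passed.
def _resolve_dtypes(input_dtype,
                    param_dtype: Optional[jnp.dtype] = None,
                    computation_dtype: Optional[jnp.dtype] = None
                    ) -> Tuple[jnp.dtype, jnp.dtype]:
  resolved_param = input_dtype if param_dtype is None else param_dtype
  if computation_dtype is None:
    computation_dtype = jnp.result_type(input_dtype, resolved_param)
  return resolved_param, computation_dtype
```

The modules themselves would then default `param_dtype` and `dtype` to `float32`, so this helper only ever infers types when a user explicitly passes `None`.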
flax/linen/linear.py (outdated):
```python
assert jnp.issubdtype(input_dtype, jnp.number)
if jnp.issubdtype(input_dtype, jnp.complexfloating):
  assert jnp.issubdtype(returned_param_dtype, jnp.complexfloating)
```
complex numbers can still be projected by a real transformation
Are you sure that differentiation with respect to such parameters will work in that case? Consider:
```python
from jax import grad
from functools import partial
import jax.numpy as jnp

x = jnp.zeros(3, dtype=jnp.complex64)

def f(x, w):
  return jnp.sum(w @ x)

grad(partial(f, x), holomorphic=True)(jnp.eye(3, dtype=jnp.complex64))  # Okay.
grad(partial(f, x), holomorphic=True)(jnp.eye(3))  # TypeError: grad with holomorphic=True requires inputs with complex dtype, but got float32.
```
Anyway, I've removed the assertions. I guess it's the user's problem if she tries to differentiate heterogeneous types.
flax/linen/linear.py (outdated):
```python
assert jnp.issubdtype(input_dtype, jnp.number)
if jnp.issubdtype(input_dtype, jnp.complexfloating):
  assert jnp.issubdtype(returned_param_dtype, jnp.complexfloating)
  assert jnp.issubdtype(dtype, jnp.complexfloating)
```
Casting back to real is currently allowed. We should look into disabling this in JAX instead of doing it inconsistently in this PR, I think.
Oh, good point! For some reason I thought it was a narrowing error.
I agree that we shouldn't try to fix this here. Is there a JAX or NumPy issue somewhere that you know of? Maybe we should at least create one?
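For reference, a small illustration of the behaviour in question (not code from this PR):

```python
import jax.numpy as jnp

z = jnp.array([1.0 + 2.0j], dtype=jnp.complex64)
# Casting complex back to real is currently allowed: the imaginary part is
# dropped (possibly with a warning, depending on the JAX version; plain NumPy
# emits a ComplexWarning for the equivalent cast).
r = z.astype(jnp.float32)
print(r)  # [1.]
```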
flax/linen/linear.py (outdated):
```python
  return returned_param_dtype, dtype


def _canonicalize_numeric_dtypes(
```
I think it would be nice to expose and document these APIs publicly so people can implement their own layers according to the spec.
I think you can have a `canonicalize` and a `canonicalize_dtype_inexact`.
Also, I think you should use the true dtypes of the params and not the `param_dtype`, because users can cast params after construction.
> I think it would be nice to expose and document these APIs publicly so people can implement their own layers according to the spec.
Yes, great point.
Did you want them in a new file? I'll expose them here for now, but let me know if you want them somewhere else.
> Also I think you should use the true dtypes of the params

Sorry, I'm not sure what you want me to do here? The `param_dtype` is an attribute of the module, which is supposed to be used to create the parameters. How can I get the "true dtype of the params" before I create them?
I would either make a `dtypes.py` or put it in `module.py`.

> Sorry, I'm not sure what you want me to do here?

You can pass the parameter values into `jnp.result_type(*inputs, *params)`.
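Concretely, something like this (the values are only for illustration):

```python
import jax.numpy as jnp

inputs = jnp.ones((4, 8), dtype=jnp.bfloat16)
kernel = jnp.ones((8, 3), dtype=jnp.float32)   # the actual param value, possibly recast by the user
bias = jnp.zeros((3,), dtype=jnp.float32)

# Infer the computation dtype from the real arrays, not from param_dtype.
dtype = jnp.result_type(inputs, kernel, bias)  # float32 here; bfloat16 if the params were recast
```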
> I would either make a dtypes.py or put it in module.py

Sure, I added `dtypes.py` since `module.py` is getting large.

> You can pass the parameter values into jnp.result_type(*inputs, *params)

Sorry, I still don't understand what you mean. The parameters don't exist at the point that canonicalize is called. The parameter creation depends on the output of canonicalize.
The `canonicalize_dtype` should only infer the dtype if it is `None`. `param_dtype` defaults to `float32` and cannot be inferred (so it's not `Optional[Dtype]` but `Dtype`). This way you can init the params and then use the params + inputs to infer the dtype.
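Roughly like this; just a sketch of the shape of the API, not the final implementation:

```python
from typing import Any, Optional

import jax.numpy as jnp

Dtype = Any  # placeholder alias for this sketch


def canonicalize_dtype(*args, dtype: Optional[Dtype] = None) -> Dtype:
  """Returns `dtype` if given, otherwise the result type of `args`.

  `args` would be the inputs together with the actual parameter values, so a
  user who recasts the params (e.g. to bfloat16 for eval) automatically gets
  the lower-precision computation whenever `dtype` is None.
  """
  if dtype is None:
    return jnp.result_type(*args)
  return dtype
```

A module would then create its params with `param_dtype` (defaulting to `float32`) and call something like `canonicalize_dtype(inputs, kernel, bias, dtype=self.dtype)` before the actual computation.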
> the canonicalize_dtype should only infer the dtype if it is None

Right, that's what it does.

> This way you can init the params and then use the params + inputs to infer the dtype

Sorry, I'm really trying to understand you here, but I still don't see it. What is the difference between

- passing the parameter dtype to the canonicalize function, versus
- initializing the parameters using the parameter dtype, and passing the parameters to the canonicalize function?

Won't the exact same type inference happen either way?

> param_dtype defaults to float32 and cannot be inferred (so it's not Optional[Dtype] but Dtype).

We can do that, but I don't see how that's an advantage. It just makes the canonicalize function less flexible without changing the behaviour when the parameter dtype is provided.

We can make the parameter dtype required on the modules though.
The param_dtype argument is not the same as the parameter dtype. In 99% of the cases users will stick to the default float32. For example during eval the two can diverge because a user does something like:
```python
eval_params = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params)
```
The reasoning behind this common trick is that small updates during training don't accumulate well in half precision but during eval the weights are static and using half precision is a free gain in most cases.
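As a hedged sketch of that flow (the module and shapes are arbitrary, assuming the dtype-inference behaviour this PR is adding):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=4)   # dtype left unspecified -> inferred from inputs and params
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
# ... train params in float32 ...

# Eval time: recast the trained params; no need to rebuild the module.
eval_params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
x = jnp.ones((1, 8), dtype=jnp.bfloat16)
y = model.apply(eval_params, x)  # the computation now runs in bfloat16
```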
The core idea behind canonicalize dtype is to reproduce what you normally expect from an equivalent numpy function. So imagine that we wrote nn.Dense as a pure numpy function:
```python
def dense(input, kernel, bias, dtype=None):
  dtype = dtype or jnp.result_type(input, kernel, bias)
  input, kernel, bias = (jnp.asarray(x, dtype) for x in (input, kernel, bias))
  return input @ kernel + bias
```
Here the user has to provide kernel and bias so the dtypes of the params are determined by the user and don't depend on input.dtype. The default dtype is just np.result_type(input, kernel, bias).
The reason why we allow this dtype to be inferred is because it preserves the precision of the computation.
When inferring params, however, you are inferring behavior. E.g. a complex linear mapping is something different than a real linear mapping even if you have a complex input. Similarly, a learned half-precision dense does something very different in practice than an f32 dense layer even if the inputs are in half precision.
A second argument against a None default for param_dtype is that it would cause a big backwards-incompatible change, because many users rely on f32 defaults for half-precision inputs.
> The reasoning behind this common trick is that small updates during training don't accumulate well in half precision but during eval the weights are static and using half precision is a free gain in most cases.
Okay! I see what you're saying now.
I've started this work here, but it gets tricky when there are submodules (how do I take into account the final DenseGeneral that depends on the first set of parameters, and itself has other parameters?). I'm not sure that this approach will work in general even if it looks okay for the simple modules. I'd have to canonicalize multiple times and produce various intermediate dtypes. And it gets even more complicated with the recursive modules that have a carry. This would be a mess.
Honestly, I think recasting the parameters that were created by the modules is a hack. You're essentially bypassing the public interface and accessing the parameters, which are akin to private variables. And it's also bad because it only works when you have the `dtype` attribute set to `None`. If `dtype` is set, then changing the parameter dtypes has no effect on the computation.

I think a much nicer solution for people who want to do "evaluation in half precision" is to reconstruct the modules with the `dtype` and `param_dtype` attributes set to half precision, and then transform the parameters through a public interface.
> A complex linear mapping is something different than a real linear mapping even if you have a complex input.
Yes that's true, but (unless very special care is taken) the parameter cotangents will still be complex, so I imagine that in most cases, complex inputs imply complex parameters.
> A second argument against a None default for param_dtype is that it would cause a big backwards-incompatible change, because many users rely on f32 defaults for half-precision inputs.
Right, that's why I switched all the defaults to float32. This way, this PR just provides the ability to specify computation and parameter dtype. It doesn't change behavior.
However, I still think you should deprecate this default in another PR. As you say: "The core idea behind canonicalize dtype is to reproduce what you normally expect from an equivalent numpy function." And the equivalent numpy function always produces outputs based on its input types. It doesn't silently widen everything to float32.
Also, if someone does want half-precision (or complex, or double-precision) computation throughout their network, they will need to select it for every module. I read the FLIP, but I think things can change in the future. Double precision is already starting to be as performant as single precision on some GPUs. It would be better not to bake in defaults that you can't easily remove.
@NeilGirdhar I finally found the time to look at the changes. Please have a look.

Just an FYI that Avital is OOO for a few weeks, so I suppose he hasn't been checking his emails.

Ah okay! No problem, I wasn't sure. I'll try to get this done today then 😄

@jheek First pass of the review is done. Please let me know about the above questions when you find more time 😄
I've been thinking about this some more, and I think a big source of our back-and-forth was the confusion that

```python
kernel = self.param('kernel', self.kernel_init, kernel_shape, param_dtype)
```

can have a dtype other than `param_dtype`. While that is true, I think this should be prevented. The parameters and variables are created by module code that the user does not have access to. And as such the user should not be counting on any types or shapes of such parameters. All modules should be free to, in some future version of Flax, create structured parameter types like

```python
def kernel_init(rng, shape, dtype) -> KernelParameter: ...
```

where `KernelParameter` is some structured parameter type.

At heart, the problem is that there are a variety of workflows that Flax protects despite them not being crystallized in public interfaces. In our above discussion, it's the ability to change the behavior of a module by recasting the parameters and variables it has created. This workflow might be a useful trick, but I think it should be exposed in a public interface.

And are you even testing these workflows? I don't remember seeing a test where you initialize variables with a module M, transform them to have different dtypes, and then push them through apply on that original module M to verify that it still does the right thing. Allowing this workflow multiplies the amount of testing you need to do tremendously, because there are all kinds of changes that a user could make to the variables.

I propose the following:
```python
T = TypeVar('T')


class Module:
  @traceback_util.api_boundary
  def transform_variables(self,
                          old_variables: FrozenVariableDict,
                          *args,
                          f: Callable[[T, T], T],
                          method: Optional[Callable[..., Any]] = None,
                          mutable: CollectionFilter = DenyList("intermediates"),
                          **kwargs: Any) -> FrozenVariableDict:
    """Transforms the variables created by a module method and returns modified
    variables.
    """
```

This just does an initialize as usual, but after creating a variable, it passes the newly created value and the corresponding value from `old_variables` through `f` to produce the value that is stored. The workflow described above would then be:

```python
my_module = SomeModule(...)
variables = my_module.init(...)
# train variables using my_module...

new_module = SomeModule(...)  # with different types
new_variables = new_module.transform_variables(variables, ...)
new_module.apply(new_variables, ...)  # guaranteed to work
```

In short, `transform_variables` would make this kind of recasting part of the public interface, so modules can guarantee that the result still works with `apply`. What do you think?
The variable collections are far from opaque. You make assumptions about their structure when doing transformations, taking gradients, when optimizing, etc. At the minimum they must be PyTrees of JAX arrays. This approach is much more flexible than you would think at first glance, though. For example, you can "box" a param in an arbitrary dataclass and add as much metadata as you want to parameters. Still, once you do a jax.tree_map, another part of the code can transform it and skip over the internal metadata without a problem.
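For example, a rough sketch of such a boxed param (`BoxedParam` is just an illustrative name, using `flax.struct` so the metadata stays out of the pytree leaves):

```python
import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class BoxedParam:
  value: jnp.ndarray                                   # pytree leaf, visible to tree_map
  description: str = struct.field(pytree_node=False)   # static metadata, skipped by tree_map


p = BoxedParam(value=jnp.ones((2, 2)), description="kernel of some layer")
# Other code can transform the array without knowing about the metadata.
p_half = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), p)
```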
Sorry, I've been meaning to get back to you on this, but I've had a lot of work to do.
Yes, you're right that you can treat the variables as a PyTree. What I was trying to get at by making it as opaque as possible is to make it so that users of a module X should not make assumptions about the objects that are placed into the variables tree by X. Otherwise, the user code will break when X changes. This ensures the separation of concerns.

Users of a module X can configure its functionality by adjusting its dataclass fields. What you're proposing is a second way of changing its functionality by changing the data types of the contents of the variables. I understand that people are already doing this, but it is absolutely horrible. It's only one extra line of code to rebuild an appropriate module with the configuration you want, and then use that.

Trying to make all of the modules robust to changes in dataclass types is going to be way too much work for your team. Besides actually making this work everywhere, you will need to test this (and there are no tests). It's also not easy to make canonicalize do this. You suggested that you can just put the array objects from the variables into the type inference. That does work for simple modules like `Dense`, but it gets much trickier for modules with submodules or a recurrent carry.

Therefore, I think this is bad design. You should infer the computation dtype based on the module fields alone. If people want to cast the parameters, then they'll also have to reconstruct the module. Anyone who was relying on this undocumented behaviour will have to slightly modify their code, unfortunately.

What do you think?
I hate to bump this up again but... what's the status?

@PhilipVinc I was still waiting on a reply to my last comment. Unfortunately, I've decided not to spend any more time working on any flax pull requests. You're welcome to lift my code into a pull request of your own if you like.

@PhilipVinc @NeilGirdhar I'm taking over the implementation of the default dtype FLIP.
Fixes #1777