Extend data type support (for bfloat16 in particular) #2656

Open
nenb opened this issue Jan 6, 2025 · 6 comments

Comments

@nenb

nenb commented Jan 6, 2025

Problem
I would like to read/write NumPy dtype extensions (such as bfloat16) with zarr version 2. I am using the ml_dtypes package from the JAX team for the dtype extensions.

```python
import numpy as np
import ml_dtypes
import zarr

arr = np.array([ml_dtypes.bfloat16(1)])
zarr.save('example.zarr', arr)  # ValueError: No cast function available.
```

I experience a similar issue when trying to read such dtype extensions.

The problem relates to the limited extensibility of NumPy's dtype kind codes. The JAX team describes it well.
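To make the kind-code limitation concrete: zarr v2 records an array's dtype as a NumPy type string (byte order, kind character, item size). Assuming, as I understand ml_dtypes, that bfloat16 reports NumPy kind `'V'` rather than `'f'`, the most a type string can say about it is "2 opaque bytes", which reads back as raw void data. The sketch below uses only plain NumPy (no ml_dtypes), and the `describe` helper is purely illustrative:

```python
import numpy as np


def describe(dtype_like) -> tuple[str, int, str]:
    # The pieces a NumPy type string is built from: kind, item size, full typestr.
    dt = np.dtype(dtype_like)
    return dt.kind, dt.itemsize, dt.str


half = describe(np.float16)  # kind 'f': the type string alone identifies IEEE half precision
blob = describe("V2")        # kind 'V': the type string only says "2 bytes of something"

# Parsing the void type string back gives opaque bytes, not a floating-point type.
restored = np.dtype(blob[2])
print(half, blob, np.issubdtype(restored, np.floating))
```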

Background
bfloat16 is a very important dtype in the AI/ML community. I would like to use zarr (and specifically the Python implementation) to share models such as LLMs. However, the lack of bfloat16 support is a major blocker.

Questions

  1. Is this something that could be resolved with zarr v2?
  2. Is this something that I could resolve using zarr v3 today?
  3. If the answer to both questions is no, what would be required to support it in zarr v3 in the future?

Related issues
#711

cc @jhamman (as suggested by @TomNicholas)

@d-v-b
Contributor

d-v-b commented Jan 7, 2025

I think this should be a priority for zarr v3, and I can't see any technical barriers for it. Would you have time / energy to work on a PR that would add this? Happy to give pointers for where to start.

@nenb
Author

nenb commented Jan 7, 2025

> Would you have time / energy to work on a PR that would add this?

Yes! And thank you. 😄

> Happy to give pointers for where to start.

Pointers would be very welcome!

@d-v-b
Contributor

d-v-b commented Jan 7, 2025

This is the function we use for converting user input into a concrete dtype object:

```python
def parse_dtype(dtype: Any, zarr_format: ZarrFormat) -> np.dtype[Any]:
```

You would need to test that both the string `bfloat16` and the concrete bfloat16 dtype come out of this function as concrete instances of bfloat16.
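One way to meet that requirement, sketched under assumptions: NumPy's own `np.dtype` already accepts scalar types and dtype instances, so the missing piece is mapping the *name* `"bfloat16"` to the ml_dtypes dtype. The registry `EXTENSION_DTYPES` and the function `parse_dtype_sketch` below are hypothetical illustrations, not zarr-python's API; ml_dtypes is treated as optional so the sketch still runs without it.

```python
from typing import Any

import numpy as np

# Hypothetical name -> dtype registry for types NumPy cannot resolve from a name
# by itself. Illustration only; this is not zarr-python's actual API.
EXTENSION_DTYPES: dict[str, np.dtype] = {}

try:
    import ml_dtypes  # optional: bfloat16 is registered only when installed

    EXTENSION_DTYPES["bfloat16"] = np.dtype(ml_dtypes.bfloat16)
except ImportError:
    pass


def parse_dtype_sketch(dtype: Any) -> np.dtype:
    """Resolve a string name, scalar type, or dtype instance to a concrete np.dtype."""
    if isinstance(dtype, str) and dtype in EXTENSION_DTYPES:
        return EXTENSION_DTYPES[dtype]
    # np.dtype handles built-in names, scalar types and dtype instances,
    # and raises TypeError for names it does not understand.
    return np.dtype(dtype)
```

With ml_dtypes installed, the property to test would be `parse_dtype_sketch("bfloat16") == parse_dtype_sketch(ml_dtypes.bfloat16)`.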

Because zarr v3 has its own data type specification, designed to be decoupled from NumPy, we have a separate parsing step for creating a zarr v3 metadata document:

```python
def parse(cls, dtype: DataType | Any | None) -> DataType:
```

If you look at the DataType object defined there, you'll see that it's basically an enum with some helper functions that map its variants to strings that can be serialized to zarr metadata documents, and to strings that can be used to make a NumPy dtype. I think the v3 spec already reserves the name float16 (you should check whether bfloat16 is the same as the "IEEE 754 half-precision floating point: sign bit, 5 bits exponent, 10 bits mantissa" referenced by the spec; if so, we can just use the float16 name for the metadata, otherwise we need to make a new name).
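For that check: bfloat16 uses 1 sign bit, 8 exponent bits and 7 mantissa bits (the top half of a float32), whereas IEEE 754 half precision uses 1, 5 and 10. So the two share a width but not a layout. The NumPy-only sketch below decodes both encodings of 1.0 and shows that the bfloat16 bit pattern means something else when read as `float16`. The helper names are illustrative, and the float32 → bfloat16 step truncates rather than rounding to nearest even, which is enough for exactly representable values like these.

```python
import numpy as np


def bfloat16_bits(x: float) -> int:
    # bfloat16 = upper 16 bits of a float32 (1 sign, 8 exponent, 7 mantissa bits).
    # Plain truncation rounds toward zero; real converters usually round to nearest even.
    return int(np.array([x], dtype=np.float32).view(np.uint32)[0]) >> 16


def bfloat16_to_float(bits: int) -> float:
    # Re-expand to a float32 by zero-filling the low 16 bits.
    return float(np.array([bits << 16], dtype=np.uint32).view(np.float32)[0])


def float16_bits(x: float) -> int:
    # IEEE 754 half precision: 1 sign, 5 exponent, 10 mantissa bits.
    return int(np.array([x], dtype=np.float16).view(np.uint16)[0])


def split_fields(bits: int, exp_bits: int, man_bits: int) -> tuple[int, int, int]:
    # Split a 16-bit pattern into (sign, biased exponent, mantissa).
    sign = bits >> (exp_bits + man_bits)
    exponent = (bits >> man_bits) & ((1 << exp_bits) - 1)
    mantissa = bits & ((1 << man_bits) - 1)
    return sign, exponent, mantissa


print(hex(bfloat16_bits(1.0)), split_fields(bfloat16_bits(1.0), 8, 7))  # 0x3f80 (0, 127, 0)
print(hex(float16_bits(1.0)), split_fields(float16_bits(1.0), 5, 10))   # 0x3c00 (0, 15, 0)

# The bfloat16 encoding of 1.0, reinterpreted as IEEE half precision:
print(float(np.array([0x3F80], dtype=np.uint16).view(np.float16)[0]))   # 1.875
```

Because the same 16 bits decode differently, reusing the `float16` name for bfloat16 would make the stored bytes ambiguous, which points toward a separate name.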

Once the metadata details are ironed out, you would need to check that zarr-python can correctly create arrays with this dtype. I think we have some tests that are parametrized over dtype, so drop float16 (or whatever string code you end up using) in there and see what happens.
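The dtype-parametrized pattern can be sketched with NumPy alone: list the dtypes, append bfloat16 only when ml_dtypes is installed, and check that raw bytes plus the recorded dtype reproduce the array. Here `roundtrip_through_bytes` is a hypothetical stand-in for writing and reading a chunk; in zarr-python's own test suite the same parametrization would go through zarr's array creation and I/O instead.

```python
import numpy as np

# dtypes to exercise; the extension dtype is appended only when ml_dtypes is available.
DTYPES: list = ["int8", "uint16", "float16", "float32", "float64", "complex64"]
try:
    import ml_dtypes

    DTYPES.append(ml_dtypes.bfloat16)
except ImportError:
    pass


def roundtrip_through_bytes(arr: np.ndarray) -> np.ndarray:
    # Stand-in for a chunk store: keep only the raw bytes plus the recorded dtype.
    return np.frombuffer(arr.tobytes(), dtype=arr.dtype).reshape(arr.shape)


def check_dtype(dtype) -> None:
    # Build small test data from float32, then round-trip it and compare.
    original = np.arange(6, dtype=np.float32).astype(dtype).reshape(2, 3)
    restored = roundtrip_through_bytes(original)
    assert restored.dtype == original.dtype
    assert np.array_equal(restored, original)


for dtype in DTYPES:
    check_dtype(dtype)
```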

@nenb
Author

nenb commented Jan 7, 2025

This is great. I'll tag you when I have a first draft.

Is there a benefit to trying to add support for v2 as well?

@d-v-b
Contributor

d-v-b commented Jan 7, 2025

I think v2 support would also be great, provided the result complies with the requirements for dtypes in the v2 spec.
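As I read the v2 spec, simple dtypes are stored as NumPy type strings: one byte-order character (`<`, `>`, `|`), one type code from `biufcmMSUV`, and a byte count, with a unit suffix for datetimes and timedeltas such as `<M8[ns]`. The regex below is a hypothetical sketch of that rule, not zarr-python code. Under it, a dedicated code like `<bf2` is not valid, so a spec-compliant v2 encoding of bfloat16 would fall back to a void type string such as `|V2`, which by itself does not say the bytes are floats.

```python
import re

import numpy as np

# Sketch of the v2 simple-dtype format: byte order, one type code, byte count,
# plus an optional unit suffix for datetime/timedelta types (e.g. "<M8[ns]").
_V2_TYPESTR = re.compile(r"[<>|][biufcmMSUV]\d+(\[\w+\])?")


def is_v2_simple_typestr(typestr: str) -> bool:
    return _V2_TYPESTR.fullmatch(typestr) is not None


print(is_v2_simple_typestr(np.dtype("float16").str))  # True
print(is_v2_simple_typestr("|V2"))                    # True
print(is_v2_simple_typestr("<bf2"))                   # False
```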

@nenb
Author

nenb commented Jan 7, 2025

@d-v-b I sketched some quick changes here. It's not a first draft, but it contains things I could use some pointers on if you have time (to stop me going in the wrong direction...). In particular, I'd appreciate feedback on the comment about NumPy kind codes.
