Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace ndim with len(shape) #193

Conversation

dziulek
Copy link

@dziulek dziulek commented Mar 11, 2024

Hi!

I use tensorflow==2.9.1 and I came across a situation where tf.Tensor could not have ndim attribute at some point of execution. Moreover, what I found out is that the problem does not occur in tensorflow==2.16.
Here you have a simple code to reproduce the error.

import tensorflow as tf
import jaxtyping as jax
import beartype

@tf.function()
@jax.jaxtyped(typechecker=beartype.beartype)
def map_function(tensor: jax.Float[tf.Tensor, "h w c"]) -> jax.Float[tf.Tensor, "h w c"]:
    return 1 - tensor

def main():

    tf.config.run_functions_eagerly(True)
    tf.data.experimental.enable_debug_mode()

    dataset = tf.data.Dataset.from_tensor_slices(tensors=tf.random.uniform((100,30,30,3)))
    dataset = dataset.map(map_function)

if __name__ == "__main__":

    main()

This is what i get
AttributeError: 'Tensor' object has no attribute 'ndim'

Unfortunately I'm pinned to the 2.9.1 version of tensorflow. Let me know what do you think.

@patrick-kidger patrick-kidger merged commit 07e58de into patrick-kidger:main Mar 11, 2024
1 check passed
@patrick-kidger
Copy link
Owner

LGTM! No idea why they don't have an ndim attribute, but happy to support this use-case.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants