Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for arbitrary image resolutions #24

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

LeviVasconcelos
Copy link

@LeviVasconcelos LeviVasconcelos commented Dec 16, 2022

Hi,

This PR adds support for arbitrary image resolution. Here's what I did to make it possible:

  • Rework Image Resizing layers as pointed out here
  • Rewrite Block and Unblock layers to use pure tensorflow: this was necessary because einops does not accept tensors as pattern arguments, which is necessary in some layers of maxim.
  • Realize that dim_u and dim_v could be substituted for the window_size squared.
  • Apply changes to the model itself. (edit: by that I meant substitute the layers for the TF ones, and plug in the necessary arguments for them to work)

I hope this helps, please let me know what you think.

you can test it quickly by: pytest -v maxim/tests/

Best,
Levi.

@@ -6,7 +6,7 @@
from tensorflow.keras import backend as K
from tensorflow.keras import layers

from ..layers import BlockImages, SwapAxes, UnblockImages
from ..layers import SwapAxes, TFBlockImagesByGrid, TFUnblockImages
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why there's a separate layer for handling blocking by grids?

Copy link
Author

@LeviVasconcelos LeviVasconcelos Dec 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BlockByGrid can be implemented as follows, please see a more detailed explanation here:

def BlockByGrid(image, grid_size):
    block_size = (image_height // grid_size[0], image_width // grid_size[1])
    return BlockImage(image, block_size)

But, while implementing TFBlockImages I used tf.split which expects an int literal as argument for num_or_size_splits.

However, in cases where we only have the grid_size and the block_size has to be computed on the fly (as here), it needs to be a tensor, and we can't use tf.split ins this case. That's why I also wrote BlockByGrid.

Comment on lines -52 to +49
x = BlockImages()(x, patch_size=(fh, fw))
x, ph, pw = TFBlockImagesByGrid()(x, grid_size=(gh, gw))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How come these operations are the same?

Copy link
Author

@LeviVasconcelos LeviVasconcelos Dec 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the original implementation, the authors implement BlockByGrid by computing the block size of a grid cell, and using BlockImages (which block images into patches of block-size).

From the paper, the authors explain the difference between "grid" and "block" like that:
Screenshot from 2022-12-16 12-53-08
Note that we can achieve the same result as the grid split by forwarding a block size of [3,2] instead. This is exactly what the authors do in the original code as highlighted here.

They are equivalent because it does the split based on the grid_size as argument instead of the block_size (called as (fh, fw) in the code) as the authors did.

A more formal test is performed here.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining.

Note that we can achieve the same result as the grid split by forwarding a block size of [3,2] instead.

How is the block size of [3, 2] interpreted in that case?

Copy link
Author

@LeviVasconcelos LeviVasconcelos Dec 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the original code, it is done as explained, note here:

gh, gw = grid_size
fh, fw = h // gh, w // gw
u = BlockImages()(u, patch_size=(fh, fw))

Note that this code is very similar to the pseudo-code written here. grid_size is passed as a parameter, but h and whave to be inferred from the image dimensions (which in case of (None, None, 3)), they are None tensors. Thus can't be used in the einops operations, and the way I found to overcome this was to rewrite the operations in tf.

We can use the block [3,2] to compute the green part of the image (grid blocking with grid_size=[3,2]) this way:

In the example shown in the image, we have that image size is [6,4]. Thus to split it with a grid_size of [2,2], we can do:

gh, gw = (2, 2)
h, w = (6,4) # image dimensions
fh, fw = h // gh, w // gw # Note that fh = 3, and fw = 2
block_image = BlockImages()(image_from_the_piture, patch_size=(fh,fw)) # patch_size=(3,2)

The above code snippet implements the green part of the image, and is very similar to what we described first.

In case with the TFBlockByGrid(), we can simply do:

gh, gw = (2,2)
block_image_using_tfblockByGrid = TFBlockByGrid()(image_from_the_picture, grid_size=(gh,gw))

and block_image should be equivalent to block_image_using_tfblockByGrid, as asserted by this test

I am not sure if this answer what you asked, though. Let me know.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

So, TFBlockByGrid() becomes more idiomatic in that sense. We want to have grid sizes of (2, 2) in the output so, directly pass that as an argument. Correct?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly!

Copy link
Owner

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your hard work. Left a couple of comments.

@@ -66,7 +63,7 @@ def apply(x):
)(y)
y = layers.Dropout(dropout_rate)(y)
x = x + y
x = UnblockImages()(x, grid_size=(gh, gw), patch_size=(fh, fw))
x = TFUnblockImages()(x, grid_size=(gh, gw), patch_size=(ph, pw))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same. You're changing the semanticity of the code. Could you please elaborate why?

Reading this change and also previous x, ph, pw = TFBlockImagesByGrid()(x, grid_size=(gh, gw)) and comparing them to their previous versions -- they don't read the same too.

dim_u = K.int_shape(u)[-3]
ghu, gwu = grid_size
u, phu, pwu = TFBlockImagesByGrid()(u, grid_size=(ghu, gwu))
dim_u = ghu * gwu
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain the rationale in the comment.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, advisable not to change the original variable names here and elsewhere.

Copy link
Author

@LeviVasconcelos LeviVasconcelos Dec 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable names will be recovered on next push. Did it only for readability (since they get rewritten a few lines below)

If i understood correclty, you are asking why we can substitute K.int_shape(u)[-3] for (gh * gw):

From BlockImages(), we have that the output's shape is "b (gh gw) (fh fw) c". Thus, since:
dim_u = K.int_shape(u)[-3]
dim_u = (gh * gw)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same reason why fh and fw are getting replaced by gh and gw here?

Copy link
Author

@LeviVasconcelos LeviVasconcelos Dec 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Essentially, those transformations are the same:

def same_operations(random_image, grid_size=(gh,gw)):
    b, h, w, c = random_image.shape
    image_blocked_by_grid = BlockByGrid(random_image, grid_size=(gh, gw))
    image_blocked_by_block = BlockByPatch(random_image, patch_size=(h // gh, w // gw)
    image_blocked_by_grid == image_blocked_by_block # this should be True.

we have this pseudo-code as a test here. Note that BlockImages() used in the test corresponds to the original einops implementation.

maxim/layers.py Outdated
Comment on lines 16 to 28
def call(self, image, patch_size):
bs, h, w, num_channels = (tf.shape(image)[0], tf.shape(image)[1], tf.shape(image)[2], tf.shape(image)[3])
ph, pw = patch_size
gh = h // ph
gw = w // pw
pad = [[0, 0], [0, 0]]
patches = tf.space_to_batch_nd(image, [ph, pw], pad)
patches = tf.split(patches, ph * pw, axis=0)
patches = tf.stack(patches, 3) # (bs, h/p, h/p, p*p, 3)
patches_dim = tf.shape(patches)
patches = tf.reshape(patches, [patches_dim[0], patches_dim[1], patches_dim[2], -1])
patches = tf.reshape(patches, (patches_dim[0], patches_dim[1] * patches_dim[2], ph * pw, num_channels))
return [patches, gh, gw]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm honestly not sure why we are getting rid of einops. This is significantly more lines of code and also more complex to read.

Copy link
Author

@LeviVasconcelos LeviVasconcelos Dec 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, using einops would be in hand. But please, consider this code snippet:

img = tf.random.uniform((1, 4, 4, 1))
block_img = einops.rearrange(img, 'b (gh fh) (gw fw) c -> b (gh gw) (fh fw) c', fh=2, fw=2) # this should work fine
block_img_with_tensors_as_arguments = einops.rearrange(img, 'b (gh fh) (gw fw) c -> b (gh gw) (fh fw) c', fh=tf.constant([2]), fw=tf.constant([2])) # this breaks.

The problem with einops is that it expects int literals as argument to the symbols used in the pattern string. I could not make it work using tensors as shown by the example above. At some stages of the model (here, here, here), the split is computed in online fashion, thus relying on tensors (for the case where the img size is None). Thus it was necessary to rewrite using tensorflow.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it wasn't a problem with the current version of the code. What changed?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current version of the code informs the image dimension beforehand, thus when you do:

n, h, w, num_channels = (
            K.int_shape(x)[0],
            K.int_shape(x)[1],
            K.int_shape(x)[2],
            K.int_shape(x)[3],
        )

you have the integer literals we need for the einops operations. However, In case when we feed (None, None, 3) as input, h and w cannot be used for computing direct literals for the einops operations, and they have to be represented as tensorflow placeholders (None tensor).

maxim/layers.py Outdated
Comment on lines 17 to 28
bs, h, w, num_channels = (tf.shape(image)[0], tf.shape(image)[1], tf.shape(image)[2], tf.shape(image)[3])
ph, pw = patch_size
gh = h // ph
gw = w // pw
pad = [[0, 0], [0, 0]]
patches = tf.space_to_batch_nd(image, [ph, pw], pad)
patches = tf.split(patches, ph * pw, axis=0)
patches = tf.stack(patches, 3) # (bs, h/p, h/p, p*p, 3)
patches_dim = tf.shape(patches)
patches = tf.reshape(patches, [patches_dim[0], patches_dim[1], patches_dim[2], -1])
patches = tf.reshape(patches, (patches_dim[0], patches_dim[1] * patches_dim[2], ph * pw, num_channels))
return [patches, gh, gw]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the line number you're using for Black formatting? The line-numbers seem long and should be formatted accordingly.

Copy link
Author

@LeviVasconcelos LeviVasconcelos Dec 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where I work we use 122, I am reformatting with 88 (black's default, IIRC).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

80 is the default. You can bump it to 90 (which is what I used).

@@ -76,28 +100,60 @@ def get_config(self):


@tf.keras.utils.register_keras_serializable("maxim")
class Resizing(layers.Layer):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the need to segregate this to Up and Down?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it easier to read, but indeed it adds a chunk of code. Reformatting to use a single layer only.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's easier to read, I would consider adding an elaborate comment in the script so that readers are aware.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think now is better (with a single resizing layer).

maxim/layers.py Outdated
return tf.image.resize(
x,
size=(self.height, self.width),
def __call__(self, img):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Prefer call() or is there anything I am missing out on?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, fixing...

maxim/maxim.py Show resolved Hide resolved
maxim/maxim.py Outdated
layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same"
)
ConvT_up = functools.partial(layers.Conv2DTranspose, kernel_size=(2, 2), strides=(2, 2), padding="same")
Conv_down = functools.partial(layers.Conv2D, kernel_size=(4, 4), strides=(2, 2), padding="same")


def MAXIM(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the main change here needed to facilitate (None, None, 3) input tensors?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the resize layers and piping ratio to them.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resizing layers were there previously too. Do you mean having ratio instead of separate height and width was the key change?

Copy link
Author

@LeviVasconcelos LeviVasconcelos Dec 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To this specific file (maxim.py) yes, for the whole PR no.

The key changes were to rewrite the einops operations in pure TF.

Another key change was to make dim_u and dim_v independent of the input image size. Note that, in this PR, dim_u and dim_v are computed from grid_size and block_size which are passed as parameters, as compared to the current version (links: dim_u, dim_v) which rely on on-the-fly computation based on the input image size.

A last change was to plug in the resizing layers with the correct ratio. On your branch feat/dynamic_shape, there's a little bug: when you pass the ratio for the upsampling layers here. This casting to int() is premature, since some of those values are supposed to be < 0, and casting to int will project them to 0.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This casting to int() is premature, since some of those values are supposed to be < 0, and casting to int will project them to 0.

Would appreciate an example.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider this and this lines of code from the feat/dynamic-shape branch.

If you have j > i on the second one, you get a below zero fraction which will be projected to zero.
Same goes for the same line: whenever (depth - j - 1 - i) is positive, you have a below zero ratio.
If I recall correctly, the second case (depth - j - 1 - i) was yielding 0.25, 0.5, and so on. Thus I changed it.

The way i did to overcome this, was to actually pass the float number, and compute the new desired image size: img_size * ratio. And just after that converting back to int.

Copy link
Owner

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all your explanation. I truly appreciate the hard work here.

A couple more comments.

The immediate next step could be to verify the actual outputs of the reworked models. Let's plan on that.

@LeviVasconcelos
Copy link
Author

Please let me know if anything remains unclear.

Best,

Copy link
Owner

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good.

@sayakpaul
Copy link
Owner

@LeviVasconcelos

Things are looking good. Thanks so much for your hard work. I left a couple more questions.

I would suggest doing the following:

@sayakpaul
Copy link
Owner

@LeviVasconcelos a friendly ping :)

@LeviVasconcelos
Copy link
Author

I was in vacation this last week, thus the late reply.

What do you think of a quick call this or next week? I have a couple questions that I think would be quickly answered in a ~10 mins call.

Let me know what you think...

@sayakpaul
Copy link
Owner

Would prefer chatting via email as I will be busy next week.

@LeviVasconcelos
Copy link
Author

Just a friendly heads up: i will be very busy until jan 20th. Afterward i should start working on it.

Best,

list changes done to achieve arbitrary image shapes.
@danwexler
Copy link

Great to see the update. I'm eager to test out this code. My hope is to get it working in TFJS. I'm guessing that may require the implementation of a few operations, similar to what was done to get it working for arbitrary resolutions? Any tips or suggestions appreciated.

@sayakpaul
Copy link
Owner

Great to know it however I don't about TFJS :(

Maybe reach out to Jason Meyes?

@LeviVasconcelos
Copy link
Author

LeviVasconcelos commented Jul 2, 2023

Hi @sayakpaul ,

sorry for the delay, life got in the way =/.

@LeviVasconcelos

Things are looking good. Thanks so much for your hard work. I left a couple more questions.

I would suggest doing the following:

* Including a detailed summary of changes needed to make this work in the README.

I uploaded the README file explaining in details the changes done.

* Running the [conversion script](https://github.com/sayakpaul/maxim-tf/blob/main/convert_to_tf.py) for each model and verifying their outputs.

I ran convert_to_tf for 5 different models, using as checkpoints the models provided in gs://gresearch/maxim/ckpt . The results of run_eval.py for each model can be found here.

I also modified run_eval.py by removing the dynamic_resize flag.

* Creating PRs to each of the MAXIM model repositories (Google) here: https://huggingface.co/models?search=maxim

Should I create the PRs right away? Or should we merge this first?

@LeviVasconcelos
Copy link
Author

@sayakpaul friendly ping.

@sayakpaul
Copy link
Owner

Hey thanks for your hardwork!

Could you be so kind to remind me about this again in maybe 1 week? A little busy right now.

@LeviVasconcelos
Copy link
Author

@sayakpaul pinging as requested ;)

@rogeriofonteles
Copy link

Hey guys, any update on this?

@LeviVasconcelos
Copy link
Author

Gently pinging @sayakpaul here.

@fcmr
Copy link

fcmr commented Sep 28, 2023

Hi! Any news on this PR?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants