-
Notifications
You must be signed in to change notification settings - Fork 860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature Refinement to Improve High Resolution Image Inpainting #112
Conversation
Wow! That's extremely cool! Can't wait to try! |
Wow folks! that is really impressive! |
Just curious, have you tried your approach with other methods, e.g. MAT or ZITS? Is the improvement for them is like that for LaMa? I'm also amazed by the fact that such an optimization technique does not introduce high-frequency artifacts... How do you think, why is it so? |
Hi, we haven't tried it with other methods due to our limited bandwidth. We started with LaMa, and then during evaluation we found that without refinement, LaMa still generalizes better than the newer methods on high resolution inpainting. So we didn't have a very strong motivation to try this on the other methods. We'd love to see how it works on them though. When applying the refinement, we're basically asking the network to find a high-resolution featuremap that produces an output that, when downscaled, looks like the low-resolution output. We hypothesize that when this featuremap is found, it contains high-level encoded information learned from training about the contents in that region. For example, optimization may adjust a feature that was "kind-of-brick-like" to become "very-red-brick-like." The optimized high level features then gets decoded into low and high frequency brick-like textures. So, it's possible that these latent features of the downscaler control low and high frequency details jointly. We also tried to optimize the latent features of the upscaler, but it didn't work well, and produced cloud-like artifacts. So it probably has something to do with large and overlapping receptive fields of pixels of the featuremap. |
Thank you for the clarification!
Yeah, but in lower resolution these methods are stronger than LaMa - so they might probably benefit more from your method.
Yes, lo-freq and high-freq details do not seem to be disentangled in the features. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd kindly suggest simplifying the usage and making default config values more sensible for broader set of environments (e.g. single-gpu). Anyway, this is a great contribution!
batch['mask'] = (batch['mask'] > 0) * 1 | ||
batch = model(batch) | ||
cur_res = batch[predict_config.out_key][0].permute(1, 2, 0).detach().cpu().numpy() | ||
unpad_to_size = batch.get('unpad_to_size', None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
L84-87 should be outside if-else - they need to be executed regardless refinement
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
L84-87 already get addressed inside the refiner. Refiner works on unpadded images (it does the necesssary padding internally and then unpads the output appropriately). We can:
- add an assertion to check
unpad_to_size
is not None - enable refiner to just run on padded image, if
unpad_to_size
isNone
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I see. I'd move padding-unpadding from the refiner to predict.py - so both parts of the code are simplified and no logic duplication is introduced. What do you think, is it possible and does it make sense?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can let the refiner get padded input. But refiner still needs some padding in place. Because -
Suppose your input image is a square of size 1000. Then the original image isn't padded because 1000%8==0
, but in the refiner, once we downscale the image, it's size becomes 500, and 500%8!=0
. So we have to pad it to make it 504x504.
So we can't get rid of lines 301 and 302 in refinement.py
, but we can:
- let the padded image to be input to refiner, so that we take L84-87 outside the if-else.
- refiner then doesn't check
unpad_to_size
argument at all. - Padding would happen in the refiner to ensure downscaled image size is divisible by 8.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
refiner still needs some padding in place
I see, thank you for the clarification! Let's just leave that piece of the as is - and add a comment about "padding-unpadding is handled within refiner"
image size is divisible by 8.
Padding size depends on depth of the generator and thus needs to be configurable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha, will add the comment. Yeah the padding size of the refiner is not exactly 8, but exactly equal to dataset.pad_out_to_modulo in the predict config. I'll add a comment there in the PR
|
||
refine: False # refiner will only run if this is True | ||
refiner: | ||
gpu_ids: 0,1 # the GPU ids of the machine to use. If only single GPU, use: "0," |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd suggest using only 0 by default - or even introduce "None" default (so refiner would rely on the parent device
setting). That would make this work by default in more environments without any modifications by default.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually the refiner needs around 24GB GPU to process 1.8 megapixel images (~1200x1500). Since most people have two 12GB GPUs instead, we decided to split the model onto two GPUs, that's why the default config setting.
Do you suggest to still make it None by default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, right, I have not thought about memory consumption. It seems that the most of the consumption comes from storing activations for backward... And you're splitting res-blocks between GPUs to distribute that memory - not to speedup inference - because GPUs are called sequentially.
I have a couple ideas how to overcome that without complex logic or requirement to have two GPUs:
- Set
param.requires_grad_(False)
for all parameters in the generator - that will lead to storing only activations, not gradients for parameters. - Use activation checkpointing - it does something very similar to what you're doing - it splits a nn.Sequential in multiple chunks and runs each chunk with torch.no_grad - so only activations between chunks have to be stored. That will slow the optimization down, but maybe not severely.
- torch.cuda.amp - optimize in fp16 instead of fp32. In case of refinement there is no adversarial training, so there should not be stability issues due to reduced precision (but I'm not sure)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the ideas! param.requires_grad_
is already set to False
, since we freeze the model here: https://github.com/geomagical/lama-with-refiner/blob/24a20f804390c6ab969c28abbe999c940f8d6a56/bin/predict.py#L58
I also manually verified the requires_grad for all the params of the model, they were False
.
We were already looking at activation checkpointing, will focus on it more now that you have also mentioned it. Will try the third idea also.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.cuda.amp isn't working because pytorch doesn't seem to support Half
dtype for torch.fft.rfftn
. PFA link to the relevant issues in Pytorch repo:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also manually verified the requires_grad for all the params of the model
Great, thank you!
torch.cuda.amp isn't working because pytorch doesn't seem to support Half
Sure, I've forgot that I've already tried half and failed because of that... We could wrap rfftn/irfftn with conversion to and from .float(), but I'm not sure there wouldn't be other issues..
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello @windj007, sorry for coming back after 2 months! We picked up the experiments, our findings:
- We were able to perform the optimization in mixed precision. I haven't benchmarked it quantitatively, but qualitative results look good. However, for 1024x1024 images, it only reduces the memory from
21-22GB -> 17-18GB
, so it is still not sufficient to fit on a single 12GB GPU - We also tried to play with checkpointing. Performing it naively throws
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
, which we bypass by settinguse_reentrant=False
. However, this setting has some memory leak problem, which causes the GPU consumption to increase at each training loop, eventually leading to OOM error. We plan to raise this issue on the Pytorch repo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please share your code for mixed-precision?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, it's in amp_float16
branch of our fork:
https://github.com/geomagical/lama-with-refiner/tree/amp_float16
You can get this code by:
git clone [email protected]:geomagical/lama-with-refiner.git
git checkout amp_float16
Also, I've changed the config file of the refiner to run on a single GPU. But yeah feel free to play around with config parameters or anything :)
Link to the config file in the code: https://github.com/geomagical/lama-with-refiner/blob/amp_float16/configs/prediction/default.yaml
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you so much 🙏
Thanks for your feedback, appreciate it! We are very much interested in addressing your concerns until you're comfortable enough to merge this PR into your code. The usage complication is primarily because we try to fit the refinement on multiple devices. Refinement requires at least 24GB GPU to run on images of sizes like 1024x1024. |
refine: False # refiner will only run if this is True | ||
refiner: | ||
gpu_ids: 0,1 # the GPU ids of the machine to use. If only single GPU, use: "0," | ||
modulo: ${dataset.pad_out_to_modulo} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@windj007 refiner padding is defined here
@senya-ashukha that works too! :) |
Hello, Thanks for your shout out on the main README. Let us know if you would like us to close this PR. |
Sounds good, thanks a lot! Happy to follow up on any further review comments regarding code/formatting/ideas or anything :) |
is there anyway i can make this work on a gpu of 11GB? or run it on CPU? |
@mhashas You can try using our mixed precision branch: #112 (comment) We can probably have a CPU version, but it will be super slow. If you are okay with a slow CPU version, I think replacing the device with torch.device('cpu') in predict.py and refinement.py should work. |
Feature Refinement to Improve High Resolution Image Inpainting
Can I separate out the Feature Refinement to Improve High Resolution Image Inpainting technique, just for a single image? |
Feature Refinement to Improve High Resolution Image Inpainting
We are a team of researchers at Geomagical Labs (geomagical.com), a subsidiary of IKEA. We work on pioneering Mixed Reality apps which allow customers to scan photorealistic models of their indoor spaces and re-imagine them with virtual furniture.
In this PR we propose an additional refinement step for LaMa to improve high-resolution inpainting results. We observed that when inpainting large regions at high resolution, LaMa struggles at structure completion. However, at low resolutions, LaMa can infill the same missing region much better. To address this we added an additional refinement step that uses the structure from low resolution predictions to guide higher resolution predictions.
Our approach can work on any inpainting network, and does not require any additional training or network modification.
How to run refinement
To run refinement, simply pass
refine=True
in the evaluation step as:Evaluation
Here's a few example comparisons, with each triplet showing the masked image, inpainting with LaMa, and inpainting with LaMa using refinement:
Comparison of unrefined and refined images on all test images (kindly shared by you) is available here: https://drive.google.com/drive/folders/15LEa9k_7-dUKb2CPUDuw7e6Zk28KCtzz?usp=sharing
We also performed some numerical evaluation on 1024x1024 size images sampled from [1], using the thin, medium, and thick masks. Results indicate that LaMa+refinement outperforms all the recent inpainting baselines on high resultion inpainting:
Table 1. Performance comparison of various recent inpainting approaches on 1k 1024x1024 size images
Video
We have also created a video to explain the technical details of our approach:
https://www.youtube.com/watch?v=gEukhOheWgE
References
[1]
Unsplash Dataset. https://unsplash.com/data, 2020
[2]
Suvorov, R., Logacheva, E., Mashikhin, A., Remizova, A., Ashukha, A., Silvestrov, A., Kong, N., Goka, H., Park, K. and Lempitsky, V., 2022. Resolution-robust Large Mask Inpainting with Fourier Convolutions. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (pp. 2149-2159)
[3]
Zeng, Y., Fu, J., Chao, H. and Guo, B., 2022. Aggregated contextual transformations for high-resolution image inpainting. IEEE Transactions on Visualization and Computer Graphics.
[4]
Rombach, R., Blattmann, A., Lorenz, D., Esser, P. and Ommer, B., 2021. High-Resolution Image Synthesis with Latent Diffusion Models. arXiv preprint arXiv:2112.10752.
[5]
Dong, Q., Cao, C. and Fu, Y., 2022. Incremental Transformer Structure Enhanced Image Inpainting with Masking Positional Encoding. arXiv preprint arXiv:2203.00867.
[6]
Li, W., Lin, Z., Zhou, K., Qi, L., Wang, Y. and Jia, J., 2022. MAT: Mask-Aware Transformer for Large Hole Image Inpainting. arXiv preprint arXiv:2203.15270.