-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add RAFT model for optical flow #5022
Conversation
💊 CI failures summary and remediationsAs of commit f077d7c (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Please report bugs/suggestions to the (internal) Dr. CI Users group. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here are my initial comments @fmassa @datumbox @haooooooqi !
@@ -0,0 +1 @@ | |||
from ._raft.raft import RAFT, raft, raft_small |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should expose the building blocks as well, stuff like ResidualBlock
, FeatureEncoder
, etc.
Should we expose these in a models.optical_flow.raft
namespace?
I remember @datumbox mentioning a few issues when we have both a module and a function with the same name. Ideally, I'd like to keep the raft()
name for the function builder, but I'm happy to get your thoughts on this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Blocking:
See @pmeier's #4867 (comment) around this. We might need to rename this to raft_large
similar to mobilenet_v3_large
and mobilenet_v3_small
.
Question:
Can you talk about the choice to create a second private submodule _raft
? Why is that?
FYI:
We should expose the building blocks as well, stuff like ResidualBlock, FeatureEncoder, etc.
If you check existing models, such as resnet and faster_rcnn you will see we don't expose these on the __all__
or anywhere else. We had discussions about what this means but we didn't have a consensus (public API vs developer API discussion). I would be in favour of not exposing these publicly to be consistent with everywhere else. We should discuss and resolve this once we have time in the roadmap.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the feedback
Can you talk about the choice to create a second private submodule _raft? Why is that?
I wrote it like this out of habit but I had no intention of keeping it this way
Instead of renaming raft()
to raft_large()
, would it be OK instead to rename the .py
file to something like raft_core.py
? Or raft_implem.py
?
We would have raft()
, raft_small()
and RAFT
available from torchvision.models.optical_flow
and then we would be able to access the other building blocks like ResidualBlock
in torchvision.models.optical_flow.raft_core
-- I would make sure to exclude them from __all__
.
Would that be OK?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Though calling it raft_large
would make it consistent with other small/large models, I can see that on the paper they don't actually call it large. They do use Raft-s
or "small" for the small version, so that's canonical at least. So I understand why you want to avoid naming it like that.
Renaming the file should be OK but @pmeier should confirm. Though I'm not sure what the actual name should be. We typically name the files after their algorithm and we let the model builder have extra info on the variant.
@fmassa thoughts on this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I'll use raft.py
and raft_large()
for now to move forward. We can revisit later if needed, thanks for the input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not commenting on the name, because I don't have an opinion on that. We just need to make sure that we don't have an attribute with the same name as a module in the same namespace. Otherwise we can no longer access the module. Thus having raft.py
and def raft()
is problematic, but raft.py
and raft_large()
is fine.
|
||
|
||
class ResidualBlock(nn.Module): | ||
# This is pretty similar to resnet.BasicBlock except for one call to relu, and the bias terms |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can probably merge this with the resnet implementation (same for BottleneckBlock below)... but this might make the API awkward and bloated. For now I'd say it's simpler and safer to keep these implementations separate
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: Agreed. One more reason why we shouldn't expose this on the __all__
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's keep this here as is
self.convflow1 = ConvNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) | ||
self.convflow2 = ConvNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it make sense to use a ConvNormActivation
while passing norm_layer=None
? No strong opinion from me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: It's up to you. The class was recently updated to support it but I checked and so far all the uses of ConvNormActivation
currently pass a non-None value. So you would be the first to use this idiom. It saves you a few lines of code, but not too much.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm ok with this.
flow_head_hidden_size, | ||
# Mask predictor | ||
use_mask_predictor, | ||
**kwargs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kwargs are here to override the RAFT
class parameters, e.g. to override the entire feature_encoder
, or the whole update_block
.
|
||
|
||
class FeatureEncoder(nn.Module): | ||
def __init__(self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), norm_layer=nn.BatchNorm2d): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I named this layers
to be sort of consistent with the ResNet
class. I'm happy to consider other names though
|
||
|
||
class RecurrentBlock(nn.Module): | ||
def __init__(self, *, input_size, hidden_size, kernel_size=((1, 5), (5, 1)), padding=((0, 2), (2, 0))): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where possible I tried to write defaults that will correspond to the normal RAFT model (kernel_size, padding). I refrained from doing that for the input and output shapes though. Thoughts?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: I don't know how much "standard" these values are. If you think multiple input_sizes
and hidden_sizes
will keep using the same kernels, then that's OK. From other models I think that's reasonable.
|
||
|
||
class MaskPredictor(nn.Module): | ||
def __init__(self, *, in_channels, hidden_size, multiplier=0.25): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm keeping the 0.25 default for consistency with the original code, but I'm tempted to set it to 1 to encourage users not to use it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: Given that TorchVision is supposed to follow very closely the papers, using the 0.25 aligns with this principle. It will also allow you to port weights easier. Personally I'm fine with how you implement and document it. I would also recommend writing a blogpost about your implementation (similar to 1, 2) as this will allow you to discuss such details more thoroughly and help new joiners understand the implementation.
# As in the original paper, the actual output of the context encoder is split in 2 parts: | ||
# - one part is used to initialize the hidden state of the reccurent units of the update block | ||
# - the rest is the "actual" context. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a bit that I did not see mentioned in the paper. I wish we could separate the initialization of the hidden state from the context encoder, but this would likely prevent us from exactly reproduce the original implementation.
|
||
batch_size, _, h, w = image1.shape | ||
torch._assert((h, w) == image2.shape[-2:], "input images should have the same shape") | ||
torch._assert((h % 8 == 0) and (w % 8 == 0), "input image H and W should be divisible by 8") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This "8" downscaling factor is hard-coded in different places. We should be able to generalize it in future versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: other models have similar requirements. An example is MobileNets which use the _make_divisible()
method located here. You could create a similar helper method and generalize but that's a NIT.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@NicolasHug Great work.
I'm reviewing code practices and API, not ML validity as this is something you checked already with @haooooooqi.
I clearly marked my comments as FYI (which are just for discussion), Questions (where I just want more info), NIT (which are non-blocking and could optionally be addressed on follow up PRs) and Blocking (which I think need to be done here).
My Blocking comments are minimal as you can see below and can be addressed easily. After addressing them, ping me to approve but I would wait for the review of @fmassa prior merging.
@@ -0,0 +1 @@ | |||
from ._raft.raft import RAFT, raft, raft_small |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Blocking:
See @pmeier's #4867 (comment) around this. We might need to rename this to raft_large
similar to mobilenet_v3_large
and mobilenet_v3_small
.
Question:
Can you talk about the choice to create a second private submodule _raft
? Why is that?
FYI:
We should expose the building blocks as well, stuff like ResidualBlock, FeatureEncoder, etc.
If you check existing models, such as resnet and faster_rcnn you will see we don't expose these on the __all__
or anywhere else. We had discussions about what this means but we didn't have a consensus (public API vs developer API discussion). I would be in favour of not exposing these publicly to be consistent with everywhere else. We should discuss and resolve this once we have time in the roadmap.
|
||
|
||
class ResidualBlock(nn.Module): | ||
# This is pretty similar to resnet.BasicBlock except for one call to relu, and the bias terms |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: Agreed. One more reason why we shouldn't expose this on the __all__
.
|
||
|
||
class MaskPredictor(nn.Module): | ||
def __init__(self, *, in_channels, hidden_size, multiplier=0.25): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: Given that TorchVision is supposed to follow very closely the papers, using the 0.25 aligns with this principle. It will also allow you to port weights easier. Personally I'm fine with how you implement and document it. I would also recommend writing a blogpost about your implementation (similar to 1, 2) as this will allow you to discuss such details more thoroughly and help new joiners understand the implementation.
|
||
batch_size, _, h, w = image1.shape | ||
torch._assert((h, w) == image2.shape[-2:], "input images should have the same shape") | ||
torch._assert((h % 8 == 0) and (w % 8 == 0), "input image H and W should be divisible by 8") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI: other models have similar requirements. An example is MobileNets which use the _make_divisible()
method located here. You could create a similar helper method and generalize but that's a NIT.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @NicolasHug !
Approving to unblock, but check @datumbox blocking comments as well before merging.
|
||
|
||
def grid_sample(img, absolute_grid, *args, **kwargs): | ||
"""Same as torch's grid_sample, with absolute pixel coordinates instead of normalized coordinates.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Food for thought: would it be a significant limitation if we were to return the flows in normalized coordinates, and perform all computations in normalized coordinates?
Or is it standard that flow images are in absolute coordinates?
Some of the benefits of keeping it in relative coordinates is that you don't need to multiply by the scaling factor when upsampling an image.
def raft_large(*, pretrained=False, progress=True, **kwargs): | ||
|
||
if pretrained: | ||
raise NotImplementedError("Pretrained weights aren't available yet") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps using NotImplementedError
is indeed the right exception here but I would recommend for now to throw:
raise NotImplementedError("Pretrained weights aren't available yet") | |
raise ValueError(f"No checkpoint is available for model.") |
This is because some tests check for this specific exception:
vision/test/test_prototype_models.py
Lines 45 to 46 in 3d8723d
if "No checkpoint is available" in msg: | |
pytest.skip(msg) |
@@ -818,5 +818,31 @@ def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_load | |||
assert n_trainable_params == _model_tests_values[model_name]["n_trn_params_per_layer"] | |||
|
|||
|
|||
@needs_cuda | |||
@pytest.mark.parametrize("model_builder", (models.optical_flow.raft_large, models.optical_flow.raft_small)) | |||
@pytest.mark.parametrize("scripted", (False, True)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm parametrizing over this because testing with _check_jit_scriptable
unfortunately fails on very few entries, e.g.:
Mismatched elements: 153 / 11520 (1.3%)
Greatest absolute difference: 0.0002608299255371094 at index (0, 0, 79, 45) (up to 0.0001 allowed)
Greatest relative difference: 0.021354377198448304 at index (0, 1, 53, 68) (up to 0.0001 allowed)
I could add tol
parameters to the check, but I feel like this current test is just as fine
Test failures are unrelated, merging. Thanks a lot for the reviews!! |
Reviewed By: NicolasHug Differential Revision: D32950937 fbshipit-source-id: 7e024dad4c3d55bc832beadfd1b3ffe867f238f3
Towards #4644
This PR adds the RAFT model, along with its basic building blocks (feature encoder, correlation block, update block, etc.) and two model builder function:
raft_large()
andraft_small()
.The architecture (not the code!) is exactly the same as the original implementation from https://github.com/princeton-vl/RAFT, which will allow us support the original paper's weights if we want to. This architecture differs slightly from what is described in the paper. I have annotated the paper with these differences, hoping this can help the review:
RAFT- Recurrent All-Pairs Field Transforms forOptical Flow (1).pdf
A summary is this:
API
The
RAFT
class accepts torch.nn.Module instances as input and offers a low-level API.The model builder functions
raft_large()
andraft_small()
are higher-level and do not require any parameter. They can however take as input the same parameters as theRAFT
class, so as to override their defaults. E.g.:The building blocks like
FeatureEncoder
,ResidualBlock
,UpdateBlock
, etc. are (sort of publicly) available intorchvision.models.optical_flow.raft
, but are not exposed in__all__
.Still left TODO, here or in follow-up PRs:
raft()
andraft_small()
In follow up PRs I will also submit the training reference and associated transforms.
I will write review comments below to hightlight important bits, or things where I'm not to sure what the best way is.
cc @datumbox