This is the code for the ICCV 2019 paper CapsuleVOS: Semi-Supervised Video Object Segmentation Using Capsule Routing.
arXiv link: https://arxiv.org/abs/1910.00132
The network is implemented using TensorFlow 1.4.1.
Python packages used: numpy, scipy, scikit-video
- caps_layers_cod.py: Contains the functions used to construct the capsule layers (primary, convolutional, and fully-connected) and to perform conditional capsule routing.
- caps_network_train.py: Contains the CapsuleVOS model for training.
- caps_network_test.py: Contains the CapsuleVOS model for testing.
- caps_main.py: Contains the main function, which is called to train the network.
- config.py: Contains the hyperparameters used for the network, training, and inference.
- inference.py: Contains the inference code.
- load_youtube_data_multi.py: Contains the training data-generator for YoutubeVOS 2018 dataset.
- load_youtubevalid_data.py: Contains the validation data-generator for YoutubeVOS 2018 dataset.
We have supplied the code for training and inference of the model on the YoutubeVOS-2018 dataset. The files load_youtube_data_multi.py and load_youtubevalid_data.py create two DataLoaders, one for training and one for validation. The data_loc variable at the top of each file should be set to the base directory that contains the frames and annotations.
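Before training, it can help to confirm that data_loc points at the right place. The sketch below assumes the standard YouTube-VOS 2018 layout (JPEGImages and Annotations subfolders under each split, with one folder per video ID); the repository's loaders may expect a different layout, and the helper check_data_loc is hypothetical, not part of this code base.

```python
import os


def check_data_loc(data_loc, split="train"):
    """Hypothetical helper: list the video IDs that have both frames and annotations.

    Assumes the standard YouTube-VOS 2018 layout (an assumption, not taken
    from this repository):
        <data_loc>/<split>/JPEGImages/<video_id>/*.jpg
        <data_loc>/<split>/Annotations/<video_id>/*.png
    """
    frames_dir = os.path.join(data_loc, split, "JPEGImages")
    annos_dir = os.path.join(data_loc, split, "Annotations")
    for d in (frames_dir, annos_dir):
        if not os.path.isdir(d):
            raise FileNotFoundError("expected directory not found: " + d)
    frame_videos = set(os.listdir(frames_dir))
    anno_videos = set(os.listdir(annos_dir))
    # Only videos with both frames and annotations are usable for training.
    return sorted(frame_videos & anno_videos)
```

For example, check_data_loc("/path/to/youtube-vos") should return a non-empty list of video IDs if data_loc is set correctly for that layout.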
To run this code, you need to do the following:
- Download the YoutubeVOS dataset
- Perform interpolation for the training frames following the paper's instructions
Once the data is set up, you can train (and test) the network by calling python3 caps_main.py.
The config.py file contains several hyperparameters that are useful for training the network.
During training and testing (validation), the losses and accuracies are printed to stdout and also written to an output*.txt file.
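The repository's own logging code is not shown here; as a minimal sketch, assuming a simple "tee" approach, one message can be sent to the terminal and appended to an output text file like this. The function name log_line is hypothetical.

```python
import sys


def log_line(message, log_file, stream=None):
    """Hypothetical helper: print a message and append it to an open log file."""
    if stream is None:
        stream = sys.stdout
    print(message, file=stream)
    log_file.write(message + "\n")
    # Flush so the file stays current even if training is interrupted.
    log_file.flush()
```

Opening the log file in append mode ("a") keeps the metrics from earlier runs instead of overwriting them.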
Pretrained weights for the network are available here. To use them for inference, place them in the network_saves_best folder.
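To confirm that the weights were copied completely, you can look for a full TensorFlow 1.x checkpoint in network_saves_best. This sketch assumes the weights are a standard tf.train.Saver (V2) checkpoint, stored as <prefix>.index, <prefix>.meta and <prefix>.data-XXXXX-of-XXXXX files; the helper find_checkpoint_prefixes is hypothetical. Within TensorFlow itself, tf.train.latest_checkpoint(directory) serves a similar purpose by reading the "checkpoint" index file.

```python
import os
import re


def find_checkpoint_prefixes(save_dir="network_saves_best"):
    """Hypothetical helper: return checkpoint prefixes that have .index, .meta and data files."""
    names = set(os.listdir(save_dir))
    prefixes = []
    for name in sorted(names):
        if not name.endswith(".index"):
            continue
        prefix = name[: -len(".index")]
        has_meta = prefix + ".meta" in names
        # Data shards are named like <prefix>.data-00000-of-00001.
        shard = re.compile(re.escape(prefix) + r"\.data-\d{5}-of-\d{5}")
        has_data = any(shard.fullmatch(n) for n in names)
        if has_meta and has_data:
            prefixes.append(prefix)
    return prefixes
```

An empty result usually means a file is missing or the weights were placed in a nested subfolder.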
If you just want to test the trained model with the weights above, run the inference code by calling python3 inference.py. This code reads in an .mp4 file and a reference segmentation mask, and outputs the segmented frames of the video to the Output folder.
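In YouTube-VOS, a reference mask is typically a palette PNG in which each pixel stores an object ID, with 0 as background. This README does not say how inference.py handles multiple objects, so the numpy sketch below only illustrates one common preprocessing step: splitting such a mask into one binary mask per object. The function split_reference_mask is hypothetical. Loading the video could be done with scikit-video's skvideo.io.vread, which returns a (frames, height, width, channels) array.

```python
import numpy as np


def split_reference_mask(mask):
    """Hypothetical helper: split an (H, W) object-ID mask into per-object binary masks.

    Returns (object_ids, masks), where masks has shape (num_objects, H, W)
    and masks[k] is 1 wherever mask == object_ids[k]. ID 0 is background.
    """
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError("expected a 2-D mask of object IDs, got shape %s" % (mask.shape,))
    # np.unique returns the IDs in sorted order.
    object_ids = [int(i) for i in np.unique(mask) if i != 0]
    if not object_ids:
        return [], np.zeros((0,) + mask.shape, dtype=np.uint8)
    masks = np.stack([(mask == i).astype(np.uint8) for i in object_ids])
    return object_ids, masks
```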
An example video is available in the Example folder.