-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 917bec2
Showing
38 changed files
with
4,248 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
env/ | ||
wandb/ | ||
**/*.pyc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2021 Toyota Research Institute (TRI) | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# **SimNet**: Enabling Robust Unknown Object Manipulation from Pure Synthetic Data via Stereo | ||
[Thomas Kollar](mailto:[email protected]), [Michael Laskey](mailto:[email protected]), [Kevin Stone](mailto:[email protected]), [Brijen Thananjeyan](mailto:[email protected]), [Mark Tjersland](mailto:[email protected]) | ||
<a href="https://www.tri.global/" target="_blank"> | ||
<img align="right" src="/media/tri-logo.png" width="20%"/> | ||
</a> | ||
|
||
[**paper**](https://arxiv.org/abs/2106.16118) / [**project site**](https://sites.google.com/view/simnet-corl-2021) / **blog** | ||
|
||
<img width="90%" src="/media/model.png"/> | ||
|
||
This repo contains the code to train the SimNet architecture on procedurally generated simulation | ||
data from scratch (no transfer learning required). We also provide a small set of in-house | ||
manually labelled validation data containing 3d oriented bounding box labels. | ||
|
||
|
||
## Training the model | ||
|
||
### Requirements | ||
|
||
You will need a Nvidia GPU with at least 12GB of RAM. All code was tested and developed on Ubuntu | ||
20.04. | ||
|
||
All commands are assumed to be run from the root of the `simnet` repo directory (represented by | ||
`$SIMNET_REPO` in commands below). | ||
|
||
### Setup | ||
|
||
#### Python | ||
Create a python 3.8 virtual environment and install requirements: | ||
|
||
```bash | ||
cd $SIMNET_REPO | ||
conda create -y --prefix ./env python=3.8 | ||
./env/bin/python -m pip install --upgrade pip | ||
./env/bin/python -m pip install -r frozen_requirements.txt | ||
``` | ||
|
||
#### Docker | ||
Make sure docker is installed and working without requiring `sudo`. If it is not installed, follow | ||
the [official instructions](https://docs.docker.com/engine/install/) for setting it up. | ||
```bash | ||
docker ps | ||
``` | ||
|
||
#### Wandb | ||
|
||
Launch `wandb` local server for logging training results (*you do not need to do this if you already have a wandb | ||
account setup*). This will launch a local webserver [http://localhost:8080](http://localhost:8080) using docker that you | ||
can use to visualize training progress and validation images. You will have to visit the | ||
[http://localhost:8080/authorize](http://localhost:8080/authorize) page to get the local API access token (this can | ||
take a few minutes the first time). Once you get the key you can paste it into the terminal to continue. | ||
|
||
```bash | ||
cd $SIMNET_REPO | ||
./env/bin/wandb local | ||
``` | ||
|
||
|
||
#### Datasets | ||
|
||
Download and untar train+val datasets | ||
[simnet2021a.tar](https://tri-robotics-public.s3.amazonaws.com/github/simnet/datasets/simnet2021a.tar) | ||
(18GB, md5 checksum:`b8e1d3cb7200b44b1de223e87141f14b`). This file contains all the training and | ||
validation you need to replicate our small objects results. | ||
|
||
```bash | ||
cd $SIMNET_REPO | ||
wget https://tri-robotics-public.s3.amazonaws.com/github/simnet/datasets/simnet2021a.tar -P datasets | ||
tar xf datasets/simnet2021a.tar -C datasets | ||
``` | ||
|
||
### Train and Validate | ||
|
||
Overfit test: | ||
```bash | ||
./runner.sh net_train.py @config/net_config_overfit.txt | ||
``` | ||
|
||
Full training run (requires 12GB GPU memory) | ||
```bash | ||
./runner.sh net_train.py @config/net_config.txt | ||
``` | ||
|
||
#### Results | ||
|
||
Check wandb ([http://localhost:8080](http://localhost:8080)) to see training progress. On a Titan V, it takes about 48 | ||
hours for training to converge, but decent validation results can be seen around 24 hours. | ||
|
||
Example validation image visualization: | ||
<img width="100%" src="/media/wandb_example_val_images.png"/> | ||
|
||
Example 3D oriented bounding box mAP on validation dataset: | ||
<img width="50%" src="/media/wandb_example_3dmap.png"/> | ||
|
||
|
||
## Licenses | ||
|
||
The source code is released under the MIT license. | ||
|
||
The datasets are released under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](http://creativecommons.org/licenses/by-nc-sa/4.0/). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.ckpt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
--max_steps=400000 | ||
--model_file=models/panoptic_net.py | ||
--model_name=res_fpn | ||
--output=ckpts | ||
--train_path=file://datasets/simnet2021a/train | ||
--train_batch_size=12 | ||
--train_num_workers=4 | ||
--val_path=file://datasets/simnet2021a/val | ||
--val_batch_size=1 | ||
--val_num_workers=10 | ||
--optim_learning_rate=0.0006 | ||
--optim_momentum=0.9 | ||
--optim_weight_decay=1e-4 | ||
--optim_poly_exp=0.9 | ||
--optim_warmup_epochs=1 | ||
--loss_seg_mult=1.0 | ||
--loss_depth_mult=1.0 | ||
--loss_depth_refine_mult=1.0 | ||
--loss_vertex_mult=0.1 | ||
--loss_rotation_mult=0.1 | ||
--loss_heatmap_mult=100.0 | ||
--loss_z_centroid_mult=0.1 | ||
--wandb_name=simnet-github |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
--max_steps=4000 | ||
--model_file=models/panoptic_net.py | ||
--model_name=res_fpn | ||
--output=ckpts | ||
--train_path=file://datasets/simnet2021a/train?samples=400 | ||
--train_batch_size=1 | ||
--train_num_workers=4 | ||
--val_path=file://datasets/simnet2021a/train?samples=1 | ||
--val_batch_size=1 | ||
--val_num_workers=4 | ||
--optim_learning_rate=0.0006 | ||
--optim_momentum=0.9 | ||
--optim_weight_decay=1e-4 | ||
--optim_poly_exp=0.9 | ||
--optim_warmup_epochs=1 | ||
--loss_seg_mult=1.0 | ||
--loss_depth_mult=1.0 | ||
--loss_depth_refine_mult=1.0 | ||
--loss_vertex_mult=0.1 | ||
--loss_rotation_mult=0.1 | ||
--loss_heatmap_mult=100.0 | ||
--loss_z_centroid_mult=0.1 | ||
--wandb_name=simnet-github-overfit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
absl-py==0.13.0 | ||
backcall==0.2.0 | ||
boto3==1.18.1 | ||
botocore==1.21.1 | ||
cachetools==4.2.2 | ||
certifi==2021.5.30 | ||
charset-normalizer==2.0.3 | ||
click==8.0.1 | ||
colour==0.1.5 | ||
configparser==5.0.2 | ||
cycler==0.10.0 | ||
decorator==4.4.2 | ||
docker-pycreds==0.4.0 | ||
future==0.18.2 | ||
gitdb==4.0.7 | ||
GitPython==3.1.18 | ||
google-auth==1.33.0 | ||
google-auth-oauthlib==0.4.4 | ||
greenlet==1.1.0 | ||
grpcio==1.38.1 | ||
idna==3.2 | ||
imageio==2.9.0 | ||
ipython==7.25.0 | ||
ipython-genutils==0.2.0 | ||
jedi==0.18.0 | ||
jmespath==0.10.0 | ||
kiwisolver==1.3.1 | ||
Markdown==3.3.4 | ||
matplotlib==3.4.2 | ||
matplotlib-inline==0.1.2 | ||
msgpack==1.0.2 | ||
networkx==2.5.1 | ||
numpy==1.21.0 | ||
oauthlib==3.1.1 | ||
opencv-python==4.5.3.56 | ||
parso==0.8.2 | ||
pathtools==0.1.2 | ||
pexpect==4.8.0 | ||
pickleshare==0.7.5 | ||
Pillow==8.3.1 | ||
promise==2.3 | ||
prompt-toolkit==3.0.19 | ||
protobuf==3.17.3 | ||
psutil==5.8.0 | ||
ptyprocess==0.7.0 | ||
pyasn1==0.4.8 | ||
pyasn1-modules==0.2.8 | ||
Pygments==2.9.0 | ||
pynvim==0.4.3 | ||
pyparsing==2.4.7 | ||
python-dateutil==2.8.2 | ||
pytorch-lightning==0.7.5 | ||
PyWavelets==1.1.1 | ||
PyYAML==5.4.1 | ||
requests==2.26.0 | ||
requests-oauthlib==1.3.0 | ||
rsa==4.7.2 | ||
s3transfer==0.5.0 | ||
scikit-image==0.18.2 | ||
scipy==1.7.0 | ||
sentry-sdk==1.3.0 | ||
shortuuid==1.0.1 | ||
six==1.16.0 | ||
smmap==4.0.0 | ||
subprocess32==3.5.4 | ||
tensorboard==2.5.0 | ||
tensorboard-data-server==0.6.1 | ||
tensorboard-plugin-wit==1.8.0 | ||
tifffile==2021.7.2 | ||
torch==1.4.0 | ||
torchvision==0.5.0 | ||
tqdm==4.61.2 | ||
traitlets==5.0.5 | ||
urllib3==1.26.6 | ||
wandb==0.11.0 | ||
wcwidth==0.2.5 | ||
Werkzeug==2.0.1 | ||
ydiff==1.2 | ||
zstandard==0.15.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../simnet/lib/datapoint.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../../../simnet/lib/net/post_processing/eval3d.py |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import os | ||
|
||
os.environ['PYTHONHASHSEED'] = str(1) | ||
|
||
import argparse | ||
from importlib.machinery import SourceFileLoader | ||
import sys | ||
|
||
import random | ||
|
||
random.seed(12345) | ||
import numpy as np | ||
|
||
np.random.seed(12345) | ||
import torch | ||
|
||
torch.manual_seed(12345) | ||
|
||
import wandb | ||
|
||
import pytorch_lightning as pl | ||
from pytorch_lightning.callbacks import ModelCheckpoint | ||
from pytorch_lightning import loggers | ||
|
||
from simnet.lib.net import common | ||
from simnet.lib import datapoint, camera | ||
from simnet.lib.net.post_processing.eval3d import Eval3d, extract_objects_from_detections | ||
from simnet.lib.net import panoptic_trainer | ||
|
||
_GPU_TO_USE = 0 | ||
|
||
|
||
class EvalMethod(): | ||
|
||
def __init__(self): | ||
|
||
self.eval_3d = Eval3d() | ||
self.camera_model = camera.FMKCamera() | ||
|
||
def process_sample(self, pose_outputs, box_outputs, seg_outputs, detections_gt, scene_name): | ||
detections = pose_outputs.get_detections(self.camera_model) | ||
if scene_name != 'sim': | ||
table_detection, detections_gt, detections = extract_objects_from_detections( | ||
detections_gt, detections | ||
) | ||
self.eval_3d.process_sample(detections, detections_gt, scene_name) | ||
return True | ||
|
||
def process_all_dataset(self, log): | ||
log['all 3Dmap'] = self.eval_3d.process_all_3D_dataset() | ||
|
||
def draw_detections( | ||
self, pose_outputs, box_outputs, seg_outputs, keypoint_outputs, left_image_np, llog, prefix | ||
): | ||
pose_vis = pose_outputs.get_visualization_img( | ||
np.copy(left_image_np), camera_model=self.camera_model | ||
) | ||
llog[f'{prefix}/pose'] = wandb.Image(pose_vis, caption=prefix) | ||
seg_vis = seg_outputs.get_visualization_img(np.copy(left_image_np)) | ||
llog[f'{prefix}/seg'] = wandb.Image(seg_vis, caption=prefix) | ||
|
||
def reset(self): | ||
self.eval_3d = Eval3d() | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser(fromfile_prefix_chars='@') | ||
common.add_train_args(parser) | ||
hparams = parser.parse_args() | ||
train_ds = datapoint.make_dataset(hparams.train_path) | ||
samples_per_epoch = len(train_ds.list()) | ||
samples_per_step = hparams.train_batch_size | ||
steps = hparams.max_steps | ||
steps_per_epoch = samples_per_epoch // samples_per_step | ||
epochs = int(np.ceil(steps / steps_per_epoch)) | ||
actual_steps = epochs * steps_per_epoch | ||
print('Samples per epoch', samples_per_epoch) | ||
print('Steps per epoch', steps_per_epoch) | ||
print('Target steps:', steps) | ||
print('Actual steps:', actual_steps) | ||
print('Epochs:', epochs) | ||
|
||
model = panoptic_trainer.PanopticModel(hparams, epochs, train_ds, EvalMethod()) | ||
model_checkpoint = ModelCheckpoint(filepath=hparams.output, save_top_k=-1, period=1, mode='max') | ||
wandb_logger = loggers.WandbLogger(name=hparams.wandb_name, project='simnet') | ||
trainer = pl.Trainer( | ||
max_nb_epochs=epochs, | ||
early_stop_callback=None, | ||
gpus=[_GPU_TO_USE], | ||
checkpoint_callback=model_checkpoint, | ||
#val_check_interval=0.7, | ||
check_val_every_n_epoch=1, | ||
logger=wandb_logger, | ||
default_save_path=hparams.output, | ||
use_amp=False, | ||
print_nan_grads=True | ||
) | ||
trainer.fit(model) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
torch==1.4.0 | ||
torchvision==0.5.0 | ||
pytorch_lightning==0.7.5 | ||
opencv-python | ||
shortuuid | ||
boto3 | ||
zstandard | ||
scikit-image | ||
colour | ||
ipython | ||
wandb |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#!/bin/bash | ||
|
||
set -Eeuo pipefail | ||
|
||
SCRIPT_DIR=$(dirname $(readlink -f $0)) | ||
export PYTHONPATH=$(readlink -f "${SCRIPT_DIR}") | ||
export OPENBLAS_NUM_THREADS=1 | ||
|
||
$SCRIPT_DIR/env/bin/python $@ |
Oops, something went wrong.