-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Zachary Ravichandran
committed
Jan 17, 2020
0 parents
commit 705afd1
Showing
9 changed files
with
431 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
cmake_minimum_required(VERSION 2.8.3) | ||
project(semantic_segmentation_ros) | ||
|
||
find_package(catkin_simple REQUIRED) | ||
|
||
catkin_python_setup() | ||
catkin_simple() | ||
|
||
cs_install() | ||
cs_export() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
<?xml version="1.0" encoding="ISO-8859-15"?> | ||
<launch> | ||
<arg name="model" default="TRTModel"/> | ||
<arg name="weight_file" default="$(find semantic_segmentation_ros)/cfg/v5.3/v53-unet-resnet18.onnx"/> | ||
|
||
<node name="semantic_segmentation" pkg="semantic_segmentation_ros" type="semantic_segmentation.py" output="screen"> | ||
<param name="model" value="$(arg model)"/> | ||
<param name="weight_file" value="$(arg weight_file)"/> | ||
<remap from="/image/image_raw" to="/tesse/left_cam_rgb/image_raw"/> | ||
</node> | ||
</launch> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
<?xml version="1.0"?> | ||
<package> | ||
<name>semantic_segmentation_ros</name> | ||
<version>0.0.1</version> | ||
<description>Semantic Segmentation in ROS</description> | ||
|
||
<license>MIT</license> | ||
<maintainer email="[email protected]">Zac Ravichandran</maintainer> | ||
<buildtool_depend>catkin</buildtool_depend> | ||
<buildtool_depend>catkin_simple</buildtool_depend> | ||
|
||
<!-- Package build dependencies --> | ||
<build_depend>rospy</build_depend> | ||
<build_depend>cv_bridge</build_depend> | ||
<build_depend>sensor_msgs</build_depend> | ||
|
||
</package> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
################################################################################################### | ||
# Distribution authorized to U.S. Government agencies and their contractors. Other requests for | ||
# this document shall be referred to the MIT Lincoln Laboratory Technology Office. | ||
# | ||
# This material is based upon work supported by the Under Secretary of Defense for Research and | ||
# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions | ||
# or recommendations expressed in this material are those of the author(s) and do not necessarily | ||
# reflect the views of the Under Secretary of Defense for Research and Engineering. | ||
# | ||
|
||
# (c) 2019 Massachusetts Institute of Technology. | ||
# | ||
# The software/firmware is provided to you on an As-Is basis | ||
# | ||
# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 | ||
# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work | ||
# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other | ||
# than as specifically authorized by the U.S. Government may violate any copyrights that exist in | ||
# this work. | ||
################################################################################################### | ||
|
||
from setuptools import setup | ||
|
||
setup( | ||
name="semantic_segmentation_ros", | ||
version='0.0.1', | ||
description="semantifc segmentation in ROS", | ||
packages=["semantic_segmentation_ros"], | ||
package_dir={'': 'src'} | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
################################################################################################### | ||
# Distribution authorized to U.S. Government agencies and their contractors. Other requests for | ||
# this document shall be referred to the MIT Lincoln Laboratory Technology Office. | ||
# | ||
# This material is based upon work supported by the Under Secretary of Defense for Research and | ||
# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions | ||
# or recommendations expressed in this material are those of the author(s) and do not necessarily | ||
# reflect the views of the Under Secretary of Defense for Research and Engineering. | ||
# | ||
|
||
# (c) 2019 Massachusetts Institute of Technology. | ||
# | ||
# The software/firmware is provided to you on an As-Is basis | ||
# | ||
# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 | ||
# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work | ||
# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other | ||
# than as specifically authorized by the U.S. Government may violate any copyrights that exist in | ||
# this work. | ||
################################################################################################### | ||
|
||
import numpy as np | ||
import tensorrt as trt | ||
import time | ||
|
||
import pycuda.autoinit | ||
import pycuda.driver as cuda | ||
|
||
from semantic_segmentation_ros.utils import pad_image, unpad_image | ||
|
||
|
||
def get_model(model, weight_file): | ||
if model == "TRTModel": | ||
return TesseTRTModel(weight_file) | ||
else: | ||
raise ValueError("Currently only TensorRT models are supported") | ||
|
||
|
||
class TesseTRTModel: | ||
def __init__(self, onnx_file_path): | ||
self.trt_model = TRTModel(onnx_file_path) | ||
|
||
def infer(self, image): | ||
image = pad_image(image, 16, 0) | ||
image = image.transpose(2, 0, 1).astype(np.float32) | ||
pred = self.trt_model.infer(image) | ||
pred = unpad_image(pred, 16, 0) | ||
return pred | ||
|
||
|
||
class TRTModel: | ||
TRT_LOGGER = trt.Logger(trt.Logger.ERROR) | ||
|
||
def __init__(self, onnx_file_path): | ||
""" Model for performing inference with TensorRT. | ||
Args: | ||
onnx_file_path (str): Path to ONNX file. | ||
""" | ||
self.input_shape = None | ||
self.output_shape = None | ||
self.engine = self.build_engine(onnx_file_path) | ||
self.context = self.engine.create_execution_context() | ||
|
||
self.out_cpu, self.in_gpu, self.out_gpu, = self.allocate_buffers(self.engine) | ||
|
||
def build_engine(self, onnx_file_path): | ||
""" Build TensorRT engine from ONNX file. | ||
Args: | ||
onnx_file_path (str): Path to ONNX file. | ||
Returns: | ||
tensorrt.ICudaEngine: TensoRT model as described by | ||
the input ONNX file. | ||
""" | ||
with trt.Builder( | ||
self.TRT_LOGGER | ||
) as builder, builder.create_network() as network, trt.OnnxParser( | ||
network, self.TRT_LOGGER | ||
) as parser: | ||
builder.max_workspace_size = 1 << 25 | ||
builder.max_batch_size = 1 | ||
with open(onnx_file_path, "rb") as f: | ||
if not parser.parse(f.read()): | ||
raise ValueError("ONNX file parsing failed") | ||
self.input_shape = network.get_input(0).shape | ||
self.output_shape = network.get_output(0).shape | ||
engine = builder.build_cuda_engine(network) | ||
return engine | ||
|
||
def infer(self, inputs): | ||
""" Perform inference. | ||
Args: | ||
inputs (np.ndarray): Input array of shape CxHxW | ||
Returns: | ||
np.ndarray: Prediction of shape specified by the network. | ||
""" | ||
assert inputs.shape == self.input_shape, "%s, %s" % ( | ||
inputs.shape, | ||
self.input_shape, | ||
) | ||
inputs = inputs.reshape(-1) | ||
cuda.memcpy_htod(self.in_gpu, inputs) | ||
self.context.execute(1, [int(self.in_gpu), int(self.out_gpu)]) | ||
cuda.memcpy_dtoh(self.out_cpu, self.out_gpu) | ||
return self.out_cpu.reshape(self.output_shape) | ||
|
||
def allocate_buffers(self, engine): | ||
""" Allocate required memory for model inference | ||
Args: | ||
engine (tensorrt.ICudaEngine): TensorRT Engine. | ||
Returns: | ||
Tuple[np.ndarray, | ||
pycuda._driver.DeviceAllocation, | ||
pycuda._driver.DeviceAllocation] | ||
Host output, device input, device output | ||
""" | ||
# host cpu memory | ||
h_in_size = trt.volume(engine.get_binding_shape(0)) | ||
h_out_size = trt.volume(engine.get_binding_shape(1)) | ||
h_in_dtype = trt.nptype(engine.get_binding_dtype(0)) | ||
h_out_dtype = trt.nptype(engine.get_binding_dtype(1)) | ||
in_cpu = cuda.pagelocked_empty(h_in_size, h_in_dtype) | ||
out_cpu = cuda.pagelocked_empty(h_out_size, h_out_dtype) | ||
|
||
# allocate gpu memory | ||
in_gpu = cuda.mem_alloc(in_cpu.nbytes) | ||
out_gpu = cuda.mem_alloc(out_cpu.nbytes) | ||
# stream = cuda.Stream() | ||
return out_cpu, in_gpu, out_gpu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#! /usr/bin/env python | ||
|
||
################################################################################################### | ||
# Distribution authorized to U.S. Government agencies and their contractors. Other requests for | ||
# this document shall be referred to the MIT Lincoln Laboratory Technology Office. | ||
# | ||
# This material is based upon work supported by the Under Secretary of Defense for Research and | ||
# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions | ||
# or recommendations expressed in this material are those of the author(s) and do not necessarily | ||
# reflect the views of the Under Secretary of Defense for Research and Engineering. | ||
# | ||
|
||
# (c) 2019 Massachusetts Institute of Technology. | ||
# | ||
# The software/firmware is provided to you on an As-Is basis | ||
# | ||
# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 | ||
# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work | ||
# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other | ||
# than as specifically authorized by the U.S. Government may violate any copyrights that exist in | ||
# this work. | ||
################################################################################################### | ||
|
||
import rospy | ||
from sensor_msgs.msg import Image | ||
import numpy as np | ||
from cv_bridge import CvBridge | ||
|
||
from semantic_segmentation_ros.models import get_model | ||
from semantic_segmentation_ros.utils import get_debug_image | ||
|
||
|
||
class SemanticSegmentation: | ||
n_channels = 3 | ||
|
||
def __init__(self): | ||
model = rospy.get_param("~model", "") | ||
model_weight_path = rospy.get_param("~weight_file", "") | ||
self.publish_debug_image = rospy.get_param("~debug_image", True) | ||
|
||
if model is "" or model_weight_path is "": | ||
raise ValueError("Must provide value for param `~model` and `~model_path`") | ||
|
||
self.model = get_model(model, model_weight_path) | ||
|
||
rospy.loginfo("model initialized") | ||
|
||
self.image_subscriber = rospy.Subscriber( | ||
"/image/image_raw", Image, self.image_callback | ||
) | ||
self.image_publisher = rospy.Publisher("/prediction", Image, queue_size=10) | ||
|
||
if self.publish_debug_image: | ||
self.debug_image_publisher = rospy.Publisher( | ||
"/prediction_debug", Image, queue_size=10 | ||
) | ||
|
||
self.cv_bridge = CvBridge() | ||
self.last_image_timestamp = None | ||
self.predict = False | ||
|
||
self.spin() | ||
|
||
def image_callback(self, image_msg): | ||
self.last_image_timestamp = self.decode_image(image_msg) | ||
self.predict = True | ||
|
||
def spin(self): | ||
while not rospy.is_shutdown(): | ||
if self.predict: | ||
img, timestamp = self.last_image_timestamp | ||
prediction = self.model.infer(img) | ||
prediction = prediction.astype(np.uint8) | ||
|
||
self.image_publisher.publish( | ||
self.get_image_message(prediction, timestamp) | ||
) | ||
|
||
if self.publish_debug_image: | ||
self.debug_image_publisher.publish( | ||
self.get_image_message( | ||
get_debug_image(prediction), timestamp, "rgb8" | ||
) | ||
) | ||
|
||
self.predict = False | ||
|
||
def decode_image(self, image_msg): | ||
height, width = image_msg.height, image_msg.width | ||
img = ( | ||
np.frombuffer(image_msg.data, dtype=np.uint8).reshape( | ||
(height, width, self.n_channels) | ||
) | ||
/ 255.0 | ||
) | ||
return img, image_msg.header.stamp | ||
|
||
def get_image_message(self, image, timestamp, encoding="mono8"): | ||
img_msg = self.cv_bridge.cv2_to_imgmsg(image.astype(np.uint8), encoding) | ||
img_msg.header.stamp = timestamp | ||
return img_msg | ||
|
||
|
||
if __name__ == "__main__": | ||
rospy.init_node("SemanticSegmentation_node") | ||
node = SemanticSegmentation() | ||
rospy.spin() |
Oops, something went wrong.