diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100755 index 0000000..ad9bc80 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,10 @@ +cmake_minimum_required(VERSION 2.8.3) +project(semantic_segmentation_ros) + +find_package(catkin_simple REQUIRED) + +catkin_python_setup() +catkin_simple() + +cs_install() +cs_export() diff --git a/docs/tesse-semantic-segmentation.gif b/docs/tesse-semantic-segmentation.gif new file mode 100644 index 0000000..fe2a70c Binary files /dev/null and b/docs/tesse-semantic-segmentation.gif differ diff --git a/launch/semantic_segmentation_tesse.launch b/launch/semantic_segmentation_tesse.launch new file mode 100644 index 0000000..a91808d --- /dev/null +++ b/launch/semantic_segmentation_tesse.launch @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/package.xml b/package.xml new file mode 100755 index 0000000..dcdca52 --- /dev/null +++ b/package.xml @@ -0,0 +1,17 @@ + + + semantic_segmentation_ros + 0.0.1 + Semantic Segmentation in ROS + + MIT + Zac Ravichandran + catkin + catkin_simple + + + rospy + cv_bridge + sensor_msgs + + diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..efb2c5f --- /dev/null +++ b/setup.py @@ -0,0 +1,30 @@ +################################################################################################### +# Distribution authorized to U.S. Government agencies and their contractors. Other requests for +# this document shall be referred to the MIT Lincoln Laboratory Technology Office. +# +# This material is based upon work supported by the Under Secretary of Defense for Research and +# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions +# or recommendations expressed in this material are those of the author(s) and do not necessarily +# reflect the views of the Under Secretary of Defense for Research and Engineering. +# + +# (c) 2019 Massachusetts Institute of Technology. +# +# The software/firmware is provided to you on an As-Is basis +# +# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 +# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work +# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other +# than as specifically authorized by the U.S. Government may violate any copyrights that exist in +# this work. +################################################################################################### + +from setuptools import setup + +setup( + name="semantic_segmentation_ros", + version='0.0.1', + description="semantifc segmentation in ROS", + packages=["semantic_segmentation_ros"], + package_dir={'': 'src'} +) diff --git a/src/semantic_segmentation_ros/__init__.py b/src/semantic_segmentation_ros/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/semantic_segmentation_ros/models.py b/src/semantic_segmentation_ros/models.py new file mode 100644 index 0000000..3d0cdaa --- /dev/null +++ b/src/semantic_segmentation_ros/models.py @@ -0,0 +1,135 @@ +################################################################################################### +# Distribution authorized to U.S. Government agencies and their contractors. Other requests for +# this document shall be referred to the MIT Lincoln Laboratory Technology Office. +# +# This material is based upon work supported by the Under Secretary of Defense for Research and +# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions +# or recommendations expressed in this material are those of the author(s) and do not necessarily +# reflect the views of the Under Secretary of Defense for Research and Engineering. +# + +# (c) 2019 Massachusetts Institute of Technology. +# +# The software/firmware is provided to you on an As-Is basis +# +# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 +# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work +# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other +# than as specifically authorized by the U.S. Government may violate any copyrights that exist in +# this work. +################################################################################################### + +import numpy as np +import tensorrt as trt +import time + +import pycuda.autoinit +import pycuda.driver as cuda + +from semantic_segmentation_ros.utils import pad_image, unpad_image + + +def get_model(model, weight_file): + if model == "TRTModel": + return TesseTRTModel(weight_file) + else: + raise ValueError("Currently only TensorRT models are supported") + + +class TesseTRTModel: + def __init__(self, onnx_file_path): + self.trt_model = TRTModel(onnx_file_path) + + def infer(self, image): + image = pad_image(image, 16, 0) + image = image.transpose(2, 0, 1).astype(np.float32) + pred = self.trt_model.infer(image) + pred = unpad_image(pred, 16, 0) + return pred + + +class TRTModel: + TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + def __init__(self, onnx_file_path): + """ Model for performing inference with TensorRT. + + Args: + onnx_file_path (str): Path to ONNX file. + """ + self.input_shape = None + self.output_shape = None + self.engine = self.build_engine(onnx_file_path) + self.context = self.engine.create_execution_context() + + self.out_cpu, self.in_gpu, self.out_gpu, = self.allocate_buffers(self.engine) + + def build_engine(self, onnx_file_path): + """ Build TensorRT engine from ONNX file. + + Args: + onnx_file_path (str): Path to ONNX file. + + Returns: + tensorrt.ICudaEngine: TensoRT model as described by + the input ONNX file. + """ + with trt.Builder( + self.TRT_LOGGER + ) as builder, builder.create_network() as network, trt.OnnxParser( + network, self.TRT_LOGGER + ) as parser: + builder.max_workspace_size = 1 << 25 + builder.max_batch_size = 1 + with open(onnx_file_path, "rb") as f: + if not parser.parse(f.read()): + raise ValueError("ONNX file parsing failed") + self.input_shape = network.get_input(0).shape + self.output_shape = network.get_output(0).shape + engine = builder.build_cuda_engine(network) + return engine + + def infer(self, inputs): + """ Perform inference. + + Args: + inputs (np.ndarray): Input array of shape CxHxW + + Returns: + np.ndarray: Prediction of shape specified by the network. + """ + assert inputs.shape == self.input_shape, "%s, %s" % ( + inputs.shape, + self.input_shape, + ) + inputs = inputs.reshape(-1) + cuda.memcpy_htod(self.in_gpu, inputs) + self.context.execute(1, [int(self.in_gpu), int(self.out_gpu)]) + cuda.memcpy_dtoh(self.out_cpu, self.out_gpu) + return self.out_cpu.reshape(self.output_shape) + + def allocate_buffers(self, engine): + """ Allocate required memory for model inference + + Args: + engine (tensorrt.ICudaEngine): TensorRT Engine. + + Returns: + Tuple[np.ndarray, + pycuda._driver.DeviceAllocation, + pycuda._driver.DeviceAllocation] + Host output, device input, device output + """ + # host cpu memory + h_in_size = trt.volume(engine.get_binding_shape(0)) + h_out_size = trt.volume(engine.get_binding_shape(1)) + h_in_dtype = trt.nptype(engine.get_binding_dtype(0)) + h_out_dtype = trt.nptype(engine.get_binding_dtype(1)) + in_cpu = cuda.pagelocked_empty(h_in_size, h_in_dtype) + out_cpu = cuda.pagelocked_empty(h_out_size, h_out_dtype) + + # allocate gpu memory + in_gpu = cuda.mem_alloc(in_cpu.nbytes) + out_gpu = cuda.mem_alloc(out_cpu.nbytes) + # stream = cuda.Stream() + return out_cpu, in_gpu, out_gpu diff --git a/src/semantic_segmentation_ros/semantic_segmentation.py b/src/semantic_segmentation_ros/semantic_segmentation.py new file mode 100755 index 0000000..f84c3e7 --- /dev/null +++ b/src/semantic_segmentation_ros/semantic_segmentation.py @@ -0,0 +1,107 @@ +#! /usr/bin/env python + +################################################################################################### +# Distribution authorized to U.S. Government agencies and their contractors. Other requests for +# this document shall be referred to the MIT Lincoln Laboratory Technology Office. +# +# This material is based upon work supported by the Under Secretary of Defense for Research and +# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions +# or recommendations expressed in this material are those of the author(s) and do not necessarily +# reflect the views of the Under Secretary of Defense for Research and Engineering. +# + +# (c) 2019 Massachusetts Institute of Technology. +# +# The software/firmware is provided to you on an As-Is basis +# +# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 +# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work +# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other +# than as specifically authorized by the U.S. Government may violate any copyrights that exist in +# this work. +################################################################################################### + +import rospy +from sensor_msgs.msg import Image +import numpy as np +from cv_bridge import CvBridge + +from semantic_segmentation_ros.models import get_model +from semantic_segmentation_ros.utils import get_debug_image + + +class SemanticSegmentation: + n_channels = 3 + + def __init__(self): + model = rospy.get_param("~model", "") + model_weight_path = rospy.get_param("~weight_file", "") + self.publish_debug_image = rospy.get_param("~debug_image", True) + + if model is "" or model_weight_path is "": + raise ValueError("Must provide value for param `~model` and `~model_path`") + + self.model = get_model(model, model_weight_path) + + rospy.loginfo("model initialized") + + self.image_subscriber = rospy.Subscriber( + "/image/image_raw", Image, self.image_callback + ) + self.image_publisher = rospy.Publisher("/prediction", Image, queue_size=10) + + if self.publish_debug_image: + self.debug_image_publisher = rospy.Publisher( + "/prediction_debug", Image, queue_size=10 + ) + + self.cv_bridge = CvBridge() + self.last_image_timestamp = None + self.predict = False + + self.spin() + + def image_callback(self, image_msg): + self.last_image_timestamp = self.decode_image(image_msg) + self.predict = True + + def spin(self): + while not rospy.is_shutdown(): + if self.predict: + img, timestamp = self.last_image_timestamp + prediction = self.model.infer(img) + prediction = prediction.astype(np.uint8) + + self.image_publisher.publish( + self.get_image_message(prediction, timestamp) + ) + + if self.publish_debug_image: + self.debug_image_publisher.publish( + self.get_image_message( + get_debug_image(prediction), timestamp, "rgb8" + ) + ) + + self.predict = False + + def decode_image(self, image_msg): + height, width = image_msg.height, image_msg.width + img = ( + np.frombuffer(image_msg.data, dtype=np.uint8).reshape( + (height, width, self.n_channels) + ) + / 255.0 + ) + return img, image_msg.header.stamp + + def get_image_message(self, image, timestamp, encoding="mono8"): + img_msg = self.cv_bridge.cv2_to_imgmsg(image.astype(np.uint8), encoding) + img_msg.header.stamp = timestamp + return img_msg + + +if __name__ == "__main__": + rospy.init_node("SemanticSegmentation_node") + node = SemanticSegmentation() + rospy.spin() diff --git a/src/semantic_segmentation_ros/utils.py b/src/semantic_segmentation_ros/utils.py new file mode 100644 index 0000000..d83e7d7 --- /dev/null +++ b/src/semantic_segmentation_ros/utils.py @@ -0,0 +1,121 @@ +################################################################################################### +# Distribution authorized to U.S. Government agencies and their contractors. Other requests for +# this document shall be referred to the MIT Lincoln Laboratory Technology Office. +# +# This material is based upon work supported by the Under Secretary of Defense for Research and +# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions +# or recommendations expressed in this material are those of the author(s) and do not necessarily +# reflect the views of the Under Secretary of Defense for Research and Engineering. +# + +# (c) 2019 Massachusetts Institute of Technology. +# +# The software/firmware is provided to you on an As-Is basis +# +# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 +# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work +# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other +# than as specifically authorized by the U.S. Government may violate any copyrights that exist in +# this work. +################################################################################################### + +import numpy as np + + +# generic labels from the cityscapes dataset +SEGMENTATION_COLORS = np.array( + [ + [128, 64, 128], + [244, 35, 2320], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [70, 130, 180], + [220, 20, 60], + [255, 0, 0], + [0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32], + ] +) + + +# colors used in TESSE v53 +SEGMENTATION_COLORS_V53 = np.array( + [ + [0, 171, 143], + [1, 155, 211], + [2, 222, 110], + [3, 69, 227], + [4, 218, 221], + [5, 81, 38], + [6, 229, 176], + [7, 106, 60], + [8, 179, 10], + [9, 118, 90], + [10, 138, 80], + ] +) + + +def get_debug_image(image): + """ Turn color code labels. + + Args: + image (np.ndarray): HxW label image. + + Returns: + np.ndarray: Color coded RGB image. + """ + labels = np.unique(image) + if labels.shape[0] > SEGMENTATION_COLORS.shape[0]: + raise ValueError("Need more segmentation colors") + + color_image = np.zeros(image.shape + (3,)) + for i, label in enumerate(labels): + color_image[np.where(image == label)] = SEGMENTATION_COLORS_V53[label] + return color_image + + +def pad_image(img, h_pad, w_pad): + """Add a padding of (`h_pad`, `w_pad`) to `img` + + Args: + img (np.ndarray): Array of shape (h, w, c) + h_pad (int): Total width padding + w_pad (int): Total width padding + + Returns: + np.ndarray: Array of shape (h+h_pad, w+w_pad, c) + """ + assert h_pad % 2 == 0 and w_pad % 2 == 0 + padding = ((h_pad // 2, h_pad // 2), (w_pad // 2, w_pad // 2)) + + if len(img.shape) == 3: + padding += ((0, 0),) + + return np.pad(img, padding, mode="constant") + + +def unpad_image(img, h_pad, w_pad): + """Remove the edge (`h_pad`, `w_pad`) indicies of `img` + + Args: + img (np.ndarray): Array of shape (h+h_pad, w+w_pad, c) + h_pad (int): Height padding to remove. + w_pad (int): Width padding to remove. + + Returns: + np.ndarray: Array of shape (h, w, c) + """ + assert h_pad % 2 == 0 and w_pad % 2 == 0 + h, w = img.shape[:2] + return img[h_pad // 2 : h - (h_pad // 2), w_pad // 2 : w - (w_pad // 2)]