diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100755
index 0000000..ad9bc80
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,10 @@
+cmake_minimum_required(VERSION 2.8.3)
+project(semantic_segmentation_ros)
+
+find_package(catkin_simple REQUIRED)
+
+catkin_python_setup()
+catkin_simple()
+
+cs_install()
+cs_export()
diff --git a/docs/tesse-semantic-segmentation.gif b/docs/tesse-semantic-segmentation.gif
new file mode 100644
index 0000000..fe2a70c
Binary files /dev/null and b/docs/tesse-semantic-segmentation.gif differ
diff --git a/launch/semantic_segmentation_tesse.launch b/launch/semantic_segmentation_tesse.launch
new file mode 100644
index 0000000..a91808d
--- /dev/null
+++ b/launch/semantic_segmentation_tesse.launch
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
diff --git a/package.xml b/package.xml
new file mode 100755
index 0000000..dcdca52
--- /dev/null
+++ b/package.xml
@@ -0,0 +1,17 @@
+
+
+ semantic_segmentation_ros
+ 0.0.1
+ Semantic Segmentation in ROS
+
+ MIT
+ Zac Ravichandran
+ catkin
+ catkin_simple
+
+
+ rospy
+ cv_bridge
+ sensor_msgs
+
+
diff --git a/setup.py b/setup.py
new file mode 100755
index 0000000..efb2c5f
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,30 @@
+###################################################################################################
+# Distribution authorized to U.S. Government agencies and their contractors. Other requests for
+# this document shall be referred to the MIT Lincoln Laboratory Technology Office.
+#
+# This material is based upon work supported by the Under Secretary of Defense for Research and
+# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions
+# or recommendations expressed in this material are those of the author(s) and do not necessarily
+# reflect the views of the Under Secretary of Defense for Research and Engineering.
+#
+
+# (c) 2019 Massachusetts Institute of Technology.
+#
+# The software/firmware is provided to you on an As-Is basis
+#
+# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013
+# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work
+# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other
+# than as specifically authorized by the U.S. Government may violate any copyrights that exist in
+# this work.
+###################################################################################################
+
+from setuptools import setup
+
+setup(
+ name="semantic_segmentation_ros",
+ version='0.0.1',
+ description="semantifc segmentation in ROS",
+ packages=["semantic_segmentation_ros"],
+ package_dir={'': 'src'}
+)
diff --git a/src/semantic_segmentation_ros/__init__.py b/src/semantic_segmentation_ros/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/semantic_segmentation_ros/models.py b/src/semantic_segmentation_ros/models.py
new file mode 100644
index 0000000..3d0cdaa
--- /dev/null
+++ b/src/semantic_segmentation_ros/models.py
@@ -0,0 +1,135 @@
+###################################################################################################
+# Distribution authorized to U.S. Government agencies and their contractors. Other requests for
+# this document shall be referred to the MIT Lincoln Laboratory Technology Office.
+#
+# This material is based upon work supported by the Under Secretary of Defense for Research and
+# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions
+# or recommendations expressed in this material are those of the author(s) and do not necessarily
+# reflect the views of the Under Secretary of Defense for Research and Engineering.
+#
+
+# (c) 2019 Massachusetts Institute of Technology.
+#
+# The software/firmware is provided to you on an As-Is basis
+#
+# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013
+# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work
+# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other
+# than as specifically authorized by the U.S. Government may violate any copyrights that exist in
+# this work.
+###################################################################################################
+
+import numpy as np
+import tensorrt as trt
+import time
+
+import pycuda.autoinit
+import pycuda.driver as cuda
+
+from semantic_segmentation_ros.utils import pad_image, unpad_image
+
+
+def get_model(model, weight_file):
+ if model == "TRTModel":
+ return TesseTRTModel(weight_file)
+ else:
+ raise ValueError("Currently only TensorRT models are supported")
+
+
+class TesseTRTModel:
+ def __init__(self, onnx_file_path):
+ self.trt_model = TRTModel(onnx_file_path)
+
+ def infer(self, image):
+ image = pad_image(image, 16, 0)
+ image = image.transpose(2, 0, 1).astype(np.float32)
+ pred = self.trt_model.infer(image)
+ pred = unpad_image(pred, 16, 0)
+ return pred
+
+
+class TRTModel:
+ TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+
+ def __init__(self, onnx_file_path):
+ """ Model for performing inference with TensorRT.
+
+ Args:
+ onnx_file_path (str): Path to ONNX file.
+ """
+ self.input_shape = None
+ self.output_shape = None
+ self.engine = self.build_engine(onnx_file_path)
+ self.context = self.engine.create_execution_context()
+
+ self.out_cpu, self.in_gpu, self.out_gpu, = self.allocate_buffers(self.engine)
+
+ def build_engine(self, onnx_file_path):
+ """ Build TensorRT engine from ONNX file.
+
+ Args:
+ onnx_file_path (str): Path to ONNX file.
+
+ Returns:
+ tensorrt.ICudaEngine: TensoRT model as described by
+ the input ONNX file.
+ """
+ with trt.Builder(
+ self.TRT_LOGGER
+ ) as builder, builder.create_network() as network, trt.OnnxParser(
+ network, self.TRT_LOGGER
+ ) as parser:
+ builder.max_workspace_size = 1 << 25
+ builder.max_batch_size = 1
+ with open(onnx_file_path, "rb") as f:
+ if not parser.parse(f.read()):
+ raise ValueError("ONNX file parsing failed")
+ self.input_shape = network.get_input(0).shape
+ self.output_shape = network.get_output(0).shape
+ engine = builder.build_cuda_engine(network)
+ return engine
+
+ def infer(self, inputs):
+ """ Perform inference.
+
+ Args:
+ inputs (np.ndarray): Input array of shape CxHxW
+
+ Returns:
+ np.ndarray: Prediction of shape specified by the network.
+ """
+ assert inputs.shape == self.input_shape, "%s, %s" % (
+ inputs.shape,
+ self.input_shape,
+ )
+ inputs = inputs.reshape(-1)
+ cuda.memcpy_htod(self.in_gpu, inputs)
+ self.context.execute(1, [int(self.in_gpu), int(self.out_gpu)])
+ cuda.memcpy_dtoh(self.out_cpu, self.out_gpu)
+ return self.out_cpu.reshape(self.output_shape)
+
+ def allocate_buffers(self, engine):
+ """ Allocate required memory for model inference
+
+ Args:
+ engine (tensorrt.ICudaEngine): TensorRT Engine.
+
+ Returns:
+ Tuple[np.ndarray,
+ pycuda._driver.DeviceAllocation,
+ pycuda._driver.DeviceAllocation]
+ Host output, device input, device output
+ """
+ # host cpu memory
+ h_in_size = trt.volume(engine.get_binding_shape(0))
+ h_out_size = trt.volume(engine.get_binding_shape(1))
+ h_in_dtype = trt.nptype(engine.get_binding_dtype(0))
+ h_out_dtype = trt.nptype(engine.get_binding_dtype(1))
+ in_cpu = cuda.pagelocked_empty(h_in_size, h_in_dtype)
+ out_cpu = cuda.pagelocked_empty(h_out_size, h_out_dtype)
+
+ # allocate gpu memory
+ in_gpu = cuda.mem_alloc(in_cpu.nbytes)
+ out_gpu = cuda.mem_alloc(out_cpu.nbytes)
+ # stream = cuda.Stream()
+ return out_cpu, in_gpu, out_gpu
diff --git a/src/semantic_segmentation_ros/semantic_segmentation.py b/src/semantic_segmentation_ros/semantic_segmentation.py
new file mode 100755
index 0000000..f84c3e7
--- /dev/null
+++ b/src/semantic_segmentation_ros/semantic_segmentation.py
@@ -0,0 +1,107 @@
+#! /usr/bin/env python
+
+###################################################################################################
+# Distribution authorized to U.S. Government agencies and their contractors. Other requests for
+# this document shall be referred to the MIT Lincoln Laboratory Technology Office.
+#
+# This material is based upon work supported by the Under Secretary of Defense for Research and
+# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions
+# or recommendations expressed in this material are those of the author(s) and do not necessarily
+# reflect the views of the Under Secretary of Defense for Research and Engineering.
+#
+
+# (c) 2019 Massachusetts Institute of Technology.
+#
+# The software/firmware is provided to you on an As-Is basis
+#
+# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013
+# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work
+# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other
+# than as specifically authorized by the U.S. Government may violate any copyrights that exist in
+# this work.
+###################################################################################################
+
+import rospy
+from sensor_msgs.msg import Image
+import numpy as np
+from cv_bridge import CvBridge
+
+from semantic_segmentation_ros.models import get_model
+from semantic_segmentation_ros.utils import get_debug_image
+
+
+class SemanticSegmentation:
+ n_channels = 3
+
+ def __init__(self):
+ model = rospy.get_param("~model", "")
+ model_weight_path = rospy.get_param("~weight_file", "")
+ self.publish_debug_image = rospy.get_param("~debug_image", True)
+
+ if model is "" or model_weight_path is "":
+ raise ValueError("Must provide value for param `~model` and `~model_path`")
+
+ self.model = get_model(model, model_weight_path)
+
+ rospy.loginfo("model initialized")
+
+ self.image_subscriber = rospy.Subscriber(
+ "/image/image_raw", Image, self.image_callback
+ )
+ self.image_publisher = rospy.Publisher("/prediction", Image, queue_size=10)
+
+ if self.publish_debug_image:
+ self.debug_image_publisher = rospy.Publisher(
+ "/prediction_debug", Image, queue_size=10
+ )
+
+ self.cv_bridge = CvBridge()
+ self.last_image_timestamp = None
+ self.predict = False
+
+ self.spin()
+
+ def image_callback(self, image_msg):
+ self.last_image_timestamp = self.decode_image(image_msg)
+ self.predict = True
+
+ def spin(self):
+ while not rospy.is_shutdown():
+ if self.predict:
+ img, timestamp = self.last_image_timestamp
+ prediction = self.model.infer(img)
+ prediction = prediction.astype(np.uint8)
+
+ self.image_publisher.publish(
+ self.get_image_message(prediction, timestamp)
+ )
+
+ if self.publish_debug_image:
+ self.debug_image_publisher.publish(
+ self.get_image_message(
+ get_debug_image(prediction), timestamp, "rgb8"
+ )
+ )
+
+ self.predict = False
+
+ def decode_image(self, image_msg):
+ height, width = image_msg.height, image_msg.width
+ img = (
+ np.frombuffer(image_msg.data, dtype=np.uint8).reshape(
+ (height, width, self.n_channels)
+ )
+ / 255.0
+ )
+ return img, image_msg.header.stamp
+
+ def get_image_message(self, image, timestamp, encoding="mono8"):
+ img_msg = self.cv_bridge.cv2_to_imgmsg(image.astype(np.uint8), encoding)
+ img_msg.header.stamp = timestamp
+ return img_msg
+
+
+if __name__ == "__main__":
+ rospy.init_node("SemanticSegmentation_node")
+ node = SemanticSegmentation()
+ rospy.spin()
diff --git a/src/semantic_segmentation_ros/utils.py b/src/semantic_segmentation_ros/utils.py
new file mode 100644
index 0000000..d83e7d7
--- /dev/null
+++ b/src/semantic_segmentation_ros/utils.py
@@ -0,0 +1,121 @@
+###################################################################################################
+# Distribution authorized to U.S. Government agencies and their contractors. Other requests for
+# this document shall be referred to the MIT Lincoln Laboratory Technology Office.
+#
+# This material is based upon work supported by the Under Secretary of Defense for Research and
+# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions
+# or recommendations expressed in this material are those of the author(s) and do not necessarily
+# reflect the views of the Under Secretary of Defense for Research and Engineering.
+#
+
+# (c) 2019 Massachusetts Institute of Technology.
+#
+# The software/firmware is provided to you on an As-Is basis
+#
+# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013
+# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work
+# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other
+# than as specifically authorized by the U.S. Government may violate any copyrights that exist in
+# this work.
+###################################################################################################
+
+import numpy as np
+
+
+# generic labels from the cityscapes dataset
+SEGMENTATION_COLORS = np.array(
+ [
+ [128, 64, 128],
+ [244, 35, 2320],
+ [70, 70, 70],
+ [102, 102, 156],
+ [190, 153, 153],
+ [153, 153, 153],
+ [250, 170, 30],
+ [220, 220, 0],
+ [107, 142, 35],
+ [152, 251, 152],
+ [70, 130, 180],
+ [220, 20, 60],
+ [255, 0, 0],
+ [0, 0, 142],
+ [0, 0, 70],
+ [0, 60, 100],
+ [0, 80, 100],
+ [0, 0, 230],
+ [119, 11, 32],
+ ]
+)
+
+
+# colors used in TESSE v53
+SEGMENTATION_COLORS_V53 = np.array(
+ [
+ [0, 171, 143],
+ [1, 155, 211],
+ [2, 222, 110],
+ [3, 69, 227],
+ [4, 218, 221],
+ [5, 81, 38],
+ [6, 229, 176],
+ [7, 106, 60],
+ [8, 179, 10],
+ [9, 118, 90],
+ [10, 138, 80],
+ ]
+)
+
+
+def get_debug_image(image):
+ """ Turn color code labels.
+
+ Args:
+ image (np.ndarray): HxW label image.
+
+ Returns:
+ np.ndarray: Color coded RGB image.
+ """
+ labels = np.unique(image)
+ if labels.shape[0] > SEGMENTATION_COLORS.shape[0]:
+ raise ValueError("Need more segmentation colors")
+
+ color_image = np.zeros(image.shape + (3,))
+ for i, label in enumerate(labels):
+ color_image[np.where(image == label)] = SEGMENTATION_COLORS_V53[label]
+ return color_image
+
+
+def pad_image(img, h_pad, w_pad):
+ """Add a padding of (`h_pad`, `w_pad`) to `img`
+
+ Args:
+ img (np.ndarray): Array of shape (h, w, c)
+ h_pad (int): Total width padding
+ w_pad (int): Total width padding
+
+ Returns:
+ np.ndarray: Array of shape (h+h_pad, w+w_pad, c)
+ """
+ assert h_pad % 2 == 0 and w_pad % 2 == 0
+ padding = ((h_pad // 2, h_pad // 2), (w_pad // 2, w_pad // 2))
+
+ if len(img.shape) == 3:
+ padding += ((0, 0),)
+
+ return np.pad(img, padding, mode="constant")
+
+
+def unpad_image(img, h_pad, w_pad):
+ """Remove the edge (`h_pad`, `w_pad`) indicies of `img`
+
+ Args:
+ img (np.ndarray): Array of shape (h+h_pad, w+w_pad, c)
+ h_pad (int): Height padding to remove.
+ w_pad (int): Width padding to remove.
+
+ Returns:
+ np.ndarray: Array of shape (h, w, c)
+ """
+ assert h_pad % 2 == 0 and w_pad % 2 == 0
+ h, w = img.shape[:2]
+ return img[h_pad // 2 : h - (h_pad // 2), w_pad // 2 : w - (w_pad // 2)]