Commit

initial commit
Zachary Ravichandran committed Jan 17, 2020
0 parents commit 705afd1
Showing 9 changed files with 431 additions and 0 deletions.
10 changes: 10 additions & 0 deletions CMakeLists.txt
@@ -0,0 +1,10 @@
cmake_minimum_required(VERSION 2.8.3)
project(semantic_segmentation_ros)

find_package(catkin_simple REQUIRED)

catkin_python_setup()
catkin_simple()

cs_install()
cs_export()
Binary file added docs/tesse-semantic-segmentation.gif
11 changes: 11 additions & 0 deletions launch/semantic_segmentation_tesse.launch
@@ -0,0 +1,11 @@
<?xml version="1.0" encoding="ISO-8859-15"?>
<launch>
  <arg name="model" default="TRTModel"/>
  <arg name="weight_file" default="$(find semantic_segmentation_ros)/cfg/v5.3/v53-unet-resnet18.onnx"/>

  <node name="semantic_segmentation" pkg="semantic_segmentation_ros" type="semantic_segmentation.py" output="screen">
    <param name="model" value="$(arg model)"/>
    <param name="weight_file" value="$(arg weight_file)"/>
    <remap from="/image/image_raw" to="/tesse/left_cam_rgb/image_raw"/>
  </node>
</launch>
17 changes: 17 additions & 0 deletions package.xml
@@ -0,0 +1,17 @@
<?xml version="1.0"?>
<package>
  <name>semantic_segmentation_ros</name>
  <version>0.0.1</version>
  <description>Semantic Segmentation in ROS</description>

  <license>MIT</license>
  <maintainer email="[email protected]">Zac Ravichandran</maintainer>
  <buildtool_depend>catkin</buildtool_depend>
  <buildtool_depend>catkin_simple</buildtool_depend>

  <!-- Package build dependencies -->
  <build_depend>rospy</build_depend>
  <build_depend>cv_bridge</build_depend>
  <build_depend>sensor_msgs</build_depend>

</package>
30 changes: 30 additions & 0 deletions setup.py
@@ -0,0 +1,30 @@
###################################################################################################
# Distribution authorized to U.S. Government agencies and their contractors. Other requests for
# this document shall be referred to the MIT Lincoln Laboratory Technology Office.
#
# This material is based upon work supported by the Under Secretary of Defense for Research and
# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions
# or recommendations expressed in this material are those of the author(s) and do not necessarily
# reflect the views of the Under Secretary of Defense for Research and Engineering.
#

# (c) 2019 Massachusetts Institute of Technology.
#
# The software/firmware is provided to you on an As-Is basis
#
# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013
# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work
# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other
# than as specifically authorized by the U.S. Government may violate any copyrights that exist in
# this work.
###################################################################################################

from setuptools import setup

setup(
    name="semantic_segmentation_ros",
    version="0.0.1",
    description="semantic segmentation in ROS",
    packages=["semantic_segmentation_ros"],
    package_dir={"": "src"},
)
Empty file.
135 changes: 135 additions & 0 deletions src/semantic_segmentation_ros/models.py
@@ -0,0 +1,135 @@
###################################################################################################
# Distribution authorized to U.S. Government agencies and their contractors. Other requests for
# this document shall be referred to the MIT Lincoln Laboratory Technology Office.
#
# This material is based upon work supported by the Under Secretary of Defense for Research and
# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions
# or recommendations expressed in this material are those of the author(s) and do not necessarily
# reflect the views of the Under Secretary of Defense for Research and Engineering.
#

# (c) 2019 Massachusetts Institute of Technology.
#
# The software/firmware is provided to you on an As-Is basis
#
# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013
# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work
# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other
# than as specifically authorized by the U.S. Government may violate any copyrights that exist in
# this work.
###################################################################################################

import numpy as np
import tensorrt as trt
import time

import pycuda.autoinit
import pycuda.driver as cuda

from semantic_segmentation_ros.utils import pad_image, unpad_image


def get_model(model, weight_file):
    if model == "TRTModel":
        return TesseTRTModel(weight_file)
    else:
        raise ValueError("Currently only TensorRT models are supported")


class TesseTRTModel:
    def __init__(self, onnx_file_path):
        self.trt_model = TRTModel(onnx_file_path)

    def infer(self, image):
        image = pad_image(image, 16, 0)
        image = image.transpose(2, 0, 1).astype(np.float32)
        pred = self.trt_model.infer(image)
        pred = unpad_image(pred, 16, 0)
        return pred


class TRTModel:
    TRT_LOGGER = trt.Logger(trt.Logger.ERROR)

    def __init__(self, onnx_file_path):
        """ Model for performing inference with TensorRT.
        Args:
            onnx_file_path (str): Path to ONNX file.
        """
        self.input_shape = None
        self.output_shape = None
        self.engine = self.build_engine(onnx_file_path)
        self.context = self.engine.create_execution_context()

        self.out_cpu, self.in_gpu, self.out_gpu = self.allocate_buffers(self.engine)

    def build_engine(self, onnx_file_path):
        """ Build TensorRT engine from ONNX file.
        Args:
            onnx_file_path (str): Path to ONNX file.
        Returns:
            tensorrt.ICudaEngine: TensorRT model as described by
                the input ONNX file.
        """
        with trt.Builder(
            self.TRT_LOGGER
        ) as builder, builder.create_network() as network, trt.OnnxParser(
            network, self.TRT_LOGGER
        ) as parser:
            builder.max_workspace_size = 1 << 25
            builder.max_batch_size = 1
            with open(onnx_file_path, "rb") as f:
                if not parser.parse(f.read()):
                    raise ValueError("ONNX file parsing failed")
            self.input_shape = network.get_input(0).shape
            self.output_shape = network.get_output(0).shape
            engine = builder.build_cuda_engine(network)
        return engine

    def infer(self, inputs):
        """ Perform inference.
        Args:
            inputs (np.ndarray): Input array of shape CxHxW.
        Returns:
            np.ndarray: Prediction of shape specified by the network.
        """
        assert inputs.shape == self.input_shape, "%s, %s" % (
            inputs.shape,
            self.input_shape,
        )
        inputs = inputs.reshape(-1)
        cuda.memcpy_htod(self.in_gpu, inputs)
        self.context.execute(1, [int(self.in_gpu), int(self.out_gpu)])
        cuda.memcpy_dtoh(self.out_cpu, self.out_gpu)
        return self.out_cpu.reshape(self.output_shape)

    def allocate_buffers(self, engine):
        """ Allocate required memory for model inference.
        Args:
            engine (tensorrt.ICudaEngine): TensorRT engine.
        Returns:
            Tuple[np.ndarray,
                  pycuda._driver.DeviceAllocation,
                  pycuda._driver.DeviceAllocation]:
                Host output, device input, and device output buffers.
        """
        # host (CPU) memory
        h_in_size = trt.volume(engine.get_binding_shape(0))
        h_out_size = trt.volume(engine.get_binding_shape(1))
        h_in_dtype = trt.nptype(engine.get_binding_dtype(0))
        h_out_dtype = trt.nptype(engine.get_binding_dtype(1))
        in_cpu = cuda.pagelocked_empty(h_in_size, h_in_dtype)
        out_cpu = cuda.pagelocked_empty(h_out_size, h_out_dtype)

        # device (GPU) memory
        in_gpu = cuda.mem_alloc(in_cpu.nbytes)
        out_gpu = cuda.mem_alloc(out_cpu.nbytes)
        # stream = cuda.Stream()
        return out_cpu, in_gpu, out_gpu
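
A minimal offline sketch of how the wrapper above might be exercised; the ONNX path and the image size are assumptions for illustration (the input must match the dimensions baked into the ONNX model, modulo the padding applied in TesseTRTModel.infer), not values fixed by this commit:

# Hypothetical smoke test for TesseTRTModel; path and image size are assumed.
import numpy as np
from semantic_segmentation_ros.models import get_model

model = get_model("TRTModel", "/path/to/v53-unet-resnet18.onnx")  # assumed path
image = np.random.rand(240, 320, 3)   # HxWx3 float image in [0, 1], assumed size
prediction = model.infer(image)       # per-pixel class prediction
print(prediction.shape)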
107 changes: 107 additions & 0 deletions src/semantic_segmentation_ros/semantic_segmentation.py
@@ -0,0 +1,107 @@
#! /usr/bin/env python

###################################################################################################
# Distribution authorized to U.S. Government agencies and their contractors. Other requests for
# this document shall be referred to the MIT Lincoln Laboratory Technology Office.
#
# This material is based upon work supported by the Under Secretary of Defense for Research and
# Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions
# or recommendations expressed in this material are those of the author(s) and do not necessarily
# reflect the views of the Under Secretary of Defense for Research and Engineering.
#

# (c) 2019 Massachusetts Institute of Technology.
#
# The software/firmware is provided to you on an As-Is basis
#
# Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013
# or 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work
# are defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other
# than as specifically authorized by the U.S. Government may violate any copyrights that exist in
# this work.
###################################################################################################

import rospy
from sensor_msgs.msg import Image
import numpy as np
from cv_bridge import CvBridge

from semantic_segmentation_ros.models import get_model
from semantic_segmentation_ros.utils import get_debug_image


class SemanticSegmentation:
    n_channels = 3

    def __init__(self):
        model = rospy.get_param("~model", "")
        model_weight_path = rospy.get_param("~weight_file", "")
        self.publish_debug_image = rospy.get_param("~debug_image", True)

        if model == "" or model_weight_path == "":
            raise ValueError("Must provide values for params `~model` and `~weight_file`")

        self.model = get_model(model, model_weight_path)

        rospy.loginfo("model initialized")

        self.image_subscriber = rospy.Subscriber(
            "/image/image_raw", Image, self.image_callback
        )
        self.image_publisher = rospy.Publisher("/prediction", Image, queue_size=10)

        if self.publish_debug_image:
            self.debug_image_publisher = rospy.Publisher(
                "/prediction_debug", Image, queue_size=10
            )

        self.cv_bridge = CvBridge()
        self.last_image_timestamp = None
        self.predict = False

        self.spin()

    def image_callback(self, image_msg):
        self.last_image_timestamp = self.decode_image(image_msg)
        self.predict = True

    def spin(self):
        while not rospy.is_shutdown():
            if self.predict:
                img, timestamp = self.last_image_timestamp
                prediction = self.model.infer(img)
                prediction = prediction.astype(np.uint8)

                self.image_publisher.publish(
                    self.get_image_message(prediction, timestamp)
                )

                if self.publish_debug_image:
                    self.debug_image_publisher.publish(
                        self.get_image_message(
                            get_debug_image(prediction), timestamp, "rgb8"
                        )
                    )

                self.predict = False

    def decode_image(self, image_msg):
        height, width = image_msg.height, image_msg.width
        img = (
            np.frombuffer(image_msg.data, dtype=np.uint8).reshape(
                (height, width, self.n_channels)
            )
            / 255.0
        )
        return img, image_msg.header.stamp

    def get_image_message(self, image, timestamp, encoding="mono8"):
        img_msg = self.cv_bridge.cv2_to_imgmsg(image.astype(np.uint8), encoding)
        img_msg.header.stamp = timestamp
        return img_msg


if __name__ == "__main__":
    rospy.init_node("SemanticSegmentation_node")
    node = SemanticSegmentation()
    rospy.spin()
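
For completeness, a minimal subscriber sketch for sanity-checking the node's output; the /prediction topic name and mono8 encoding come from the node above, while the listener node name and callback are illustrative assumptions:

# Hypothetical sanity check: report the shape of each published prediction.
import numpy as np
import rospy
from sensor_msgs.msg import Image

def on_prediction(msg):
    # mono8 prediction: one uint8 class label per pixel (assumes step == width)
    labels = np.frombuffer(msg.data, dtype=np.uint8).reshape(msg.height, msg.width)
    rospy.loginfo("prediction %dx%d, %d distinct labels", msg.height, msg.width, len(np.unique(labels)))

rospy.init_node("prediction_listener")
rospy.Subscriber("/prediction", Image, on_prediction)
rospy.spin()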