Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions config/image_object_detection.yaml

This file was deleted.

File renamed without changes.
File renamed without changes.
25 changes: 25 additions & 0 deletions image_object_detection/config/image_object_detection.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
image_object_detection_node:
ros__parameters:
model.image_size: 640
model.confidence: 0.25
model.iou_threshold: 0.45
model.weights_file: yolov7-tiny.pt
model.device: '0'

selected_detections: ['person'] # Classes to detect ['person', 'car']

show_image: False
publish_debug_image: True

# Lists of topics to subscribe
camera_topics:
- '/camera/image_raw'
# - '/camera1/image_raw'
# - '/camera2/image_raw'
# - '/camera3/image_raw'

# QoS policy for the image subscriber
subscribers.qos_policy: 'best_effort'

# QoS policy for the image debug publisher
image_debug_publisher.qos_policy: 'best_effort'
3 changes: 3 additions & 0 deletions package.xml → image_object_detection/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
<depend>cv_bridge</depend>
<depend>image_transport</depend>
<depend>vision_msgs</depend>
<depend>std_msgs</depend>
<depend>image_object_detection_msgs</depend>

<export>
<build_type>ament_python</build_type>
</export>
</package>

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import std_srvs.srv
from sensor_msgs.msg import CompressedImage, Image
from vision_msgs.msg import Detection2D, ObjectHypothesisWithPose

import torch
import torch.backends.cudnn as cudnn
Expand All @@ -27,7 +28,9 @@
from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh, set_logging
from utils.plots import plot_one_box
from utils.torch_utils import select_device
from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose
from vision_msgs.msg import Detection2DArray, Detection2D
from image_object_detection_msgs.srv import SetDetectionClasses

from ament_index_python.packages import get_package_share_directory

PACKAGE_NAME = "image_object_detection"
Expand All @@ -37,7 +40,7 @@ class ImageDetectObjectNode(Node):
def __init__(self):
super().__init__("image_object_detection_node")

# parametros
# Model parameters
self.declare_parameter("model.image_size", 640)
self.model_image_size = (
self.get_parameter("model.image_size").get_parameter_value().integer_value
Expand Down Expand Up @@ -112,6 +115,13 @@ def __init__(self):
callback=self.set_processing_enabled_callback,
)

self.set_classes_service = self.create_service(
SetDetectionClasses,
'set_detection_classes',
self.set_detection_classes_callback
)


if self.subscribers_qos == "best_effort":
self.get_logger().info("Using best effort qos policy for subscribers")
self.qos = QoSProfile(
Expand All @@ -129,60 +139,66 @@ def __init__(self):

self.bridge = cv_bridge.CvBridge()

self.image_sub = self.create_subscription(
msg_type=Image, topic="image", callback=self.image_callback, qos_profile=self.qos
)

self.image_compressed_sub = self.create_subscription(
msg_type=CompressedImage,
topic="image/compressed",
callback=self.image_compressed_callback,
qos_profile=self.qos,
)

self.detection_publisher = self.create_publisher(
msg_type=Detection2DArray, topic="detections", qos_profile=self.qos
)

if self.enable_publish_debug_image:
if self.qos_policy == "best_effort":
self.get_logger().info("Using best effort qos policy for debug image publisher")
self.qos = QoSProfile(
reliability=QoSReliabilityPolicy.BEST_EFFORT,
history=QoSHistoryPolicy.KEEP_LAST,
depth=1,
)
else:
self.get_logger().info("Using reliable qos policy for debug image publisher")
self.qos = QoSProfile(
reliability=QoSReliabilityPolicy.RELIABLE,
history=QoSHistoryPolicy.KEEP_LAST,
depth=1,
# Get the list of camera topics from the config file
self.declare_parameter("camera_topics", ["/cameras/frontleft_fisheye_image/image", "/cameras/frontright_fisheye_image/image", "/cameras/left_fisheye_image/image", "/cameras/right_fisheye_image/image"])
self.camera_topics = self.get_parameter("camera_topics").get_parameter_value().string_array_value

self.get_logger().info(f"Subscribed to topics: {self.camera_topics}")

# Initialize subscribers and publishers for each camera topic
self.subscribers = []
self.detection_publishers = {}
self.debug_image_publishers = {}

for topic in self.camera_topics:
# Create a subscriber for each camera topic
self.subscribers.append(
self.create_subscription(
Image,
topic,
callback=self.image_callback_factory(topic),
qos_profile=self.qos,
)
)

self.debug_image_publisher = self.create_publisher(
msg_type=Image, topic="debug_image", qos_profile=self.qos
# Create a detection publisher for each camera
detection_topic = f"{topic}/detections"
self.detection_publishers[topic] = self.create_publisher(
Detection2DArray, detection_topic, self.qos
)

# Create a debug image publisher for each camera (if enabled)
if self.enable_publish_debug_image:
debug_image_topic = f"{topic}/debug_image"
self.debug_image_publishers[topic] = self.create_publisher(
Image, debug_image_topic, self.qos
)

self.initialize_model()


def set_detection_classes_callback(self, request, response):
self.selected_detections = request.classes
self.get_logger().info(f"Updated selected_detections: {self.selected_detections}")
response.success = True
response.message = f"Successfully updated detection classes to {self.selected_detections}"
return response

def initialize_model(self):
with torch.no_grad():
# Initialize
set_logging()
self.device = select_device(self.device)
self.half = self.device.type != "cpu"

# Load model
self.model = attempt_load(
self.model_weights_file, map_location=self.device
) # load FP32 model
)
self.stride = int(self.model.stride.max())

self.imgsz = check_img_size(self.model_image_size, s=self.stride)

if self.half:
self.model.half() # to FP16
self.model.half()

cudnn.benchmark = True

Expand Down Expand Up @@ -215,17 +231,13 @@ def accomodate_image_to_model(self, img0):
def image_compressed_callback(self, msg):
if not self.processing_enabled:
return

try:
self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format)
img = self.accomodate_image_to_model(self.cv_img)

detections_msg, debugimg = self.predict(img, self.cv_img)
self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format)
img = self.accomodate_image_to_model(self.cv_img)

self.detection_publisher.publish(detections_msg)
except CvBridgeError as e:
self.get_logger().error(f"Error converting image: {e}")
return
detections_msg, debugimg = self.predict(img, self.cv_img)

self.detection_publisher.publish(detections_msg)

if debugimg is not None:
self.publish_debug_image(debugimg)
Expand All @@ -234,25 +246,43 @@ def image_compressed_callback(self, msg):
cv2.imshow("Compressed Image", debugimg)
cv2.waitKey(1)

def image_callback(self, msg):
if not self.processing_enabled:
return
def image_callback_factory(self, topic):
def callback(msg):
try:
cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8")
self.image_queue[topic] = cv_img
except CvBridgeError as e:
self.get_logger().error(f"Error converting image from {topic}: {e}")

self.cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8")
img = self.accomodate_image_to_model(self.cv_img)
return callback

def image_callback_factory(self, topic):
def callback(msg):
if not self.processing_enabled:
return

detections_msg, debugimg = self.predict(img, self.cv_img)
try:
cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8")
img = self.accomodate_image_to_model(cv_img)

self.detection_publisher.publish(detections_msg)
detections_msg, debugimg = self.predict(img, cv_img)

if debugimg is not None:
self.publish_debug_image(debugimg)
# Publish detections for the current camera
self.detection_publishers[topic].publish(detections_msg)

if self.show_image:
cv2.imshow("Detection", debugimg)
cv2.waitKey(1)
# Publish debug image for the current camera (if enabled)
if self.enable_publish_debug_image and topic in self.debug_image_publishers:
self.publish_debug_image(debugimg, topic)

if self.show_image:
cv2.imshow(f"Detection from {topic}", debugimg)
cv2.waitKey(1)
except CvBridgeError as e:
self.get_logger().error(f"Error converting image from {topic}: {e}")

def publish_debug_image(self, debugimg):
return callback

def publish_debug_image(self, debugimg, topic):
if self.debug_image_output_format == "mono8":
debugimg = cv2.cvtColor(debugimg, cv2.COLOR_RGB2GRAY)
elif self.debug_image_output_format == "rgb8":
Expand All @@ -261,11 +291,12 @@ def publish_debug_image(self, debugimg):
debugimg = cv2.cvtColor(debugimg, cv2.COLOR_BGR2RGBA)
else:
self.get_logger().error(
"Unsupported debug image output format: {}".format(self.debug_image_output_format)
f"Unsupported debug image output format: {self.debug_image_output_format}"
)
return

self.debug_image_publisher.publish(
# Publish the debug image for the current camera
self.debug_image_publishers[topic].publish(
self.bridge.cv2_to_imgmsg(debugimg, self.debug_image_output_format)
)

Expand Down Expand Up @@ -294,7 +325,6 @@ def predict(self, model_img, original_image):
).round()

for *xyxy, conf, cls in reversed(det):
# clase clases deseadas
if self.names[int(cls)] in self.selected_detections:
detection2D_msg = Detection2D()
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
Expand Down Expand Up @@ -322,7 +352,6 @@ def predict(self, model_img, original_image):

return detections_msg, original_image


def main(args=None):
print(args)
rclpy.init(args=sys.argv)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
21 changes: 21 additions & 0 deletions image_object_detection_msgs/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
cmake_minimum_required(VERSION 3.8)
project(image_object_detection_msgs)

if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
add_compile_options(-Wall -Wextra -Wpedantic)
endif()

find_package(ament_cmake REQUIRED)
find_package(rosidl_default_generators REQUIRED)

rosidl_generate_interfaces(${PROJECT_NAME}
"srv/SetDetectionClasses.srv"
)

ament_export_dependencies(rosidl_default_runtime)
install(
DIRECTORY srv
DESTINATION share/${PROJECT_NAME}
)

ament_package()
18 changes: 18 additions & 0 deletions image_object_detection_msgs/package.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<?xml version="1.0"?>
<?xml-model href="http://download.ros.org/schema/package_format3.xsd" schematypens="http://www.w3.org/2001/XMLSchema"?>
<package format="3">
<name>image_object_detection_msgs</name>
<version>1.0.0</version>
<description>Messages for image object detection</description>
<maintainer email="pablo@ibrobotics.com">Pablo Iñigo Blasco</maintainer>
<license>BSD-3-Clause</license>

<buildtool_depend>ament_cmake</buildtool_depend>
<build_depend>rosidl_default_generators</build_depend>
<exec_depend>rosidl_default_runtime</exec_depend>
<member_of_group>rosidl_interface_packages</member_of_group>

<export>
<build_type>ament_cmake</build_type>
</export>
</package>
4 changes: 4 additions & 0 deletions image_object_detection_msgs/srv/SetDetectionClasses.srv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
string[] classes
---
bool success
string message
Binary file not shown.
Binary file not shown.