diff --git a/config/image_object_detection.yaml b/config/image_object_detection.yaml deleted file mode 100644 index 7fead2c..0000000 --- a/config/image_object_detection.yaml +++ /dev/null @@ -1,12 +0,0 @@ -image_object_detection_node: - ros__parameters: - model.image_size: 640 - model.confidence: 0.25 - model.iou_threshold: 0.45 - model.weights_file: yolov7-tiny.pt - model.device: '0' - - selected_detections: ['person'] - - show_image: False - publish_debug_image: True diff --git a/LICENSE b/image_object_detection/LICENSE similarity index 100% rename from LICENSE rename to image_object_detection/LICENSE diff --git a/README.md b/image_object_detection/README.md similarity index 100% rename from README.md rename to image_object_detection/README.md diff --git a/image_object_detection/config/image_object_detection.yaml b/image_object_detection/config/image_object_detection.yaml new file mode 100644 index 0000000..81187b4 --- /dev/null +++ b/image_object_detection/config/image_object_detection.yaml @@ -0,0 +1,25 @@ +image_object_detection_node: + ros__parameters: + model.image_size: 640 + model.confidence: 0.25 + model.iou_threshold: 0.45 + model.weights_file: yolov7-tiny.pt + model.device: '0' + + selected_detections: ['person'] # Classes to detect ['person', 'car'] + + show_image: False + publish_debug_image: True + + # Lists of topics to subscribe + camera_topics: + - '/camera/image_raw' + # - '/camera1/image_raw' + # - '/camera2/image_raw' + # - '/camera3/image_raw' + + # QoS policy for the image subscriber + subscribers.qos_policy: 'best_effort' + + # QoS policy for the image debug publisher + image_debug_publisher.qos_policy: 'best_effort' diff --git a/launch/image_object_detection_launch.py b/image_object_detection/launch/image_object_detection_launch.py similarity index 100% rename from launch/image_object_detection_launch.py rename to image_object_detection/launch/image_object_detection_launch.py diff --git a/package.xml b/image_object_detection/package.xml similarity index 91% rename from package.xml rename to image_object_detection/package.xml index 4143738..3c7e08f 100644 --- a/package.xml +++ b/image_object_detection/package.xml @@ -18,8 +18,11 @@ cv_bridge image_transport vision_msgs + std_msgs + image_object_detection_msgs ament_python + diff --git a/requirements.txt b/image_object_detection/requirements.txt similarity index 100% rename from requirements.txt rename to image_object_detection/requirements.txt diff --git a/resource/image_object_detection b/image_object_detection/resource/image_object_detection similarity index 100% rename from resource/image_object_detection rename to image_object_detection/resource/image_object_detection diff --git a/resource/models b/image_object_detection/resource/models similarity index 100% rename from resource/models rename to image_object_detection/resource/models diff --git a/resource/utils b/image_object_detection/resource/utils similarity index 100% rename from resource/utils rename to image_object_detection/resource/utils diff --git a/setup.cfg b/image_object_detection/setup.cfg similarity index 100% rename from setup.cfg rename to image_object_detection/setup.cfg diff --git a/setup.py b/image_object_detection/setup.py similarity index 100% rename from setup.py rename to image_object_detection/setup.py diff --git a/src/image_object_detection/__init__.py b/image_object_detection/src/image_object_detection/__init__.py similarity index 100% rename from src/image_object_detection/__init__.py rename to image_object_detection/src/image_object_detection/__init__.py diff --git a/src/image_object_detection/image_object_detection_node.py b/image_object_detection/src/image_object_detection/image_object_detection_node.py similarity index 71% rename from src/image_object_detection/image_object_detection_node.py rename to image_object_detection/src/image_object_detection/image_object_detection_node.py index 855ddab..14a514b 100644 --- a/src/image_object_detection/image_object_detection_node.py +++ b/image_object_detection/src/image_object_detection/image_object_detection_node.py @@ -19,6 +19,7 @@ import std_srvs.srv from sensor_msgs.msg import CompressedImage, Image +from vision_msgs.msg import Detection2D, ObjectHypothesisWithPose import torch import torch.backends.cudnn as cudnn @@ -27,7 +28,9 @@ from utils.general import check_img_size, non_max_suppression, scale_coords, xyxy2xywh, set_logging from utils.plots import plot_one_box from utils.torch_utils import select_device -from vision_msgs.msg import Detection2D, Detection2DArray, ObjectHypothesisWithPose +from vision_msgs.msg import Detection2DArray, Detection2D +from image_object_detection_msgs.srv import SetDetectionClasses + from ament_index_python.packages import get_package_share_directory PACKAGE_NAME = "image_object_detection" @@ -37,7 +40,7 @@ class ImageDetectObjectNode(Node): def __init__(self): super().__init__("image_object_detection_node") - # parametros + # Model parameters self.declare_parameter("model.image_size", 640) self.model_image_size = ( self.get_parameter("model.image_size").get_parameter_value().integer_value @@ -112,6 +115,13 @@ def __init__(self): callback=self.set_processing_enabled_callback, ) + self.set_classes_service = self.create_service( + SetDetectionClasses, + 'set_detection_classes', + self.set_detection_classes_callback + ) + + if self.subscribers_qos == "best_effort": self.get_logger().info("Using best effort qos policy for subscribers") self.qos = QoSProfile( @@ -129,60 +139,66 @@ def __init__(self): self.bridge = cv_bridge.CvBridge() - self.image_sub = self.create_subscription( - msg_type=Image, topic="image", callback=self.image_callback, qos_profile=self.qos - ) - - self.image_compressed_sub = self.create_subscription( - msg_type=CompressedImage, - topic="image/compressed", - callback=self.image_compressed_callback, - qos_profile=self.qos, - ) - - self.detection_publisher = self.create_publisher( - msg_type=Detection2DArray, topic="detections", qos_profile=self.qos - ) - - if self.enable_publish_debug_image: - if self.qos_policy == "best_effort": - self.get_logger().info("Using best effort qos policy for debug image publisher") - self.qos = QoSProfile( - reliability=QoSReliabilityPolicy.BEST_EFFORT, - history=QoSHistoryPolicy.KEEP_LAST, - depth=1, - ) - else: - self.get_logger().info("Using reliable qos policy for debug image publisher") - self.qos = QoSProfile( - reliability=QoSReliabilityPolicy.RELIABLE, - history=QoSHistoryPolicy.KEEP_LAST, - depth=1, + # Get the list of camera topics from the config file + self.declare_parameter("camera_topics", ["/cameras/frontleft_fisheye_image/image", "/cameras/frontright_fisheye_image/image", "/cameras/left_fisheye_image/image", "/cameras/right_fisheye_image/image"]) + self.camera_topics = self.get_parameter("camera_topics").get_parameter_value().string_array_value + + self.get_logger().info(f"Subscribed to topics: {self.camera_topics}") + + # Initialize subscribers and publishers for each camera topic + self.subscribers = [] + self.detection_publishers = {} + self.debug_image_publishers = {} + + for topic in self.camera_topics: + # Create a subscriber for each camera topic + self.subscribers.append( + self.create_subscription( + Image, + topic, + callback=self.image_callback_factory(topic), + qos_profile=self.qos, ) + ) - self.debug_image_publisher = self.create_publisher( - msg_type=Image, topic="debug_image", qos_profile=self.qos + # Create a detection publisher for each camera + detection_topic = f"{topic}/detections" + self.detection_publishers[topic] = self.create_publisher( + Detection2DArray, detection_topic, self.qos ) + # Create a debug image publisher for each camera (if enabled) + if self.enable_publish_debug_image: + debug_image_topic = f"{topic}/debug_image" + self.debug_image_publishers[topic] = self.create_publisher( + Image, debug_image_topic, self.qos + ) + self.initialize_model() + + def set_detection_classes_callback(self, request, response): + self.selected_detections = request.classes + self.get_logger().info(f"Updated selected_detections: {self.selected_detections}") + response.success = True + response.message = f"Successfully updated detection classes to {self.selected_detections}" + return response + def initialize_model(self): with torch.no_grad(): - # Initialize set_logging() self.device = select_device(self.device) self.half = self.device.type != "cpu" - # Load model self.model = attempt_load( self.model_weights_file, map_location=self.device - ) # load FP32 model + ) self.stride = int(self.model.stride.max()) self.imgsz = check_img_size(self.model_image_size, s=self.stride) if self.half: - self.model.half() # to FP16 + self.model.half() cudnn.benchmark = True @@ -215,17 +231,13 @@ def accomodate_image_to_model(self, img0): def image_compressed_callback(self, msg): if not self.processing_enabled: return - - try: - self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format) - img = self.accomodate_image_to_model(self.cv_img) - detections_msg, debugimg = self.predict(img, self.cv_img) + self.cv_img = self.bridge.compressed_imgmsg_to_cv2(msg, self.debug_image_output_format) + img = self.accomodate_image_to_model(self.cv_img) - self.detection_publisher.publish(detections_msg) - except CvBridgeError as e: - self.get_logger().error(f"Error converting image: {e}") - return + detections_msg, debugimg = self.predict(img, self.cv_img) + + self.detection_publisher.publish(detections_msg) if debugimg is not None: self.publish_debug_image(debugimg) @@ -234,25 +246,43 @@ def image_compressed_callback(self, msg): cv2.imshow("Compressed Image", debugimg) cv2.waitKey(1) - def image_callback(self, msg): - if not self.processing_enabled: - return + def image_callback_factory(self, topic): + def callback(msg): + try: + cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8") + self.image_queue[topic] = cv_img + except CvBridgeError as e: + self.get_logger().error(f"Error converting image from {topic}: {e}") - self.cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8") - img = self.accomodate_image_to_model(self.cv_img) + return callback + + def image_callback_factory(self, topic): + def callback(msg): + if not self.processing_enabled: + return - detections_msg, debugimg = self.predict(img, self.cv_img) + try: + cv_img = self.bridge.imgmsg_to_cv2(msg, "bgr8") + img = self.accomodate_image_to_model(cv_img) - self.detection_publisher.publish(detections_msg) + detections_msg, debugimg = self.predict(img, cv_img) - if debugimg is not None: - self.publish_debug_image(debugimg) + # Publish detections for the current camera + self.detection_publishers[topic].publish(detections_msg) - if self.show_image: - cv2.imshow("Detection", debugimg) - cv2.waitKey(1) + # Publish debug image for the current camera (if enabled) + if self.enable_publish_debug_image and topic in self.debug_image_publishers: + self.publish_debug_image(debugimg, topic) + + if self.show_image: + cv2.imshow(f"Detection from {topic}", debugimg) + cv2.waitKey(1) + except CvBridgeError as e: + self.get_logger().error(f"Error converting image from {topic}: {e}") - def publish_debug_image(self, debugimg): + return callback + + def publish_debug_image(self, debugimg, topic): if self.debug_image_output_format == "mono8": debugimg = cv2.cvtColor(debugimg, cv2.COLOR_RGB2GRAY) elif self.debug_image_output_format == "rgb8": @@ -261,11 +291,12 @@ def publish_debug_image(self, debugimg): debugimg = cv2.cvtColor(debugimg, cv2.COLOR_BGR2RGBA) else: self.get_logger().error( - "Unsupported debug image output format: {}".format(self.debug_image_output_format) + f"Unsupported debug image output format: {self.debug_image_output_format}" ) return - self.debug_image_publisher.publish( + # Publish the debug image for the current camera + self.debug_image_publishers[topic].publish( self.bridge.cv2_to_imgmsg(debugimg, self.debug_image_output_format) ) @@ -294,7 +325,6 @@ def predict(self, model_img, original_image): ).round() for *xyxy, conf, cls in reversed(det): - # clase clases deseadas if self.names[int(cls)] in self.selected_detections: detection2D_msg = Detection2D() xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() @@ -322,7 +352,6 @@ def predict(self, model_img, original_image): return detections_msg, original_image - def main(args=None): print(args) rclpy.init(args=sys.argv) diff --git a/src/image_object_detection/test_publisher.py b/image_object_detection/src/image_object_detection/test_publisher.py similarity index 100% rename from src/image_object_detection/test_publisher.py rename to image_object_detection/src/image_object_detection/test_publisher.py diff --git a/src/models/__init__.py b/image_object_detection/src/models/__init__.py similarity index 100% rename from src/models/__init__.py rename to image_object_detection/src/models/__init__.py diff --git a/src/models/common.py b/image_object_detection/src/models/common.py similarity index 100% rename from src/models/common.py rename to image_object_detection/src/models/common.py diff --git a/src/models/detect.py b/image_object_detection/src/models/detect.py similarity index 100% rename from src/models/detect.py rename to image_object_detection/src/models/detect.py diff --git a/src/models/experimental.py b/image_object_detection/src/models/experimental.py similarity index 100% rename from src/models/experimental.py rename to image_object_detection/src/models/experimental.py diff --git a/src/models/yolo.py b/image_object_detection/src/models/yolo.py similarity index 100% rename from src/models/yolo.py rename to image_object_detection/src/models/yolo.py diff --git a/src/utils/__init__.py b/image_object_detection/src/utils/__init__.py similarity index 100% rename from src/utils/__init__.py rename to image_object_detection/src/utils/__init__.py diff --git a/src/utils/activations.py b/image_object_detection/src/utils/activations.py similarity index 100% rename from src/utils/activations.py rename to image_object_detection/src/utils/activations.py diff --git a/src/utils/add_nms.py b/image_object_detection/src/utils/add_nms.py similarity index 100% rename from src/utils/add_nms.py rename to image_object_detection/src/utils/add_nms.py diff --git a/src/utils/autoanchor.py b/image_object_detection/src/utils/autoanchor.py similarity index 100% rename from src/utils/autoanchor.py rename to image_object_detection/src/utils/autoanchor.py diff --git a/src/utils/aws/__init__.py b/image_object_detection/src/utils/aws/__init__.py similarity index 100% rename from src/utils/aws/__init__.py rename to image_object_detection/src/utils/aws/__init__.py diff --git a/src/utils/aws/mime.sh b/image_object_detection/src/utils/aws/mime.sh similarity index 100% rename from src/utils/aws/mime.sh rename to image_object_detection/src/utils/aws/mime.sh diff --git a/src/utils/aws/resume.py b/image_object_detection/src/utils/aws/resume.py similarity index 100% rename from src/utils/aws/resume.py rename to image_object_detection/src/utils/aws/resume.py diff --git a/src/utils/aws/userdata.sh b/image_object_detection/src/utils/aws/userdata.sh similarity index 100% rename from src/utils/aws/userdata.sh rename to image_object_detection/src/utils/aws/userdata.sh diff --git a/src/utils/datasets.py b/image_object_detection/src/utils/datasets.py similarity index 100% rename from src/utils/datasets.py rename to image_object_detection/src/utils/datasets.py diff --git a/src/utils/general.py b/image_object_detection/src/utils/general.py similarity index 100% rename from src/utils/general.py rename to image_object_detection/src/utils/general.py diff --git a/src/utils/google_app_engine/Dockerfile b/image_object_detection/src/utils/google_app_engine/Dockerfile similarity index 100% rename from src/utils/google_app_engine/Dockerfile rename to image_object_detection/src/utils/google_app_engine/Dockerfile diff --git a/src/utils/google_app_engine/additional_requirements.txt b/image_object_detection/src/utils/google_app_engine/additional_requirements.txt similarity index 100% rename from src/utils/google_app_engine/additional_requirements.txt rename to image_object_detection/src/utils/google_app_engine/additional_requirements.txt diff --git a/src/utils/google_app_engine/app.yaml b/image_object_detection/src/utils/google_app_engine/app.yaml similarity index 100% rename from src/utils/google_app_engine/app.yaml rename to image_object_detection/src/utils/google_app_engine/app.yaml diff --git a/src/utils/google_utils.py b/image_object_detection/src/utils/google_utils.py similarity index 100% rename from src/utils/google_utils.py rename to image_object_detection/src/utils/google_utils.py diff --git a/src/utils/loss.py b/image_object_detection/src/utils/loss.py similarity index 100% rename from src/utils/loss.py rename to image_object_detection/src/utils/loss.py diff --git a/src/utils/metrics.py b/image_object_detection/src/utils/metrics.py similarity index 100% rename from src/utils/metrics.py rename to image_object_detection/src/utils/metrics.py diff --git a/src/utils/plots.py b/image_object_detection/src/utils/plots.py similarity index 100% rename from src/utils/plots.py rename to image_object_detection/src/utils/plots.py diff --git a/src/utils/torch_utils.py b/image_object_detection/src/utils/torch_utils.py similarity index 100% rename from src/utils/torch_utils.py rename to image_object_detection/src/utils/torch_utils.py diff --git a/src/utils/wandb_logging/__init__.py b/image_object_detection/src/utils/wandb_logging/__init__.py similarity index 100% rename from src/utils/wandb_logging/__init__.py rename to image_object_detection/src/utils/wandb_logging/__init__.py diff --git a/src/utils/wandb_logging/log_dataset.py b/image_object_detection/src/utils/wandb_logging/log_dataset.py similarity index 100% rename from src/utils/wandb_logging/log_dataset.py rename to image_object_detection/src/utils/wandb_logging/log_dataset.py diff --git a/src/utils/wandb_logging/wandb_utils.py b/image_object_detection/src/utils/wandb_logging/wandb_utils.py similarity index 100% rename from src/utils/wandb_logging/wandb_utils.py rename to image_object_detection/src/utils/wandb_logging/wandb_utils.py diff --git a/test/test_copyright.py b/image_object_detection/test/test_copyright.py similarity index 100% rename from test/test_copyright.py rename to image_object_detection/test/test_copyright.py diff --git a/test/test_flake8.py b/image_object_detection/test/test_flake8.py similarity index 100% rename from test/test_flake8.py rename to image_object_detection/test/test_flake8.py diff --git a/test/test_pep257.py b/image_object_detection/test/test_pep257.py similarity index 100% rename from test/test_pep257.py rename to image_object_detection/test/test_pep257.py diff --git a/yolov7-tiny.pt b/image_object_detection/yolov7-tiny.pt similarity index 100% rename from yolov7-tiny.pt rename to image_object_detection/yolov7-tiny.pt diff --git a/image_object_detection_msgs/CMakeLists.txt b/image_object_detection_msgs/CMakeLists.txt new file mode 100644 index 0000000..107bf4a --- /dev/null +++ b/image_object_detection_msgs/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.8) +project(image_object_detection_msgs) + +if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_compile_options(-Wall -Wextra -Wpedantic) +endif() + +find_package(ament_cmake REQUIRED) +find_package(rosidl_default_generators REQUIRED) + +rosidl_generate_interfaces(${PROJECT_NAME} + "srv/SetDetectionClasses.srv" +) + +ament_export_dependencies(rosidl_default_runtime) +install( + DIRECTORY srv + DESTINATION share/${PROJECT_NAME} +) + +ament_package() \ No newline at end of file diff --git a/image_object_detection_msgs/package.xml b/image_object_detection_msgs/package.xml new file mode 100644 index 0000000..aed5890 --- /dev/null +++ b/image_object_detection_msgs/package.xml @@ -0,0 +1,18 @@ + + + + image_object_detection_msgs + 1.0.0 + Messages for image object detection + Pablo IƱigo Blasco + BSD-3-Clause + + ament_cmake + rosidl_default_generators + rosidl_default_runtime + rosidl_interface_packages + + + ament_cmake + + diff --git a/image_object_detection_msgs/srv/SetDetectionClasses.srv b/image_object_detection_msgs/srv/SetDetectionClasses.srv new file mode 100644 index 0000000..cc5e818 --- /dev/null +++ b/image_object_detection_msgs/srv/SetDetectionClasses.srv @@ -0,0 +1,4 @@ +string[] classes +--- +bool success +string message diff --git a/src/utils/wandb_logging/__pycache__/__init__.cpython-38.pyc b/src/utils/wandb_logging/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index f5eb56d..0000000 Binary files a/src/utils/wandb_logging/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/src/utils/wandb_logging/__pycache__/wandb_utils.cpython-38.pyc b/src/utils/wandb_logging/__pycache__/wandb_utils.cpython-38.pyc deleted file mode 100644 index a67c720..0000000 Binary files a/src/utils/wandb_logging/__pycache__/wandb_utils.cpython-38.pyc and /dev/null differ