diff --git a/examples/time_in_zone/requirements.txt b/examples/time_in_zone/requirements.txt
index 9f9446c836..b5ff1911b6 100644
--- a/examples/time_in_zone/requirements.txt
+++ b/examples/time_in_zone/requirements.txt
@@ -1,4 +1,6 @@
supervision
ultralytics
inference
-pytube
+# https://github.com/pytube/pytube/issues/2044
+# pytube
+pytubefix
diff --git a/examples/time_in_zone/scripts/download_from_youtube.py b/examples/time_in_zone/scripts/download_from_youtube.py
index d867363175..b740095573 100644
--- a/examples/time_in_zone/scripts/download_from_youtube.py
+++ b/examples/time_in_zone/scripts/download_from_youtube.py
@@ -3,7 +3,7 @@
import argparse
import os
-from pytube import YouTube
+from pytubefix import YouTube
def main(url: str, output_path: str | None, file_name: str | None) -> None:
diff --git a/mkdocs.yml b/mkdocs.yml
index f25015348a..daf3d09845 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -26,7 +26,7 @@ extra:
extra_css:
- stylesheets/extra.css
- - stylesheets/cookbooks-card.css
+ - stylesheets/cookbooks_card.css
nav:
- Home: index.md
@@ -47,6 +47,7 @@ nav:
- Boxes: detection/utils/boxes.md
- Masks: detection/utils/masks.md
- Polygons: detection/utils/polygons.md
+ - VLMs: detection/utils/vlms.md
- Keypoint Detection:
- Core: keypoint/core.md
- Annotators: keypoint/annotators.md
diff --git a/pyproject.toml b/pyproject.toml
index cae78492ac..e7f86b89ef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -2,7 +2,7 @@
name = "supervision"
description = "A set of easy-to-use utils that will come in handy in any Computer Vision project"
license = { text = "MIT" }
-version = "0.26.1"
+version = "0.27.0"
readme = "README.md"
requires-python = ">=3.9"
authors = [
@@ -67,11 +67,12 @@ dev = [
"ipywidgets>=8.1.1",
"jupytext>=1.16.1",
"nbconvert>=7.14.2",
- "docutils!=0.21"
+ "docutils!=0.21",
+ "pre-commit>=3.8.0"
]
docs = [
"mkdocs-material[imaging]>=9.5.5",
- "mkdocstrings>=0.25.2,<0.30.0",
+ "mkdocstrings>=0.25.2,<0.31.0",
"mkdocstrings-python>=1.10.9",
"mike>=2.0.0",
"mkdocs-jupyter>=0.24.3",
@@ -81,7 +82,7 @@ docs = [
build = [
"twine>=5.1.1,<7.0.0",
"wheel>=0.40,<0.46",
- "build>=0.10,<1.3"
+ "build>=0.10,<1.4"
]
[tool.bandit]
diff --git a/supervision/__init__.py b/supervision/__init__.py
index 925628be3c..00820076fe 100644
--- a/supervision/__init__.py
+++ b/supervision/__init__.py
@@ -65,6 +65,7 @@
polygon_to_xyxy,
xcycwh_to_xyxy,
xywh_to_xyxy,
+ xyxy_to_mask,
xyxy_to_polygons,
xyxy_to_xcycarh,
xyxy_to_xywh,
@@ -86,12 +87,14 @@
calculate_masks_centroids,
contains_holes,
contains_multiple_segments,
+ filter_segments_by_distance,
move_masks,
)
from supervision.detection.utils.polygons import (
approximate_polygon,
filter_polygons_by_area,
)
+from supervision.detection.utils.vlms import edit_distance, fuzzy_match_index
from supervision.detection.vlm import LMM, VLM
from supervision.draw.color import Color, ColorPalette
from supervision.draw.utils import (
@@ -107,24 +110,26 @@
)
from supervision.geometry.core import Point, Position, Rect
from supervision.geometry.utils import get_polygon_center
-from supervision.keypoint.annotators import (
+from supervision.key_points.annotators import (
EdgeAnnotator,
VertexAnnotator,
VertexLabelAnnotator,
)
-from supervision.keypoint.core import KeyPoints
+from supervision.key_points.core import KeyPoints
from supervision.metrics.detection import ConfusionMatrix, MeanAveragePrecision
from supervision.tracker.byte_tracker.core import ByteTrack
from supervision.utils.conversion import cv2_to_pillow, pillow_to_cv2
from supervision.utils.file import list_files_with_extensions
from supervision.utils.image import (
ImageSink,
- create_tiles,
crop_image,
+ get_image_resolution_wh,
+ grayscale_image,
letterbox_image,
overlay_image,
resize_image,
scale_image,
+ tint_image,
)
from supervision.utils.notebook import plot_image, plot_images_grid
from supervision.utils.video import (
@@ -205,7 +210,6 @@
"clip_boxes",
"contains_holes",
"contains_multiple_segments",
- "create_tiles",
"crop_image",
"cv2_to_pillow",
"draw_filled_polygon",
@@ -215,10 +219,15 @@
"draw_polygon",
"draw_rectangle",
"draw_text",
+ "edit_distance",
"filter_polygons_by_area",
+ "filter_segments_by_distance",
+ "fuzzy_match_index",
"get_coco_class_index_mapping",
+ "get_image_resolution_wh",
"get_polygon_center",
"get_video_frames_generator",
+ "grayscale_image",
"letterbox_image",
"list_files_with_extensions",
"mask_iou_batch",
@@ -242,8 +251,10 @@
"rle_to_mask",
"scale_boxes",
"scale_image",
+ "tint_image",
"xcycwh_to_xyxy",
"xywh_to_xyxy",
+ "xyxy_to_mask",
"xyxy_to_polygons",
"xyxy_to_xcycarh",
"xyxy_to_xywh",
diff --git a/supervision/annotators/base.py b/supervision/annotators/base.py
index 159ad5567e..9b4bbcbe27 100644
--- a/supervision/annotators/base.py
+++ b/supervision/annotators/base.py
@@ -1,19 +1,7 @@
from abc import ABC, abstractmethod
-from typing import TypeVar
-
-import numpy as np
-from PIL import Image
from supervision.detection.core import Detections
-
-ImageType = TypeVar("ImageType", np.ndarray, Image.Image)
-"""
-An image of type `np.ndarray` or `PIL.Image.Image`.
-
-Unlike a `Union`, ensures the type remains consistent. If a function
-takes an `ImageType` argument and returns an `ImageType`, when you
-pass an `np.ndarray`, you will get an `np.ndarray` back.
-"""
+from supervision.draw.base import ImageType
class BaseAnnotator(ABC):
diff --git a/supervision/annotators/core.py b/supervision/annotators/core.py
index 0b7d4b7632..900d823b2c 100644
--- a/supervision/annotators/core.py
+++ b/supervision/annotators/core.py
@@ -7,8 +7,9 @@
import numpy as np
import numpy.typing as npt
from PIL import Image, ImageDraw, ImageFont
+from scipy.interpolate import splev, splprep
-from supervision.annotators.base import BaseAnnotator, ImageType
+from supervision.annotators.base import BaseAnnotator
from supervision.annotators.utils import (
PENDING_TRACK_ID,
ColorLookup,
@@ -28,12 +29,13 @@
polygon_to_mask,
xyxy_to_polygons,
)
+from supervision.draw.base import ImageType
from supervision.draw.color import Color, ColorPalette
from supervision.draw.utils import draw_polygon, draw_rounded_rectangle, draw_text
from supervision.geometry.core import Point, Position, Rect
from supervision.utils.conversion import (
- ensure_cv2_image_for_annotation,
- ensure_pil_image_for_annotation,
+ ensure_cv2_image_for_class_method,
+ ensure_pil_image_for_class_method,
)
from supervision.utils.image import (
crop_image,
@@ -51,25 +53,28 @@ class _BaseLabelAnnotator(BaseAnnotator):
Attributes:
color (Union[Color, ColorPalette]): The color to use for the label background.
+ color_lookup (ColorLookup): The method used to determine the color of the label.
text_color (Union[Color, ColorPalette]): The color to use for the label text.
text_padding (int): The padding around the label text, in pixels.
text_anchor (Position): The position of the text relative to the detection
- bounding box.
- color_lookup (ColorLookup): The method used to determine the color of the label.
+ bounding box.
+ text_offset (Tuple[int, int]): A tuple of 2D coordinates `(x, y)` to
+ offset the text position from the anchor point, in pixels.
border_radius (int): The radius of the label background corners, in pixels.
smart_position (bool): Whether to intelligently adjust the label position to
- avoid overlapping with other elements.
+ avoid overlapping with other elements.
max_line_length (Optional[int]): Maximum number of characters per line before
- wrapping the text. None means no wrapping.
+ wrapping the text. None means no wrapping.
"""
def __init__(
self,
color: Color | ColorPalette = ColorPalette.DEFAULT,
+ color_lookup: ColorLookup = ColorLookup.CLASS,
text_color: Color | ColorPalette = Color.WHITE,
text_padding: int = 10,
text_position: Position = Position.TOP_LEFT,
- color_lookup: ColorLookup = ColorLookup.CLASS,
+ text_offset: tuple[int, int] = (0, 0),
border_radius: int = 0,
smart_position: bool = False,
max_line_length: int | None = None,
@@ -79,27 +84,29 @@ def __init__(
Args:
color (Union[Color, ColorPalette], optional): The color to use for the label
- background.
+ background.
+ color_lookup (ColorLookup, optional): The method used to determine the color
+ of the label
text_color (Union[Color, ColorPalette], optional): The color to use for the
- label text.
+ label text.
text_padding (int, optional): The padding around the label text, in pixels.
text_position (Position, optional): The position of the text relative to the
- detection bounding box.
- color_lookup (ColorLookup, optional): The method used to determine the color
- of the label
+ detection bounding box.
+ text_offset (Tuple[int, int], optional): A tuple of 2D coordinates
+ `(x, y)` to offset the text position from the anchor point, in pixels.
border_radius (int, optional): The radius of the label background corners,
- in pixels.
+ in pixels.
smart_position (bool, optional): Whether to intelligently adjust the label
- position to avoid overlapping with other elements.
+ position to avoid overlapping with other elements.
max_line_length (Optional[int], optional): Maximum number of characters per
- line before wrapping the text. None means no wrapping.
-
+ line before wrapping the text. None means no wrapping.
"""
self.color: Color | ColorPalette = color
+ self.color_lookup: ColorLookup = color_lookup
self.text_color: Color | ColorPalette = text_color
self.text_padding: int = text_padding
self.text_anchor: Position = text_position
- self.color_lookup: ColorLookup = color_lookup
+ self.text_offset: tuple[int, int] = text_offset
self.border_radius: int = border_radius
self.smart_position = smart_position
self.max_line_length: int | None = max_line_length
@@ -171,7 +178,7 @@ def __init__(
self.thickness: int = thickness
self.color_lookup: ColorLookup = color_lookup
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -254,7 +261,7 @@ def __init__(
self.thickness: int = thickness
self.color_lookup: ColorLookup = color_lookup
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -343,7 +350,7 @@ def __init__(
self.opacity = opacity
self.color_lookup: ColorLookup = color_lookup
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -433,7 +440,7 @@ def __init__(
self.thickness: int = thickness
self.color_lookup: ColorLookup = color_lookup
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -520,7 +527,7 @@ def __init__(
self.color_lookup: ColorLookup = color_lookup
self.opacity = opacity
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -616,7 +623,7 @@ def __init__(
self.color_lookup: ColorLookup = color_lookup
self.kernel_size: int = kernel_size
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -716,7 +723,7 @@ def __init__(
self.end_angle: int = end_angle
self.color_lookup: ColorLookup = color_lookup
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -808,7 +815,7 @@ def __init__(
self.corner_length: int = corner_length
self.color_lookup: ColorLookup = color_lookup
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -897,7 +904,7 @@ def __init__(
self.thickness: int = thickness
self.color_lookup: ColorLookup = color_lookup
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -996,7 +1003,7 @@ def __init__(
self.outline_thickness = outline_thickness
self.outline_color: Color | ColorPalette = outline_color
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -1076,37 +1083,100 @@ class LabelAnnotator(_BaseLabelAnnotator):
def __init__(
self,
color: Color | ColorPalette = ColorPalette.DEFAULT,
+ color_lookup: ColorLookup = ColorLookup.CLASS,
text_color: Color | ColorPalette = Color.WHITE,
text_scale: float = 0.5,
text_thickness: int = 1,
text_padding: int = 10,
text_position: Position = Position.TOP_LEFT,
- color_lookup: ColorLookup = ColorLookup.CLASS,
+ text_offset: tuple[int, int] = (0, 0),
border_radius: int = 0,
smart_position: bool = False,
max_line_length: int | None = None,
):
+ """
+ Args:
+ color (Union[Color, ColorPalette]): The color or color palette to use for
+ annotating the text background.
+ color_lookup (ColorLookup): Strategy for mapping colors to annotations.
+ Options are `INDEX`, `CLASS`, `TRACK`.
+ text_color (Union[Color, ColorPalette]): The color or color palette to use
+ for the text.
+ text_scale (float): Font scale for the text.
+ text_thickness (int): Thickness of the text characters.
+ text_padding (int): Padding around the text within its background box.
+ text_position (Position): Position of the text relative to the detection.
+ Possible values are defined in the `Position` enum.
+ text_offset (Tuple[int, int]): A tuple of 2D coordinates `(x, y)` to
+ offset the text position from the anchor point, in pixels.
+ border_radius (int): The radius to apply round edges. If the selected
+ value is higher than the lower dimension, width or height, is clipped.
+ smart_position (bool): Spread out the labels to avoid overlapping.
+ max_line_length (Optional[int]): Maximum number of characters per line
+ before wrapping the text. None means no wrapping.
+ """
self.text_scale: float = text_scale
self.text_thickness: int = text_thickness
super().__init__(
color=color,
+ color_lookup=color_lookup,
text_color=text_color,
text_padding=text_padding,
text_position=text_position,
- color_lookup=color_lookup,
+ text_offset=text_offset,
border_radius=border_radius,
smart_position=smart_position,
max_line_length=max_line_length,
)
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
- scene: ImageType, # Ensure scene is initially a NumPy array here
+ scene: ImageType,
detections: Detections,
labels: list[str] | None = None,
custom_color_lookup: np.ndarray | None = None,
) -> np.ndarray:
+ """
+ Annotates the given scene with labels based on the provided detections.
+
+ Args:
+ scene (ImageType): The image where labels will be drawn.
+ `ImageType` is a flexible type, accepting either `numpy.ndarray`
+ or `PIL.Image.Image`.
+ detections (Detections): Object detections to annotate.
+ labels (Optional[List[str]]): Custom labels for each detection.
+ custom_color_lookup (Optional[np.ndarray]): Custom color lookup array.
+ Allows to override the default color mapping strategy.
+
+ Returns:
+ The annotated image, matching the type of `scene` (`numpy.ndarray`
+ or `PIL.Image.Image`)
+
+ Example:
+ ```python
+ import supervision as sv
+
+ image = ...
+ detections = sv.Detections(...)
+
+ labels = [
+ f"{class_name} {confidence:.2f}"
+ for class_name, confidence
+ in zip(detections['class_name'], detections.confidence)
+ ]
+
+ label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
+ annotated_frame = label_annotator.annotate(
+ scene=image.copy(),
+ detections=detections,
+ labels=labels
+ )
+ ```
+
+ 
+ """
assert isinstance(scene, np.ndarray)
validate_labels(labels, detections)
@@ -1144,7 +1214,12 @@ def _get_label_properties(
anchor=self.text_anchor
).astype(int)
- for label, center_coords in zip(labels, anchors_coordinates):
+ for label, center_coordinates in zip(labels, anchors_coordinates):
+ center_coordinates = (
+ center_coordinates[0] + self.text_offset[0],
+ center_coordinates[1] + self.text_offset[1],
+ )
+
wrapped_lines = wrap_text(label, self.max_line_length)
line_heights = []
line_widths = []
@@ -1170,7 +1245,7 @@ def _get_label_properties(
height_padded = total_height + 2 * self.text_padding
text_background_xyxy = resolve_text_background_xyxy(
- center_coordinates=tuple(center_coords),
+ center_coordinates=center_coordinates,
text_wh=(width_padded, height_padded),
position=self.text_anchor,
)
@@ -1317,31 +1392,54 @@ class RichLabelAnnotator(_BaseLabelAnnotator):
def __init__(
self,
color: Color | ColorPalette = ColorPalette.DEFAULT,
+ color_lookup: ColorLookup = ColorLookup.CLASS,
text_color: Color | ColorPalette = Color.WHITE,
font_path: str | None = None,
font_size: int = 10,
text_padding: int = 10,
text_position: Position = Position.TOP_LEFT,
- color_lookup: ColorLookup = ColorLookup.CLASS,
+ text_offset: tuple[int, int] = (0, 0),
border_radius: int = 0,
smart_position: bool = False,
max_line_length: int | None = None,
):
+ """
+ Args:
+ color (Union[Color, ColorPalette]): The color or color palette to use for
+ annotating the text background.
+ color_lookup (ColorLookup): Strategy for mapping colors to annotations.
+ Options are `INDEX`, `CLASS`, `TRACK`.
+ text_color (Union[Color, ColorPalette]): The color to use for the text.
+ font_path (Optional[str]): Path to the font file (e.g., ".ttf" or ".otf")
+ to use for rendering text. If `None`, the default PIL font will be used.
+ font_size (int): Font size for the text.
+ text_padding (int): Padding around the text within its background box.
+ text_position (Position): Position of the text relative to the detection.
+ Possible values are defined in the `Position` enum.
+ text_offset (Tuple[int, int]): A tuple of 2D coordinates `(x, y)` to
+ offset the text position from the anchor point, in pixels.
+ border_radius (int): The radius to apply round edges. If the selected
+ value is higher than the lower dimension, width or height, is clipped.
+ smart_position (bool): Spread out the labels to avoid overlapping.
+ max_line_length (Optional[int]): Maximum number of characters per line
+ before wrapping the text. None means no wrapping.
+ """
self.font_path = font_path
self.font_size = font_size
self.font = self._load_font(font_size, font_path)
super().__init__(
color=color,
+ color_lookup=color_lookup,
text_color=text_color,
text_padding=text_padding,
text_position=text_position,
- color_lookup=color_lookup,
+ text_offset=text_offset,
border_radius=border_radius,
smart_position=smart_position,
max_line_length=max_line_length,
)
- @ensure_pil_image_for_annotation
+ @ensure_pil_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -1349,6 +1447,44 @@ def annotate(
labels: list[str] | None = None,
custom_color_lookup: np.ndarray | None = None,
) -> ImageType:
+ """
+ Annotates the given scene with labels based on the provided
+ detections, with support for Unicode characters.
+
+ Args:
+ scene (ImageType): The image where labels will be drawn.
+ `ImageType` is a flexible type, accepting either `numpy.ndarray`
+ or `PIL.Image.Image`.
+ detections (Detections): Object detections to annotate.
+ labels (Optional[List[str]]): Custom labels for each detection.
+ custom_color_lookup (Optional[np.ndarray]): Custom color lookup array.
+ Allows to override the default color mapping strategy.
+
+ Returns:
+ The annotated image, matching the type of `scene` (`numpy.ndarray`
+ or `PIL.Image.Image`)
+
+ Example:
+ ```python
+ import supervision as sv
+
+ image = ...
+ detections = sv.Detections(...)
+
+ labels = [
+ f"{class_name} {confidence:.2f}"
+ for class_name, confidence
+ in zip(detections['class_name'], detections.confidence)
+ ]
+
+ rich_label_annotator = sv.RichLabelAnnotator(font_path="path/to/font.ttf")
+ annotated_frame = rich_label_annotator.annotate(
+ scene=image.copy(),
+ detections=detections,
+ labels=labels
+ )
+ ```
+ """
assert isinstance(scene, Image.Image)
validate_labels(labels, detections)
@@ -1386,7 +1522,12 @@ def _get_label_properties(
anchor=self.text_anchor
).astype(int)
- for label, center_coords in zip(labels, anchor_coordinates):
+ for label, center_coordinates in zip(labels, anchor_coordinates):
+ center_coordinates = (
+ center_coordinates[0] + self.text_offset[0],
+ center_coordinates[1] + self.text_offset[1],
+ )
+
wrapped_lines = wrap_text(label, self.max_line_length)
# Calculate the total text height and maximum width
@@ -1409,7 +1550,7 @@ def _get_label_properties(
height_padded = int(total_height + 2 * self.text_padding)
text_background_xyxy = resolve_text_background_xyxy(
- center_coordinates=tuple(center_coords),
+ center_coordinates=center_coordinates,
text_wh=(width_padded, height_padded),
position=self.text_anchor,
)
@@ -1525,7 +1666,7 @@ def __init__(
self.position = icon_position
self.offset_xy = offset_xy
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self, scene: ImageType, detections: Detections, icon_path: str | list[str]
) -> ImageType:
@@ -1614,7 +1755,7 @@ def __init__(self, kernel_size: int = 15):
"""
self.kernel_size: int = kernel_size
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -1681,6 +1822,7 @@ def __init__(
position: Position = Position.CENTER,
trace_length: int = 30,
thickness: int = 2,
+ smooth: bool = False,
color_lookup: ColorLookup = ColorLookup.CLASS,
):
"""
@@ -1692,15 +1834,17 @@ def __init__(
trace_length (int): The maximum length of the trace in terms of historical
points. Defaults to `30`.
thickness (int): The thickness of the trace lines. Defaults to `2`.
+ smooth (bool): Smooth the trace lines.
color_lookup (ColorLookup): Strategy for mapping colors to annotations.
Options are `INDEX`, `CLASS`, `TRACK`.
"""
self.color: Color | ColorPalette = color
self.trace = Trace(max_size=trace_length, anchor=position)
self.thickness = thickness
+ self.smooth = smooth
self.color_lookup: ColorLookup = color_lookup
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -1769,10 +1913,18 @@ def annotate(
else custom_color_lookup,
)
xy = self.trace.get(tracker_id=tracker_id)
+ spline_points = xy.astype(np.int32)
+
+ if len(xy) > 3 and self.smooth:
+ x, y = xy[:, 0], xy[:, 1]
+ tck, u = splprep([x, y], s=20)
+ x_new, y_new = splev(np.linspace(0, 1, 100), tck)
+ spline_points = np.stack([x_new, y_new], axis=1).astype(np.int32)
+
if len(xy) > 1:
scene = cv2.polylines(
scene,
- [xy.astype(np.int32)],
+ [spline_points],
False,
color=color.as_bgr(),
thickness=self.thickness,
@@ -1814,7 +1966,7 @@ def __init__(
self.low_hue = low_hue
self.heat_mask: npt.NDArray[np.float32] | None = None
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(self, scene: ImageType, detections: Detections) -> ImageType:
"""
Annotates the scene with a heatmap based on the provided detections.
@@ -1896,7 +2048,7 @@ def __init__(self, pixel_size: int = 20):
"""
self.pixel_size: int = pixel_size
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -1993,7 +2145,7 @@ def __init__(
self.outline_thickness: int = outline_thickness
self.outline_color: Color | ColorPalette = outline_color
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -2105,7 +2257,7 @@ def __init__(
raise ValueError("roundness attribute must be float between (0, 1.0]")
self.roundness: float = roundness
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -2245,7 +2397,7 @@ def __init__(
else int(0.15 * self.height)
)
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -2426,7 +2578,7 @@ def __init__(
self.border_thickness: int = border_thickness
self.border_color_lookup: ColorLookup = border_color_lookup
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self,
scene: ImageType,
@@ -2575,7 +2727,7 @@ def __init__(
self.opacity = opacity
self.force_box = force_box
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(self, scene: ImageType, detections: Detections) -> ImageType:
"""
Applies a colored overlay to the scene outside of the detected regions.
@@ -2673,7 +2825,7 @@ def __init__(
self.label_scale = label_scale
self.text_thickness = int(self.label_scale + 1.2)
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(
self, scene: ImageType, detections_1: Detections, detections_2: Detections
) -> ImageType:
diff --git a/supervision/annotators/utils.py b/supervision/annotators/utils.py
index a84e2049d4..42a436c9bb 100644
--- a/supervision/annotators/utils.py
+++ b/supervision/annotators/utils.py
@@ -2,6 +2,7 @@
import textwrap
from enum import Enum
+from typing import Any
import numpy as np
@@ -142,35 +143,45 @@ def resolve_color(
detection_idx=detection_idx,
color_lookup=color_lookup,
)
- if color_lookup == ColorLookup.TRACK and idx == PENDING_TRACK_ID:
+ if (
+ isinstance(color_lookup, ColorLookup)
+ and color_lookup == ColorLookup.TRACK
+ and idx == PENDING_TRACK_ID
+ ):
return PENDING_TRACK_COLOR
return get_color_by_index(color=color, idx=idx)
-def wrap_text(text: str, max_line_length=None) -> list[str]:
+def wrap_text(text: Any, max_line_length=None) -> list[str]:
"""
- Wraps text to the specified maximum line length, respecting existing newlines.
- Uses the textwrap library for robust text wrapping.
+ Wrap `text` to the specified maximum line length, respecting existing
+ newlines. Falls back to str() if `text` is not already a string.
Args:
- text (str): The text to wrap.
+ text (Any): The text (or object) to wrap.
+ max_line_length (int | None): Maximum width for each wrapped line.
Returns:
- List[str]: A list of text lines after wrapping.
+ list[str]: Wrapped lines.
"""
if not text:
return [""]
+ if not isinstance(text, str):
+ text = str(text)
+
if max_line_length is None:
return text.splitlines() or [""]
+ if max_line_length <= 0:
+ raise ValueError("max_line_length must be a positive integer")
+
paragraphs = text.split("\n")
- all_lines = []
+ all_lines: list[str] = []
for paragraph in paragraphs:
- if not paragraph:
- # Keep empty lines
+ if paragraph == "":
all_lines.append("")
continue
@@ -182,12 +193,9 @@ def wrap_text(text: str, max_line_length=None) -> list[str]:
drop_whitespace=True,
)
- if wrapped:
- all_lines.extend(wrapped)
- else:
- all_lines.append("")
+ all_lines.extend(wrapped or [""])
- return all_lines if all_lines else [""]
+ return all_lines or [""]
def validate_labels(labels: list[str] | None, detections: Detections):
diff --git a/supervision/detection/core.py b/supervision/detection/core.py
index ea466b8955..6ca6abc567 100644
--- a/supervision/detection/core.py
+++ b/supervision/detection/core.py
@@ -40,12 +40,14 @@
from supervision.detection.vlm import (
LMM,
VLM,
+ from_deepseek_vl_2,
from_florence_2,
from_google_gemini_2_0,
from_google_gemini_2_5,
from_moondream,
from_paligemma,
from_qwen_2_5_vl,
+ from_qwen_3_vl,
validate_vlm_parameters,
)
from supervision.geometry.core import Position
@@ -295,18 +297,24 @@ def from_ultralytics(cls, ultralytics_results) -> Detections:
class_id=np.arange(len(ultralytics_results)),
)
- class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int)
- class_names = np.array([ultralytics_results.names[i] for i in class_id])
- return cls(
- xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(),
- confidence=ultralytics_results.boxes.conf.cpu().numpy(),
- class_id=class_id,
- mask=extract_ultralytics_masks(ultralytics_results),
- tracker_id=ultralytics_results.boxes.id.int().cpu().numpy()
- if ultralytics_results.boxes.id is not None
- else None,
- data={CLASS_NAME_DATA_FIELD: class_names},
- )
+ if (
+ hasattr(ultralytics_results, "boxes")
+ and ultralytics_results.boxes is not None
+ ):
+ class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int)
+ class_names = np.array([ultralytics_results.names[i] for i in class_id])
+ return cls(
+ xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(),
+ confidence=ultralytics_results.boxes.conf.cpu().numpy(),
+ class_id=class_id,
+ mask=extract_ultralytics_masks(ultralytics_results),
+ tracker_id=ultralytics_results.boxes.id.int().cpu().numpy()
+ if ultralytics_results.boxes.id is not None
+ else None,
+ data={CLASS_NAME_DATA_FIELD: class_names},
+ )
+
+ return cls.empty()
@classmethod
def from_yolo_nas(cls, yolo_nas_results) -> Detections:
@@ -830,6 +838,7 @@ def from_lmm(cls, lmm: LMM | str, result: str | dict, **kwargs: Any) -> Detectio
| Google Gemini 2.0 | `GOOGLE_GEMINI_2_0` | detection | `resolution_wh` | `classes` |
| Google Gemini 2.5 | `GOOGLE_GEMINI_2_5` | detection, segmentation | `resolution_wh` | `classes` |
| Moondream | `MOONDREAM` | detection | `resolution_wh` | |
+ | DeepSeek-VL2 | `DEEPSEEK_VL_2` | detection | `resolution_wh` | `classes` |
Args:
lmm (Union[LMM, str]): The type of LMM (Large Multimodal Model) to use.
@@ -943,6 +952,36 @@ def from_lmm(cls, lmm: LMM | str, result: str | dict, **kwargs: Any) -> Detectio
# array([0, 1])
```
+ !!! example "Qwen3-VL"
+
+ ```python
+ import supervision as sv
+
+ qwen_3_vl_result = \"\"\"```json
+ [
+ {"bbox_2d": [139, 768, 315, 954], "label": "cat"},
+ {"bbox_2d": [366, 679, 536, 849], "label": "dog"}
+ ]
+ ```\"\"\"
+ detections = sv.Detections.from_lmm(
+ sv.LMM.QWEN_3_VL,
+ qwen_3_vl_result,
+ resolution_wh=(1000, 1000),
+ classes=['cat', 'dog'],
+ )
+ detections.xyxy
+ # array([[139., 768., 315., 954.], [366., 679., 536., 849.]])
+
+ detections.class_id
+ # array([0, 1])
+
+ detections.data
+ # {'class_name': array(['cat', 'dog'], dtype='<U3')}
+ ```
# array([[1752.28, 818.82, 2165.72, 1229.14],
# [1908.01, 1346.67, 2585.99, 2024.11]])
```
+
+ !!! example "DeepSeek-VL2"
+
+
+ ??? tip "Prompt engineering"
+
+ To get the best results from DeepSeek-VL2, use optimized prompts that leverage
+ its object detection and visual grounding capabilities effectively.
+
+ **For general object detection, use the following user prompt:**
+
+ ```
+ \\n<|ref|>The giraffe at the front<|/ref|>
+ ```
+
+ **For visual grounding, use the following user prompt:**
+
+ ```
+ \\n<|grounding|>Detect the giraffes
+ ```
+
+ ```python
+ from PIL import Image
+ import supervision as sv
+
+ deepseek_vl2_result = "<|ref|>The giraffe at the back<|/ref|><|det|>[[580, 270, 999, 904]]<|/det|><|ref|>The giraffe at the front<|/ref|><|det|>[[26, 31, 632, 998]]<|/det|><|end▁of▁sentence|>"
+
+ detections = sv.Detections.from_vlm(
+ vlm=sv.VLM.DEEPSEEK_VL_2, result=deepseek_vl2_result, resolution_wh=image.size
+ )
+
+ detections.xyxy
+ # array([[ 420, 293, 724, 982],
+ # [ 18, 33, 458, 1084]])
+
+ detections.class_id
+ # array([0, 1])
+
+ detections.data
+ # {'class_name': array(['The giraffe at the back', 'The giraffe at the front'], dtype='<U24')}
LMM.PALIGEMMA: VLM.PALIGEMMA,
LMM.FLORENCE_2: VLM.FLORENCE_2,
LMM.QWEN_2_5_VL: VLM.QWEN_2_5_VL,
+ LMM.DEEPSEEK_VL_2: VLM.DEEPSEEK_VL_2,
LMM.GOOGLE_GEMINI_2_0: VLM.GOOGLE_GEMINI_2_0,
LMM.GOOGLE_GEMINI_2_5: VLM.GOOGLE_GEMINI_2_5,
}
@@ -1161,9 +1242,11 @@ def from_vlm(cls, vlm: VLM | str, result: str | dict, **kwargs: Any) -> Detectio
| PaliGemma | `PALIGEMMA` | detection | `resolution_wh` | `classes` |
| PaliGemma 2 | `PALIGEMMA` | detection | `resolution_wh` | `classes` |
| Qwen2.5-VL | `QWEN_2_5_VL` | detection | `resolution_wh`, `input_wh` | `classes` |
+ | Qwen3-VL | `QWEN_3_VL` | detection | `resolution_wh` | `classes` |
| Google Gemini 2.0 | `GOOGLE_GEMINI_2_0` | detection | `resolution_wh` | `classes` |
| Google Gemini 2.5 | `GOOGLE_GEMINI_2_5` | detection, segmentation | `resolution_wh` | `classes` |
| Moondream | `MOONDREAM` | detection | `resolution_wh` | |
+ | DeepSeek-VL2 | `DEEPSEEK_VL_2` | detection | `resolution_wh` | `classes` |
Args:
vlm (Union[VLM, str]): The type of VLM (Vision Language Model) to use.
@@ -1277,6 +1360,36 @@ def from_vlm(cls, vlm: VLM | str, result: str | dict, **kwargs: Any) -> Detectio
# array([0, 1])
```
+ !!! example "Qwen3-VL"
+
+ ```python
+ import supervision as sv
+
+ qwen_3_vl_result = \"\"\"```json
+ [
+ {"bbox_2d": [139, 768, 315, 954], "label": "cat"},
+ {"bbox_2d": [366, 679, 536, 849], "label": "dog"}
+ ]
+ ```\"\"\"
+ detections = sv.Detections.from_vlm(
+ sv.VLM.QWEN_3_VL,
+ qwen_3_vl_result,
+ resolution_wh=(1000, 1000),
+ classes=['cat', 'dog'],
+ )
+ detections.xyxy
+ # array([[139., 768., 315., 954.], [366., 679., 536., 849.]])
+
+ detections.class_id
+ # array([0, 1])
+
+ detections.data
+ # {'class_name': array(['cat', 'dog'], dtype='<U3')}
+ ```
@@ ... @@ def from_vlm(cls, vlm: VLM | str, result: str | dict, **kwargs: Any) -> Detectio
# [1908.01, 1346.67, 2585.99, 2024.11]])
```
+ !!! example "DeepSeek-VL2"
+
+
+ ??? tip "Prompt engineering"
+
+ To get the best results from DeepSeek-VL2, use optimized prompts that leverage
+ its object detection and visual grounding capabilities effectively.
+
+ **For general object detection, use the following user prompt:**
+
+ ```
+ \\n<|ref|>The giraffe at the front<|/ref|>
+ ```
+
+ **For visual grounding, use the following user prompt:**
+
+ ```
+ \\n<|grounding|>Detect the giraffes
+ ```
+
+ ```python
+ from PIL import Image
+ import supervision as sv
+
+ deepseek_vl2_result = "<|ref|>The giraffe at the back<|/ref|><|det|>[[580, 270, 999, 904]]<|/det|><|ref|>The giraffe at the front<|/ref|><|det|>[[26, 31, 632, 998]]<|/det|><|end▁of▁sentence|>"
+
+ detections = sv.Detections.from_vlm(
+ vlm=sv.VLM.DEEPSEEK_VL_2, result=deepseek_vl2_result, resolution_wh=image.size
+ )
+
+ detections.xyxy
+ # array([[ 420, 293, 724, 982],
+ # [ 18, 33, 458, 1084]])
+
+ detections.class_id
+ # array([0, 1])
+
+ detections.data
+ # {'class_name': array(['The giraffe at the back', 'The giraffe at the front'], dtype=' Detectio
if vlm == VLM.QWEN_2_5_VL:
xyxy, class_id, class_name = from_qwen_2_5_vl(result, **kwargs)
data = {CLASS_NAME_DATA_FIELD: class_name}
+ confidence = np.ones(len(xyxy), dtype=float)
+ return cls(xyxy=xyxy, class_id=class_id, confidence=confidence, data=data)
+
+ if vlm == VLM.QWEN_3_VL:
+ xyxy, class_id, class_name = from_qwen_3_vl(result, **kwargs)
+ data = {CLASS_NAME_DATA_FIELD: class_name}
+ confidence = np.ones(len(xyxy), dtype=float)
+ return cls(xyxy=xyxy, class_id=class_id, confidence=confidence, data=data)
+
+ if vlm == VLM.DEEPSEEK_VL_2:
+ xyxy, class_id, class_name = from_deepseek_vl_2(result, **kwargs)
+ data = {CLASS_NAME_DATA_FIELD: class_name}
return cls(xyxy=xyxy, class_id=class_id, data=data)
if vlm == VLM.FLORENCE_2:
@@ -1922,6 +2088,43 @@ def box_area(self) -> np.ndarray:
"""
return (self.xyxy[:, 3] - self.xyxy[:, 1]) * (self.xyxy[:, 2] - self.xyxy[:, 0])
+ @property
+ def box_aspect_ratio(self) -> np.ndarray:
+ """
+ Compute the aspect ratio (width divided by height) for each bounding box.
+
+ Returns:
+ np.ndarray: Array of shape `(N,)` containing aspect ratios, where `N` is the
+ number of boxes (width / height for each box).
+
+ Examples:
+ ```python
+ import numpy as np
+ import supervision as sv
+
+ xyxy = np.array([
+ [10, 10, 50, 50],
+ [60, 10, 180, 50],
+ [10, 60, 50, 180],
+ ])
+
+ detections = sv.Detections(xyxy=xyxy)
+
+ detections.box_aspect_ratio
+ # array([1.0, 3.0, 0.33333333])
+
+ ar = detections.box_aspect_ratio
+ detections[(ar < 2.0) & (ar > 0.5)].xyxy
+ # array([[10., 10., 50., 50.]])
+ ```
+ """
+ widths = self.xyxy[:, 2] - self.xyxy[:, 0]
+ heights = self.xyxy[:, 3] - self.xyxy[:, 1]
+
+ aspect_ratios = np.full_like(widths, np.nan, dtype=np.float64)
+ np.divide(widths, heights, out=aspect_ratios, where=heights != 0)
+ return aspect_ratios
+
def with_nms(
self,
threshold: float = 0.5,
diff --git a/supervision/detection/tools/inference_slicer.py b/supervision/detection/tools/inference_slicer.py
index aaecccb3dc..4cc05f19cc 100644
--- a/supervision/detection/tools/inference_slicer.py
+++ b/supervision/detection/tools/inference_slicer.py
@@ -11,11 +11,9 @@
from supervision.detection.utils.boxes import move_boxes, move_oriented_boxes
from supervision.detection.utils.iou_and_nms import OverlapFilter, OverlapMetric
from supervision.detection.utils.masks import move_masks
-from supervision.utils.image import crop_image
-from supervision.utils.internal import (
- SupervisionWarnings,
- warn_deprecated,
-)
+from supervision.draw.base import ImageType
+from supervision.utils.image import crop_image, get_image_resolution_wh
+from supervision.utils.internal import SupervisionWarnings
def move_detections(
@@ -53,111 +51,106 @@ def move_detections(
class InferenceSlicer:
"""
- InferenceSlicer performs slicing-based inference for small target detection. This
- method, often referred to as
- [Slicing Adaptive Inference (SAHI)](https://ieeexplore.ieee.org/document/9897990),
- involves dividing a larger image into smaller slices, performing inference on each
- slice, and then merging the detections.
+ Perform tiled inference on large images by slicing them into overlapping patches.
+
+ This class divides an input image into overlapping slices of configurable size
+ and overlap, runs inference on each slice through a user-provided callback, and
+ merges the resulting detections. The slicing process allows efficient processing
+ of large images with limited resources while preserving detection accuracy via
+ configurable overlap and post-processing of overlaps. Uses multi-threading for
+ parallel slice inference.
Args:
- slice_wh (Tuple[int, int]): Dimensions of each slice measured in pixels. The
- tuple should be in the format `(width, height)`.
- overlap_ratio_wh (Optional[Tuple[float, float]]): [β οΈ Deprecated: please set
- to `None` and use `overlap_wh`] A tuple representing the
- desired overlap ratio for width and height between consecutive slices.
- Each value should be in the range [0, 1), where 0 means no overlap and
- a value close to 1 means high overlap.
- overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired
- overlap for width and height between consecutive slices measured in pixels.
- Each value should be greater than or equal to 0. Takes precedence over
- `overlap_ratio_wh`.
- overlap_filter (Union[OverlapFilter, str]): Strategy for
- filtering or merging overlapping detections in slices.
- iou_threshold (float): Intersection over Union (IoU) threshold
- used when filtering by overlap.
- overlap_metric (Union[OverlapMetric, str]): Metric used for matching detections
- in slices.
- callback (Callable): A function that performs inference on a given image
- slice and returns detections.
- thread_workers (int): Number of threads for parallel execution.
-
- Note:
- The class ensures that slices do not exceed the boundaries of the original
- image. As a result, the final slices in the row and column dimensions might be
- smaller than the specified slice dimensions if the image's width or height is
- not a multiple of the slice's width or height minus the overlap.
+ callback (Callable[[ImageType], Detections]): Inference function that takes
+ a sliced image and returns a `Detections` object.
+ slice_wh (int or tuple[int, int]): Size of each slice `(width, height)`.
+ If int, both width and height are set to this value.
+ overlap_wh (int or tuple[int, int]): Overlap size `(width, height)` between
+ slices. If int, both width and height are set to this value.
+ overlap_filter (OverlapFilter or str): Strategy to merge overlapping
+ detections (`NON_MAX_SUPPRESSION`, `NON_MAX_MERGE`, or `NONE`).
+ iou_threshold (float): IOU threshold used in merging overlap filtering.
+ overlap_metric (OverlapMetric or str): Metric to compute overlap
+ (`IOU` or `IOS`).
+ thread_workers (int): Number of threads for concurrent slice inference.
+
+ Raises:
+ ValueError: If `slice_wh` or `overlap_wh` are invalid or inconsistent.
+
+ Example:
+ ```python
+ import cv2
+ import supervision as sv
+ from rfdetr import RFDETRMedium
+
+ model = RFDETRMedium()
+
+ def callback(tile):
+ return model.predict(tile)
+
+ slicer = sv.InferenceSlicer(callback, slice_wh=640, overlap_wh=100)
+
+ image = cv2.imread("example.png")
+ detections = slicer(image)
+ ```
+
+ ```python
+ import supervision as sv
+ from PIL import Image
+ from ultralytics import YOLO
+
+ model = YOLO("yolo11m.pt")
+
+ def callback(tile):
+ results = model(tile)[0]
+ return sv.Detections.from_ultralytics(results)
+
+ slicer = sv.InferenceSlicer(callback, slice_wh=640, overlap_wh=100)
+
+ image = Image.open("example.png")
+ detections = slicer(image)
+ ```
"""
def __init__(
self,
- callback: Callable[[np.ndarray], Detections],
- slice_wh: tuple[int, int] = (320, 320),
- overlap_ratio_wh: tuple[float, float] | None = (0.2, 0.2),
- overlap_wh: tuple[int, int] | None = None,
+ callback: Callable[[ImageType], Detections],
+ slice_wh: int | tuple[int, int] = 640,
+ overlap_wh: int | tuple[int, int] = 100,
overlap_filter: OverlapFilter | str = OverlapFilter.NON_MAX_SUPPRESSION,
iou_threshold: float = 0.5,
overlap_metric: OverlapMetric | str = OverlapMetric.IOU,
thread_workers: int = 1,
):
- if overlap_ratio_wh is not None:
- warn_deprecated(
- "`overlap_ratio_wh` in `InferenceSlicer.__init__` is deprecated and "
- "will be removed in `supervision-0.27.0`. Please manually set it to "
- "`None` and use `overlap_wh` instead."
- )
+ slice_wh_norm = self._normalize_slice_wh(slice_wh)
+ overlap_wh_norm = self._normalize_overlap_wh(overlap_wh)
- self._validate_overlap(overlap_ratio_wh, overlap_wh)
- self.overlap_ratio_wh = overlap_ratio_wh
- self.overlap_wh = overlap_wh
+ self._validate_overlap(slice_wh=slice_wh_norm, overlap_wh=overlap_wh_norm)
- self.slice_wh = slice_wh
+ self.slice_wh = slice_wh_norm
+ self.overlap_wh = overlap_wh_norm
self.iou_threshold = iou_threshold
self.overlap_metric = OverlapMetric.from_value(overlap_metric)
self.overlap_filter = OverlapFilter.from_value(overlap_filter)
self.callback = callback
self.thread_workers = thread_workers
- def __call__(self, image: np.ndarray) -> Detections:
+ def __call__(self, image: ImageType) -> Detections:
"""
- Performs slicing-based inference on the provided image using the specified
- callback.
+ Perform tiled inference on the full image and return merged detections.
Args:
- image (np.ndarray): The input image on which inference needs to be
- performed. The image should be in the format
- `(height, width, channels)`.
+ image (ImageType): The full image to run inference on.
Returns:
- Detections: A collection of detections for the entire image after merging
- results from all slices and applying NMS.
-
- Example:
- ```python
- import cv2
- import supervision as sv
- from ultralytics import YOLO
-
- image = cv2.imread(SOURCE_IMAGE_PATH)
- model = YOLO(...)
-
- def callback(image_slice: np.ndarray) -> sv.Detections:
- result = model(image_slice)[0]
- return sv.Detections.from_ultralytics(result)
-
- slicer = sv.InferenceSlicer(
- callback=callback,
- overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION,
- )
-
- detections = slicer(image)
- ```
+ Detections: Merged detections across all slices.
"""
- detections_list = []
- resolution_wh = (image.shape[1], image.shape[0])
+ detections_list: list[Detections] = []
+ resolution_wh = get_image_resolution_wh(image)
+
offsets = self._generate_offset(
resolution_wh=resolution_wh,
slice_wh=self.slice_wh,
- overlap_ratio_wh=self.overlap_ratio_wh,
overlap_wh=self.overlap_wh,
)
@@ -171,129 +164,178 @@ def callback(image_slice: np.ndarray) -> sv.Detections:
merged = Detections.merge(detections_list=detections_list)
if self.overlap_filter == OverlapFilter.NONE:
return merged
- elif self.overlap_filter == OverlapFilter.NON_MAX_SUPPRESSION:
+ if self.overlap_filter == OverlapFilter.NON_MAX_SUPPRESSION:
return merged.with_nms(
- threshold=self.iou_threshold, overlap_metric=self.overlap_metric
+ threshold=self.iou_threshold,
+ overlap_metric=self.overlap_metric,
)
- elif self.overlap_filter == OverlapFilter.NON_MAX_MERGE:
+ if self.overlap_filter == OverlapFilter.NON_MAX_MERGE:
return merged.with_nmm(
- threshold=self.iou_threshold, overlap_metric=self.overlap_metric
+ threshold=self.iou_threshold,
+ overlap_metric=self.overlap_metric,
)
- else:
- warnings.warn(
- f"Invalid overlap filter strategy: {self.overlap_filter}",
- category=SupervisionWarnings,
- )
- return merged
- def _run_callback(self, image, offset) -> Detections:
+ warnings.warn(
+ f"Invalid overlap filter strategy: {self.overlap_filter}",
+ category=SupervisionWarnings,
+ )
+ return merged
+
+ def _run_callback(self, image: ImageType, offset: np.ndarray) -> Detections:
"""
- Run the provided callback on a slice of an image.
+ Run detection callback on a sliced portion of the image and adjust coordinates.
Args:
- image (np.ndarray): The input image on which inference needs to run
- offset (np.ndarray): An array of shape `(4,)` containing coordinates
- for the slice.
+ image (ImageType): The full image.
+ offset (numpy.ndarray): Coordinates `(x_min, y_min, x_max, y_max)` defining
+ the slice region.
Returns:
- Detections: A collection of detections for the slice.
+ Detections: Detections adjusted to the full image coordinate system.
"""
- image_slice = crop_image(image=image, xyxy=offset)
+ image_slice: ImageType = crop_image(image=image, xyxy=offset)
detections = self.callback(image_slice)
- resolution_wh = (image.shape[1], image.shape[0])
+ resolution_wh = get_image_resolution_wh(image)
+
detections = move_detections(
- detections=detections, offset=offset[:2], resolution_wh=resolution_wh
+ detections=detections,
+ offset=offset[:2],
+ resolution_wh=resolution_wh,
)
-
return detections
+ @staticmethod
+ def _normalize_slice_wh(
+ slice_wh: int | tuple[int, int],
+ ) -> tuple[int, int]:
+ if isinstance(slice_wh, int):
+ if slice_wh <= 0:
+ raise ValueError(
+ f"`slice_wh` must be a positive integer. Received: {slice_wh}"
+ )
+ return slice_wh, slice_wh
+
+ if isinstance(slice_wh, tuple) and len(slice_wh) == 2:
+ width, height = slice_wh
+ if width <= 0 or height <= 0:
+ raise ValueError(
+ f"`slice_wh` values must be positive. Received: {slice_wh}"
+ )
+ return width, height
+
+ raise ValueError(
+ "`slice_wh` must be an int or a tuple of two positive integers "
+ "(slice_w, slice_h). "
+ f"Received: {slice_wh}"
+ )
+
+ @staticmethod
+ def _normalize_overlap_wh(
+ overlap_wh: int | tuple[int, int],
+ ) -> tuple[int, int]:
+ if isinstance(overlap_wh, int):
+ if overlap_wh < 0:
+ raise ValueError(
+ "`overlap_wh` must be a non negative integer. "
+ f"Received: {overlap_wh}"
+ )
+ return overlap_wh, overlap_wh
+
+ if isinstance(overlap_wh, tuple) and len(overlap_wh) == 2:
+ overlap_w, overlap_h = overlap_wh
+ if overlap_w < 0 or overlap_h < 0:
+ raise ValueError(
+ f"`overlap_wh` values must be non negative. Received: {overlap_wh}"
+ )
+ return overlap_w, overlap_h
+
+ raise ValueError(
+ "`overlap_wh` must be an int or a tuple of two non negative integers "
+ "(overlap_w, overlap_h). "
+ f"Received: {overlap_wh}"
+ )
+
@staticmethod
def _generate_offset(
resolution_wh: tuple[int, int],
slice_wh: tuple[int, int],
- overlap_ratio_wh: tuple[float, float] | None,
- overlap_wh: tuple[int, int] | None,
+ overlap_wh: tuple[int, int],
) -> np.ndarray:
"""
- Generate offset coordinates for slicing an image based on the given resolution,
- slice dimensions, and overlap ratios.
+ Generate bounding boxes defining the coordinates of image slices with overlap.
Args:
- resolution_wh (Tuple[int, int]): A tuple representing the width and height
- of the image to be sliced.
- slice_wh (Tuple[int, int]): Dimensions of each slice measured in pixels. The
- tuple should be in the format `(width, height)`.
- overlap_ratio_wh (Optional[Tuple[float, float]]): A tuple representing the
- desired overlap ratio for width and height between consecutive slices.
- Each value should be in the range [0, 1), where 0 means no overlap and
- a value close to 1 means high overlap.
- overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired
- overlap for width and height between consecutive slices measured in
- pixels. Each value should be greater than or equal to 0.
+ resolution_wh (tuple[int, int]): Image resolution `(width, height)`.
+ slice_wh (tuple[int, int]): Size of each slice `(width, height)`.
+ overlap_wh (tuple[int, int]): Overlap size between slices `(width, height)`.
Returns:
- np.ndarray: An array of shape `(n, 4)` containing coordinates for each
- slice in the format `[xmin, ymin, xmax, ymax]`.
-
- Note:
- The function ensures that slices do not exceed the boundaries of the
- original image. As a result, the final slices in the row and column
- dimensions might be smaller than the specified slice dimensions if the
- image's width or height is not a multiple of the slice's width or
- height minus the overlap.
+ numpy.ndarray: Array of shape `(num_slices, 4)` with each row as
+ `(x_min, y_min, x_max, y_max)` coordinates for a slice.
"""
slice_width, slice_height = slice_wh
image_width, image_height = resolution_wh
- overlap_width = (
- overlap_wh[0]
- if overlap_wh is not None
- else int(overlap_ratio_wh[0] * slice_width)
+ overlap_width, overlap_height = overlap_wh
+
+ stride_x = slice_width - overlap_width
+ stride_y = slice_height - overlap_height
+
+ def _compute_axis_starts(
+ image_size: int,
+ slice_size: int,
+ stride: int,
+ ) -> list[int]:
+ if image_size <= slice_size:
+ return [0]
+
+ if stride == slice_size:
+ return np.arange(0, image_size, stride).tolist()
+
+ last_start = image_size - slice_size
+ starts = np.arange(0, last_start, stride).tolist()
+ if not starts or starts[-1] != last_start:
+ starts.append(last_start)
+ return starts
+
+ x_starts = _compute_axis_starts(
+ image_size=image_width,
+ slice_size=slice_width,
+ stride=stride_x,
)
- overlap_height = (
- overlap_wh[1]
- if overlap_wh is not None
- else int(overlap_ratio_wh[1] * slice_height)
+ y_starts = _compute_axis_starts(
+ image_size=image_height,
+ slice_size=slice_height,
+ stride=stride_y,
)
- width_stride = slice_width - overlap_width
- height_stride = slice_height - overlap_height
+ x_min, y_min = np.meshgrid(x_starts, y_starts)
+ x_max = np.clip(x_min + slice_width, 0, image_width)
+ y_max = np.clip(y_min + slice_height, 0, image_height)
- ws = np.arange(0, image_width, width_stride)
- hs = np.arange(0, image_height, height_stride)
-
- xmin, ymin = np.meshgrid(ws, hs)
- xmax = np.clip(xmin + slice_width, 0, image_width)
- ymax = np.clip(ymin + slice_height, 0, image_height)
-
- offsets = np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4)
+ offsets = np.stack(
+ [x_min, y_min, x_max, y_max],
+ axis=-1,
+ ).reshape(-1, 4)
return offsets
@staticmethod
def _validate_overlap(
- overlap_ratio_wh: tuple[float, float] | None,
- overlap_wh: tuple[int, int] | None,
+ slice_wh: tuple[int, int],
+ overlap_wh: tuple[int, int],
) -> None:
- if overlap_ratio_wh is not None and overlap_wh is not None:
+ overlap_w, overlap_h = overlap_wh
+ slice_w, slice_h = slice_wh
+
+ if overlap_w < 0 or overlap_h < 0:
raise ValueError(
- "Both `overlap_ratio_wh` and `overlap_wh` cannot be provided. "
- "Please provide only one of them."
+ "Overlap values must be greater than or equal to 0. "
+ f"Received: {overlap_wh}"
)
- if overlap_ratio_wh is None and overlap_wh is None:
+
+ if overlap_w >= slice_w or overlap_h >= slice_h:
raise ValueError(
- "Either `overlap_ratio_wh` or `overlap_wh` must be provided. "
- "Please provide one of them."
+ "`overlap_wh` must be smaller than `slice_wh` in both dimensions "
+ f"to keep a positive stride. Received overlap_wh={overlap_wh}, "
+ f"slice_wh={slice_wh}."
)
-
- if overlap_ratio_wh is not None:
- if not (0 <= overlap_ratio_wh[0] < 1 and 0 <= overlap_ratio_wh[1] < 1):
- raise ValueError(
- "Overlap ratios must be in the range [0, 1). "
- f"Received: {overlap_ratio_wh}"
- )
- if overlap_wh is not None:
- if not (overlap_wh[0] >= 0 and overlap_wh[1] >= 0):
- raise ValueError(
- "Overlap values must be greater than or equal to 0. "
- f"Received: {overlap_wh}"
- )
diff --git a/supervision/detection/utils/boxes.py b/supervision/detection/utils/boxes.py
index 3b01fcb68b..52e4b69569 100644
--- a/supervision/detection/utils/boxes.py
+++ b/supervision/detection/utils/boxes.py
@@ -14,8 +14,8 @@ def clip_boxes(xyxy: np.ndarray, resolution_wh: tuple[int, int]) -> np.ndarray:
xyxy (np.ndarray): A numpy array of shape `(N, 4)` where each
row corresponds to a bounding box in
the format `(x_min, y_min, x_max, y_max)`.
- resolution_wh (Tuple[int, int]): A tuple of the form `(width, height)`
- representing the resolution of the frame.
+ resolution_wh (Tuple[int, int]): A tuple of the form
+ `(width, height)` representing the resolution of the frame.
Returns:
np.ndarray: A numpy array of shape `(N, 4)` where each row
@@ -95,24 +95,27 @@ def pad_boxes(xyxy: np.ndarray, px: int, py: int | None = None) -> np.ndarray:
def denormalize_boxes(
- normalized_xyxy: np.ndarray,
+ xyxy: np.ndarray,
resolution_wh: tuple[int, int],
normalization_factor: float = 1.0,
) -> np.ndarray:
"""
- Converts normalized bounding box coordinates to absolute pixel values.
+ Convert normalized bounding box coordinates to absolute pixel coordinates.
+
+ Multiplies each bounding box coordinate by image size and divides by
+ `normalization_factor`, mapping values from normalized `[0, normalization_factor]`
+ to absolute pixel values for a given resolution.
Args:
- normalized_xyxy (np.ndarray): A numpy array of shape `(N, 4)` where each row
- contains normalized coordinates in the format `(x_min, y_min, x_max, y_max)`,
- with values between 0 and `normalization_factor`.
- resolution_wh (Tuple[int, int]): A tuple `(width, height)` representing the
- target image resolution.
- normalization_factor (float, optional): The normalization range of the input
- coordinates. Defaults to 1.0.
+ xyxy (`numpy.ndarray`): Normalized bounding boxes of shape `(N, 4)`,
+ where each row is `(x_min, y_min, x_max, y_max)`, values in
+ `[0, normalization_factor]`.
+ resolution_wh (`tuple[int, int]`): Target image resolution as `(width, height)`.
+ normalization_factor (`float`): Maximum value of input coordinate range.
+ Defaults to `1.0`.
Returns:
- np.ndarray: An array of shape `(N, 4)` with absolute coordinates in
+ (`numpy.ndarray`): Array of shape `(N, 4)` with absolute coordinates in
`(x_min, y_min, x_max, y_max)` format.
Examples:
@@ -120,35 +123,39 @@ def denormalize_boxes(
import numpy as np
import supervision as sv
- # Default normalization (0-1)
- normalized_xyxy = np.array([
+ xyxy = np.array([
[0.1, 0.2, 0.5, 0.6],
- [0.3, 0.4, 0.7, 0.8]
+ [0.3, 0.4, 0.7, 0.8],
+ [0.2, 0.1, 0.6, 0.5]
])
- resolution_wh = (100, 200)
- sv.denormalize_boxes(normalized_xyxy, resolution_wh)
+
+ sv.denormalize_boxes(xyxy, (1280, 720))
# array([
- # [ 10., 40., 50., 120.],
- # [ 30., 80., 70., 160.]
+ # [128., 144., 640., 432.],
+ # [384., 288., 896., 576.],
+ # [256., 72., 768., 360.]
# ])
+ ```
- # Custom normalization (0-100)
- normalized_xyxy = np.array([
- [10., 20., 50., 60.],
- [30., 40., 70., 80.]
+ ```
+ import numpy as np
+ import supervision as sv
+
+ xyxy = np.array([
+ [256., 128., 768., 640.]
])
- sv.denormalize_boxes(normalized_xyxy, resolution_wh, normalization_factor=100.0)
+
+ sv.denormalize_boxes(xyxy, (1280, 720), normalization_factor=1024.0)
# array([
- # [ 10., 40., 50., 120.],
- # [ 30., 80., 70., 160.]
+ # [320., 90., 960., 450.]
# ])
```
- """ # noqa E501 // docs
+ """
width, height = resolution_wh
- result = normalized_xyxy.copy()
+ result = xyxy.copy()
- result[[0, 2]] = (result[[0, 2]] * width) / normalization_factor
- result[[1, 3]] = (result[[1, 3]] * height) / normalization_factor
+ result[:, [0, 2]] = (result[:, [0, 2]] * width) / normalization_factor
+ result[:, [1, 3]] = (result[:, [1, 3]] * height) / normalization_factor
return result
diff --git a/supervision/detection/utils/converters.py b/supervision/detection/utils/converters.py
index 9e02783a04..4aef2dc87c 100644
--- a/supervision/detection/utils/converters.py
+++ b/supervision/detection/utils/converters.py
@@ -229,6 +229,70 @@ def mask_to_xyxy(masks: np.ndarray) -> np.ndarray:
return xyxy
+def xyxy_to_mask(boxes: np.ndarray, resolution_wh: tuple[int, int]) -> np.ndarray:
+ """
+ Converts a 2D `np.ndarray` of bounding boxes into a 3D `np.ndarray` of bool masks.
+
+ Parameters:
+ boxes (np.ndarray): A 2D `np.ndarray` of shape `(N, 4)`
+ containing bounding boxes `(x_min, y_min, x_max, y_max)`
+ resolution_wh (Tuple[int, int]): A tuple `(width, height)` specifying
+ the resolution of the output masks
+
+ Returns:
+ np.ndarray: A 3D `np.ndarray` of shape `(N, height, width)`
+ containing 2D bool masks for each bounding box
+
+ Examples:
+ ```python
+ import numpy as np
+ import supervision as sv
+
+ boxes = np.array([[0, 0, 2, 2]])
+
+ sv.xyxy_to_mask(boxes, (5, 5))
+ # array([
+ # [[ True, True, True, False, False],
+ # [ True, True, True, False, False],
+ # [ True, True, True, False, False],
+ # [False, False, False, False, False],
+ # [False, False, False, False, False]]
+ # ])
+
+ boxes = np.array([[0, 0, 1, 1], [3, 3, 4, 4]])
+
+ sv.xyxy_to_mask(boxes, (5, 5))
+ # array([
+ # [[ True, True, False, False, False],
+ # [ True, True, False, False, False],
+ # [False, False, False, False, False],
+ # [False, False, False, False, False],
+ # [False, False, False, False, False]],
+ #
+ # [[False, False, False, False, False],
+ # [False, False, False, False, False],
+ # [False, False, False, False, False],
+ # [False, False, False, True, True],
+ # [False, False, False, True, True]]
+ # ])
+ ```
+ """
+ width, height = resolution_wh
+ n = boxes.shape[0]
+ masks = np.zeros((n, height, width), dtype=bool)
+
+ for i, (x_min, y_min, x_max, y_max) in enumerate(boxes):
+ x_min = max(0, int(x_min))
+ y_min = max(0, int(y_min))
+ x_max = min(width - 1, int(x_max))
+ y_max = min(height - 1, int(y_max))
+
+ if x_max >= x_min and y_max >= y_min:
+ masks[i, y_min : y_max + 1, x_min : x_max + 1] = True
+
+ return masks
+
+
def mask_to_polygons(mask: np.ndarray) -> list[np.ndarray]:
"""
Converts a binary mask to a list of polygons.
diff --git a/supervision/detection/utils/internal.py b/supervision/detection/utils/internal.py
index bc6579a8b9..f5d6dc9fbf 100644
--- a/supervision/detection/utils/internal.py
+++ b/supervision/detection/utils/internal.py
@@ -271,7 +271,6 @@ def merge_metadata(metadata_list: list[dict[str, Any]]) -> dict[str, Any]:
"{type(value)}, {type(other_value)}."
)
else:
- print("hm")
if merged_metadata[key] != value:
raise ValueError(f"Conflicting metadata for key: '{key}'.")
diff --git a/supervision/detection/utils/iou_and_nms.py b/supervision/detection/utils/iou_and_nms.py
index 1a6f80bc58..8a444cf16e 100644
--- a/supervision/detection/utils/iou_and_nms.py
+++ b/supervision/detection/utils/iou_and_nms.py
@@ -64,7 +64,7 @@ class OverlapMetric(Enum):
IOS = "IOS"
@classmethod
- def list(cls):
+ def list(cls) -> list[str]:
return list(map(lambda c: c.value, cls))
@classmethod
@@ -72,7 +72,7 @@ def from_value(cls, value: OverlapMetric | str) -> OverlapMetric:
if isinstance(value, cls):
return value
if isinstance(value, str):
- value = value.lower()
+ value = value.upper()
try:
return cls(value)
except ValueError:
@@ -86,91 +86,107 @@ def from_value(cls, value: OverlapMetric | str) -> OverlapMetric:
def box_iou(
box_true: list[float] | np.ndarray,
box_detection: list[float] | np.ndarray,
+ overlap_metric: OverlapMetric | str = OverlapMetric.IOU,
) -> float:
- r"""
- Compute the Intersection over Union (IoU) between two bounding boxes.
-
- \[
- \text{IoU} = \frac{|\text{box}_{\text{true}} \cap \text{box}_{\text{detection}}|}{|\text{box}_{\text{true}} \cup \text{box}_{\text{detection}}|}
- \]
+ """
+ Compute overlap metric between two bounding boxes.
- Note:
- Use `box_iou` when computing IoU between two individual boxes.
- For comparing multiple boxes (arrays of boxes), use `box_iou_batch` for better
- performance.
+ Supports standard IOU (intersection-over-union) and IOS
+ (intersection-over-smaller-area) metrics. Returns the overlap value in range
+ `[0, 1]`.
Args:
- box_true (Union[List[float], np.ndarray]): A single bounding box represented as
- [x_min, y_min, x_max, y_max].
- box_detection (Union[List[float], np.ndarray]):
- A single bounding box represented as [x_min, y_min, x_max, y_max].
+ box_true (`list[float]` or `numpy.array`): Ground truth box in format
+ `(x_min, y_min, x_max, y_max)`.
+ box_detection (`list[float]` or `numpy.array`): Detected box in format
+ `(x_min, y_min, x_max, y_max)`.
+ overlap_metric (`OverlapMetric` or `str`): Overlap type.
+ Use `OverlapMetric.IOU` for IOU or
+ `OverlapMetric.IOS` for IOS. Defaults to `OverlapMetric.IOU`.
Returns:
- IoU (float): IoU score between the two boxes. Ranges from 0.0 (no overlap)
- to 1.0 (perfect overlap).
+ (`float`): Overlap value between boxes in `[0, 1]`.
+
+ Raises:
+ ValueError: If `overlap_metric` is not IOU or IOS.
Examples:
- ```python
- import numpy as np
+ ```
import supervision as sv
- box_true = np.array([100, 100, 200, 200])
- box_detection = np.array([150, 150, 250, 250])
+ box_true = [100, 100, 200, 200]
+ box_detection = [150, 150, 250, 250]
- sv.box_iou(box_true=box_true, box_detection=box_detection)
- # 0.14285814285714285
+ sv.box_iou(box_true, box_detection, overlap_metric=sv.OverlapMetric.IOU)
+ # 0.14285714285714285
+
+ sv.box_iou(box_true, box_detection, overlap_metric=sv.OverlapMetric.IOS)
+ # 0.25
```
- """ # noqa: E501
- box_true = np.array(box_true)
- box_detection = np.array(box_detection)
+ """
+ overlap_metric = OverlapMetric.from_value(overlap_metric)
+ x_min_true, y_min_true, x_max_true, y_max_true = np.array(box_true)
+ x_min_det, y_min_det, x_max_det, y_max_det = np.array(box_detection)
- inter_x1 = max(box_true[0], box_detection[0])
- inter_y1 = max(box_true[1], box_detection[1])
- inter_x2 = min(box_true[2], box_detection[2])
- inter_y2 = min(box_true[3], box_detection[3])
+ x_min_inter = max(x_min_true, x_min_det)
+ y_min_inter = max(y_min_true, y_min_det)
+ x_max_inter = min(x_max_true, x_max_det)
+ y_max_inter = min(y_max_true, y_max_det)
- inter_w = max(0, inter_x2 - inter_x1)
- inter_h = max(0, inter_y2 - inter_y1)
+ inter_w = max(0.0, x_max_inter - x_min_inter)
+ inter_h = max(0.0, y_max_inter - y_min_inter)
- inter_area = inter_w * inter_h
+ area_inter = inter_w * inter_h
- area_true = (box_true[2] - box_true[0]) * (box_true[3] - box_true[1])
- area_detection = (box_detection[2] - box_detection[0]) * (
- box_detection[3] - box_detection[1]
- )
+ area_true = (x_max_true - x_min_true) * (y_max_true - y_min_true)
+ area_det = (x_max_det - x_min_det) * (y_max_det - y_min_det)
- union_area = area_true + area_detection - inter_area
+ if overlap_metric == OverlapMetric.IOU:
+ area_norm = area_true + area_det - area_inter
+ elif overlap_metric == OverlapMetric.IOS:
+ area_norm = min(area_true, area_det)
+ else:
+ raise ValueError(
+ f"overlap_metric {overlap_metric} is not supported, "
+ "only 'IOU' and 'IOS' are supported"
+ )
- return inter_area / union_area + 1e-6
+ if area_norm <= 0.0:
+ return 0.0
+
+ return float(area_inter / area_norm)
def box_iou_batch(
boxes_true: np.ndarray,
boxes_detection: np.ndarray,
- overlap_metric: OverlapMetric = OverlapMetric.IOU,
+ overlap_metric: OverlapMetric | str = OverlapMetric.IOU,
) -> np.ndarray:
"""
- Compute Intersection over Union (IoU) of two sets of bounding boxes -
- `boxes_true` and `boxes_detection`. Both sets
- of boxes are expected to be in `(x_min, y_min, x_max, y_max)` format.
+ Compute pairwise overlap scores between batches of bounding boxes.
- Note:
- Use `box_iou` when computing IoU between two individual boxes.
- For comparing multiple boxes (arrays of boxes), use `box_iou_batch` for better
- performance.
+ Supports standard IOU (intersection-over-union) and IOS
+ (intersection-over-smaller-area) metrics for all `boxes_true` and
+ `boxes_detection` pairs. Returns a matrix of overlap values in range
+ `[0, 1]`, matching each box from the first batch to each from the second.
Args:
- boxes_true (np.ndarray): 2D `np.ndarray` representing ground-truth boxes.
- `shape = (N, 4)` where `N` is number of true objects.
- boxes_detection (np.ndarray): 2D `np.ndarray` representing detection boxes.
- `shape = (M, 4)` where `M` is number of detected objects.
- overlap_metric (OverlapMetric): Metric used to compute the degree of overlap
- between pairs of boxes (e.g., IoU, IoS).
+ boxes_true (`numpy.array`): Array of reference boxes in
+ shape `(N, 4)` as `(x_min, y_min, x_max, y_max)`.
+ boxes_detection (`numpy.array`): Array of detected boxes in
+ shape `(M, 4)` as `(x_min, y_min, x_max, y_max)`.
+ overlap_metric (`OverlapMetric` or `str`): Overlap type.
+ Use `OverlapMetric.IOU` for intersection-over-union,
+ `OverlapMetric.IOS` for intersection-over-smaller-area.
+ Defaults to `OverlapMetric.IOU`.
Returns:
- np.ndarray: Pairwise IoU of boxes from `boxes_true` and `boxes_detection`.
- `shape = (N, M)` where `N` is number of true objects and
- `M` is number of detected objects.
+ (`numpy.array`): Overlap matrix of shape `(N, M)`, where entry
+ `[i, j]` is the overlap between `boxes_true[i]` and
+ `boxes_detection[j]`.
+
+ Raises:
+ ValueError: If `overlap_metric` is not IOU or IOS.
Examples:
```python
@@ -186,49 +202,57 @@ def box_iou_batch(
[320, 320, 420, 420]
])
- sv.box_iou_batch(boxes_true=boxes_true, boxes_detection=boxes_detection)
- # array([
- # [0.14285714, 0. ],
- # [0. , 0.47058824]
- # ])
+ sv.box_iou_batch(boxes_true, boxes_detection, overlap_metric=OverlapMetric.IOU)
+ # array([[0.14285715, 0. ],
+ # [0. , 0.47058824]])
+
+ sv.box_iou_batch(boxes_true, boxes_detection, overlap_metric=OverlapMetric.IOS)
+ # array([[0.25, 0. ],
+ # [0. , 0.64]])
```
"""
+ overlap_metric = OverlapMetric.from_value(overlap_metric)
+ x_min_true, y_min_true, x_max_true, y_max_true = boxes_true.T
+ x_min_det, y_min_det, x_max_det, y_max_det = boxes_detection.T
+ count_true, count_det = boxes_true.shape[0], boxes_detection.shape[0]
+
+ if count_true == 0 or count_det == 0:
+ return np.empty((count_true, count_det), dtype=np.float32)
- def box_area(box):
- return (box[2] - box[0]) * (box[3] - box[1])
+ x_min_inter = np.empty((count_true, count_det), dtype=np.float32)
+ x_max_inter = np.empty_like(x_min_inter)
+ y_min_inter = np.empty_like(x_min_inter)
+ y_max_inter = np.empty_like(x_min_inter)
- area_true = box_area(boxes_true.T)
- area_detection = box_area(boxes_detection.T)
+ np.maximum(x_min_true[:, None], x_min_det[None, :], out=x_min_inter)
+ np.minimum(x_max_true[:, None], x_max_det[None, :], out=x_max_inter)
+ np.maximum(y_min_true[:, None], y_min_det[None, :], out=y_min_inter)
+ np.minimum(y_max_true[:, None], y_max_det[None, :], out=y_max_inter)
- top_left = np.maximum(boxes_true[:, None, :2], boxes_detection[:, :2])
- bottom_right = np.minimum(boxes_true[:, None, 2:], boxes_detection[:, 2:])
+ # we reuse x_max_inter and y_max_inter to store inter_w and inter_h
+ np.subtract(x_max_inter, x_min_inter, out=x_max_inter) # inter_w
+ np.subtract(y_max_inter, y_min_inter, out=y_max_inter) # inter_h
+ np.clip(x_max_inter, 0.0, None, out=x_max_inter)
+ np.clip(y_max_inter, 0.0, None, out=y_max_inter)
- area_inter = np.prod(np.clip(bottom_right - top_left, a_min=0, a_max=None), 2)
+ area_inter = x_max_inter * y_max_inter # inter_w * inter_h
+
+ area_true = (x_max_true - x_min_true) * (y_max_true - y_min_true)
+ area_det = (x_max_det - x_min_det) * (y_max_det - y_min_det)
if overlap_metric == OverlapMetric.IOU:
- union_area = area_true[:, None] + area_detection - area_inter
- ious = np.divide(
- area_inter,
- union_area,
- out=np.zeros_like(area_inter, dtype=float),
- where=union_area != 0,
- )
+ area_norm = area_true[:, None] + area_det[None, :] - area_inter
elif overlap_metric == OverlapMetric.IOS:
- small_area = np.minimum(area_true[:, None], area_detection)
- ious = np.divide(
- area_inter,
- small_area,
- out=np.zeros_like(area_inter, dtype=float),
- where=small_area != 0,
- )
+ area_norm = np.minimum(area_true[:, None], area_det[None, :])
else:
raise ValueError(
f"overlap_metric {overlap_metric} is not supported, "
"only 'IOU' and 'IOS' are supported"
)
- ious = np.nan_to_num(ious)
- return ious
+ out = np.zeros_like(area_inter, dtype=np.float32)
+ np.divide(area_inter, area_norm, out=out, where=area_norm > 0)
+ return out
def _jaccard(box_a: list[float], box_b: list[float], is_crowd: bool) -> float:
diff --git a/supervision/detection/utils/masks.py b/supervision/detection/utils/masks.py
index c5cfee0172..04502dba05 100644
--- a/supervision/detection/utils/masks.py
+++ b/supervision/detection/utils/masks.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+from typing import Literal
+
import cv2
import numpy as np
import numpy.typing as npt
@@ -260,3 +262,139 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
resized_masks = masks[:, yv, xv]
return resized_masks.reshape(masks.shape[0], new_height, new_width)
+
+
+def filter_segments_by_distance(
+ mask: npt.NDArray[np.bool_],
+ absolute_distance: float | None = 100.0,
+ relative_distance: float | None = None,
+ connectivity: int = 8,
+ mode: Literal["edge", "centroid"] = "edge",
+) -> npt.NDArray[np.bool_]:
+ """
+ Keep the largest connected component and any other components within a distance
+ threshold.
+
+ Distance can be absolute in pixels or relative to the image diagonal.
+
+ Args:
+ mask: Boolean mask HxW.
+ absolute_distance: Max allowed distance in pixels to the main component.
+ Ignored if `relative_distance` is provided.
+ relative_distance: Fraction of the diagonal. If set, threshold = fraction * sqrt(H^2 + W^2).
+ connectivity: Defines which neighboring pixels are considered connected.
+ - 4-connectedness: Only orthogonal neighbors.
+ ```
+ [ ][X][ ]
+ [X][O][X]
+ [ ][X][ ]
+ ```
+ - 8-connectedness: Includes diagonal neighbors.
+ ```
+ [X][X][X]
+ [X][O][X]
+ [X][X][X]
+ ```
+ Default is 8.
+ mode: Defines how distance between components is measured.
+ - "edge": Uses distance between nearest edges (via distance transform).
+ - "centroid": Uses distance between component centroids.
+
+ Returns:
+ Boolean mask after filtering.
+
+ Examples:
+ ```python
+ import numpy as np
+ import supervision as sv
+
+ mask = np.array([
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ], dtype=bool)
+
+ sv.filter_segments_by_distance(
+ mask,
+ absolute_distance=2,
+ mode="edge",
+ connectivity=8
+ ).astype(int)
+
+ # np.array([
+ # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ # [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ # [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
+ # [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
+ # [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+ # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+ # ], dtype=bool)
+
+    # The nearby 2×2 block at columns 6–7 is kept because its edge distance
+    # is within 2 pixels. The distant block at columns 9–10 is removed.
+ ```
+ """ # noqa E501 // docs
+ if mask.dtype != bool:
+ raise TypeError("mask must be boolean")
+
+ height, width = mask.shape
+ if not np.any(mask):
+ return mask.copy()
+
+ image = mask.astype(np.uint8)
+ num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
+ image, connectivity=connectivity
+ )
+
+ if num_labels <= 1:
+ return mask.copy()
+
+ areas = stats[1:, cv2.CC_STAT_AREA]
+ main_label = 1 + int(np.argmax(areas))
+
+ if relative_distance is not None:
+ diagonal = float(np.hypot(height, width))
+ threshold = float(relative_distance) * diagonal
+ else:
+ threshold = float(absolute_distance)
+
+ keep_labels = np.zeros(num_labels, dtype=bool)
+ keep_labels[main_label] = True
+
+ if mode == "centroid":
+ differences = centroids[1:] - centroids[main_label]
+ distances = np.sqrt(np.sum(differences**2, axis=1))
+ nearby = 1 + np.where(distances <= threshold)[0]
+ keep_labels[nearby] = True
+ elif mode == "edge":
+ main_mask = (labels == main_label).astype(np.uint8)
+ inverse = 1 - main_mask
+ distance_transform = cv2.distanceTransform(inverse, cv2.DIST_L2, 3)
+ for label in range(1, num_labels):
+ if label == main_label:
+ continue
+ component = labels == label
+ if not np.any(component):
+ continue
+ min_distance = float(distance_transform[component].min())
+ if min_distance <= threshold:
+ keep_labels[label] = True
+ else:
+ raise ValueError("mode must be 'edge' or 'centroid'")
+
+ return keep_labels[labels]
diff --git a/supervision/detection/utils/vlms.py b/supervision/detection/utils/vlms.py
new file mode 100644
index 0000000000..8c4aca74aa
--- /dev/null
+++ b/supervision/detection/utils/vlms.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+
+def edit_distance(string_1: str, string_2: str, case_sensitive: bool = True) -> int:
+ """
+ Calculates the minimum number of single-character edits required
+ to transform one string into another. Allowed operations are insertion,
+ deletion, and substitution.
+
+ Args:
+ string_1 (str): The source string to be transformed.
+ string_2 (str): The target string to transform into.
+ case_sensitive (bool, optional): Whether comparison should be case-sensitive.
+ Defaults to True.
+
+ Returns:
+ int: The minimum number of edits required to convert `string_1`
+ into `string_2`.
+
+ Examples:
+ ```python
+ import supervision as sv
+
+ sv.edit_distance("hello", "hello")
+ # 0
+
+ sv.edit_distance("Test", "test", case_sensitive=True)
+ # 1
+
+ sv.edit_distance("abc", "xyz")
+ # 3
+
+ sv.edit_distance("hello", "")
+ # 5
+
+ sv.edit_distance("", "")
+ # 0
+
+ sv.edit_distance("hello world", "helloworld")
+ # 1
+ ```
+ """
+ if not case_sensitive:
+ string_1 = string_1.lower()
+ string_2 = string_2.lower()
+
+ if len(string_1) < len(string_2):
+ string_1, string_2 = string_2, string_1
+
+ prev_row = list(range(len(string_2) + 1))
+ curr_row = [0] * (len(string_2) + 1)
+
+ for i in range(1, len(string_1) + 1):
+ curr_row[0] = i
+ for j in range(1, len(string_2) + 1):
+ if string_1[i - 1] == string_2[j - 1]:
+ substitution_cost = 0
+ else:
+ substitution_cost = 1
+ curr_row[j] = min(
+ prev_row[j] + 1,
+ curr_row[j - 1] + 1,
+ prev_row[j - 1] + substitution_cost,
+ )
+ prev_row, curr_row = curr_row, prev_row
+
+ return prev_row[len(string_2)]
+
+
+def fuzzy_match_index(
+ candidates: list[str],
+ query: str,
+ threshold: int,
+ case_sensitive: bool = True,
+) -> int | None:
+ """
+ Searches for the first string in `candidates` whose edit distance
+ to `query` is less than or equal to `threshold`.
+
+ Args:
+ candidates (list[str]): List of strings to search.
+ query (str): String to compare against the candidates.
+ threshold (int): Maximum allowed edit distance for a match.
+ case_sensitive (bool, optional): Whether matching should be case-sensitive.
+
+ Returns:
+        int | None: Index of the first matching string in candidates,
+ or None if no match is found.
+
+ Examples:
+ ```python
+ fuzzy_match_index(["cat", "dog", "rat"], "dat", threshold=1)
+ # 0
+
+ fuzzy_match_index(["alpha", "beta", "gamma"], "bata", threshold=1)
+ # 1
+
+ fuzzy_match_index(["one", "two", "three"], "ten", threshold=2)
+ # None
+ ```
+ """
+ for idx, candidate in enumerate(candidates):
+ if edit_distance(candidate, query, case_sensitive=case_sensitive) <= threshold:
+ return idx
+ return None
diff --git a/supervision/detection/vlm.py b/supervision/detection/vlm.py
index 1a5ad231e6..97988c9f09 100644
--- a/supervision/detection/vlm.py
+++ b/supervision/detection/vlm.py
@@ -1,5 +1,6 @@
from __future__ import annotations
+import ast
import base64
import io
import json
@@ -27,7 +28,8 @@ class LMM(Enum):
Attributes:
PALIGEMMA: Google's PaliGemma vision-language model.
FLORENCE_2: Microsoft's Florence-2 vision-language model.
- QWEN_2_5_VL: Qwen2.5-VL open vision-language model from Alibaba.
+        QWEN_2_5_VL: Qwen2.5-VL open vision-language model from Alibaba.
+ QWEN_3_VL: Qwen3-VL open vision-language model from Alibaba.
GOOGLE_GEMINI_2_0: Google Gemini 2.0 vision-language model.
GOOGLE_GEMINI_2_5: Google Gemini 2.5 vision-language model.
MOONDREAM: The Moondream vision-language model.
@@ -36,6 +38,8 @@ class LMM(Enum):
PALIGEMMA = "paligemma"
FLORENCE_2 = "florence_2"
QWEN_2_5_VL = "qwen_2_5_vl"
+ QWEN_3_VL = "qwen_3_vl"
+ DEEPSEEK_VL_2 = "deepseek_vl_2"
GOOGLE_GEMINI_2_0 = "gemini_2_0"
GOOGLE_GEMINI_2_5 = "gemini_2_5"
MOONDREAM = "moondream"
@@ -68,6 +72,7 @@ class VLM(Enum):
PALIGEMMA: Google's PaliGemma vision-language model.
FLORENCE_2: Microsoft's Florence-2 vision-language model.
QWEN_2_5_VL: Qwen2.5-VL open vision-language model from Alibaba.
+ QWEN_3_VL: Qwen3-VL open vision-language model from Alibaba.
GOOGLE_GEMINI_2_0: Google Gemini 2.0 vision-language model.
GOOGLE_GEMINI_2_5: Google Gemini 2.5 vision-language model.
MOONDREAM: The Moondream vision-language model.
@@ -76,6 +81,8 @@ class VLM(Enum):
PALIGEMMA = "paligemma"
FLORENCE_2 = "florence_2"
QWEN_2_5_VL = "qwen_2_5_vl"
+ QWEN_3_VL = "qwen_3_vl"
+ DEEPSEEK_VL_2 = "deepseek_vl_2"
GOOGLE_GEMINI_2_0 = "gemini_2_0"
GOOGLE_GEMINI_2_5 = "gemini_2_5"
MOONDREAM = "moondream"
@@ -104,6 +111,8 @@ def from_value(cls, value: VLM | str) -> VLM:
VLM.PALIGEMMA: str,
VLM.FLORENCE_2: dict,
VLM.QWEN_2_5_VL: str,
+ VLM.QWEN_3_VL: str,
+ VLM.DEEPSEEK_VL_2: str,
VLM.GOOGLE_GEMINI_2_0: str,
VLM.GOOGLE_GEMINI_2_5: str,
VLM.MOONDREAM: dict,
@@ -113,6 +122,8 @@ def from_value(cls, value: VLM | str) -> VLM:
VLM.PALIGEMMA: ["resolution_wh"],
VLM.FLORENCE_2: ["resolution_wh"],
VLM.QWEN_2_5_VL: ["input_wh", "resolution_wh"],
+ VLM.QWEN_3_VL: ["resolution_wh"],
+ VLM.DEEPSEEK_VL_2: ["resolution_wh"],
VLM.GOOGLE_GEMINI_2_0: ["resolution_wh"],
VLM.GOOGLE_GEMINI_2_5: ["resolution_wh"],
VLM.MOONDREAM: ["resolution_wh"],
@@ -122,6 +133,8 @@ def from_value(cls, value: VLM | str) -> VLM:
VLM.PALIGEMMA: ["resolution_wh", "classes"],
VLM.FLORENCE_2: ["resolution_wh"],
VLM.QWEN_2_5_VL: ["input_wh", "resolution_wh", "classes"],
+ VLM.QWEN_3_VL: ["resolution_wh", "classes"],
+ VLM.DEEPSEEK_VL_2: ["resolution_wh", "classes"],
VLM.GOOGLE_GEMINI_2_0: ["resolution_wh", "classes"],
VLM.GOOGLE_GEMINI_2_5: ["resolution_wh", "classes"],
VLM.MOONDREAM: ["resolution_wh"],
@@ -230,6 +243,51 @@ def from_paligemma(
return xyxy, class_id, class_name
+def recover_truncated_qwen_2_5_vl_response(text: str) -> Any | None:
+ """
+ Attempt to recover and parse a truncated or malformed JSON snippet from Qwen-2.5-VL
+ output.
+
+ This utility extracts a JSON-like portion from a string that may be truncated or
+ malformed, cleans trailing commas, and attempts to parse it into a Python object.
+
+ Args:
+ text (str): Raw text containing the JSON snippet possibly truncated or
+ incomplete.
+
+ Returns:
+ Parsed Python object (usually list) if recovery and parsing succeed;
+ otherwise `None`.
+ """
+ try:
+ first_bracket = text.find("[")
+ if first_bracket == -1:
+ return None
+ snippet = text[first_bracket:]
+
+ last_brace = snippet.rfind("}")
+ if last_brace == -1:
+ return None
+
+ snippet = snippet[: last_brace + 1]
+
+ prefix_end = snippet.find("[")
+ if prefix_end == -1:
+ return None
+
+ prefix = snippet[: prefix_end + 1]
+ body = snippet[prefix_end + 1 :].rstrip()
+
+ if body.endswith(","):
+ body = body[:-1].rstrip()
+
+ repaired = prefix + body + "]"
+
+ return json.loads(repaired)
+ except Exception:
+ return None
+
+
def from_qwen_2_5_vl(
result: str,
input_wh: tuple[int, int],
@@ -237,7 +295,7 @@ def from_qwen_2_5_vl(
classes: list[str] | None = None,
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]:
"""
- Parse and scale bounding boxes from Qwen-2.5-VL style JSON output.
+ Parse and rescale bounding boxes and class labels from Qwen-2.5-VL JSON output.
The JSON is expected to be enclosed in triple backticks with the format:
```json
@@ -248,38 +306,52 @@ def from_qwen_2_5_vl(
```
Args:
- result: String containing the JSON snippet enclosed by triple backticks.
- input_wh: (input_width, input_height) describing the original bounding box
- scale.
- resolution_wh: (output_width, output_height) to which we rescale the boxes.
- classes: Optional list of valid class names. If provided, returned boxes/labels
- are filtered to only those classes found here.
+ result (str): String containing Qwen-2.5-VL JSON bounding box and label data.
+ input_wh (tuple[int, int]): Width and height of the coordinate space where boxes
+ are normalized.
+ resolution_wh (tuple[int, int]): Target width and height to scale bounding
+ boxes.
+ classes (list[str] or None): Optional list of valid class names to filter
+ results. If provided, only boxes with labels in this list are returned.
Returns:
- xyxy (np.ndarray): An array of shape `(n, 4)` containing
- the bounding boxes coordinates in format `[x1, y1, x2, y2]`
- class_id (Optional[np.ndarray]): An array of shape `(n,)` containing
- the class indices for each bounding box (or None if `classes` is not
- provided)
- class_name (np.ndarray): An array of shape `(n,)` containing
- the class labels for each bounding box
+ xyxy (np.ndarray): Array of shape `(N, 4)` with rescaled bounding boxes in
+ `(x_min, y_min, x_max, y_max)` format.
+ class_id (np.ndarray or None): Array of shape `(N,)` with indices of classes,
+ or `None` if no filtering applied.
+ class_name (np.ndarray): Array of shape `(N,)` with class names as strings.
"""
in_w, in_h = validate_resolution(input_wh)
out_w, out_h = validate_resolution(resolution_wh)
- pattern = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL)
-
- match = pattern.search(result)
- if not match:
- return np.empty((0, 4)), None, np.empty((0,), dtype=str)
+ text = result.strip()
+ text = re.sub(r"^```(json)?", "", text, flags=re.IGNORECASE).strip()
+ text = re.sub(r"```$", "", text).strip()
- json_snippet = match.group(1)
+ start = text.find("[")
+ end = text.rfind("]")
+ if start != -1 and end != -1 and end > start:
+ text = text[start : end + 1].strip()
try:
- data = json.loads(json_snippet)
+ data = json.loads(text)
except json.JSONDecodeError:
- return np.empty((0, 4)), None, np.empty((0,), dtype=str)
+ repaired = recover_truncated_qwen_2_5_vl_response(text)
+ if repaired is not None:
+ data = repaired
+ else:
+ try:
+ data = ast.literal_eval(text)
+ except (ValueError, SyntaxError, TypeError):
+ return (
+ np.empty((0, 4)),
+ np.empty((0,), dtype=int),
+ np.empty((0,), dtype=str),
+ )
+
+ if not isinstance(data, list):
+ return (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty((0,), dtype=str))
boxes_list = []
labels_list = []
@@ -291,7 +363,7 @@ def from_qwen_2_5_vl(
labels_list.append(item["label"])
if not boxes_list:
- return np.empty((0, 4)), None, np.empty((0,), dtype=str)
+ return (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty((0,), dtype=str))
xyxy = np.array(boxes_list, dtype=float)
class_name = np.array(labels_list, dtype=str)
@@ -310,6 +382,109 @@ def from_qwen_2_5_vl(
return xyxy, class_id, class_name
+def from_qwen_3_vl(
+ result: str,
+ resolution_wh: tuple[int, int],
+ classes: list[str] | None = None,
+) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]:
+ """
+ Parse and scale bounding boxes from Qwen-3-VL style JSON output.
+
+ Args:
+ result (str): String containing the Qwen-3-VL JSON output.
+ resolution_wh (tuple[int, int]): Target resolution `(width, height)` to
+ scale bounding boxes.
+ classes (list[str] or None): Optional list of valid classes to filter
+ results.
+
+ Returns:
+ xyxy (np.ndarray): Array of bounding boxes with shape `(N, 4)` in
+ `(x_min, y_min, x_max, y_max)` format scaled to `resolution_wh`.
+ class_id (np.ndarray or None): Array of class indices for each box, or
+ None if no filtering by classes.
+ class_name (np.ndarray): Array of class names as strings.
+ """
+ return from_qwen_2_5_vl(
+ result=result,
+ input_wh=(1000, 1000),
+ resolution_wh=resolution_wh,
+ classes=classes,
+ )
+
+
+def from_deepseek_vl_2(
+ result: str, resolution_wh: tuple[int, int], classes: list[str] | None = None
+) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]:
+ """
+ Parse bounding boxes from deepseek-vl2-formatted text, scale them to the specified
+ resolution, and optionally filter by classes.
+
+ The DeepSeek-VL2 output typically contains pairs of <|ref|> ... <|/ref|> labels
+ and <|det|> ... <|/det|> bounding box definitions. Each <|det|> section may
+ contain one or more bounding boxes in the form [[x1, y1, x2, y2], [x1, y1, x2, y2], ...]
+ (scaled to 0..999). For example:
+
+ ```
+    <|ref|>The giraffe at the back<|/ref|><|det|>[[580, 270, 999, 904]]<|/det|><|ref|>The giraffe at the front<|/ref|><|det|>[[26, 31, 632, 998]]<|/det|><|end▁of▁sentence|>
+ ```
+
+ Args:
+ result: String containing deepseek-vl2-formatted locations and labels.
+ resolution_wh: Tuple (width, height) to which we scale the box coordinates.
+ classes: Optional list of valid class names. If provided, boxes and labels not
+ in this list are filtered out.
+
+ Returns:
+ xyxy (np.ndarray): An array of shape `(n, 4)` containing
+ the bounding boxes coordinates in format `[x1, y1, x2, y2]`.
+ class_id (Optional[np.ndarray]): An array of shape `(n,)` containing
+ the class indices for each bounding box (or `None` if classes is not
+ provided).
+ class_name (np.ndarray): An array of shape `(n,)` containing
+ the class labels for each bounding box.
+ """ # noqa: E501
+
+ width, height = resolution_wh
+ label_segments = re.findall(r"<\|ref\|>(.*?)<\|/ref\|>", result, flags=re.S)
+ detection_segments = re.findall(r"<\|det\|>(.*?)<\|/det\|>", result, flags=re.S)
+
+ if len(label_segments) != len(detection_segments):
+ raise ValueError(
+ f"Number of ref tags ({len(label_segments)}) "
+ f"and det tags ({len(detection_segments)}) in the result must be equal."
+ )
+
+ xyxy, class_name_list = [], []
+ for label, detection_blob in zip(label_segments, detection_segments):
+ current_class_name = label.strip()
+ for box in re.findall(r"\[(.*?)\]", detection_blob):
+ x1, y1, x2, y2 = map(float, box.strip("[]").split(","))
+ xyxy.append(
+ [
+ (x1 / 999 * width),
+ (y1 / 999 * height),
+ (x2 / 999 * width),
+ (y2 / 999 * height),
+ ]
+ )
+ class_name_list.append(current_class_name)
+
+ xyxy = np.array(xyxy, dtype=np.float32)
+ class_name = np.array(class_name_list)
+
+ if classes is not None:
+ mask = np.array([name in classes for name in class_name], dtype=bool)
+ xyxy = xyxy[mask]
+ class_name = class_name[mask]
+ class_id = np.array([classes.index(name) for name in class_name])
+ else:
+ unique_classes = sorted(list(set(class_name)))
+ class_to_id = {name: i for i, name in enumerate(unique_classes)}
+ class_id = np.array([class_to_id[name] for name in class_name])
+
+ return xyxy, class_id, class_name
+
+
def from_florence_2(
result: dict, resolution_wh: tuple[int, int]
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None, np.ndarray | None]:
@@ -460,7 +635,7 @@ def from_google_gemini_2_0(
return np.empty((0, 4)), None, np.empty((0,), dtype=str)
labels = []
- boxes_list = []
+ xyxy = []
for item in data:
if "box_2d" not in item or "label" not in item:
@@ -468,18 +643,16 @@ def from_google_gemini_2_0(
labels.append(item["label"])
box = item["box_2d"]
# Gemini bbox order is [y_min, x_min, y_max, x_max]
- boxes_list.append(
- denormalize_boxes(
- np.array([box[1], box[0], box[3], box[2]]).astype(np.float64),
- resolution_wh=(w, h),
- normalization_factor=1000,
- )
- )
+ xyxy.append([box[1], box[0], box[3], box[2]])
- if not boxes_list:
+ if len(xyxy) == 0:
return np.empty((0, 4)), None, np.empty((0,), dtype=str)
- xyxy = np.array(boxes_list)
+ xyxy = denormalize_boxes(
+ np.array(xyxy, dtype=np.float64),
+ resolution_wh=(w, h),
+ normalization_factor=1000,
+ )
class_name = np.array(labels)
class_id = None
@@ -571,10 +744,10 @@ def from_google_gemini_2_5(
box = item["box_2d"]
# Gemini bbox order is [y_min, x_min, y_max, x_max]
absolute_bbox = denormalize_boxes(
- np.array([box[1], box[0], box[3], box[2]]).astype(np.float64),
+ np.array([[box[1], box[0], box[3], box[2]]]).astype(np.float64),
resolution_wh=(w, h),
normalization_factor=1000,
- )
+ )[0]
boxes_list.append(absolute_bbox)
if "mask" in item:
@@ -657,7 +830,7 @@ def from_google_gemini_2_5(
def from_moondream(
result: dict,
resolution_wh: tuple[int, int],
-) -> tuple[np.ndarray]:
+) -> np.ndarray:
"""
Parse and scale bounding boxes from moondream JSON output.
@@ -695,7 +868,7 @@ def from_moondream(
if "objects" not in result or not isinstance(result["objects"], list):
return np.empty((0, 4), dtype=float)
- denormalize_xyxy = []
+ xyxy = []
for item in result["objects"]:
if not all(k in item for k in ["x_min", "y_min", "x_max", "y_max"]):
@@ -706,14 +879,12 @@ def from_moondream(
x_max = item["x_max"]
y_max = item["y_max"]
- denormalize_xyxy.append(
- denormalize_boxes(
- np.array([x_min, y_min, x_max, y_max]).astype(np.float64),
- resolution_wh=(w, h),
- )
- )
+ xyxy.append([x_min, y_min, x_max, y_max])
- if not denormalize_xyxy:
+ if len(xyxy) == 0:
return np.empty((0, 4))
- return np.array(denormalize_xyxy, dtype=float)
+ return denormalize_boxes(
+ np.array(xyxy).astype(np.float64),
+ resolution_wh=(w, h),
+ )
diff --git a/supervision/draw/base.py b/supervision/draw/base.py
new file mode 100644
index 0000000000..e27c1d3c6b
--- /dev/null
+++ b/supervision/draw/base.py
@@ -0,0 +1,13 @@
+from typing import TypeVar
+
+import numpy as np
+from PIL import Image
+
+ImageType = TypeVar("ImageType", np.ndarray, Image.Image)
+"""
+An image of type `np.ndarray` or `PIL.Image.Image`.
+
+Unlike a `Union`, ensures the type remains consistent. If a function
+takes an `ImageType` argument and returns an `ImageType`, when you
+pass an `np.ndarray`, you will get an `np.ndarray` back.
+"""
diff --git a/supervision/draw/utils.py b/supervision/draw/utils.py
index ed4a903746..0d9ffe1241 100644
--- a/supervision/draw/utils.py
+++ b/supervision/draw/utils.py
@@ -346,28 +346,50 @@ def draw_image(
def calculate_optimal_text_scale(resolution_wh: tuple[int, int]) -> float:
"""
- Calculate font scale based on the resolution of an image.
+ Calculate optimal font scale based on image resolution. Adjusts font scale
+ proportionally to the smallest dimension of the given image resolution for
+ consistent readability.
- Parameters:
- resolution_wh (Tuple[int, int]): A tuple representing the width and height
- of the image.
+ Args:
+ resolution_wh (tuple[int, int]): (width, height) of the image in pixels
Returns:
- float: The calculated font scale factor.
+ float: recommended font scale factor
+
+ Examples:
+ ```python
+ import supervision as sv
+
+ sv.calculate_optimal_text_scale((1920, 1080))
+ # 1.08
+ sv.calculate_optimal_text_scale((640, 480))
+ # 0.48
+ ```
"""
return min(resolution_wh) * 1e-3
def calculate_optimal_line_thickness(resolution_wh: tuple[int, int]) -> int:
"""
- Calculate line thickness based on the resolution of an image.
+ Calculate optimal line thickness based on image resolution. Adjusts the line
+ thickness for readability depending on the smallest dimension of the provided
+ image resolution.
- Parameters:
- resolution_wh (Tuple[int, int]): A tuple representing the width and height
- of the image.
+ Args:
+ resolution_wh (tuple[int, int]): (width, height) of the image in pixels
Returns:
- int: The calculated line thickness in pixels.
+ int: recommended line thickness in pixels
+
+ Examples:
+ ```python
+ import supervision as sv
+
+ sv.calculate_optimal_line_thickness((1920, 1080))
+ # 4
+ sv.calculate_optimal_line_thickness((640, 480))
+ # 2
+ ```
"""
if min(resolution_wh) < 1080:
return 2
diff --git a/supervision/keypoint/__init__.py b/supervision/key_points/__init__.py
similarity index 100%
rename from supervision/keypoint/__init__.py
rename to supervision/key_points/__init__.py
diff --git a/supervision/keypoint/annotators.py b/supervision/key_points/annotators.py
similarity index 97%
rename from supervision/keypoint/annotators.py
rename to supervision/key_points/annotators.py
index d83d7d5e0f..c3f9e984b1 100644
--- a/supervision/keypoint/annotators.py
+++ b/supervision/key_points/annotators.py
@@ -6,14 +6,14 @@
import cv2
import numpy as np
-from supervision.annotators.base import ImageType
from supervision.detection.utils.boxes import pad_boxes, spread_out_boxes
+from supervision.draw.base import ImageType
from supervision.draw.color import Color
from supervision.draw.utils import draw_rounded_rectangle
from supervision.geometry.core import Rect
-from supervision.keypoint.core import KeyPoints
-from supervision.keypoint.skeletons import SKELETONS_BY_VERTEX_COUNT
-from supervision.utils.conversion import ensure_cv2_image_for_annotation
+from supervision.key_points.core import KeyPoints
+from supervision.key_points.skeletons import SKELETONS_BY_VERTEX_COUNT
+from supervision.utils.conversion import ensure_cv2_image_for_class_method
class BaseKeyPointAnnotator(ABC):
@@ -43,7 +43,7 @@ def __init__(
self.color = color
self.radius = radius
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(self, scene: ImageType, key_points: KeyPoints) -> ImageType:
"""
Annotates the given scene with skeleton vertices based on the provided key
@@ -120,7 +120,7 @@ def __init__(
self.thickness = thickness
self.edges = edges
- @ensure_cv2_image_for_annotation
+ @ensure_cv2_image_for_class_method
def annotate(self, scene: ImageType, key_points: KeyPoints) -> ImageType:
"""
Annotates the given scene by drawing lines between specified key points to form
diff --git a/supervision/keypoint/core.py b/supervision/key_points/core.py
similarity index 81%
rename from supervision/keypoint/core.py
rename to supervision/key_points/core.py
index 0d57c56183..66341f9df5 100644
--- a/supervision/keypoint/core.py
+++ b/supervision/key_points/core.py
@@ -10,7 +10,7 @@
from supervision.config import CLASS_NAME_DATA_FIELD
from supervision.detection.core import Detections
from supervision.detection.utils.internal import get_data_item, is_data_equal
-from supervision.validators import validate_keypoints_fields
+from supervision.validators import validate_key_points_fields
@dataclass
@@ -23,7 +23,7 @@ class simplifies data manipulation and filtering, providing a uniform API for
=== "Ultralytics"
- Use [`sv.KeyPoints.from_ultralytics`](/latest/keypoint/core/#supervision.keypoint.core.KeyPoints.from_ultralytics)
+ Use [`sv.KeyPoints.from_ultralytics`](/latest/keypoint/core/#supervision.key_points.core.KeyPoints.from_ultralytics)
method, which accepts [YOLOv8-pose](https://docs.ultralytics.com/models/yolov8/), [YOLO11-pose](https://docs.ultralytics.com/models/yolo11/)
[pose](https://docs.ultralytics.com/tasks/pose/) result.
@@ -41,7 +41,7 @@ class simplifies data manipulation and filtering, providing a uniform API for
=== "Inference"
- Use [`sv.KeyPoints.from_inference`](/latest/keypoint/core/#supervision.keypoint.core.KeyPoints.from_inference)
+ Use [`sv.KeyPoints.from_inference`](/latest/keypoint/core/#supervision.key_points.core.KeyPoints.from_inference)
method, which accepts [Inference](https://inference.roboflow.com/) pose result.
```python
@@ -58,7 +58,7 @@ class simplifies data manipulation and filtering, providing a uniform API for
=== "MediaPipe"
- Use [`sv.KeyPoints.from_mediapipe`](/latest/keypoint/core/#supervision.keypoint.core.KeyPoints.from_mediapipe)
+ Use [`sv.KeyPoints.from_mediapipe`](/latest/keypoint/core/#supervision.key_points.core.KeyPoints.from_mediapipe)
method, which accepts [MediaPipe](https://github.com/google-ai-edge/mediapipe)
pose result.
@@ -89,10 +89,61 @@ class simplifies data manipulation and filtering, providing a uniform API for
pose_landmarker_result, (image_width, image_height))
```
+ === "Transformers"
+
+ Use [`sv.KeyPoints.from_transformers`](/latest/keypoint/core/#supervision.key_points.core.KeyPoints.from_transformers)
+ method, which accepts [ViTPose](https://huggingface.co/docs/transformers/en/model_doc/vitpose) result.
+
+ ```python
+ from PIL import Image
+ import requests
+ import supervision as sv
+ import torch
+ from transformers import (
+ AutoProcessor,
+ RTDetrForObjectDetection,
+ VitPoseForPoseEstimation,
+ )
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ frame = Image.open("source.png")
+
+ DETECTION_MODEL_ID = "PekingU/rtdetr_r50vd_coco_o365"
+
+ detection_processor = AutoProcessor.from_pretrained(DETECTION_MODEL_ID, use_fast=True)
+ detection_model = RTDetrForObjectDetection.from_pretrained(DETECTION_MODEL_ID, device_map=device)
+
+ inputs = detection_processor(images=frame, return_tensors="pt").to(device)
+
+ with torch.no_grad():
+ outputs = detection_model(**inputs)
+
+ target_size = torch.tensor([(frame.height, frame.width)])
+ results = detection_processor.post_process_object_detection(
+ outputs, target_sizes=target_size, threshold=0.3)
+
+ detections = sv.Detections.from_transformers(results[0])
+ boxes = sv.xyxy_to_xywh(detections[detections.class_id == 0].xyxy)
+
+ POSE_ESTIMATION_MODEL_ID = "usyd-community/vitpose-base-simple"
+
+ pose_estimation_processor = AutoProcessor.from_pretrained(POSE_ESTIMATION_MODEL_ID)
+ pose_estimation_model = VitPoseForPoseEstimation.from_pretrained(
+ POSE_ESTIMATION_MODEL_ID, device_map=device)
+
+ inputs = pose_estimation_processor(frame, boxes=[boxes], return_tensors="pt").to(device)
+
+ with torch.no_grad():
+ outputs = pose_estimation_model(**inputs)
+
+ results = pose_estimation_processor.post_process_pose_estimation(outputs, boxes=[boxes])
+ key_points = sv.KeyPoints.from_transformers(results[0])
+ ```
+
Attributes:
xy (np.ndarray): An array of shape `(n, m, 2)` containing
`n` detected objects, each composed of `m` equally-sized
- sets of keypoints, where each point is `[x, y]`.
+ sets of key points, where each point is `[x, y]`.
class_id (Optional[np.ndarray]): An array of shape
`(n,)` containing the class ids of the detected objects.
confidence (Optional[np.ndarray]): An array of shape
@@ -109,7 +160,7 @@ class simplifies data manipulation and filtering, providing a uniform API for
data: dict[str, npt.NDArray[Any] | list] = field(default_factory=dict)
def __post_init__(self):
- validate_keypoints_fields(
+ validate_key_points_fields(
xy=self.xy,
confidence=self.confidence,
class_id=self.class_id,
@@ -514,13 +565,13 @@ def from_detectron2(cls, detectron2_results: Any) -> KeyPoints:
return cls.empty()
@classmethod
- def from_transformers(cls, transfomers_results: Any) -> KeyPoints:
+ def from_transformers(cls, transformers_results: Any) -> KeyPoints:
"""
Create a `sv.KeyPoints` object from the
[Transformers](https://github.com/huggingface/transformers) inference result.
Args:
- transfomers_results (Any): The output of a
+ transformers_results (Any): The output of a
Transformers model containing instances with prediction data.
Returns:
@@ -545,9 +596,9 @@ def from_transformers(cls, transfomers_results: Any) -> KeyPoints:
DETECTION_MODEL_ID = "PekingU/rtdetr_r50vd_coco_o365"
detection_processor = AutoProcessor.from_pretrained(DETECTION_MODEL_ID, use_fast=True)
- detection_model = RTDetrForObjectDetection.from_pretrained(DETECTION_MODEL_ID, device_map=DEVICE)
+ detection_model = RTDetrForObjectDetection.from_pretrained(DETECTION_MODEL_ID, device_map=device)
- inputs = detection_processor(images=frame, return_tensors="pt").to(DEVICE)
+ inputs = detection_processor(images=frame, return_tensors="pt").to(device)
with torch.no_grad():
outputs = detection_model(**inputs)
@@ -563,9 +614,9 @@ def from_transformers(cls, transfomers_results: Any) -> KeyPoints:
pose_estimation_processor = AutoProcessor.from_pretrained(POSE_ESTIMATION_MODEL_ID)
pose_estimation_model = VitPoseForPoseEstimation.from_pretrained(
- POSE_ESTIMATION_MODEL_ID, device_map=DEVICE)
+ POSE_ESTIMATION_MODEL_ID, device_map=device)
- inputs = pose_estimation_processor(frame, boxes=[boxes], return_tensors="pt").to(DEVICE)
+ inputs = pose_estimation_processor(frame, boxes=[boxes], return_tensors="pt").to(device)
with torch.no_grad():
outputs = pose_estimation_model(**inputs)
@@ -576,8 +627,8 @@ def from_transformers(cls, transfomers_results: Any) -> KeyPoints:
""" # noqa: E501 // docs
- if "keypoints" in transfomers_results[0]:
- if transfomers_results[0]["keypoints"].cpu().numpy().size == 0:
+ if "keypoints" in transformers_results[0]:
+ if transformers_results[0]["keypoints"].cpu().numpy().size == 0:
return cls.empty()
result_data = [
@@ -585,7 +636,7 @@ def from_transformers(cls, transfomers_results: Any) -> KeyPoints:
result["keypoints"].cpu().numpy(),
result["scores"].cpu().numpy(),
)
- for result in transfomers_results
+ for result in transformers_results
]
xy, scores = zip(*result_data)
@@ -599,55 +650,68 @@ def from_transformers(cls, transfomers_results: Any) -> KeyPoints:
return cls.empty()
def __getitem__(
- self, index: int | slice | list[int] | np.ndarray | str
- ) -> KeyPoints | list | np.ndarray | None:
- """
- Get a subset of the `sv.KeyPoints` object or access an item from its data field.
-
- When provided with an integer, slice, list of integers, or a numpy array, this
- method returns a new `sv.KeyPoints` object that represents a subset of the
- original `sv.KeyPoints`. When provided with a string, it accesses the
- corresponding item in the data dictionary.
-
- Args:
- index (Union[int, slice, List[int], np.ndarray, str]): The index, indices,
- or key to access a subset of the `sv.KeyPoints` or an item from the
- data.
-
- Returns:
- A subset of the `sv.KeyPoints` object or an item from the data field.
-
- Examples:
- ```python
- import supervision as sv
-
- key_points = sv.KeyPoints()
-
- # access the first keypoint using an integer index
- key_points[0]
-
- # access the first 10 keypoints using index slice
- key_points[0:10]
-
- # access selected keypoints using a list of indices
- key_points[[0, 2, 4]]
-
- # access keypoints with selected class_id
- key_points[key_points.class_id == 0]
-
- # access keypoints with confidence greater than 0.5
- key_points[key_points.confidence > 0.5]
- ```
- """
+ self, index: int | slice | list[int] | np.ndarray | tuple | str
+ ) -> KeyPoints | np.ndarray | list | None:
if isinstance(index, str):
return self.data.get(index)
- if isinstance(index, int):
- index = [index]
+
+ if not isinstance(index, tuple):
+ index = (index, slice(None))
+
+ i, j = index
+
+ if isinstance(i, int):
+ i = [i]
+
+ if isinstance(i, list) and all(isinstance(x, bool) for x in i):
+ i = np.array(i)
+ if isinstance(j, list) and all(isinstance(x, bool) for x in j):
+ j = np.array(j)
+
+ if isinstance(i, np.ndarray) and i.dtype == bool:
+ i = np.flatnonzero(i)
+ if isinstance(j, np.ndarray) and j.dtype == bool:
+ j = np.flatnonzero(j)
+
+ if (
+ isinstance(i, (list, np.ndarray))
+ and isinstance(j, (list, np.ndarray))
+ and not np.isscalar(i)
+ and not np.isscalar(j)
+ ):
+ i, j = np.ix_(i, j)
+
+ xy_selected = self.xy[i, j]
+
+ conf_selected = self.confidence[i, j] if self.confidence is not None else None
+
+ class_id_selected = self.class_id[i] if self.class_id is not None else None
+
+ data_selected = get_data_item(self.data, i)
+
+ if xy_selected.ndim == 1:
+ xy_selected = xy_selected.reshape(1, 1, 2)
+ if conf_selected is not None:
+ conf_selected = conf_selected.reshape(1, 1)
+ elif xy_selected.ndim == 2:
+ if np.isscalar(index[0]) or (
+ isinstance(index[0], np.ndarray) and index[0].ndim == 0
+ ):
+ xy_selected = xy_selected[np.newaxis, ...]
+ if conf_selected is not None:
+ conf_selected = conf_selected[np.newaxis, ...]
+ elif np.isscalar(index[1]) or (
+ isinstance(index[1], np.ndarray) and index[1].ndim == 0
+ ):
+ xy_selected = xy_selected[:, np.newaxis, :]
+ if conf_selected is not None:
+ conf_selected = conf_selected[:, np.newaxis]
+
return KeyPoints(
- xy=self.xy[index],
- confidence=self.confidence[index] if self.confidence is not None else None,
- class_id=self.class_id[index] if self.class_id is not None else None,
- data=get_data_item(self.data, index),
+ xy=xy_selected,
+ confidence=conf_selected,
+ class_id=class_id_selected,
+ data=data_selected,
)
def __setitem__(self, key: str, value: np.ndarray | list):
@@ -668,12 +732,12 @@ def __setitem__(self, key: str, value: np.ndarray | list):
model = YOLO('yolov8s.pt')
result = model(image)[0]
- keypoints = sv.KeyPoints.from_ultralytics(result)
+ key_points = sv.KeyPoints.from_ultralytics(result)
- keypoints['class_name'] = [
+ key_points['class_name'] = [
model.model.names[class_id]
for class_id
- in keypoints.class_id
+ in key_points.class_id
]
```
"""
@@ -688,7 +752,7 @@ def __setitem__(self, key: str, value: np.ndarray | list):
@classmethod
def empty(cls) -> KeyPoints:
"""
- Create an empty Keypoints object with no keypoints.
+ Create an empty KeyPoints object with no key points.
Returns:
An empty `sv.KeyPoints` object.
@@ -706,9 +770,9 @@ def is_empty(self) -> bool:
"""
Returns `True` if the `KeyPoints` object is considered empty.
"""
- empty_keypoints = KeyPoints.empty()
- empty_keypoints.data = self.data
- return self == empty_keypoints
+ empty_key_points = KeyPoints.empty()
+ empty_key_points.data = self.data
+ return self == empty_key_points
def as_detections(
self, selected_keypoint_indices: Iterable[int] | None = None
@@ -716,21 +780,21 @@ def as_detections(
"""
Convert a KeyPoints object to a Detections object. This
approximates the bounding box of the detected object by
- taking the bounding box that fits all keypoints.
+ taking the bounding box that fits all key points.
Arguments:
selected_keypoint_indices (Optional[Iterable[int]]): The
- indices of the keypoints to include in the bounding box
- calculation. This helps focus on a subset of keypoints,
- e.g. when some are occluded. Captures all keypoints by default.
+ indices of the key points to include in the bounding box
+ calculation. This helps focus on a subset of key points,
+ e.g. when some are occluded. Captures all key points by default.
Returns:
detections (Detections): The converted detections object.
Examples:
```python
- keypoints = sv.KeyPoints.from_inference(...)
- detections = keypoints.as_detections()
+ key_points = sv.KeyPoints.from_inference(...)
+ detections = key_points.as_detections()
```
"""
if self.is_empty():
diff --git a/supervision/keypoint/skeletons.py b/supervision/key_points/skeletons.py
similarity index 100%
rename from supervision/keypoint/skeletons.py
rename to supervision/key_points/skeletons.py
diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py
index 43a3116ac2..6967361702 100644
--- a/supervision/metrics/mean_average_precision.py
+++ b/supervision/metrics/mean_average_precision.py
@@ -102,12 +102,12 @@ def __str__(self) -> str:
f"maxDets=100 ] = {self.map50:.3f}\n"
f"Average Precision (AP) @[ IoU=0.75 | area= all | "
f"maxDets=100 ] = {self.map75:.3f}\n"
- f"Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] "
- f"= {self.small_objects.map50_95:.3f}\n"
- f"Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] \
- = {self.medium_objects.map50_95:.3f}\n"
- f"Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] \
- = {self.large_objects.map50_95:.3f}"
+ f"Average Precision (AP) @[ IoU=0.50:0.95 | area= small | "
+ f"maxDets=100 ] = {self.small_objects.map50_95:.3f}\n"
+ f"Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | "
+ f"maxDets=100 ] = {self.medium_objects.map50_95:.3f}\n"
+ f"Average Precision (AP) @[ IoU=0.50:0.95 | area= large | "
+ f"maxDets=100 ] = {self.large_objects.map50_95:.3f}"
)
def to_pandas(self) -> pd.DataFrame:
diff --git a/supervision/utils/conversion.py b/supervision/utils/conversion.py
index 79ec500300..b1c8f16bfd 100644
--- a/supervision/utils/conversion.py
+++ b/supervision/utils/conversion.py
@@ -4,10 +4,10 @@
import numpy as np
from PIL import Image
-from supervision.annotators.base import ImageType
+from supervision.draw.base import ImageType
-def ensure_cv2_image_for_annotation(annotate_func):
+def ensure_cv2_image_for_class_method(annotate_func):
"""
Decorates `BaseAnnotator.annotate` implementations, converts scene to
an image type used internally by the annotators, converts back when annotation
@@ -32,7 +32,7 @@ def wrapper(self, scene: ImageType, *args, **kwargs):
return wrapper
-def ensure_cv2_image_for_processing(image_processing_fun):
+def ensure_cv2_image_for_standalone_function(image_processing_fun):
"""
Decorates image processing functions that accept np.ndarray, converting `image` to
np.ndarray, converts back when processing is complete.
@@ -55,7 +55,7 @@ def wrapper(image: ImageType, *args, **kwargs):
return wrapper
-def ensure_pil_image_for_annotation(annotate_func):
+def ensure_pil_image_for_class_method(annotate_func):
"""
Decorates image processing functions that accept np.ndarray, converting `image` to
PIL image, converts back when processing is complete.
diff --git a/supervision/utils/image.py b/supervision/utils/image.py
index 6960986768..4d4348d69f 100644
--- a/supervision/utils/image.py
+++ b/supervision/utils/image.py
@@ -1,119 +1,107 @@
from __future__ import annotations
-import itertools
-import math
import os
import shutil
-from collections.abc import Callable
-from functools import partial
-from typing import Literal
import cv2
import numpy as np
import numpy.typing as npt
+from PIL import Image
from supervision.annotators.base import ImageType
from supervision.draw.color import Color, unify_to_bgr
-from supervision.draw.utils import calculate_optimal_text_scale, draw_text
-from supervision.geometry.core import Point
from supervision.utils.conversion import (
- cv2_to_pillow,
- ensure_cv2_image_for_processing,
- images_to_cv2,
+ ensure_cv2_image_for_standalone_function,
)
-from supervision.utils.iterables import create_batches, fill
+from supervision.utils.internal import deprecated
-RelativePosition = Literal["top", "bottom"]
-MAX_COLUMNS_FOR_SINGLE_ROW_GRID = 3
-
-
-@ensure_cv2_image_for_processing
def crop_image(
image: ImageType,
xyxy: npt.NDArray[int] | list[int] | tuple[int, int, int, int],
) -> ImageType:
"""
- Crops the given image based on the given bounding box.
+ Crop image based on bounding box coordinates.
Args:
- image (ImageType): The image to be cropped. `ImageType` is a flexible type,
- accepting either `numpy.ndarray` or `PIL.Image.Image`.
- xyxy (Union[np.ndarray, List[int], Tuple[int, int, int, int]]): A bounding box
- coordinates in the format `(x_min, y_min, x_max, y_max)`, accepted as either
- a `numpy.ndarray`, a `list`, or a `tuple`.
+ image (`numpy.ndarray` or `PIL.Image.Image`): The image to crop.
+ xyxy (`numpy.array`, `list[int]`, or `tuple[int, int, int, int]`):
+ Bounding box coordinates in `(x_min, y_min, x_max, y_max)` format.
Returns:
- (ImageType): The cropped image. The type is determined by the input type and
- may be either a `numpy.ndarray` or `PIL.Image.Image`.
-
- === "OpenCV"
+ (`numpy.ndarray` or `PIL.Image.Image`): Cropped image matching input
+ type.
+ Examples:
```python
import cv2
import supervision as sv
- image = cv2.imread()
+ image = cv2.imread("source.png")
image.shape
# (1080, 1920, 3)
- xyxy = [200, 400, 600, 800]
+ xyxy = (400, 400, 800, 800)
cropped_image = sv.crop_image(image=image, xyxy=xyxy)
cropped_image.shape
# (400, 400, 3)
```
- === "Pillow"
-
```python
from PIL import Image
import supervision as sv
- image = Image.open()
+ image = Image.open("source.png")
image.size
# (1920, 1080)
- xyxy = [200, 400, 600, 800]
+ xyxy = (400, 400, 800, 800)
cropped_image = sv.crop_image(image=image, xyxy=xyxy)
cropped_image.size
# (400, 400)
```
- { align=center width="800" }
+ { align=center width="1000" }
""" # noqa E501 // docs
-
if isinstance(xyxy, (list, tuple)):
xyxy = np.array(xyxy)
+
xyxy = np.round(xyxy).astype(int)
x_min, y_min, x_max, y_max = xyxy.flatten()
- return image[y_min:y_max, x_min:x_max]
+
+ if isinstance(image, np.ndarray):
+ return image[y_min:y_max, x_min:x_max]
+
+ if isinstance(image, Image.Image):
+ return image.crop((x_min, y_min, x_max, y_max))
+
+ raise TypeError(
+ f"`image` must be a numpy.ndarray or PIL.Image.Image. Received {type(image)}"
+ )
-@ensure_cv2_image_for_processing
+@ensure_cv2_image_for_standalone_function
def scale_image(image: ImageType, scale_factor: float) -> ImageType:
"""
- Scales the given image based on the given scale factor.
+ Scale image by given factor. Scale factor > 1.0 zooms in, < 1.0 zooms out.
Args:
- image (ImageType): The image to be scaled. `ImageType` is a flexible type,
- accepting either `numpy.ndarray` or `PIL.Image.Image`.
- scale_factor (float): The factor by which the image will be scaled. Scale
- factor > `1.0` zooms in, < `1.0` zooms out.
+ image (`numpy.ndarray` or `PIL.Image.Image`): The image to scale.
+ scale_factor (`float`): Factor by which to scale the image.
Returns:
- (ImageType): The scaled image. The type is determined by the input type and
- may be either a `numpy.ndarray` or `PIL.Image.Image`.
+ (`numpy.ndarray` or `PIL.Image.Image`): Scaled image matching input
+ type.
Raises:
- ValueError: If the scale factor is non-positive.
-
- === "OpenCV"
+ ValueError: If scale factor is non-positive.
+ Examples:
```python
import cv2
import supervision as sv
- image = cv2.imread()
+ image = cv2.imread("source.png")
image.shape
# (1080, 1920, 3)
@@ -122,13 +110,11 @@ def scale_image(image: ImageType, scale_factor: float) -> ImageType:
# (540, 960, 3)
```
- === "Pillow"
-
```python
from PIL import Image
import supervision as sv
- image = Image.open()
+ image = Image.open("source.png")
image.size
# (1920, 1080)
@@ -136,7 +122,9 @@ def scale_image(image: ImageType, scale_factor: float) -> ImageType:
scaled_image.size
# (960, 540)
```
- """
+
+ { align=center width="1000" }
+ """ # noqa E501 // docs
if scale_factor <= 0:
raise ValueError("Scale factor must be positive.")
@@ -146,35 +134,31 @@ def scale_image(image: ImageType, scale_factor: float) -> ImageType:
return cv2.resize(image, (width_new, height_new), interpolation=cv2.INTER_LINEAR)
-@ensure_cv2_image_for_processing
+@ensure_cv2_image_for_standalone_function
def resize_image(
image: ImageType,
resolution_wh: tuple[int, int],
keep_aspect_ratio: bool = False,
) -> ImageType:
"""
- Resizes the given image to a specified resolution. Can maintain the original aspect
- ratio or resize directly to the desired dimensions.
+ Resize image to specified resolution. Can optionally maintain aspect ratio.
Args:
- image (ImageType): The image to be resized. `ImageType` is a flexible type,
- accepting either `numpy.ndarray` or `PIL.Image.Image`.
- resolution_wh (Tuple[int, int]): The target resolution as
- `(width, height)`.
- keep_aspect_ratio (bool): Flag to maintain the image's original
- aspect ratio. Defaults to `False`.
+ image (`numpy.ndarray` or `PIL.Image.Image`): The image to resize.
+ resolution_wh (`tuple[int, int]`): Target resolution as `(width, height)`.
+ keep_aspect_ratio (`bool`): Flag to maintain original aspect ratio.
+ Defaults to `False`.
Returns:
- (ImageType): The resized image. The type is determined by the input type and
- may be either a `numpy.ndarray` or `PIL.Image.Image`.
-
- === "OpenCV"
+ (`numpy.ndarray` or `PIL.Image.Image`): Resized image matching input
+ type.
+ Examples:
```python
import cv2
import supervision as sv
- image = cv2.imread()
+ image = cv2.imread("source.png")
image.shape
# (1080, 1920, 3)
@@ -185,13 +169,11 @@ def resize_image(
# (562, 1000, 3)
```
- === "Pillow"
-
```python
from PIL import Image
import supervision as sv
- image = Image.open()
+ image = Image.open("source.png")
image.size
# (1920, 1080)
@@ -202,7 +184,7 @@ def resize_image(
# (1000, 562)
```
- { align=center width="800" }
+ { align=center width="1000" }
""" # noqa E501 // docs
if keep_aspect_ratio:
image_ratio = image.shape[1] / image.shape[0]
@@ -219,59 +201,58 @@ def resize_image(
return cv2.resize(image, (width_new, height_new), interpolation=cv2.INTER_LINEAR)
-@ensure_cv2_image_for_processing
+@ensure_cv2_image_for_standalone_function
def letterbox_image(
image: ImageType,
resolution_wh: tuple[int, int],
color: tuple[int, int, int] | Color = Color.BLACK,
) -> ImageType:
"""
- Resizes and pads an image to a specified resolution with a given color, maintaining
- the original aspect ratio.
+ Resize image and pad with color to achieve desired resolution while
+ maintaining aspect ratio.
Args:
- image (ImageType): The image to be resized. `ImageType` is a flexible type,
- accepting either `numpy.ndarray` or `PIL.Image.Image`.
- resolution_wh (Tuple[int, int]): The target resolution as
- `(width, height)`.
- color (Union[Tuple[int, int, int], Color]): The color to pad with. If tuple
- provided it should be in BGR format.
+ image (`numpy.ndarray` or `PIL.Image.Image`): The image to resize and pad.
+ resolution_wh (`tuple[int, int]`): Target resolution as `(width, height)`.
+ color (`tuple[int, int, int]` or `Color`): Padding color. If tuple, should
+ be in BGR format. Defaults to `Color.BLACK`.
Returns:
- (ImageType): The resized image. The type is determined by the input type and
- may be either a `numpy.ndarray` or `PIL.Image.Image`.
-
- === "OpenCV"
+ (`numpy.ndarray` or `PIL.Image.Image`): Letterboxed image matching input
+ type.
+ Examples:
```python
import cv2
import supervision as sv
- image = cv2.imread()
+ image = cv2.imread("source.png")
image.shape
# (1080, 1920, 3)
- letterboxed_image = sv.letterbox_image(image=image, resolution_wh=(1000, 1000))
+ letterboxed_image = sv.letterbox_image(
+ image=image, resolution_wh=(1000, 1000)
+ )
letterboxed_image.shape
# (1000, 1000, 3)
```
- === "Pillow"
-
```python
from PIL import Image
import supervision as sv
- image = Image.open()
+ image = Image.open("source.png")
image.size
# (1920, 1080)
- letterboxed_image = sv.letterbox_image(image=image, resolution_wh=(1000, 1000))
+ letterboxed_image = sv.letterbox_image(
+ image=image, resolution_wh=(1000, 1000)
+ )
letterboxed_image.size
# (1000, 1000)
```
- { align=center width="800" }
+ { align=center width="1000" }
""" # noqa E501 // docs
assert isinstance(image, np.ndarray)
color = unify_to_bgr(color=color)
@@ -302,37 +283,59 @@ def letterbox_image(
return image_with_borders
+@deprecated(
+ "`overlay_image` function is deprecated and will be removed in "
+ "`supervision-0.32.0`. Use `draw_image` instead."
+)
def overlay_image(
image: npt.NDArray[np.uint8],
overlay: npt.NDArray[np.uint8],
anchor: tuple[int, int],
) -> npt.NDArray[np.uint8]:
"""
- Places an image onto a scene at a given anchor point, handling cases where
- the image's position is partially or completely outside the scene's bounds.
+ Overlay image onto scene at specified anchor point. Handles cases where
+ overlay position is partially or completely outside scene bounds.
Args:
- image (np.ndarray): The background scene onto which the image is placed.
- overlay (np.ndarray): The image to be placed onto the scene.
- anchor (Tuple[int, int]): The `(x, y)` coordinates in the scene where the
- top-left corner of the image will be placed.
+ image (`numpy.array`): Background scene with shape `(height, width, 3)`.
+ overlay (`numpy.array`): Image to overlay with shape
+ `(height, width, 3)` or `(height, width, 4)`.
+ anchor (`tuple[int, int]`): Coordinates `(x, y)` where top-left corner
+ of overlay will be placed.
Returns:
- (np.ndarray): The result image with overlay.
+ (`numpy.array`): Scene with overlay applied, shape `(height, width, 3)`.
Examples:
- ```python
+ ```
import cv2
import numpy as np
import supervision as sv
- image = cv2.imread()
+ image = cv2.imread("source.png")
overlay = np.zeros((400, 400, 3), dtype=np.uint8)
- result_image = sv.overlay_image(image=image, overlay=overlay, anchor=(200, 400))
+ overlay[:] = (0, 255, 0) # Green overlay
+
+ result_image = sv.overlay_image(
+ image=image, overlay=overlay, anchor=(200, 400)
+ )
+ cv2.imwrite("target.png", result_image)
```
- { align=center width="800" }
- """ # noqa E501 // docs
+ ```
+ import cv2
+ import numpy as np
+ import supervision as sv
+
+ image = cv2.imread("source.png")
+ overlay = cv2.imread("overlay.png", cv2.IMREAD_UNCHANGED)
+
+ result_image = sv.overlay_image(
+ image=image, overlay=overlay, anchor=(100, 100)
+ )
+ cv2.imwrite("target.png", result_image)
+ ```
+ """
scene_height, scene_width = image.shape[:2]
image_height, image_width = overlay.shape[:2]
anchor_x, anchor_y = anchor
@@ -371,6 +374,158 @@ def overlay_image(
return image
+@ensure_cv2_image_for_standalone_function
+def tint_image(
+ image: ImageType,
+ color: Color = Color.BLACK,
+ opacity: float = 0.5,
+) -> ImageType:
+ """
+ Tint image with solid color overlay at specified opacity.
+
+ Args:
+ image (`numpy.ndarray` or `PIL.Image.Image`): The image to tint.
+ color (`Color`): Overlay tint color. Defaults to `Color.BLACK`.
+ opacity (`float`): Blend ratio between overlay and image (0.0-1.0).
+ Defaults to `0.5`.
+
+ Returns:
+ (`numpy.ndarray` or `PIL.Image.Image`): Tinted image matching input
+ type.
+
+ Raises:
+ ValueError: If opacity is outside range [0.0, 1.0].
+
+ Examples:
+ ```python
+ import cv2
+ import supervision as sv
+
+ image = cv2.imread("source.png")
+ tinted_image = sv.tint_image(
+ image=image, color=sv.Color.ROBOFLOW, opacity=0.5
+ )
+ cv2.imwrite("target.png", tinted_image)
+ ```
+
+ ```python
+ from PIL import Image
+ import supervision as sv
+
+ image = Image.open("source.png")
+ tinted_image = sv.tint_image(
+ image=image, color=sv.Color.ROBOFLOW, opacity=0.5
+ )
+ tinted_image.save("target.png")
+ ```
+
+ { align=center width="1000" }
+ """ # noqa E501 // docs
+ if not 0.0 <= opacity <= 1.0:
+ raise ValueError("opacity must be between 0.0 and 1.0")
+
+ overlay = np.full_like(image, fill_value=color.as_bgr(), dtype=image.dtype)
+ cv2.addWeighted(
+ src1=overlay, alpha=opacity, src2=image, beta=1 - opacity, gamma=0, dst=image
+ )
+ return image
+
+
+@ensure_cv2_image_for_standalone_function
+def grayscale_image(image: ImageType) -> ImageType:
+ """
+ Convert image to 3-channel grayscale. Luminance channel is broadcast to
+ all three channels for compatibility with color-based drawing helpers.
+
+ Args:
+ image (`numpy.ndarray` or `PIL.Image.Image`): The image to convert to
+ grayscale.
+
+ Returns:
+ (`numpy.ndarray` or `PIL.Image.Image`): 3-channel grayscale image
+ matching input type.
+
+ Examples:
+ ```python
+ import cv2
+ import supervision as sv
+
+ image = cv2.imread("source.png")
+ grayscale_image = sv.grayscale_image(image=image)
+ cv2.imwrite("target.png", grayscale_image)
+ ```
+
+ ```python
+ from PIL import Image
+ import supervision as sv
+
+ image = Image.open("source.png")
+ grayscale_image = sv.grayscale_image(image=image)
+ grayscale_image.save("target.png")
+ ```
+
+ { align=center width="1000" }
+ """ # noqa E501 // docs
+ grayscaled = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+ return cv2.cvtColor(grayscaled, cv2.COLOR_GRAY2BGR)
+
+
+def get_image_resolution_wh(image: ImageType) -> tuple[int, int]:
+ """
+ Get image width and height as a tuple `(width, height)` for various image formats.
+
+ Supports both `numpy.ndarray` images (with shape `(H, W, ...)`) and
+ `PIL.Image.Image` inputs.
+
+ Args:
+ image (`numpy.ndarray` or `PIL.Image.Image`): Input image.
+
+ Returns:
+ (`tuple[int, int]`): Image resolution as `(width, height)`.
+
+ Raises:
+ ValueError: If a `numpy.ndarray` image has fewer than 2 dimensions.
+ TypeError: If `image` is not a supported type (`numpy.ndarray` or
+ `PIL.Image.Image`).
+
+ Examples:
+ ```python
+ import cv2
+ import supervision as sv
+
+ image = cv2.imread("example.png")
+ sv.get_image_resolution_wh(image)
+ # (1920, 1080)
+ ```
+
+ ```python
+ from PIL import Image
+ import supervision as sv
+
+ image = Image.open("example.png")
+ sv.get_image_resolution_wh(image)
+ # (1920, 1080)
+ ```
+ """
+ if isinstance(image, np.ndarray):
+ if image.ndim < 2:
+ raise ValueError(
+ "NumPy image must have at least 2 dimensions (H, W, ...). "
+ f"Received shape: {image.shape}"
+ )
+ height, width = image.shape[:2]
+ return int(width), int(height)
+
+ if isinstance(image, Image.Image):
+ width, height = image.size
+ return int(width), int(height)
+
+ raise TypeError(
+ "`image` must be a numpy.ndarray or PIL.Image.Image. "
+ f"Received type: {type(image)}"
+ )
+
+
class ImageSink:
def __init__(
self,
@@ -379,27 +534,64 @@ def __init__(
image_name_pattern: str = "image_{:05d}.png",
):
"""
- Initialize a context manager for saving images.
+ Initialize context manager for saving images to directory.
Args:
- target_dir_path (str): The target directory where images will be saved.
- overwrite (bool): Whether to overwrite the existing directory.
- Defaults to False.
- image_name_pattern (str): The image file name pattern.
- Defaults to "image_{:05d}.png".
+ target_dir_path (`str`): Target directory path where images will be
+ saved.
+ overwrite (`bool`): Whether to overwrite existing directory.
+ Defaults to `False`.
+ image_name_pattern (`str`): File name pattern for saved images.
+ Defaults to `"image_{:05d}.png"`.
Examples:
```python
import supervision as sv
- frames_generator = sv.get_video_frames_generator(, stride=2)
+ frames_generator = sv.get_video_frames_generator(
+ "source.mp4", stride=2
+ )
- with sv.ImageSink(target_dir_path=) as sink:
+ with sv.ImageSink(target_dir_path="output_frames") as sink:
for image in frames_generator:
sink.save_image(image=image)
+
+ # Directory structure:
+ # output_frames/
+ # βββ image_00000.png
+ # βββ image_00001.png
+ # βββ image_00002.png
+ # βββ image_00003.png
```
- """ # noqa E501 // docs
+ ```python
+ import cv2
+ import supervision as sv
+
+ image = cv2.imread("source.png")
+ crop_boxes = [
+ ( 0, 0, 400, 400),
+ (400, 0, 800, 400),
+ ( 0, 400, 400, 800),
+ (400, 400, 800, 800)
+ ]
+
+ with sv.ImageSink(
+ target_dir_path="image_crops",
+ overwrite=True
+ ) as sink:
+ for i, xyxy in enumerate(crop_boxes):
+ crop = sv.crop_image(image=image, xyxy=xyxy)
+ sink.save_image(image=crop, image_name=f"crop_{i}.png")
+
+ # Directory structure:
+ # image_crops/
+ # βββ crop_0.png
+ # βββ crop_1.png
+ # βββ crop_2.png
+ # βββ crop_3.png
+ ```
+ """
self.target_dir_path = target_dir_path
self.overwrite = overwrite
self.image_name_pattern = image_name_pattern
@@ -417,14 +609,14 @@ def __enter__(self):
def save_image(self, image: np.ndarray, image_name: str | None = None):
"""
- Save a given image in the target directory.
+ Save image to target directory with optional custom filename.
Args:
- image (np.ndarray): The image to be saved. The image must be in BGR color
- format.
- image_name (Optional[str]): The name to use for the saved image.
- If not provided, a name will be
- generated using the `image_name_pattern`.
+ image (`numpy.array`): Image to save with shape `(height, width, 3)`
+ in BGR format.
+ image_name (`str` or `None`): Custom filename for saved image. If
+ `None`, generates name using `image_name_pattern`. Defaults to
+ `None`.
"""
if image_name is None:
image_name = self.image_name_pattern.format(self.image_count)
@@ -435,355 +627,3 @@ def save_image(self, image: np.ndarray, image_name: str | None = None):
def __exit__(self, exc_type, exc_value, exc_traceback):
pass
-
-
-def create_tiles(
- images: list[ImageType],
- grid_size: tuple[int | None, int | None] | None = None,
- single_tile_size: tuple[int, int] | None = None,
- tile_scaling: Literal["min", "max", "avg"] = "avg",
- tile_padding_color: tuple[int, int, int] | Color = Color.from_hex("#D9D9D9"),
- tile_margin: int = 10,
- tile_margin_color: tuple[int, int, int] | Color = Color.from_hex("#BFBEBD"),
- return_type: Literal["auto", "cv2", "pillow"] = "auto",
- titles: list[str | None] | None = None,
- titles_anchors: Point | list[Point | None] | None = None,
- titles_color: tuple[int, int, int] | Color = Color.from_hex("#262523"),
- titles_scale: float | None = None,
- titles_thickness: int = 1,
- titles_padding: int = 10,
- titles_text_font: int = cv2.FONT_HERSHEY_SIMPLEX,
- titles_background_color: tuple[int, int, int] | Color = Color.from_hex("#D9D9D9"),
- default_title_placement: RelativePosition = "top",
-) -> ImageType:
- """
- Creates tiles mosaic from input images, automating grid placement and
- converting images to common resolution maintaining aspect ratio. It is
- also possible to render text titles on tiles, using optional set of
- parameters specifying text drawing (see parameters description).
-
- Automated grid placement will try to maintain square shape of grid
- (with size being the nearest integer square root of #images), up to two exceptions:
- * if there are up to 3 images - images will be displayed in single row
- * if square-grid placement causes last row to be empty - number of rows is trimmed
- until last row has at least one image
-
- Args:
- images (List[ImageType]): Images to create tiles. Elements can be either
- np.ndarray or PIL.Image, common representation will be agreed by the
- function.
- grid_size (Optional[Tuple[Optional[int], Optional[int]]]): Expected grid
- size in format (n_rows, n_cols). If not given - automated grid placement
- will be applied. One may also provide only one out of two elements of the
- tuple - then grid will be created with either n_rows or n_cols fixed,
- leaving the other dimension to be adjusted by the number of images
- single_tile_size (Optional[Tuple[int, int]]): sizeof a single tile element
- provided in (width, height) format. If not given - size of tile will be
- automatically calculated based on `tile_scaling` parameter.
- tile_scaling (Literal["min", "max", "avg"]): If `single_tile_size` is not
- given - parameter will be used to calculate tile size - using
- min / max / avg size of image provided in `images` list.
- tile_padding_color (Union[Tuple[int, int, int], sv.Color]): Color to be used in
- images letterbox procedure (while standardising tiles sizes) as a padding.
- If tuple provided - should be BGR.
- tile_margin (int): size of margin between tiles (in pixels)
- tile_margin_color (Union[Tuple[int, int, int], sv.Color]): Color of tile margin.
- If tuple provided - should be BGR.
- return_type (Literal["auto", "cv2", "pillow"]): Parameter dictates the format of
- return image. One may choose specific type ("cv2" or "pillow") to enforce
- conversion. "auto" mode takes a majority vote between types of elements in
- `images` list - resolving draws in favour of OpenCV format. "auto" can be
- safely used when all input images are of the same type.
- titles (Optional[List[Optional[str]]]): Optional titles to be added to tiles.
- Elements of that list may be empty - then specific tile (in order presented
- in `images` parameter) will not be filled with title. It is possible to
- provide list of titles shorter than `images` - then remaining titles will
- be assumed empty.
- titles_anchors (Optional[Union[Point, List[Optional[Point]]]]): Parameter to
- specify anchor points for titles. It is possible to specify anchor either
- globally or for specific tiles (following order of `images`).
- If not given (either globally, or for specific element of the list),
- it will be calculated automatically based on `default_title_placement`.
- titles_color (Union[Tuple[int, int, int], Color]): Color of titles text.
- If tuple provided - should be BGR.
- titles_scale (Optional[float]): Scale of titles. If not provided - value will
- be calculated using `calculate_optimal_text_scale(...)`.
- titles_thickness (int): Thickness of titles text.
- titles_padding (int): Size of titles padding.
- titles_text_font (int): Font to be used to render titles. Must be integer
- constant representing OpenCV font.
- (See docs: https://docs.opencv.org/4.x/d6/d6e/group__imgproc__draw.html)
- titles_background_color (Union[Tuple[int, int, int], Color]): Color of title
- text padding.
- default_title_placement (Literal["top", "bottom"]): Parameter specifies title
- anchor placement in case if explicit anchor is not provided.
-
- Returns:
- ImageType: Image with all input images located in tails grid. The output type is
- determined by `return_type` parameter.
-
- Raises:
- ValueError: In case when input images list is empty, provided `grid_size` is too
- small to fit all images, `tile_scaling` mode is invalid.
- """
- if len(images) == 0:
- raise ValueError("Could not create image tiles from empty list of images.")
- if return_type == "auto":
- return_type = _negotiate_tiles_format(images=images)
- tile_padding_color = unify_to_bgr(color=tile_padding_color)
- tile_margin_color = unify_to_bgr(color=tile_margin_color)
- images = images_to_cv2(images=images)
- if single_tile_size is None:
- single_tile_size = _aggregate_images_shape(images=images, mode=tile_scaling)
- resized_images = [
- letterbox_image(
- image=i, resolution_wh=single_tile_size, color=tile_padding_color
- )
- for i in images
- ]
- grid_size = _establish_grid_size(images=images, grid_size=grid_size)
- if len(images) > grid_size[0] * grid_size[1]:
- raise ValueError(
- f"Could not place {len(images)} in grid with size: {grid_size}."
- )
- if titles is not None:
- titles = fill(sequence=titles, desired_size=len(images), content=None)
- titles_anchors = (
- [titles_anchors]
- if not issubclass(type(titles_anchors), list)
- else titles_anchors
- )
- titles_anchors = fill(
- sequence=titles_anchors, desired_size=len(images), content=None
- )
- titles_color = unify_to_bgr(color=titles_color)
- titles_background_color = unify_to_bgr(color=titles_background_color)
- tiles = _generate_tiles(
- images=resized_images,
- grid_size=grid_size,
- single_tile_size=single_tile_size,
- tile_padding_color=tile_padding_color,
- tile_margin=tile_margin,
- tile_margin_color=tile_margin_color,
- titles=titles,
- titles_anchors=titles_anchors,
- titles_color=titles_color,
- titles_scale=titles_scale,
- titles_thickness=titles_thickness,
- titles_padding=titles_padding,
- titles_text_font=titles_text_font,
- titles_background_color=titles_background_color,
- default_title_placement=default_title_placement,
- )
- if return_type == "pillow":
- tiles = cv2_to_pillow(image=tiles)
- return tiles
-
-
-def _negotiate_tiles_format(images: list[ImageType]) -> Literal["cv2", "pillow"]:
- number_of_np_arrays = sum(issubclass(type(i), np.ndarray) for i in images)
- if number_of_np_arrays >= (len(images) // 2):
- return "cv2"
- return "pillow"
-
-
-def _calculate_aggregated_images_shape(
- images: list[np.ndarray], aggregator: Callable[[list[int]], float]
-) -> tuple[int, int]:
- height = round(aggregator([i.shape[0] for i in images]))
- width = round(aggregator([i.shape[1] for i in images]))
- return width, height
-
-
-SHAPE_AGGREGATION_FUN = {
- "min": partial(_calculate_aggregated_images_shape, aggregator=np.min),
- "max": partial(_calculate_aggregated_images_shape, aggregator=np.max),
- "avg": partial(_calculate_aggregated_images_shape, aggregator=np.average),
-}
-
-
-def _aggregate_images_shape(
- images: list[np.ndarray], mode: Literal["min", "max", "avg"]
-) -> tuple[int, int]:
- if mode not in SHAPE_AGGREGATION_FUN:
- raise ValueError(
- f"Could not aggregate images shape - provided unknown mode: {mode}. "
- f"Supported modes: {list(SHAPE_AGGREGATION_FUN.keys())}."
- )
- return SHAPE_AGGREGATION_FUN[mode](images)
-
-
-def _establish_grid_size(
- images: list[np.ndarray], grid_size: tuple[int | None, int | None] | None
-) -> tuple[int, int]:
- if grid_size is None or all(e is None for e in grid_size):
- return _negotiate_grid_size(images=images)
- if grid_size[0] is None:
- return math.ceil(len(images) / grid_size[1]), grid_size[1]
- if grid_size[1] is None:
- return grid_size[0], math.ceil(len(images) / grid_size[0])
- return grid_size
-
-
-def _negotiate_grid_size(images: list[np.ndarray]) -> tuple[int, int]:
- if len(images) <= MAX_COLUMNS_FOR_SINGLE_ROW_GRID:
- return 1, len(images)
- nearest_sqrt = math.ceil(np.sqrt(len(images)))
- proposed_columns = nearest_sqrt
- proposed_rows = nearest_sqrt
- while proposed_columns * (proposed_rows - 1) >= len(images):
- proposed_rows -= 1
- return proposed_rows, proposed_columns
-
-
-def _generate_tiles(
- images: list[np.ndarray],
- grid_size: tuple[int, int],
- single_tile_size: tuple[int, int],
- tile_padding_color: tuple[int, int, int],
- tile_margin: int,
- tile_margin_color: tuple[int, int, int],
- titles: list[str | None] | None,
- titles_anchors: list[Point | None],
- titles_color: tuple[int, int, int],
- titles_scale: float | None,
- titles_thickness: int,
- titles_padding: int,
- titles_text_font: int,
- titles_background_color: tuple[int, int, int],
- default_title_placement: RelativePosition,
-) -> np.ndarray:
- images = _draw_texts(
- images=images,
- titles=titles,
- titles_anchors=titles_anchors,
- titles_color=titles_color,
- titles_scale=titles_scale,
- titles_thickness=titles_thickness,
- titles_padding=titles_padding,
- titles_text_font=titles_text_font,
- titles_background_color=titles_background_color,
- default_title_placement=default_title_placement,
- )
- rows, columns = grid_size
- tiles_elements = list(create_batches(sequence=images, batch_size=columns))
- while len(tiles_elements[-1]) < columns:
- tiles_elements[-1].append(
- _generate_color_image(shape=single_tile_size, color=tile_padding_color)
- )
- while len(tiles_elements) < rows:
- tiles_elements.append(
- [_generate_color_image(shape=single_tile_size, color=tile_padding_color)]
- * columns
- )
- return _merge_tiles_elements(
- tiles_elements=tiles_elements,
- grid_size=grid_size,
- single_tile_size=single_tile_size,
- tile_margin=tile_margin,
- tile_margin_color=tile_margin_color,
- )
-
-
-def _draw_texts(
- images: list[np.ndarray],
- titles: list[str | None] | None,
- titles_anchors: list[Point | None],
- titles_color: tuple[int, int, int],
- titles_scale: float | None,
- titles_thickness: int,
- titles_padding: int,
- titles_text_font: int,
- titles_background_color: tuple[int, int, int],
- default_title_placement: RelativePosition,
-) -> list[np.ndarray]:
- if titles is None:
- return images
- titles_anchors = _prepare_default_titles_anchors(
- images=images,
- titles_anchors=titles_anchors,
- default_title_placement=default_title_placement,
- )
- if titles_scale is None:
- image_height, image_width = images[0].shape[:2]
- titles_scale = calculate_optimal_text_scale(
- resolution_wh=(image_width, image_height)
- )
- result = []
- for image, text, anchor in zip(images, titles, titles_anchors):
- if text is None:
- result.append(image)
- continue
- processed_image = draw_text(
- scene=image,
- text=text,
- text_anchor=anchor,
- text_color=Color.from_bgr_tuple(titles_color),
- text_scale=titles_scale,
- text_thickness=titles_thickness,
- text_padding=titles_padding,
- text_font=titles_text_font,
- background_color=Color.from_bgr_tuple(titles_background_color),
- )
- result.append(processed_image)
- return result
-
-
-def _prepare_default_titles_anchors(
- images: list[np.ndarray],
- titles_anchors: list[Point | None],
- default_title_placement: RelativePosition,
-) -> list[Point]:
- result = []
- for image, anchor in zip(images, titles_anchors):
- if anchor is not None:
- result.append(anchor)
- continue
- image_height, image_width = image.shape[:2]
- if default_title_placement == "top":
- default_anchor = Point(x=image_width / 2, y=image_height * 0.1)
- else:
- default_anchor = Point(x=image_width / 2, y=image_height * 0.9)
- result.append(default_anchor)
- return result
-
-
-def _merge_tiles_elements(
- tiles_elements: list[list[np.ndarray]],
- grid_size: tuple[int, int],
- single_tile_size: tuple[int, int],
- tile_margin: int,
- tile_margin_color: tuple[int, int, int],
-) -> np.ndarray:
- vertical_padding = (
- np.ones((single_tile_size[1], tile_margin, 3)) * tile_margin_color
- )
- merged_rows = [
- np.concatenate(
- list(
- itertools.chain.from_iterable(
- zip(row, [vertical_padding] * grid_size[1])
- )
- )[:-1],
- axis=1,
- )
- for row in tiles_elements
- ]
- row_width = merged_rows[0].shape[1]
- horizontal_padding = (
- np.ones((tile_margin, row_width, 3), dtype=np.uint8) * tile_margin_color
- )
- rows_with_paddings = []
- for row in merged_rows:
- rows_with_paddings.append(row)
- rows_with_paddings.append(horizontal_padding)
- return np.concatenate(
- rows_with_paddings[:-1],
- axis=0,
- ).astype(np.uint8)
-
-
-def _generate_color_image(
- shape: tuple[int, int], color: tuple[int, int, int]
-) -> np.ndarray:
- return np.ones((*shape[::-1], 3), dtype=np.uint8) * color
diff --git a/supervision/utils/notebook.py b/supervision/utils/notebook.py
index 9262f12bc4..3af09ebbec 100644
--- a/supervision/utils/notebook.py
+++ b/supervision/utils/notebook.py
@@ -4,7 +4,7 @@
import matplotlib.pyplot as plt
from PIL import Image
-from supervision.annotators.base import ImageType
+from supervision.draw.base import ImageType
from supervision.utils.conversion import pillow_to_cv2
diff --git a/supervision/utils/video.py b/supervision/utils/video.py
index 3b281b4e22..0ece0916da 100644
--- a/supervision/utils/video.py
+++ b/supervision/utils/video.py
@@ -1,9 +1,11 @@
from __future__ import annotations
+import threading
import time
from collections import deque
from collections.abc import Callable, Generator
from dataclasses import dataclass
+from queue import Queue
import cv2
import numpy as np
@@ -196,63 +198,126 @@ def process_video(
source_path: str,
target_path: str,
callback: Callable[[np.ndarray, int], np.ndarray],
+ *,
max_frames: int | None = None,
+ prefetch: int = 32,
+ writer_buffer: int = 32,
show_progress: bool = False,
progress_message: str = "Processing video",
) -> None:
"""
- Process a video file by applying a callback function on each frame
- and saving the result to a target video file.
+ Process video frames asynchronously using a threaded pipeline.
+
+ This function orchestrates a three-stage pipeline to optimize video processing
+ throughput:
+
+ 1. Reader thread: Continuously reads frames from the source video file and
+ enqueues them into a bounded queue (`frame_read_queue`). The queue size is
+ limited by the `prefetch` parameter to control memory usage.
+ 2. Main thread (Processor): Dequeues frames from `frame_read_queue`, applies the
+ user-defined `callback` function to process each frame, then enqueues the
+ processed frames into another bounded queue (`frame_write_queue`) for writing.
+ The processing happens in the main thread, simplifying use of stateful objects
+ without synchronization.
+ 3. Writer thread: Dequeues processed frames from `frame_write_queue` and writes
+ them sequentially to the output video file.
Args:
- source_path (str): The path to the source video file.
- target_path (str): The path to the target video file.
- callback (Callable[[np.ndarray, int], np.ndarray]): A function that takes in
- a numpy ndarray representation of a video frame and an
- int index of the frame and returns a processed numpy ndarray
- representation of the frame.
- max_frames (Optional[int]): The maximum number of frames to process.
- show_progress (bool): Whether to show a progress bar.
- progress_message (str): The message to display in the progress bar.
+ source_path (str): Path to the input video file.
+ target_path (str): Path where the processed video will be saved.
+ callback (Callable[[numpy.ndarray, int], numpy.ndarray]): Function called for
+ each frame, accepting the frame as a numpy array and its zero-based index,
+ returning the processed frame.
+ max_frames (int | None): Optional maximum number of frames to process.
+ If None, the entire video is processed (default).
+ prefetch (int): Maximum number of frames buffered by the reader thread.
+ Controls memory use; default is 32.
+ writer_buffer (int): Maximum number of frames buffered before writing.
+ Controls output buffer size; default is 32.
+ show_progress (bool): Whether to display a tqdm progress bar during processing.
+ Default is False.
+ progress_message (str): Description shown in the progress bar.
- Examples:
+ Returns:
+ None
+
+ Example:
```python
+ import cv2
import supervision as sv
+ from rfdetr import RFDETRMedium
- def callback(scene: np.ndarray, index: int) -> np.ndarray:
- ...
+ model = RFDETRMedium()
+
+ def callback(frame, frame_index):
+ return model.predict(frame)
process_video(
- source_path=,
- target_path=,
- callback=callback
+ source_path="source.mp4",
+ target_path="target.mp4",
+            callback=callback,
)
```
"""
- source_video_info = VideoInfo.from_video_path(video_path=source_path)
- video_frames_generator = get_video_frames_generator(
- source_path=source_path, end=max_frames
+ video_info = VideoInfo.from_video_path(video_path=source_path)
+ total_frames = (
+ min(video_info.total_frames, max_frames)
+ if max_frames is not None
+ else video_info.total_frames
)
- with VideoSink(target_path=target_path, video_info=source_video_info) as sink:
- total_frames = (
- min(source_video_info.total_frames, max_frames)
- if max_frames is not None
- else source_video_info.total_frames
+
+ frame_read_queue: Queue[tuple[int, np.ndarray] | None] = Queue(maxsize=prefetch)
+ frame_write_queue: Queue[np.ndarray | None] = Queue(maxsize=writer_buffer)
+
+ def reader_thread() -> None:
+ frame_generator = get_video_frames_generator(
+ source_path=source_path,
+ end=max_frames,
)
- for index, frame in enumerate(
- tqdm(
- video_frames_generator,
- total=total_frames,
- disable=not show_progress,
- desc=progress_message,
- )
- ):
- result_frame = callback(frame, index)
- sink.write_frame(frame=result_frame)
- else:
- for index, frame in enumerate(video_frames_generator):
- result_frame = callback(frame, index)
- sink.write_frame(frame=result_frame)
+ for frame_index, frame in enumerate(frame_generator):
+ frame_read_queue.put((frame_index, frame))
+ frame_read_queue.put(None)
+
+ def writer_thread(video_sink: VideoSink) -> None:
+ while True:
+ frame = frame_write_queue.get()
+ if frame is None:
+ break
+ video_sink.write_frame(frame=frame)
+
+ reader_worker = threading.Thread(target=reader_thread, daemon=True)
+ with VideoSink(target_path=target_path, video_info=video_info) as video_sink:
+ writer_worker = threading.Thread(
+ target=writer_thread,
+ args=(video_sink,),
+ daemon=True,
+ )
+
+ reader_worker.start()
+ writer_worker.start()
+
+ progress_bar = tqdm(
+ total=total_frames,
+ disable=not show_progress,
+ desc=progress_message,
+ )
+
+ try:
+ while True:
+ read_item = frame_read_queue.get()
+ if read_item is None:
+ break
+
+ frame_index, frame = read_item
+ processed_frame = callback(frame, frame_index)
+
+ frame_write_queue.put(processed_frame)
+ progress_bar.update(1)
+ finally:
+ frame_write_queue.put(None)
+ reader_worker.join()
+ writer_worker.join()
+ progress_bar.close()
class FPSMonitor:
diff --git a/supervision/validators/__init__.py b/supervision/validators/__init__.py
index 97fedabdd9..f051d89d7a 100644
--- a/supervision/validators/__init__.py
+++ b/supervision/validators/__init__.py
@@ -53,7 +53,7 @@ def validate_confidence(confidence: Any, n: int) -> None:
)
-def validate_keypoint_confidence(confidence: Any, n: int, m: int) -> None:
+def validate_key_point_confidence(confidence: Any, n: int, m: int) -> None:
expected_shape = f"({n, m})"
actual_shape = str(getattr(confidence, "shape", None))
@@ -126,7 +126,7 @@ def validate_detections_fields(
validate_data(data, n)
-def validate_keypoints_fields(
+def validate_key_points_fields(
xy: Any,
class_id: Any,
confidence: Any,
@@ -136,7 +136,7 @@ def validate_keypoints_fields(
m = len(xy[0]) if len(xy) > 0 else 0
validate_xy(xy, n, m)
validate_class_id(class_id, n)
- validate_keypoint_confidence(confidence, n, m)
+ validate_key_point_confidence(confidence, n, m)
validate_data(data, n)
diff --git a/test/annotators/test_utils.py b/test/annotators/test_utils.py
index 2fdccec116..3ab0f9b902 100644
--- a/test/annotators/test_utils.py
+++ b/test/annotators/test_utils.py
@@ -5,7 +5,7 @@
import numpy as np
import pytest
-from supervision.annotators.utils import ColorLookup, resolve_color_idx
+from supervision.annotators.utils import ColorLookup, resolve_color_idx, wrap_text
from supervision.detection.core import Detections
from test.test_utils import mock_detections
@@ -97,7 +97,7 @@
def test_resolve_color_idx(
detections: Detections,
detection_idx: int,
- color_lookup: ColorLookup,
+ color_lookup: ColorLookup | np.ndarray,
expected_result: int | None,
exception: Exception,
) -> None:
@@ -108,3 +108,67 @@ def test_resolve_color_idx(
color_lookup=color_lookup,
)
assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "text, max_line_length, expected_result, exception",
+ [
+ (None, None, [""], DoesNotRaise()), # text is None
+ ("", None, [""], DoesNotRaise()), # empty string
+ (" \t ", 3, [""], DoesNotRaise()), # whitespace-only (spaces + tab)
+ (12345, None, ["12345"], DoesNotRaise()), # plain integer
+ (-6789, None, ["-6789"], DoesNotRaise()), # negative integer
+ (np.int64(1000), None, ["1000"], DoesNotRaise()), # NumPy int64
+ ([1, 2, 3], None, ["[1, 2, 3]"], DoesNotRaise()), # list to string
+ (
+ "When you play the game of thrones, you win or you die.\nFear cuts deeper than swords.\nA mind needs books as a sword needs a whetstone.", # noqa: E501
+ None,
+ [
+ "When you play the game of thrones, you win or you die.",
+ "Fear cuts deeper than swords.",
+ "A mind needs books as a sword needs a whetstone.",
+ ],
+ DoesNotRaise(),
+ ), # Game-of-Thrones quotes, multiline
+ ("\n", None, [""], DoesNotRaise()), # single newline
+ (
+ "valarmorghulisvalardoharis",
+ 6,
+ ["valarm", "orghul", "isvala", "rdohar", "is"],
+ DoesNotRaise(),
+ ), # long Valyrian phrase, wrapped
+ (
+ "Winter is coming\nFire and blood",
+ 10,
+ [
+ "Winter is",
+ "coming",
+ "Fire and",
+ "blood",
+ ],
+ DoesNotRaise(),
+ ), # mix of short/long with newline
+ (
+ "What is dead may never die",
+ 0,
+ None,
+ pytest.raises(ValueError),
+ ), # width 0 - invalid
+ (
+ "A Lannister always pays his debts",
+ -1,
+ None,
+ pytest.raises(ValueError),
+ ), # width -1 - invalid
+ (None, 10, [""], DoesNotRaise()), # text None, width set
+ ],
+)
+def test_wrap_text(
+ text: object,
+ max_line_length: int | None,
+ expected_result: list[str],
+ exception: Exception,
+) -> None:
+ with exception:
+ result = wrap_text(text=text, max_line_length=max_line_length)
+ assert result == expected_result
diff --git a/test/detection/test_vlm.py b/test/detection/test_vlm.py
index 1b4c2f1894..9a0195f780 100644
--- a/test/detection/test_vlm.py
+++ b/test/detection/test_vlm.py
@@ -6,7 +6,10 @@
import numpy as np
import pytest
+from supervision.config import CLASS_NAME_DATA_FIELD
+from supervision.detection.core import Detections
from supervision.detection.vlm import (
+ VLM,
from_florence_2,
from_google_gemini_2_0,
from_google_gemini_2_5,
@@ -317,6 +320,43 @@ def test_from_paligemma(
np.array(["dog"], dtype=str),
),
), # out-of-bounds box
+ (
+ does_not_raise(),
+ """[
+ {'bbox_2d': [10, 20, 110, 120], 'label': 'cat'}
+ ]""",
+ (640, 640),
+ (1280, 720),
+ None,
+ (
+ np.array([[20.0, 22.5, 220.0, 135.0]]),
+ None,
+ np.array(["cat"], dtype=str),
+ ),
+ ), # python-style list, single quotes, no fences
+ (
+ does_not_raise(),
+ """```json
+ [
+ {"bbox_2d": [0, 0, 64, 64], "label": "dog"},
+ {"bbox_2d": [10, 20, 110, 120], "label": "cat"},
+ {"bbox_2d": [30, 40, 130, 140], "label":
+ """,
+ (640, 640),
+ (640, 640),
+ None,
+ (
+ np.array(
+ [
+ [0.0, 0.0, 64.0, 64.0],
+ [10.0, 20.0, 110.0, 120.0],
+ ],
+ dtype=float,
+ ),
+ None,
+ np.array(["dog", "cat"], dtype=str),
+ ),
+ ), # truncated response, last object unfinished, previous ones recovered
(
pytest.raises(ValueError),
"""```json
@@ -327,8 +367,8 @@ def test_from_paligemma(
(0, 640),
(1280, 720),
None,
- None, # won't be compared because we expect an exception
- ), # zero input width -> ValueError
+ None, # invalid input_wh
+ ),
(
pytest.raises(ValueError),
"""```json
@@ -339,8 +379,8 @@ def test_from_paligemma(
(640, 640),
(1280, -100),
None,
- None,
- ), # negative resolution height -> ValueError
+ None, # invalid resolution_wh
+ ),
],
)
def test_from_qwen_2_5_vl(
@@ -1122,3 +1162,110 @@ def test_from_google_gemini_2_5(
assert masks is not None
assert masks.shape == expected_results[4].shape
assert np.array_equal(masks, expected_results[4])
+
+
+@pytest.mark.parametrize(
+ "exception, result, resolution_wh, classes, expected_detections",
+ [
+ (
+ pytest.raises(ValueError),
+ "",
+ (100, 100),
+ None,
+ None,
+ ), # empty text
+ (
+ pytest.raises(ValueError),
+ "random text",
+ (100, 100),
+ None,
+ None,
+ ), # random text
+ (
+ does_not_raise(),
+ "<|ref|>cat<|/ref|><|det|>[[100, 200, 300, 400]]<|/det|>",
+ (1000, 1000),
+ None,
+ Detections(
+ xyxy=np.array([[100.1, 200.2, 300.3, 400.4]]),
+ class_id=np.array([0]),
+ data={CLASS_NAME_DATA_FIELD: np.array(["cat"])},
+ ),
+ ), # single box, no classes
+ (
+ does_not_raise(),
+ "<|ref|>cat<|/ref|><|det|>[[100, 200, 300, 400]]<|/det|>",
+ (1000, 1000),
+ ["cat", "dog"],
+ Detections(
+ xyxy=np.array([[100.1, 200.2, 300.3, 400.4]]),
+ class_id=np.array([0]),
+ data={CLASS_NAME_DATA_FIELD: np.array(["cat"])},
+ ),
+ ), # single box, with classes
+ (
+ does_not_raise(),
+ "<|ref|>person<|/ref|><|det|>[[100, 200, 300, 400]]<|/det|>",
+ (1000, 1000),
+ ["cat", "dog"],
+ Detections.empty(),
+ ), # single box, wrong class
+ (
+ does_not_raise(),
+ (
+ "<|ref|>cat<|/ref|><|det|>[[100, 200, 300, 400]]<|/det|>"
+ "<|ref|>dog<|/ref|><|det|>[[500, 600, 700, 800]]<|/det|>"
+ ),
+ (1000, 1000),
+ ["cat"],
+ Detections(
+ xyxy=np.array([[100.1, 200.2, 300.3, 400.4]]),
+ class_id=np.array([0]),
+ data={CLASS_NAME_DATA_FIELD: np.array(["cat"])},
+ ),
+ ), # multiple boxes, one class correct
+ (
+ pytest.raises(ValueError),
+ "<|ref|>cat<|/ref|>",
+ (100, 100),
+ None,
+ None,
+ ), # only ref
+ (
+ pytest.raises(ValueError),
+ "<|det|>[[100, 200, 300, 400]]<|/det|>",
+ (100, 100),
+ None,
+ None,
+ ), # only det
+ ],
+)
+def test_from_deepseek_vl_2(
+ exception,
+ result: str,
+ resolution_wh: tuple[int, int],
+ classes: list[str] | None,
+ expected_detections: Detections,
+):
+ with exception:
+ detections = Detections.from_vlm(
+ vlm=VLM.DEEPSEEK_VL_2,
+ result=result,
+ resolution_wh=resolution_wh,
+ classes=classes,
+ )
+
+ if expected_detections is None:
+ return
+
+ assert len(detections) == len(expected_detections)
+
+ if len(detections) == 0:
+ return
+
+ assert np.allclose(detections.xyxy, expected_detections.xyxy, atol=1e-1)
+ assert np.array_equal(detections.class_id, expected_detections.class_id)
+ assert np.array_equal(
+ detections.data[CLASS_NAME_DATA_FIELD],
+ expected_detections.data[CLASS_NAME_DATA_FIELD],
+ )
diff --git a/test/detection/tools/test_inference_slicer.py b/test/detection/tools/test_inference_slicer.py
index 2185b77f20..7c313841f3 100644
--- a/test/detection/tools/test_inference_slicer.py
+++ b/test/detection/tools/test_inference_slicer.py
@@ -1,13 +1,10 @@
from __future__ import annotations
-from contextlib import ExitStack as DoesNotRaise
-
import numpy as np
import pytest
from supervision.detection.core import Detections
from supervision.detection.tools.inference_slicer import InferenceSlicer
-from supervision.detection.utils.iou_and_nms import OverlapFilter
@pytest.fixture
@@ -20,54 +17,10 @@ def callback(_: np.ndarray) -> Detections:
return callback
-@pytest.mark.parametrize(
- "slice_wh, overlap_ratio_wh, overlap_wh, expected_overlap, exception",
- [
- # Valid case: explicit overlap_wh in pixels
- ((128, 128), None, (26, 26), (26, 26), DoesNotRaise()),
- # Valid case: overlap_wh in pixels
- ((128, 128), None, (20, 20), (20, 20), DoesNotRaise()),
- # Invalid case: negative overlap_wh, should raise ValueError
- ((128, 128), None, (-10, 20), None, pytest.raises(ValueError)),
- # Invalid case: no overlaps defined
- ((128, 128), None, None, None, pytest.raises(ValueError)),
- # Valid case: overlap_wh = 50 pixels
- ((256, 256), None, (50, 50), (50, 50), DoesNotRaise()),
- # Valid case: overlap_wh = 60 pixels
- ((200, 200), None, (60, 60), (60, 60), DoesNotRaise()),
- # Valid case: small overlap_wh values
- ((100, 100), None, (0.1, 0.1), (0.1, 0.1), DoesNotRaise()),
- # Invalid case: negative overlap_wh values
- ((128, 128), None, (-10, -10), None, pytest.raises(ValueError)),
- # Invalid case: overlap_wh greater than slice size
- ((128, 128), None, (150, 150), (150, 150), DoesNotRaise()),
- # Valid case: zero overlap
- ((128, 128), None, (0, 0), (0, 0), DoesNotRaise()),
- ],
-)
-def test_inference_slicer_overlap(
- mock_callback,
- slice_wh: tuple[int, int],
- overlap_ratio_wh: tuple[float, float] | None,
- overlap_wh: tuple[int, int] | None,
- expected_overlap: tuple[int, int] | None,
- exception: Exception,
-) -> None:
- with exception:
- slicer = InferenceSlicer(
- callback=mock_callback,
- slice_wh=slice_wh,
- overlap_ratio_wh=overlap_ratio_wh,
- overlap_wh=overlap_wh,
- overlap_filter=OverlapFilter.NONE,
- )
- assert slicer.overlap_wh == expected_overlap
-
-
@pytest.mark.parametrize(
"resolution_wh, slice_wh, overlap_wh, expected_offsets",
[
- # Case 1: No overlap, exact slices fit within image dimensions
+ # Case 1: Square image, square slices, no overlap
(
(256, 256),
(128, 128),
@@ -81,7 +34,7 @@ def test_inference_slicer_overlap(
]
),
),
- # Case 2: Overlap of 64 pixels in both directions
+ # Case 2: Square image, square slices, non-zero overlap
(
(256, 256),
(128, 128),
@@ -91,96 +44,154 @@ def test_inference_slicer_overlap(
[0, 0, 128, 128],
[64, 0, 192, 128],
[128, 0, 256, 128],
- [192, 0, 256, 128],
[0, 64, 128, 192],
[64, 64, 192, 192],
[128, 64, 256, 192],
- [192, 64, 256, 192],
[0, 128, 128, 256],
[64, 128, 192, 256],
[128, 128, 256, 256],
- [192, 128, 256, 256],
- [0, 192, 128, 256],
- [64, 192, 192, 256],
- [128, 192, 256, 256],
- [192, 192, 256, 256],
]
),
),
- # Case 3: Image not perfectly divisible by slice size (no overlap)
+ # Case 3: Rectangle image (horizontal), square slices, no overlap
(
- (300, 300),
- (128, 128),
+ (192, 128),
+ (64, 64),
(0, 0),
np.array(
[
- [0, 0, 128, 128],
- [128, 0, 256, 128],
- [256, 0, 300, 128],
- [0, 128, 128, 256],
- [128, 128, 256, 256],
- [256, 128, 300, 256],
- [0, 256, 128, 300],
- [128, 256, 256, 300],
- [256, 256, 300, 300],
+ [0, 0, 64, 64],
+ [64, 0, 128, 64],
+ [128, 0, 192, 64],
+ [0, 64, 64, 128],
+ [64, 64, 128, 128],
+ [128, 64, 192, 128],
]
),
),
- # Case 4: Overlap of 32 pixels, image not perfectly divisible by slice size
+ # Case 4: Rectangle image (horizontal), square slices, non-zero overlap
(
- (300, 300),
- (128, 128),
+ (192, 128),
+ (64, 64),
(32, 32),
np.array(
[
- [0, 0, 128, 128],
- [96, 0, 224, 128],
- [192, 0, 300, 128],
- [288, 0, 300, 128],
- [0, 96, 128, 224],
- [96, 96, 224, 224],
- [192, 96, 300, 224],
- [288, 96, 300, 224],
- [0, 192, 128, 300],
- [96, 192, 224, 300],
- [192, 192, 300, 300],
- [288, 192, 300, 300],
- [0, 288, 128, 300],
- [96, 288, 224, 300],
- [192, 288, 300, 300],
- [288, 288, 300, 300],
+ [0, 0, 64, 64],
+ [32, 0, 96, 64],
+ [64, 0, 128, 64],
+ [96, 0, 160, 64],
+ [128, 0, 192, 64],
+ [0, 32, 64, 96],
+ [32, 32, 96, 96],
+ [64, 32, 128, 96],
+ [96, 32, 160, 96],
+ [128, 32, 192, 96],
+ [0, 64, 64, 128],
+ [32, 64, 96, 128],
+ [64, 64, 128, 128],
+ [96, 64, 160, 128],
+ [128, 64, 192, 128],
]
),
),
- # Case 5: Image smaller than slice size (no overlap)
+ # Case 5: Rectangle image (vertical), square slices, no overlap
(
- (100, 100),
- (128, 128),
+ (128, 192),
+ (64, 64),
(0, 0),
np.array(
[
- [0, 0, 100, 100],
+ [0, 0, 64, 64],
+ [64, 0, 128, 64],
+ [0, 64, 64, 128],
+ [64, 64, 128, 128],
+ [0, 128, 64, 192],
+ [64, 128, 128, 192],
+ ]
+ ),
+ ),
+ # Case 6: Rectangle image (vertical), square slices, non-zero overlap
+ (
+ (128, 192),
+ (64, 64),
+ (32, 32),
+ np.array(
+ [
+ [0, 0, 64, 64],
+ [32, 0, 96, 64],
+ [64, 0, 128, 64],
+ [0, 32, 64, 96],
+ [32, 32, 96, 96],
+ [64, 32, 128, 96],
+ [0, 64, 64, 128],
+ [32, 64, 96, 128],
+ [64, 64, 128, 128],
+ [0, 96, 64, 160],
+ [32, 96, 96, 160],
+ [64, 96, 128, 160],
+ [0, 128, 64, 192],
+ [32, 128, 96, 192],
+ [64, 128, 128, 192],
+ ]
+ ),
+ ),
+ # Case 7: Square image, rectangular slices (horizontal), no overlap
+ (
+ (160, 160),
+ (80, 40),
+ (0, 0),
+ np.array(
+ [
+ [0, 0, 80, 40],
+ [80, 0, 160, 40],
+ [0, 40, 80, 80],
+ [80, 40, 160, 80],
+ [0, 80, 80, 120],
+ [80, 80, 160, 120],
+ [0, 120, 80, 160],
+ [80, 120, 160, 160],
+ ]
+ ),
+ ),
+ # Case 8: Square image, rectangular slices (vertical), non-zero overlap
+ (
+ (160, 160),
+ (40, 80),
+ (10, 20),
+ np.array(
+ [
+ [0, 0, 40, 80],
+ [30, 0, 70, 80],
+ [60, 0, 100, 80],
+ [90, 0, 130, 80],
+ [120, 0, 160, 80],
+ [0, 60, 40, 140],
+ [30, 60, 70, 140],
+ [60, 60, 100, 140],
+ [90, 60, 130, 140],
+ [120, 60, 160, 140],
+ [0, 80, 40, 160],
+ [30, 80, 70, 160],
+ [60, 80, 100, 160],
+ [90, 80, 130, 160],
+ [120, 80, 160, 160],
]
),
),
- # Case 6: Overlap_wh is greater than the slice size
- ((256, 256), (128, 128), (150, 150), np.array([]).reshape(0, 4)),
],
)
def test_generate_offset(
resolution_wh: tuple[int, int],
slice_wh: tuple[int, int],
- overlap_wh: tuple[int, int] | None,
+ overlap_wh: tuple[int, int],
expected_offsets: np.ndarray,
) -> None:
offsets = InferenceSlicer._generate_offset(
resolution_wh=resolution_wh,
slice_wh=slice_wh,
- overlap_ratio_wh=None,
overlap_wh=overlap_wh,
)
- # Verify that the generated offsets match the expected offsets
assert np.array_equal(offsets, expected_offsets), (
f"Expected {expected_offsets}, got {offsets}"
)
diff --git a/test/detection/utils/test_boxes.py b/test/detection/utils/test_boxes.py
index 919989287a..66d0d999c8 100644
--- a/test/detection/utils/test_boxes.py
+++ b/test/detection/utils/test_boxes.py
@@ -5,7 +5,12 @@
import numpy as np
import pytest
-from supervision.detection.utils.boxes import clip_boxes, move_boxes, scale_boxes
+from supervision.detection.utils.boxes import (
+ clip_boxes,
+ denormalize_boxes,
+ move_boxes,
+ scale_boxes,
+)
@pytest.mark.parametrize(
@@ -142,3 +147,88 @@ def test_scale_boxes(
with exception:
result = scale_boxes(xyxy=xyxy, factor=factor)
assert np.array_equal(result, expected_result)
+
+
+@pytest.mark.parametrize(
+ "xyxy, resolution_wh, normalization_factor, expected_result, exception",
+ [
+ (
+ np.empty(shape=(0, 4)),
+ (1280, 720),
+ 1.0,
+ np.empty(shape=(0, 4)),
+ DoesNotRaise(),
+ ), # empty array
+ (
+ np.array([[0.1, 0.2, 0.5, 0.6]]),
+ (1280, 720),
+ 1.0,
+ np.array([[128.0, 144.0, 640.0, 432.0]]),
+ DoesNotRaise(),
+ ), # single box with default normalization
+ (
+ np.array([[0.1, 0.2, 0.5, 0.6], [0.3, 0.4, 0.7, 0.8]]),
+ (1280, 720),
+ 1.0,
+ np.array([[128.0, 144.0, 640.0, 432.0], [384.0, 288.0, 896.0, 576.0]]),
+ DoesNotRaise(),
+ ), # two boxes with default normalization
+ (
+ np.array(
+ [[0.1, 0.2, 0.5, 0.6], [0.3, 0.4, 0.7, 0.8], [0.2, 0.1, 0.6, 0.5]]
+ ),
+ (1280, 720),
+ 1.0,
+ np.array(
+ [
+ [128.0, 144.0, 640.0, 432.0],
+ [384.0, 288.0, 896.0, 576.0],
+ [256.0, 72.0, 768.0, 360.0],
+ ]
+ ),
+ DoesNotRaise(),
+ ), # three boxes - regression test for issue #1959
+ (
+ np.array([[10.0, 20.0, 50.0, 60.0]]),
+ (100, 200),
+ 100.0,
+ np.array([[10.0, 40.0, 50.0, 120.0]]),
+ DoesNotRaise(),
+ ), # single box with custom normalization factor
+ (
+ np.array([[10.0, 20.0, 50.0, 60.0], [30.0, 40.0, 70.0, 80.0]]),
+ (100, 200),
+ 100.0,
+ np.array([[10.0, 40.0, 50.0, 120.0], [30.0, 80.0, 70.0, 160.0]]),
+ DoesNotRaise(),
+ ), # two boxes with custom normalization factor
+ (
+ np.array([[0.0, 0.0, 1.0, 1.0]]),
+ (1920, 1080),
+ 1.0,
+ np.array([[0.0, 0.0, 1920.0, 1080.0]]),
+ DoesNotRaise(),
+ ), # full frame box
+ (
+ np.array([[0.5, 0.5, 0.5, 0.5]]),
+ (640, 480),
+ 1.0,
+ np.array([[320.0, 240.0, 320.0, 240.0]]),
+ DoesNotRaise(),
+ ), # zero-area box (point)
+ ],
+)
+def test_denormalize_boxes(
+ xyxy: np.ndarray,
+ resolution_wh: tuple[int, int],
+ normalization_factor: float,
+ expected_result: np.ndarray,
+ exception: Exception,
+) -> None:
+ with exception:
+ result = denormalize_boxes(
+ xyxy=xyxy,
+ resolution_wh=resolution_wh,
+ normalization_factor=normalization_factor,
+ )
+ assert np.allclose(result, expected_result)
diff --git a/test/detection/utils/test_converters.py b/test/detection/utils/test_converters.py
index e13b150042..52a3b52004 100644
--- a/test/detection/utils/test_converters.py
+++ b/test/detection/utils/test_converters.py
@@ -6,6 +6,7 @@
from supervision.detection.utils.converters import (
xcycwh_to_xyxy,
xywh_to_xyxy,
+ xyxy_to_mask,
xyxy_to_xcycarh,
xyxy_to_xywh,
)
@@ -129,3 +130,174 @@ def test_xyxy_to_xcycarh(xyxy: np.ndarray, expected_result: np.ndarray) -> None:
def test_xcycwh_to_xyxy(xcycwh: np.ndarray, expected_result: np.ndarray) -> None:
result = xcycwh_to_xyxy(xcycwh)
np.testing.assert_array_equal(result, expected_result)
+
+
+@pytest.mark.parametrize(
+ "boxes,resolution_wh,expected",
+ [
+ # 0) Empty input
+ (
+ np.array([], dtype=float).reshape(0, 4),
+ (5, 4),
+ np.array([], dtype=bool).reshape(0, 4, 5),
+ ),
+ # 1) Single pixel box
+ (
+ np.array([[2, 1, 2, 1]], dtype=float),
+ (5, 4),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, False, True, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ]
+ ],
+ dtype=bool,
+ ),
+ ),
+ # 2) Horizontal line, inclusive bounds
+ (
+ np.array([[1, 2, 3, 2]], dtype=float),
+ (5, 4),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ [False, True, True, True, False],
+ [False, False, False, False, False],
+ ]
+ ],
+ dtype=bool,
+ ),
+ ),
+ # 3) Vertical line, inclusive bounds
+ (
+ np.array([[3, 0, 3, 2]], dtype=float),
+ (5, 4),
+ np.array(
+ [
+ [
+ [False, False, False, True, False],
+ [False, False, False, True, False],
+ [False, False, False, True, False],
+ [False, False, False, False, False],
+ ]
+ ],
+ dtype=bool,
+ ),
+ ),
+ # 4) Proper rectangle fill
+ (
+ np.array([[1, 1, 3, 2]], dtype=float),
+ (5, 4),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, True, True, True, False],
+ [False, True, True, True, False],
+ [False, False, False, False, False],
+ ]
+ ],
+ dtype=bool,
+ ),
+ ),
+ # 5) Negative coordinates clipped to [0, 0]
+ (
+ np.array([[-2, -1, 1, 1]], dtype=float),
+ (5, 4),
+ np.array(
+ [
+ [
+ [True, True, False, False, False],
+ [True, True, False, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ]
+ ],
+ dtype=bool,
+ ),
+ ),
+ # 6) Overflow coordinates clipped to width-1 and height-1
+ (
+ np.array([[3, 2, 10, 10]], dtype=float),
+ (5, 4),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ [False, False, False, True, True],
+ [False, False, False, True, True],
+ ]
+ ],
+ dtype=bool,
+ ),
+ ),
+ # 7) Invalid box where max < min after ints, mask stays empty
+ (
+ np.array([[3, 2, 1, 4]], dtype=float),
+ (5, 4),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ]
+ ],
+ dtype=bool,
+ ),
+ ),
+ # 8) Fractional coordinates are floored by int conversion
+ # (0.2,0.2)-(2.8,1.9) -> (0,0)-(2,1)
+ (
+ np.array([[0.2, 0.2, 2.8, 1.9]], dtype=float),
+ (5, 4),
+ np.array(
+ [
+ [
+ [True, True, True, False, False],
+ [True, True, True, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ]
+ ],
+ dtype=bool,
+ ),
+ ),
+ # 9) Multiple boxes, separate masks
+ (
+ np.array([[0, 0, 1, 0], [2, 1, 4, 3]], dtype=float),
+ (5, 4),
+ np.array(
+ [
+ # Box 0: row 0, cols 0..1
+ [
+ [True, True, False, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ],
+ # Box 1: rows 1..3, cols 2..4
+ [
+ [False, False, False, False, False],
+ [False, False, True, True, True],
+ [False, False, True, True, True],
+ [False, False, True, True, True],
+ ],
+ ],
+ dtype=bool,
+ ),
+ ),
+ ],
+)
+def test_xyxy_to_mask(boxes: np.ndarray, resolution_wh, expected: np.ndarray) -> None:
+ result = xyxy_to_mask(boxes, resolution_wh)
+ assert result.dtype == np.bool_
+ assert result.shape == expected.shape
+ np.testing.assert_array_equal(result, expected)
diff --git a/test/detection/utils/test_iou_and_nms.py b/test/detection/utils/test_iou_and_nms.py
index 8039bf2425..ab7586483c 100644
--- a/test/detection/utils/test_iou_and_nms.py
+++ b/test/detection/utils/test_iou_and_nms.py
@@ -6,11 +6,15 @@
import pytest
from supervision.detection.utils.iou_and_nms import (
+ OverlapMetric,
_group_overlapping_boxes,
+ box_iou,
+ box_iou_batch,
box_non_max_suppression,
mask_non_max_merge,
mask_non_max_suppression,
)
+from test.test_utils import random_boxes
@pytest.mark.parametrize(
@@ -631,3 +635,497 @@ def test_mask_non_max_merge(
sorted_result = sorted([sorted(group) for group in result])
sorted_expected_result = sorted([sorted(group) for group in expected_result])
assert sorted_result == sorted_expected_result
+
+
+@pytest.mark.parametrize(
+ "box_true, box_detection, overlap_metric, expected_overlap, exception",
+ [
+ (
+ [100.0, 100.0, 200.0, 200.0],
+ [150.0, 150.0, 250.0, 250.0],
+ OverlapMetric.IOU,
+ 0.14285714285714285,
+ DoesNotRaise(),
+ ), # partial overlap, IOU
+ (
+ [100.0, 100.0, 200.0, 200.0],
+ [150.0, 150.0, 250.0, 250.0],
+ OverlapMetric.IOS,
+ 0.25,
+ DoesNotRaise(),
+ ), # partial overlap, IOS
+ (
+ np.array([0.0, 0.0, 10.0, 10.0], dtype=np.float32),
+ np.array([0.0, 0.0, 10.0, 10.0], dtype=np.float32),
+ OverlapMetric.IOU,
+ 1.0,
+ DoesNotRaise(),
+ ), # identical boxes, both boxes are arrays, IOU
+ (
+ np.array([0.0, 0.0, 10.0, 10.0], dtype=np.float32),
+ np.array([0.0, 0.0, 10.0, 10.0], dtype=np.float32),
+ OverlapMetric.IOS,
+ 1.0,
+ DoesNotRaise(),
+ ), # identical boxes, both boxes are arrays, IOS
+ (
+ [0.0, 0.0, 10.0, 10.0],
+ [0.0, 0.0, 10.0, 10.0],
+ "iou",
+ 1.0,
+ DoesNotRaise(),
+ ), # identical boxes, both boxes are arrays, IOU as lowercase string
+ (
+ [0.0, 0.0, 10.0, 10.0],
+ [0.0, 0.0, 10.0, 10.0],
+ "ios",
+ 1.0,
+ DoesNotRaise(),
+ ), # identical boxes, both boxes are arrays, IOS as lowercase string
+ (
+ [0.0, 0.0, 10.0, 10.0],
+ [0.0, 0.0, 10.0, 10.0],
+ "IOU",
+ 1.0,
+ DoesNotRaise(),
+ ), # identical boxes, both boxes are arrays, IOU as uppercase string
+ (
+ [0.0, 0.0, 10.0, 10.0],
+ [0.0, 0.0, 10.0, 10.0],
+ "IOU",
+ 1.0,
+ DoesNotRaise(),
+ ), # identical boxes, both boxes are arrays, IOS as uppercase string
+ (
+ [0.0, 0.0, 10.0, 10.0],
+ [20.0, 20.0, 30.0, 30.0],
+ OverlapMetric.IOU,
+ 0.0,
+ DoesNotRaise(),
+ ), # no overlap, IOU
+ (
+ [0.0, 0.0, 10.0, 10.0],
+ [20.0, 20.0, 30.0, 30.0],
+ OverlapMetric.IOS,
+ 0.0,
+ DoesNotRaise(),
+ ), # no overlap, IOS
+ (
+ [0.0, 0.0, 10.0, 10.0],
+ [10.0, 0.0, 20.0, 10.0],
+ OverlapMetric.IOU,
+ 0.0,
+ DoesNotRaise(),
+ ), # boxes touch at edge, zero intersection, IOU
+ (
+ [0.0, 0.0, 10.0, 10.0],
+ [10.0, 0.0, 20.0, 10.0],
+ OverlapMetric.IOS,
+ 0.0,
+ DoesNotRaise(),
+ ), # boxes touch at edge, zero intersection, IOU
+ (
+ [0.0, 0.0, 10.0, 10.0],
+ [2.0, 2.0, 8.0, 8.0],
+ OverlapMetric.IOU,
+ 0.36,
+ DoesNotRaise(),
+ ), # one box inside another, IOU
+ (
+ [0.0, 0.0, 10.0, 10.0],
+ [2.0, 2.0, 8.0, 8.0],
+ OverlapMetric.IOS,
+ 1.0,
+ DoesNotRaise(),
+ ), # one box inside another, IOS
+ (
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 10.0, 10.0],
+ OverlapMetric.IOU,
+ 0.0,
+ DoesNotRaise(),
+ ), # degenerate true box with zero area, IOU
+ (
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 10.0, 10.0],
+ OverlapMetric.IOS,
+ 0.0,
+ DoesNotRaise(),
+ ), # degenerate true box with zero area, IOS
+ (
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ OverlapMetric.IOU,
+ 0.0,
+ DoesNotRaise(),
+ ), # both boxes fully degenerate, IOU
+ (
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ OverlapMetric.IOS,
+ 0.0,
+ DoesNotRaise(),
+ ), # both boxes fully degenerate, IOS
+ (
+ [-5.0, 0.0, 5.0, 10.0],
+ [0.0, 0.0, 10.0, 10.0],
+ OverlapMetric.IOU,
+ 1.0 / 3.0,
+ DoesNotRaise(),
+ ), # negative x_min, overlapping boxes, IOU is 1/3
+ (
+ [-5.0, 0.0, 5.0, 10.0],
+ [0.0, 0.0, 10.0, 10.0],
+ OverlapMetric.IOS,
+ 0.5,
+ DoesNotRaise(),
+ ), # negative x_min, overlapping boxes, IOS is 0.5
+ (
+ [0.0, 0.0, 1.0, 1.0],
+ [0.5, 0.5, 1.5, 1.5],
+ OverlapMetric.IOU,
+ 0.14285714285714285,
+ DoesNotRaise(),
+ ), # partial overlap with fractional coordinates, IOU
+ (
+ [0.0, 0.0, 1.0, 1.0],
+ [0.5, 0.5, 1.5, 1.5],
+ OverlapMetric.IOS,
+ 0.25,
+ DoesNotRaise(),
+ ), # partial overlap with fractional coordinates, IOS
+ ],
+)
+def test_box_iou(
+ box_true: list[float] | np.ndarray,
+ box_detection: list[float] | np.ndarray,
+ overlap_metric: str | OverlapMetric,
+ expected_overlap: float,
+ exception: Exception,
+) -> None:
+ with exception:
+ result = box_iou(
+ box_true=box_true,
+ box_detection=box_detection,
+ overlap_metric=overlap_metric,
+ )
+ assert result == pytest.approx(expected_overlap, rel=1e-6, abs=1e-12)
+
+
+@pytest.mark.parametrize(
+ "boxes_true, boxes_detection, overlap_metric, expected_overlap, exception",
+ [
+ # both inputs empty
+ (
+ np.empty((0, 4), dtype=np.float32),
+ np.empty((0, 4), dtype=np.float32),
+ OverlapMetric.IOU,
+ np.empty((0, 0), dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # one true box, no detections
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.empty((0, 4), dtype=np.float32),
+ OverlapMetric.IOU,
+ np.empty((1, 0), dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # no true boxes, one detection
+ (
+ np.empty((0, 4), dtype=np.float32),
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ OverlapMetric.IOU,
+ np.empty((0, 1), dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 partial overlap, IOU
+ (
+ np.array([[100.0, 100.0, 200.0, 200.0]], dtype=np.float32),
+ np.array([[150.0, 150.0, 250.0, 250.0]], dtype=np.float32),
+ OverlapMetric.IOU,
+ np.array([[0.14285715]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 partial overlap, IOS
+ (
+ np.array([[100.0, 100.0, 200.0, 200.0]], dtype=np.float32),
+ np.array([[150.0, 150.0, 250.0, 250.0]], dtype=np.float32),
+ OverlapMetric.IOS,
+ np.array([[0.25]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 identical boxes, IOU as lowercase string
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ "iou",
+ np.array([[1.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 identical boxes, IOS as lowercase string
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ "ios",
+ np.array([[1.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 identical boxes, IOU as uppercase string
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ "IOU",
+ np.array([[1.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 identical boxes, IOS as uppercase string
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ "IOS",
+ np.array([[1.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 no overlap, IOU
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[20.0, 20.0, 30.0, 30.0]], dtype=np.float32),
+ OverlapMetric.IOU,
+ np.array([[0.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 no overlap, IOS
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[20.0, 20.0, 30.0, 30.0]], dtype=np.float32),
+ OverlapMetric.IOS,
+ np.array([[0.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 touching at edge, zero intersection, IOU
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[10.0, 0.0, 20.0, 10.0]], dtype=np.float32),
+ OverlapMetric.IOU,
+ np.array([[0.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 touching at edge, zero intersection, IOS
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[10.0, 0.0, 20.0, 10.0]], dtype=np.float32),
+ OverlapMetric.IOS,
+ np.array([[0.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 box inside another, IOU
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[2.0, 2.0, 8.0, 8.0]], dtype=np.float32),
+ OverlapMetric.IOU,
+ np.array([[0.36]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 box inside another, IOS
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[2.0, 2.0, 8.0, 8.0]], dtype=np.float32),
+ OverlapMetric.IOS,
+ np.array([[1.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 degenerate true box, IOU
+ (
+ np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ OverlapMetric.IOU,
+ np.array([[0.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 degenerate true box, IOS
+ (
+ np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ OverlapMetric.IOS,
+ np.array([[0.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 both boxes degenerate, IOU
+ (
+ np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32),
+ OverlapMetric.IOU,
+ np.array([[0.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 both boxes degenerate, IOS
+ (
+ np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32),
+ OverlapMetric.IOS,
+ np.array([[0.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 negative coordinate, partial overlap, IOU
+ (
+ np.array([[-5.0, 0.0, 5.0, 10.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ OverlapMetric.IOU,
+ np.array([[1.0 / 3.0]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 negative coordinate, partial overlap, IOS
+ (
+ np.array([[-5.0, 0.0, 5.0, 10.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ OverlapMetric.IOS,
+ np.array([[0.5]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 fractional coordinates, partial overlap, IOU
+ (
+ np.array([[0.0, 0.0, 1.0, 1.0]], dtype=np.float32),
+ np.array([[0.5, 0.5, 1.5, 1.5]], dtype=np.float32),
+ OverlapMetric.IOU,
+ np.array([[0.14285715]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # 1x1 fractional coordinates, partial overlap, IOS
+ (
+ np.array([[0.0, 0.0, 1.0, 1.0]], dtype=np.float32),
+ np.array([[0.5, 0.5, 1.5, 1.5]], dtype=np.float32),
+ OverlapMetric.IOS,
+ np.array([[0.25]], dtype=np.float32),
+ DoesNotRaise(),
+ ),
+ # true batch case, 2x2, IOU
+ (
+ np.array(
+ [
+ [0.0, 0.0, 10.0, 10.0],
+ [10.0, 10.0, 20.0, 20.0],
+ ],
+ dtype=np.float32,
+ ),
+ np.array(
+ [
+ [0.0, 0.0, 10.0, 10.0],
+ [5.0, 5.0, 15.0, 15.0],
+ ],
+ dtype=np.float32,
+ ),
+ OverlapMetric.IOU,
+ np.array(
+ [
+ [1.0, 0.14285715],
+ [0.0, 0.14285715],
+ ],
+ dtype=np.float32,
+ ),
+ DoesNotRaise(),
+ ),
+ # true batch case, 2x2, IOS
+ (
+ np.array(
+ [
+ [0.0, 0.0, 10.0, 10.0],
+ [10.0, 10.0, 20.0, 20.0],
+ ],
+ dtype=np.float32,
+ ),
+ np.array(
+ [
+ [0.0, 0.0, 10.0, 10.0],
+ [5.0, 5.0, 15.0, 15.0],
+ ],
+ dtype=np.float32,
+ ),
+ OverlapMetric.IOS,
+ np.array(
+ [
+ [1.0, 0.25],
+ [0.0, 0.25],
+ ],
+ dtype=np.float32,
+ ),
+ DoesNotRaise(),
+ ),
+ # invalid overlap_metric
+ (
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32),
+ "invalid",
+ None,
+ pytest.raises(ValueError),
+ ),
+ ],
+)
+def test_box_iou_batch(
+ boxes_true: np.ndarray,
+ boxes_detection: np.ndarray,
+ overlap_metric: str | OverlapMetric,
+ expected_overlap: np.ndarray | None,
+ exception: Exception,
+) -> None:
+ with exception:
+ result = box_iou_batch(
+ boxes_true=boxes_true,
+ boxes_detection=boxes_detection,
+ overlap_metric=overlap_metric,
+ )
+
+ assert isinstance(result, np.ndarray)
+ assert result.shape == expected_overlap.shape
+ assert np.allclose(
+ result,
+ expected_overlap,
+ rtol=1e-6,
+ atol=1e-12,
+ )
+
+
+@pytest.mark.parametrize(
+ "num_true, num_det",
+ [
+ (5, 5),
+ (5, 10),
+ (10, 5),
+ (10, 10),
+ (20, 30),
+ (30, 20),
+ (50, 50),
+ (100, 100),
+ ],
+)
+@pytest.mark.parametrize(
+ "overlap_metric",
+ [OverlapMetric.IOU, OverlapMetric.IOS],
+)
+def test_box_iou_batch_symmetric_large(
+ num_true: int,
+ num_det: int,
+ overlap_metric: OverlapMetric,
+) -> None:
+ boxes_true = random_boxes(num_true)
+ boxes_det = random_boxes(num_det)
+
+ result_ab = box_iou_batch(
+ boxes_true=boxes_true,
+ boxes_detection=boxes_det,
+ overlap_metric=overlap_metric,
+ )
+ result_ba = box_iou_batch(
+ boxes_true=boxes_det,
+ boxes_detection=boxes_true,
+ overlap_metric=overlap_metric,
+ )
+
+ assert result_ab.shape == (num_true, num_det)
+ assert result_ba.shape == (num_det, num_true)
+ assert np.allclose(
+ result_ab,
+ result_ba.T,
+ rtol=1e-6,
+ atol=1e-12,
+ )
diff --git a/test/detection/utils/test_masks.py b/test/detection/utils/test_masks.py
index 2097f6082c..b41f208edb 100644
--- a/test/detection/utils/test_masks.py
+++ b/test/detection/utils/test_masks.py
@@ -10,6 +10,7 @@
calculate_masks_centroids,
contains_holes,
contains_multiple_segments,
+ filter_segments_by_distance,
move_masks,
)
@@ -500,3 +501,228 @@ def test_contains_multiple_segments(
with exception:
result = contains_multiple_segments(mask=mask, connectivity=connectivity)
assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "mask, connectivity, mode, absolute_distance, relative_distance, expected_result, exception", # noqa: E501
+ [
+ # single component, unchanged
+ (
+ np.array(
+ [
+ [0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0],
+ [0, 1, 1, 1, 0, 0],
+ [0, 1, 1, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ ],
+ dtype=bool,
+ ),
+ 8,
+ "edge",
+ 2.0,
+ None,
+ np.array(
+ [
+ [0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0],
+ [0, 1, 1, 1, 0, 0],
+ [0, 1, 1, 1, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ ],
+ dtype=bool,
+ ),
+ DoesNotRaise(),
+ ),
+ # two components, edge distance 2, kept with abs=1
+ (
+ np.array(
+ [
+ [0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 1],
+ [0, 1, 1, 1, 0, 1],
+ [0, 1, 1, 1, 0, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ ],
+ dtype=bool,
+ ),
+ 8,
+ "edge",
+ 2.0,
+ None,
+ np.array(
+ [
+ [0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 1],
+ [0, 1, 1, 1, 0, 1],
+ [0, 1, 1, 1, 0, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ ],
+ dtype=bool,
+ ),
+ DoesNotRaise(),
+ ),
+ # centroid mode, far centroids, dropped with small relative threshold
+ (
+ np.array(
+ [
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ ],
+ dtype=bool,
+ ),
+ 8,
+ "centroid",
+ None,
+ 0.3, # diagonal ~8.49, threshold ~2.55, centroid gap ~4.24
+ np.array(
+ [
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ ],
+ dtype=bool,
+ ),
+ DoesNotRaise(),
+ ),
+ # centroid mode, larger relative threshold, kept
+ (
+ np.array(
+ [
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ ],
+ dtype=bool,
+ ),
+ 8,
+ "centroid",
+ None,
+ 0.6, # diagonal ~8.49, threshold ~5.09, centroid gap ~4.24
+ np.array(
+ [
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ ],
+ dtype=bool,
+ ),
+ DoesNotRaise(),
+ ),
+ # empty mask
+ (
+ np.zeros((4, 4), dtype=bool),
+ 4,
+ "edge",
+ 2.0,
+ None,
+ np.zeros((4, 4), dtype=bool),
+ DoesNotRaise(),
+ ),
+ # full mask
+ (
+ np.ones((4, 4), dtype=bool),
+ 8,
+ "centroid",
+ None,
+ 0.2,
+ np.ones((4, 4), dtype=bool),
+ DoesNotRaise(),
+ ),
+ # two components, pixel distance = 2, kept with abs=2
+ (
+ np.array(
+ [
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 1, 1, 1],
+ [0, 1, 1, 1, 0, 1, 1, 1],
+ [0, 1, 1, 1, 0, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ ],
+ dtype=bool,
+ ),
+ 8,
+ "edge",
+ 2.0, # was 1.0
+ None,
+ np.array(
+ [
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 1, 1, 1],
+ [0, 1, 1, 1, 0, 1, 1, 1],
+ [0, 1, 1, 1, 0, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ ],
+ dtype=bool,
+ ),
+ DoesNotRaise(),
+ ),
+ # two components, pixel distance = 3, dropped with abs=2
+ (
+ np.array(
+ [
+ [0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0, 0, 1, 1],
+ [0, 1, 1, 1, 0, 0, 0, 1, 1],
+ [0, 1, 1, 1, 0, 0, 0, 1, 1],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ],
+ dtype=bool,
+ ),
+ 8,
+ "edge",
+ 2.0, # keep threshold below 3 so the right blob is removed
+ None,
+ np.array(
+ [
+ [0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0, 0, 0, 0],
+ [0, 1, 1, 1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0, 0],
+ ],
+ dtype=bool,
+ ),
+ DoesNotRaise(),
+ ),
+ ],
+)
+def test_filter_segments_by_distance_sweep(
+ mask: npt.NDArray,
+ connectivity: int,
+ mode: str,
+ absolute_distance: float | None,
+ relative_distance: float | None,
+ expected_result: npt.NDArray | None,
+ exception: Exception,
+) -> None:
+ with exception:
+ result = filter_segments_by_distance(
+ mask=mask,
+ connectivity=connectivity,
+ mode=mode, # type: ignore[arg-type]
+ absolute_distance=absolute_distance,
+ relative_distance=relative_distance,
+ )
+ assert np.array_equal(result, expected_result)
diff --git a/test/detection/utils/test_vlms.py b/test/detection/utils/test_vlms.py
new file mode 100644
index 0000000000..a6fe649b73
--- /dev/null
+++ b/test/detection/utils/test_vlms.py
@@ -0,0 +1,117 @@
+import pytest
+
+from supervision.detection.utils.vlms import edit_distance, fuzzy_match_index
+
+
+@pytest.mark.parametrize(
+ "string_1, string_2, case_sensitive, expected_result",
+ [
+ # identical strings, various cases
+ ("hello", "hello", True, 0),
+ ("hello", "hello", False, 0),
+ # case sensitive vs insensitive
+ ("Test", "test", True, 1),
+ ("Test", "test", False, 0),
+ ("CASE", "case", True, 4),
+ ("CASE", "case", False, 0),
+ # completely different
+ ("abc", "xyz", True, 3),
+ ("abc", "xyz", False, 3),
+ # one string empty
+ ("hello", "", True, 5),
+ ("", "world", True, 5),
+ # single character cases
+ ("a", "b", True, 1),
+ ("A", "a", True, 1),
+ ("A", "a", False, 0),
+ # whitespaces
+ ("hello world", "helloworld", True, 1),
+ ("test", " test", True, 1),
+ # unicode and emoji
+ ("π", "π", True, 0),
+ ("π", "π’", True, 1),
+ # long string vs empty
+ ("a" * 100, "", True, 100),
+ ("", "b" * 100, True, 100),
+ # prefix/suffix
+ ("prefix", "prefixes", True, 2),
+ ("suffix", "asuffix", True, 1),
+ # leading/trailing whitespace
+ (" hello", "hello", True, 1),
+ ("hello", "hello ", True, 1),
+ # long almost-equal string
+ (
+ "The quick brown fox jumps over the lazy dog",
+ "The quick brown fox jumps over the lazy cog",
+ True,
+ 1,
+ ),
+ (
+ "The quick brown fox jumps over the lazy dog",
+ "The quick brown fox jumps over the lazy cog",
+ False,
+ 1,
+ ),
+ # both empty
+ ("", "", True, 0),
+ ("", "", False, 0),
+ # mixed case with symbols
+ ("123ABC!", "123abc!", True, 3),
+ ("123ABC!", "123abc!", False, 0),
+ ],
+)
+def test_edit_distance(string_1, string_2, case_sensitive, expected_result):
+ assert (
+ edit_distance(string_1, string_2, case_sensitive=case_sensitive)
+ == expected_result
+ )
+
+
+@pytest.mark.parametrize(
+ "candidates, query, threshold, case_sensitive, expected_result",
+ [
+ # exact match at index 0
+ (["cat", "dog", "rat"], "cat", 0, True, 0),
+ # match at index 2 within threshold
+ (["cat", "dog", "rat"], "dat", 1, True, 0),
+ # no match due to high threshold
+ (["cat", "dog", "rat"], "bat", 0, True, None),
+ # multiple possible matches, returns first
+ (["apple", "apply", "appla"], "apple", 1, True, 0),
+ # case-insensitive match
+ (["Alpha", "beta", "Gamma"], "alpha", 0, False, 0),
+ # case-sensitive: no match
+ (["Alpha", "beta", "Gamma"], "alpha", 0, True, None),
+ # threshold boundary
+ (["alpha", "beta", "gamma"], "bata", 1, True, 1),
+ # no match (all distances too high)
+ (["one", "two", "three"], "ten", 1, True, None),
+ # unicode/emoji match
+ (["π", "π’", "π"], "π", 1, True, 0),
+ (["π", "π’", "π"], "π", 0, True, 0),
+ # empty candidates
+ ([], "any", 2, True, None),
+ # empty query, non-empty candidates
+ (["", "abc"], "", 0, True, 0),
+ (["", "abc"], "", 1, True, 0),
+ (["a", "b", "c"], "", 1, True, 0),
+ # non-empty query, empty candidate
+ (["", ""], "a", 1, True, 0),
+ # all candidates require higher edit than threshold
+ (["short", "words", "only"], "longerword", 2, True, None),
+ # repeated candidates
+ (["a", "a", "a"], "b", 1, True, 0),
+ ],
+)
+def test_fuzzy_match_index(
+ candidates, query, threshold, case_sensitive, expected_result
+):
+ assert (
+ fuzzy_match_index(
+ candidates=candidates,
+ query=query,
+ threshold=threshold,
+ case_sensitive=case_sensitive,
+ )
+ == expected_result
+ )
diff --git a/test/key_points/__init__.py b/test/key_points/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/test/key_points/test_core.py b/test/key_points/test_core.py
new file mode 100644
index 0000000000..5a8244cd87
--- /dev/null
+++ b/test/key_points/test_core.py
@@ -0,0 +1,268 @@
+from contextlib import nullcontext as DoesNotRaise
+
+import numpy as np
+import pytest
+
+from supervision.key_points.core import KeyPoints
+from test.test_utils import mock_key_points
+
+KEY_POINTS = mock_key_points(
+ xy=[
+ [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
+ [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]],
+ [[20, 21], [22, 23], [24, 25], [26, 27], [28, 29]],
+ ],
+ confidence=[
+ [0.8, 0.2, 0.6, 0.1, 0.5],
+ [0.7, 0.9, 0.3, 0.4, 0.0],
+ [0.1, 0.6, 0.8, 0.2, 0.7],
+ ],
+ class_id=[0, 1, 2],
+)
+
+
+@pytest.mark.parametrize(
+ "key_points, index, expected_result, exception",
+ [
+ (
+ KeyPoints.empty(),
+ slice(None),
+ KeyPoints.empty(),
+ DoesNotRaise(),
+ ), # slice all key points when key points object empty
+ (
+ KEY_POINTS,
+ slice(None),
+ KEY_POINTS,
+ DoesNotRaise(),
+ ), # slice all key points when key points object nonempty
+ (
+ KEY_POINTS,
+ slice(0, 1),
+ mock_key_points(
+ xy=[[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]],
+ confidence=[[0.8, 0.2, 0.6, 0.1, 0.5]],
+ class_id=[0],
+ ),
+ DoesNotRaise(),
+ ), # select the first skeleton by slice
+ (
+ KEY_POINTS,
+ slice(0, 2),
+ mock_key_points(
+ xy=[
+ [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
+ [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]],
+ ],
+ confidence=[
+ [0.8, 0.2, 0.6, 0.1, 0.5],
+ [0.7, 0.9, 0.3, 0.4, 0.0],
+ ],
+ class_id=[0, 1],
+ ),
+ DoesNotRaise(),
+        ), # select the first two skeletons by slice
+ (
+ KEY_POINTS,
+ 0,
+ mock_key_points(
+ xy=[[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]],
+ confidence=[[0.8, 0.2, 0.6, 0.1, 0.5]],
+ class_id=[0],
+ ),
+ DoesNotRaise(),
+ ), # select the first skeleton by index
+ (
+ KEY_POINTS,
+ -1,
+ mock_key_points(
+ xy=[[[20, 21], [22, 23], [24, 25], [26, 27], [28, 29]]],
+ confidence=[[0.1, 0.6, 0.8, 0.2, 0.7]],
+ class_id=[2],
+ ),
+ DoesNotRaise(),
+ ), # select the last skeleton by index
+ (
+ KEY_POINTS,
+ [0, 1],
+ mock_key_points(
+ xy=[
+ [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
+ [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]],
+ ],
+ confidence=[
+ [0.8, 0.2, 0.6, 0.1, 0.5],
+ [0.7, 0.9, 0.3, 0.4, 0.0],
+ ],
+ class_id=[0, 1],
+ ),
+ DoesNotRaise(),
+ ), # select the first two skeletons by index; list
+ (
+ KEY_POINTS,
+ np.array([0, 1]),
+ mock_key_points(
+ xy=[
+ [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
+ [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]],
+ ],
+ confidence=[
+ [0.8, 0.2, 0.6, 0.1, 0.5],
+ [0.7, 0.9, 0.3, 0.4, 0.0],
+ ],
+ class_id=[0, 1],
+ ),
+ DoesNotRaise(),
+ ), # select the first two skeletons by index; np.array
+ (
+ KEY_POINTS,
+ [True, True, False],
+ mock_key_points(
+ xy=[
+ [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
+ [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]],
+ ],
+ confidence=[
+ [0.8, 0.2, 0.6, 0.1, 0.5],
+ [0.7, 0.9, 0.3, 0.4, 0.0],
+ ],
+ class_id=[0, 1],
+ ),
+ DoesNotRaise(),
+ ), # select only skeletons associated with positive filter; list
+ (
+ KEY_POINTS,
+ np.array([True, True, False]),
+ mock_key_points(
+ xy=[
+ [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]],
+ [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]],
+ ],
+ confidence=[
+ [0.8, 0.2, 0.6, 0.1, 0.5],
+ [0.7, 0.9, 0.3, 0.4, 0.0],
+ ],
+ class_id=[0, 1],
+ ),
+ DoesNotRaise(),
+        ), # select only skeletons associated with positive filter; np.array
+ (
+ KEY_POINTS,
+ (slice(None), slice(None)),
+ KEY_POINTS,
+ DoesNotRaise(),
+ ), # slice all anchors from all skeletons
+ (
+ KEY_POINTS,
+ (slice(None), slice(0, 1)),
+ mock_key_points(
+ xy=[[[0, 1]], [[10, 11]], [[20, 21]]],
+ confidence=[[0.8], [0.7], [0.1]],
+ class_id=[0, 1, 2],
+ ),
+ DoesNotRaise(),
+ ), # slice the first anchor from every skeleton
+ (
+ KEY_POINTS,
+ (slice(None), slice(0, 2)),
+ mock_key_points(
+ xy=[[[0, 1], [2, 3]], [[10, 11], [12, 13]], [[20, 21], [22, 23]]],
+ confidence=[[0.8, 0.2], [0.7, 0.9], [0.1, 0.6]],
+ class_id=[0, 1, 2],
+ ),
+ DoesNotRaise(),
+        ), # slice the first two anchors from every skeleton
+ (
+ KEY_POINTS,
+ (slice(None), 0),
+ mock_key_points(
+ xy=[[[0, 1]], [[10, 11]], [[20, 21]]],
+ confidence=[[0.8], [0.7], [0.1]],
+ class_id=[0, 1, 2],
+ ),
+ DoesNotRaise(),
+ ), # select the first anchor from every skeleton by index
+ (
+ KEY_POINTS,
+ (slice(None), -1),
+ mock_key_points(
+ xy=[[[8, 9]], [[18, 19]], [[28, 29]]],
+ confidence=[[0.5], [0.0], [0.7]],
+ class_id=[0, 1, 2],
+ ),
+ DoesNotRaise(),
+ ), # select the last anchor from every skeleton by index
+ (
+ KEY_POINTS,
+ (slice(None), [0, 1]),
+ mock_key_points(
+ xy=[[[0, 1], [2, 3]], [[10, 11], [12, 13]], [[20, 21], [22, 23]]],
+ confidence=[[0.8, 0.2], [0.7, 0.9], [0.1, 0.6]],
+ class_id=[0, 1, 2],
+ ),
+ DoesNotRaise(),
+ ), # select the first two anchors from every skeleton by index; list
+ (
+ KEY_POINTS,
+ (slice(None), np.array([0, 1])),
+ mock_key_points(
+ xy=[[[0, 1], [2, 3]], [[10, 11], [12, 13]], [[20, 21], [22, 23]]],
+ confidence=[[0.8, 0.2], [0.7, 0.9], [0.1, 0.6]],
+ class_id=[0, 1, 2],
+ ),
+ DoesNotRaise(),
+ ), # select the first two anchors from every skeleton by index; np.array
+ (
+ KEY_POINTS,
+ (slice(None), [True, True, False, False, False]),
+ mock_key_points(
+ xy=[[[0, 1], [2, 3]], [[10, 11], [12, 13]], [[20, 21], [22, 23]]],
+ confidence=[[0.8, 0.2], [0.7, 0.9], [0.1, 0.6]],
+ class_id=[0, 1, 2],
+ ),
+ DoesNotRaise(),
+ ), # select only anchors associated with positive filter; list
+ (
+ KEY_POINTS,
+ (slice(None), np.array([True, True, False, False, False])),
+ mock_key_points(
+ xy=[[[0, 1], [2, 3]], [[10, 11], [12, 13]], [[20, 21], [22, 23]]],
+ confidence=[[0.8, 0.2], [0.7, 0.9], [0.1, 0.6]],
+ class_id=[0, 1, 2],
+ ),
+ DoesNotRaise(),
+ ), # select only anchors associated with positive filter; np.array
+ (
+ KEY_POINTS,
+ (0, 0),
+ mock_key_points(
+ xy=[
+ [[0, 1]],
+ ],
+ confidence=[
+ [0.8],
+ ],
+ class_id=[0],
+ ),
+ DoesNotRaise(),
+ ), # select the first anchor from the first skeleton by index
+ (
+ KEY_POINTS,
+ (0, -1),
+ mock_key_points(
+ xy=[
+ [[8, 9]],
+ ],
+ confidence=[
+ [0.5],
+ ],
+ class_id=[0],
+ ),
+ DoesNotRaise(),
+ ), # select the last anchor from the first skeleton by index
+ ],
+)
+def test_key_points_getitem(key_points, index, expected_result, exception):
+ with exception:
+ result = key_points[index]
+ assert result == expected_result
diff --git a/test/test_utils.py b/test/test_utils.py
index 19fffad5bb..961f8cee1b 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,16 +1,16 @@
from __future__ import annotations
+import random
from typing import Any
import numpy as np
-import numpy.typing as npt
from supervision.detection.core import Detections
-from supervision.keypoint.core import KeyPoints
+from supervision.key_points.core import KeyPoints
def mock_detections(
- xyxy: npt.NDArray[np.float32],
+ xyxy: list[list[float]],
mask: list[np.ndarray] | None = None,
confidence: list[float] | None = None,
class_id: list[int] | None = None,
@@ -34,9 +34,9 @@ def convert_data(data: dict[str, list[Any]]):
)
-def mock_keypoints(
- xy: npt.NDArray[np.float32],
- confidence: list[float] | None = None,
+def mock_key_points(
+ xy: list[list[list[float]]],
+ confidence: list[list[float]] | None = None,
class_id: list[int] | None = None,
data: dict[str, list[Any]] | None = None,
) -> KeyPoints:
@@ -53,5 +53,49 @@ def convert_data(data: dict[str, list[Any]]):
)
+def random_boxes(
+ count: int,
+ image_size: tuple[int, int] = (1920, 1080),
+ min_box_size: int = 20,
+ max_box_size: int = 200,
+ seed: int | None = None,
+) -> np.ndarray:
+ """
+ Generate random bounding boxes within given image dimensions and size constraints.
+
+ Creates `count` bounding boxes randomly positioned and sized, ensuring each
+ stays within image bounds and has width and height in the specified range.
+
+ Args:
+ count (`int`): Number of random bounding boxes to generate.
+ image_size (`tuple[int, int]`): Image size as `(width, height)`.
+ min_box_size (`int`): Minimum side length (pixels) for generated boxes.
+ max_box_size (`int`): Maximum side length (pixels) for generated boxes.
+ seed (`int` or `None`): Optional random seed for reproducibility.
+
+ Returns:
+ (`numpy.ndarray`): Array of shape `(count, 4)` with bounding boxes as
+ `(x_min, y_min, x_max, y_max)`.
+ """
+ if seed is not None:
+ random.seed(seed)
+
+ img_w, img_h = image_size
+ out = np.zeros((count, 4), dtype=np.float32)
+
+ for i in range(count):
+ w = random.uniform(min_box_size, max_box_size)
+ h = random.uniform(min_box_size, max_box_size)
+
+ x_min = random.uniform(0, img_w - w)
+ y_min = random.uniform(0, img_h - h)
+ x_max = x_min + w
+ y_max = y_min + h
+
+ out[i] = (x_min, y_min, x_max, y_max)
+
+ return out
+
+
def assert_almost_equal(actual, expected, tolerance=1e-5):
assert abs(actual - expected) < tolerance, f"Expected {expected}, but got {actual}."
diff --git a/test/utils/test_conversion.py b/test/utils/test_conversion.py
index 65cbd8a1ca..e9fabb0d81 100644
--- a/test/utils/test_conversion.py
+++ b/test/utils/test_conversion.py
@@ -3,7 +3,7 @@
from supervision.utils.conversion import (
cv2_to_pillow,
- ensure_cv2_image_for_processing,
+ ensure_cv2_image_for_standalone_function,
images_to_cv2,
pillow_to_cv2,
)
@@ -16,7 +16,7 @@ def test_ensure_cv2_image_for_processing_when_pillow_image_submitted(
param_a_value = 3
param_b_value = "some"
- @ensure_cv2_image_for_processing
+ @ensure_cv2_image_for_standalone_function
def my_custom_processing_function(
image: np.ndarray,
param_a: int,
@@ -55,7 +55,7 @@ def test_ensure_cv2_image_for_processing_when_cv2_image_submitted(
param_a_value = 3
param_b_value = "some"
- @ensure_cv2_image_for_processing
+ @ensure_cv2_image_for_standalone_function
def my_custom_processing_function(
image: np.ndarray,
param_a: int,
diff --git a/test/utils/test_image.py b/test/utils/test_image.py
index 39640330e7..688f938b70 100644
--- a/test/utils/test_image.py
+++ b/test/utils/test_image.py
@@ -2,8 +2,12 @@
import pytest
from PIL import Image, ImageChops
-from supervision import Color, Point
-from supervision.utils.image import create_tiles, letterbox_image, resize_image
+from supervision.utils.image import (
+ crop_image,
+ get_image_resolution_wh,
+ letterbox_image,
+ resize_image,
+)
def test_resize_image_for_opencv_image() -> None:
@@ -98,145 +102,59 @@ def test_letterbox_image_for_pillow_image() -> None:
)
-def test_create_tiles_with_one_image(
- one_image: np.ndarray, single_image_tile: np.ndarray
-) -> None:
- # when
- result = create_tiles(images=[one_image], single_tile_size=(240, 240))
-
- # # then
- assert np.allclose(result, single_image_tile, atol=5.0)
-
-
-def test_create_tiles_with_one_image_and_enforced_grid(
- one_image: np.ndarray, single_image_tile_enforced_grid: np.ndarray
-) -> None:
- # when
- result = create_tiles(
- images=[one_image],
- grid_size=(None, 3),
- single_tile_size=(240, 240),
- )
-
- # then
- assert np.allclose(result, single_image_tile_enforced_grid, atol=5.0)
-
-
-def test_create_tiles_with_two_images(
- two_images: list[np.ndarray], two_images_tile: np.ndarray
-) -> None:
- # when
- result = create_tiles(images=two_images, single_tile_size=(240, 240))
-
- # then
- assert np.allclose(result, two_images_tile, atol=5.0)
-
-
-def test_create_tiles_with_three_images(
- three_images: list[np.ndarray], three_images_tile: np.ndarray
-) -> None:
- # when
- result = create_tiles(images=three_images, single_tile_size=(240, 240))
-
- # then
- assert np.allclose(result, three_images_tile, atol=5.0)
-
-
-def test_create_tiles_with_four_images(
- four_images: list[np.ndarray],
- four_images_tile: np.ndarray,
-) -> None:
- # when
- result = create_tiles(images=four_images, single_tile_size=(240, 240))
-
- # then
- assert np.allclose(result, four_images_tile, atol=5.0)
-
-
-def test_create_tiles_with_all_images(
- all_images: list[np.ndarray],
- all_images_tile: np.ndarray,
-) -> None:
- # when
- result = create_tiles(images=all_images, single_tile_size=(240, 240))
-
- # then
- assert np.allclose(result, all_images_tile, atol=5.0)
-
-
-def test_create_tiles_with_all_images_and_custom_grid(
- all_images: list[np.ndarray], all_images_tile_and_custom_grid: np.ndarray
-) -> None:
- # when
- result = create_tiles(
- images=all_images,
- grid_size=(3, 3),
- single_tile_size=(240, 240),
- )
-
- # then
- assert np.allclose(result, all_images_tile_and_custom_grid, atol=5.0)
-
-
-def test_create_tiles_with_all_images_and_custom_colors(
- all_images: list[np.ndarray], all_images_tile_and_custom_colors: np.ndarray
-) -> None:
- # when
- result = create_tiles(
- images=all_images,
- tile_margin_color=(127, 127, 127),
- tile_padding_color=(224, 224, 224),
- single_tile_size=(240, 240),
- )
-
- # then
- assert np.allclose(result, all_images_tile_and_custom_colors, atol=5.0)
-
-
-def test_create_tiles_with_all_images_and_titles(
- all_images: list[np.ndarray],
- all_images_tile_and_custom_colors_and_titles: np.ndarray,
-) -> None:
- # when
- result = create_tiles(
- images=all_images,
- titles=["Image 1", None, "Image 3", "Image 4"],
- single_tile_size=(240, 240),
- )
-
- # then
- assert np.allclose(result, all_images_tile_and_custom_colors_and_titles, atol=5.0)
-
-
-def test_create_tiles_with_all_images_and_titles_with_custom_configs(
- all_images: list[np.ndarray],
- all_images_tile_and_titles_with_custom_configs: np.ndarray,
-) -> None:
- # when
- result = create_tiles(
- images=all_images,
- titles=["Image 1", None, "Image 3", "Image 4"],
- single_tile_size=(240, 240),
- titles_anchors=[
- Point(x=200, y=300),
- Point(x=300, y=400),
- None,
- Point(x=300, y=400),
- ],
- titles_color=Color.RED,
- titles_scale=1.5,
- titles_thickness=3,
- titles_padding=20,
- titles_background_color=Color.BLACK,
- default_title_placement="bottom",
- )
-
- # then
- assert np.allclose(result, all_images_tile_and_titles_with_custom_configs, atol=5.0)
-
-
-def test_create_tiles_with_all_images_and_custom_grid_to_small_to_fit_images(
- all_images: list[np.ndarray],
-) -> None:
- with pytest.raises(ValueError):
- _ = create_tiles(images=all_images, grid_size=(2, 2))
+@pytest.mark.parametrize(
+ "image, xyxy, expected_size",
+ [
+ # NumPy RGB
+ (
+ np.zeros((4, 6, 3), dtype=np.uint8),
+ (2, 1, 5, 3),
+ (3, 2), # width = 5-2, height = 3-1
+ ),
+ # NumPy grayscale
+ (
+ np.zeros((5, 5), dtype=np.uint8),
+ (1, 1, 4, 4),
+ (3, 3),
+ ),
+ # Pillow RGB
+ (
+ Image.new("RGB", (6, 4), color=0),
+ (2, 1, 5, 3),
+ (3, 2),
+ ),
+ # Pillow grayscale
+ (
+ Image.new("L", (5, 5), color=0),
+ (1, 1, 4, 4),
+ (3, 3),
+ ),
+ ],
+)
+def test_crop_image(image, xyxy, expected_size):
+ cropped = crop_image(image=image, xyxy=xyxy)
+ if isinstance(image, np.ndarray):
+ assert isinstance(cropped, np.ndarray)
+ assert cropped.shape[1] == expected_size[0] # width
+ assert cropped.shape[0] == expected_size[1] # height
+ else:
+ assert isinstance(cropped, Image.Image)
+ assert cropped.size == expected_size
+
+
+@pytest.mark.parametrize(
+ "image, expected",
+ [
+ # NumPy RGB
+ (np.zeros((4, 6, 3), dtype=np.uint8), (6, 4)),
+ # NumPy grayscale
+ (np.zeros((10, 20), dtype=np.uint8), (20, 10)),
+ # Pillow RGB
+ (Image.new("RGB", (6, 4), color=0), (6, 4)),
+ # Pillow grayscale
+ (Image.new("L", (20, 10), color=0), (20, 10)),
+ ],
+)
+def test_get_image_resolution_wh(image, expected):
+ resolution = get_image_resolution_wh(image)
+ assert resolution == expected
diff --git a/test/utils/test_internal.py b/test/utils/test_internal.py
index 07674f39cb..749d4be3a8 100644
--- a/test/utils/test_internal.py
+++ b/test/utils/test_internal.py
@@ -145,6 +145,7 @@ def __private_property(self):
"metadata",
"area",
"box_area",
+ "box_aspect_ratio",
},
DoesNotRaise(),
),