diff --git a/.github/workflows/clear-cache.yml b/.github/workflows/clear-cache.yml index a9cf4544fe..4696895ecd 100644 --- a/.github/workflows/clear-cache.yml +++ b/.github/workflows/clear-cache.yml @@ -16,7 +16,7 @@ jobs: timeout-minutes: 10 steps: - name: Clear cache - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: script: | console.log("Starting cache cleanup...") diff --git a/.github/workflows/publish-docs.yml b/.github/workflows/publish-docs.yml index 3b05352ef8..e3b3f038a0 100644 --- a/.github/workflows/publish-docs.yml +++ b/.github/workflows/publish-docs.yml @@ -34,7 +34,7 @@ jobs: fetch-depth: 0 - name: 🐍 Install uv and set Python ${{ matrix.python-version }} - uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1 + uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 with: python-version: ${{ matrix.python-version }} activate-environment: true diff --git a/.github/workflows/publish-pre-release.yml b/.github/workflows/publish-pre-release.yml index f6e410eeb0..104ecfa545 100644 --- a/.github/workflows/publish-pre-release.yml +++ b/.github/workflows/publish-pre-release.yml @@ -29,7 +29,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: 🐍 Install uv and set Python version ${{ matrix.python-version }} - uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1 + uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 with: python-version: ${{ matrix.python-version }} activate-environment: true @@ -42,6 +42,6 @@ jobs: uv run twine check --strict dist/* - name: πŸš€ Publish to PyPi - uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 with: attestations: true diff --git a/.github/workflows/publish-release.yml 
b/.github/workflows/publish-release.yml index 2ba7536f9f..45d7a2a196 100644 --- a/.github/workflows/publish-release.yml +++ b/.github/workflows/publish-release.yml @@ -27,7 +27,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: 🐍 Install uv and set Python version ${{ matrix.python-version }} - uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1 + uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 with: python-version: ${{ matrix.python-version }} activate-environment: true @@ -40,6 +40,6 @@ jobs: uv run twine check --strict dist/* - name: πŸš€ Publish to PyPi - uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 with: attestations: true diff --git a/.github/workflows/publish-testpypi.yml b/.github/workflows/publish-testpypi.yml index 69dd746d13..a6d14c5a97 100644 --- a/.github/workflows/publish-testpypi.yml +++ b/.github/workflows/publish-testpypi.yml @@ -24,7 +24,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: 🐍 Install uv and set Python version ${{ matrix.python-version }} - uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1 + uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 with: python-version: ${{ matrix.python-version }} activate-environment: true @@ -37,7 +37,7 @@ jobs: uv run twine check --strict dist/* - name: πŸš€ Publish to Test-PyPi - uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 with: repository-url: https://test.pypi.org/legacy/ attestations: true diff --git a/.github/workflows/test-doc.yml b/.github/workflows/test-doc.yml index 01930690be..cd41018d0b 100644 --- a/.github/workflows/test-doc.yml +++ b/.github/workflows/test-doc.yml 
@@ -24,7 +24,7 @@ jobs: fetch-depth: 0 - name: 🐍 Install uv and set Python ${{ matrix.python-version }} - uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1 + uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 with: python-version: ${{ matrix.python-version }} activate-environment: true diff --git a/.github/workflows/uv-test.yml b/.github/workflows/uv-test.yml index 1aa2882ec8..65be9c17e0 100644 --- a/.github/workflows/uv-test.yml +++ b/.github/workflows/uv-test.yml @@ -19,7 +19,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: 🐍 Install uv and set Python version ${{ matrix.python-version }} - uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1 + uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 with: python-version: ${{ matrix.python-version }} activate-environment: true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 24e432cd54..87528b7e40 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: additional_dependencies: ["bandit[toml]"] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.3 + rev: v0.12.5 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/docs/changelog.md b/docs/changelog.md index 4ebec50021..1974680c66 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,73 @@ # Changelog +### 0.27.0 Nov 16, 2025 + +- Added [#2008](https://github.com/roboflow/supervision/pull/2008): [`sv.filter_segments_by_distance`](https://supervision.roboflow.com/0.27.0/detection/utils/masks/#supervision.detection.utils.masks.filter_segments_by_distance) to keep the largest connected component and nearby components within an absolute or relative distance threshold. Useful for cleaning segmentation predictions from models such as SAM, SAM2, YOLO segmentation, and RF-DETR segmentation. 
+ + - Added [#2006](https://github.com/roboflow/supervision/pull/2006): [`sv.xyxy_to_mask`](https://supervision.roboflow.com/0.27.0/detection/utils/converters/#supervision.detection.utils.converters.xyxy_to_mask) to convert bounding boxes into 2D boolean masks, where each mask corresponds to a single box. + + - Added [#1943](https://github.com/roboflow/supervision/pull/1943): [`sv.tint_image`](https://supervision.roboflow.com/0.27.0/utils/image/#supervision.utils.image.tint_image) to apply a solid color overlay to an image at a given opacity. Works with both NumPy and PIL inputs. + + - Added [#1943](https://github.com/roboflow/supervision/pull/1943): [`sv.grayscale_image`](https://supervision.roboflow.com/0.27.0/utils/image/#supervision.utils.image.grayscale_image) to convert an image to 3 channel grayscale for compatibility with color based drawing utilities. + + - Added [#2014](https://github.com/roboflow/supervision/pull/2014): [`sv.get_image_resolution_wh`](https://supervision.roboflow.com/0.27.0/utils/image/#supervision.utils.image.get_image_resolution_wh) as a unified way to read image width and height from NumPy and PIL inputs. + + - Added [#1912](https://github.com/roboflow/supervision/pull/1912): [`sv.edit_distance`](https://supervision.roboflow.com/0.27.0/detection/utils/vlms/#supervision.detection.utils.vlms.edit_distance) for Levenshtein distance between two strings. Supports insert, delete, and substitute operations. + + - Added [#1912](https://github.com/roboflow/supervision/pull/1912): [`sv.fuzzy_match_index`](https://supervision.roboflow.com/0.27.0/detection/utils/vlms/#supervision.detection.utils.vlms.fuzzy_match_index) to find the first close match in a list using edit distance. + + - Changed [#2015](https://github.com/roboflow/supervision/pull/2015): [`sv.Detections.from_vlm`](https://supervision.roboflow.com/0.27.0/detection/core/#supervision.detection.core.Detections.from_vlm) and legacy `from_lmm` now support Qwen3 VL via `vlm=sv.VLM.QWEN_3_VL`. 
+ +- Changed [#1884](https://github.com/roboflow/supervision/pull/1884): [`sv.Detections.from_vlm`](https://supervision.roboflow.com/0.27.0/detection/core/#supervision.detection.core.Detections.from_vlm) and legacy `from_lmm` now support DeepSeek VL 2 via `vlm=sv.VLM.DEEPSEEK_VL_2`. + +- Changed [#2015](https://github.com/roboflow/supervision/pull/2015): [`sv.Detections.from_vlm`](https://supervision.roboflow.com/0.27.0/detection/core/#supervision.detection.core.Detections.from_vlm) now parses Qwen 2.5 VL outputs more robustly and handles incomplete or truncated JSON responses. + +- Changed [#2014](https://github.com/roboflow/supervision/pull/2014): [`sv.InferenceSlicer`](https://supervision.roboflow.com/0.27.0/detection/tools/inference_slicer/#supervision.detection.tools.inference_slicer.InferenceSlicer) now uses a new offset generation logic that removes redundant tiles and aligns borders cleanly. This reduces the number of processed tiles and shortens inference time without hurting detection quality. + +- Changed [#2016](https://github.com/roboflow/supervision/pull/2016): [`sv.Detections`](https://supervision.roboflow.com/0.27.0/detection/core/#supervision.detection.core.Detections) now includes a `box_aspect_ratio` property for vectorized aspect ratio computation, useful for filtering detections based on box shape. + +- Changed [#2001](https://github.com/roboflow/supervision/pull/2001): Significantly improved the performance of [`sv.box_iou_batch`](https://supervision.roboflow.com/0.27.0/detection/utils/iou_and_nms/#supervision.detection.utils.iou_and_nms.box_iou_batch). On internal benchmarks, processing runs approximately 2x to 5x faster. + +- Changed [#1997](https://github.com/roboflow/supervision/pull/1997): [`sv.process_video`](https://supervision.roboflow.com/0.27.0/utils/video/#supervision.utils.video.process_video) now uses a threaded reader, processor, and writer pipeline. 
This removes I/O stalls and improves throughput while keeping the callback single threaded and safe for stateful models. + + - Changed: [`sv.denormalize_boxes`](https://supervision.roboflow.com/0.27.0/detection/utils/boxes/#supervision.detection.utils.boxes.denormalize_boxes) now supports batch conversion of bounding boxes. The function accepts arrays of shape `(N, 4)` and returns a batch of absolute pixel coordinates. + + - Changed [#1917](https://github.com/roboflow/supervision/pull/1917): [`sv.LabelAnnotator`](https://supervision.roboflow.com/0.27.0/annotators/#supervision.annotators.core.LabelAnnotator) and [`sv.RichLabelAnnotator`](https://supervision.roboflow.com/0.27.0/annotators/#supervision.annotators.core.RichLabelAnnotator) now accept `text_offset=(x, y)` to shift the label relative to `text_position`. Works with smart label position and line wrapping. + +!!! failure "Removed" + Removed the deprecated `overlap_ratio_wh` argument from `sv.InferenceSlicer`. Use the pixel based `overlap_wh` argument to control slice overlap. + +!!! info "Tip" + Convert your old ratio based overlap to pixel based overlap by multiplying each ratio by the slice dimensions. 
+ + ```python + # before + + slice_wh = (640, 640) + overlap_ratio_wh = (0.25, 0.25) + + slicer = sv.InferenceSlicer( + callback=callback, + slice_wh=slice_wh, + overlap_ratio_wh=overlap_ratio_wh, + overlap_filter=sv.OverlapFilter.NON_MAX_SUPPRESSION, + ) + + # after + + overlap_wh = ( + int(overlap_ratio_wh[0] * slice_wh[0]), + int(overlap_ratio_wh[1] * slice_wh[1]), + ) + + slicer = sv.InferenceSlicer( + callback=callback, + slice_wh=slice_wh, + overlap_wh=overlap_wh, + overlap_filter=sv.OverlapFilter.NON_MAX_SUPPRESSION, + ) + ``` + ### 0.26.1 Jul 22, 2025 - Fixed [1894](https://github.com/roboflow/supervision/pull/1894): Error in [`sv.MeanAveragePrecision`](https://supervision.roboflow.com/0.26.1/metrics/mean_average_precision/#supervision.metrics.mean_average_precision.MeanAveragePrecision) where the area used for size-specific evaluation (small / medium / large) was always zero unless explicitly provided in `sv.Detections.data`. diff --git a/docs/detection/annotators.md b/docs/detection/annotators.md index 938c49ff48..a0341eddf1 100644 --- a/docs/detection/annotators.md +++ b/docs/detection/annotators.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Annotators diff --git a/docs/detection/core.md b/docs/detection/core.md index 475cdae1da..35225cec51 100644 --- a/docs/detection/core.md +++ b/docs/detection/core.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Detections diff --git a/docs/detection/utils/boxes.md b/docs/detection/utils/boxes.md index 63a3231755..020cc8f99a 100644 --- a/docs/detection/utils/boxes.md +++ b/docs/detection/utils/boxes.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Boxes Utils diff --git a/docs/detection/utils/converters.md b/docs/detection/utils/converters.md index 48bec65fe4..b6b1e2af6c 100644 --- a/docs/detection/utils/converters.md +++ b/docs/detection/utils/converters.md @@ -58,3 +58,9 @@ status: new :::supervision.detection.utils.converters.polygon_to_xyxy + +
+

xyxy_to_mask

+
+ +:::supervision.detection.utils.converters.xyxy_to_mask diff --git a/docs/detection/utils/iou_and_nms.md b/docs/detection/utils/iou_and_nms.md index 2b4e4fc334..7191656b7e 100644 --- a/docs/detection/utils/iou_and_nms.md +++ b/docs/detection/utils/iou_and_nms.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # IoU and NMS Utils diff --git a/docs/detection/utils/masks.md b/docs/detection/utils/masks.md index 9e53a6baa1..99097bef6f 100644 --- a/docs/detection/utils/masks.md +++ b/docs/detection/utils/masks.md @@ -22,3 +22,9 @@ status: new :::supervision.detection.utils.masks.contains_multiple_segments + +
+

filter_segments_by_distance

+
+ +:::supervision.detection.utils.masks.filter_segments_by_distance diff --git a/docs/detection/utils/polygons.md b/docs/detection/utils/polygons.md index cd9525345a..8a7cf1e1ce 100644 --- a/docs/detection/utils/polygons.md +++ b/docs/detection/utils/polygons.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Polygons Utils diff --git a/docs/detection/utils/vlms.md b/docs/detection/utils/vlms.md new file mode 100644 index 0000000000..2706d2a837 --- /dev/null +++ b/docs/detection/utils/vlms.md @@ -0,0 +1,18 @@ +--- +comments: true +status: new +--- + +# VLMs Utils + +
+

edit_distance

+
+ +:::supervision.detection.utils.vlms.edit_distance + +
+

fuzzy_match_index

+
+ +:::supervision.detection.utils.vlms.fuzzy_match_index diff --git a/docs/how_to/benchmark_a_model.md b/docs/how_to/benchmark_a_model.md index bf23ee0890..aa707fa734 100644 --- a/docs/how_to/benchmark_a_model.md +++ b/docs/how_to/benchmark_a_model.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- ![Corgi Example](https://media.roboflow.com/supervision/image-examples/how-to/benchmark-models/corgi-sorted-2.png) diff --git a/docs/how_to/process_datasets.md b/docs/how_to/process_datasets.md index 36c122df41..acfd941c47 100644 --- a/docs/how_to/process_datasets.md +++ b/docs/how_to/process_datasets.md @@ -331,12 +331,9 @@ for i in range(16): annotated_image = label_annotator.annotate(annotated_image, annotations, labels) annotated_images.append(annotated_image) -grid = sv.create_tiles( +sv.plot_images_grid( annotated_images, grid_size=(4, 4), - single_tile_size=(400, 400), - tile_padding_color=sv.Color.WHITE, - tile_margin_color=sv.Color.WHITE ) ``` diff --git a/docs/how_to/save_detections.md b/docs/how_to/save_detections.md index e9e1484942..c998dee7e3 100644 --- a/docs/how_to/save_detections.md +++ b/docs/how_to/save_detections.md @@ -234,7 +234,7 @@ with model = get_model(model_id="yolov8n-640") frames_generator = sv.get_video_frames_generator() - with sv.JSONSink() as sink: + with sv.JSONSink() as sink: for frame_index, frame in enumerate(frames_generator): results = model.infer(image)[0] @@ -250,7 +250,7 @@ with model = YOLO("yolov8n.pt") frames_generator = sv.get_video_frames_generator() - with sv.JSONSink() as sink: + with sv.JSONSink() as sink: for frame_index, frame in enumerate(frames_generator): results = model(frame)[0] @@ -268,7 +268,7 @@ with model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50") frames_generator = sv.get_video_frames_generator() - with sv.JSONSink() as sink: + with sv.JSONSink() as sink: for frame_index, frame in enumerate(frames_generator): frame = sv.cv2_to_pillow(frame) diff --git a/docs/how_to/track_objects.md 
b/docs/how_to/track_objects.md index c2be3cef6a..2acad7740d 100644 --- a/docs/how_to/track_objects.md +++ b/docs/how_to/track_objects.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Track Objects @@ -345,7 +344,7 @@ Supervision is versatile and compatible with various models. Check this [link](/ We will define a `callback` function, which will process each frame of the video by obtaining model predictions and then annotating the frame based on these predictions. -Let's immediately visualize the results with our [`EdgeAnnotator`](/latest/keypoint/annotators/#supervision.keypoint.annotators.EdgeAnnotator) and [`VertexAnnotator`](https://supervision.roboflow.com/latest/keypoint/annotators/#supervision.keypoint.annotators.VertexAnnotator). +Let's immediately visualize the results with our [`EdgeAnnotator`](/latest/keypoint/annotators/#supervision.key_points.annotators.EdgeAnnotator) and [`VertexAnnotator`](https://supervision.roboflow.com/latest/keypoint/annotators/#supervision.key_points.annotators.VertexAnnotator). === "Ultralytics" @@ -408,7 +407,7 @@ Let's immediately visualize the results with our [`EdgeAnnotator`](/latest/keypo ### Convert to Detections -Keypoint tracking is currently supported via the conversion of `KeyPoints` to `Detections`. This is achieved with the [`KeyPoints.as_detections()`](/latest/keypoint/core/#supervision.keypoint.core.KeyPoints.as_detections) function. +Keypoint tracking is currently supported via the conversion of `KeyPoints` to `Detections`. This is achieved with the [`KeyPoints.as_detections()`](/latest/keypoint/core/#supervision.key_points.core.KeyPoints.as_detections) function. Let's convert to detections and visualize the results with our [`BoxAnnotator`](/latest/detection/annotators/#supervision.annotators.core.BoxAnnotator). 
diff --git a/docs/keypoint/annotators.md b/docs/keypoint/annotators.md index 32f30626bb..92c7cebaf5 100644 --- a/docs/keypoint/annotators.md +++ b/docs/keypoint/annotators.md @@ -78,19 +78,19 @@ comments: true -:::supervision.keypoint.annotators.VertexAnnotator +:::supervision.key_points.annotators.VertexAnnotator -:::supervision.keypoint.annotators.EdgeAnnotator +:::supervision.key_points.annotators.EdgeAnnotator -:::supervision.keypoint.annotators.VertexLabelAnnotator +:::supervision.key_points.annotators.VertexLabelAnnotator diff --git a/docs/keypoint/core.md b/docs/keypoint/core.md index 7354babab0..e683ae873a 100644 --- a/docs/keypoint/core.md +++ b/docs/keypoint/core.md @@ -1,8 +1,7 @@ --- comments: true -status: new --- # Keypoint Detection -:::supervision.keypoint.core.KeyPoints +:::supervision.key_points.core.KeyPoints diff --git a/docs/metrics/mean_average_precision.md b/docs/metrics/mean_average_precision.md index ce3e06a411..10f7a97771 100644 --- a/docs/metrics/mean_average_precision.md +++ b/docs/metrics/mean_average_precision.md @@ -1,6 +1,5 @@ --- comments: true -status: new --- # Mean Average Precision diff --git a/docs/stylesheets/cookbooks-card.css b/docs/stylesheets/cookbooks_card.css similarity index 100% rename from docs/stylesheets/cookbooks-card.css rename to docs/stylesheets/cookbooks_card.css diff --git a/docs/utils/image.md b/docs/utils/image.md index 8e39136a8b..9d1c1895ca 100644 --- a/docs/utils/image.md +++ b/docs/utils/image.md @@ -1,5 +1,6 @@ --- comments: true +status: new --- # Image Utils @@ -29,10 +30,22 @@ comments: true :::supervision.utils.image.letterbox_image -:::supervision.utils.image.overlay_image +:::supervision.utils.image.tint_image + + + +:::supervision.utils.image.grayscale_image + + + +:::supervision.utils.image.get_image_resolution_wh

ImageSink

diff --git a/examples/time_in_zone/requirements.txt b/examples/time_in_zone/requirements.txt index 9f9446c836..b5ff1911b6 100644 --- a/examples/time_in_zone/requirements.txt +++ b/examples/time_in_zone/requirements.txt @@ -1,4 +1,6 @@ supervision ultralytics inference -pytube +# https://github.com/pytube/pytube/issues/2044 +# pytube +pytubefix diff --git a/examples/time_in_zone/scripts/download_from_youtube.py b/examples/time_in_zone/scripts/download_from_youtube.py index d867363175..b740095573 100644 --- a/examples/time_in_zone/scripts/download_from_youtube.py +++ b/examples/time_in_zone/scripts/download_from_youtube.py @@ -3,7 +3,7 @@ import argparse import os -from pytube import YouTube +from pytubefix import YouTube def main(url: str, output_path: str | None, file_name: str | None) -> None: diff --git a/mkdocs.yml b/mkdocs.yml index f25015348a..daf3d09845 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -26,7 +26,7 @@ extra: extra_css: - stylesheets/extra.css - - stylesheets/cookbooks-card.css + - stylesheets/cookbooks_card.css nav: - Home: index.md @@ -47,6 +47,7 @@ nav: - Boxes: detection/utils/boxes.md - Masks: detection/utils/masks.md - Polygons: detection/utils/polygons.md + - VLMs: detection/utils/vlms.md - Keypoint Detection: - Core: keypoint/core.md - Annotators: keypoint/annotators.md diff --git a/pyproject.toml b/pyproject.toml index cae78492ac..e7f86b89ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "supervision" description = "A set of easy-to-use utils that will come in handy in any Computer Vision project" license = { text = "MIT" } -version = "0.26.1" +version = "0.27.0" readme = "README.md" requires-python = ">=3.9" authors = [ @@ -67,11 +67,12 @@ dev = [ "ipywidgets>=8.1.1", "jupytext>=1.16.1", "nbconvert>=7.14.2", - "docutils!=0.21" + "docutils!=0.21", + "pre-commit>=3.8.0" ] docs = [ "mkdocs-material[imaging]>=9.5.5", - "mkdocstrings>=0.25.2,<0.30.0", + "mkdocstrings>=0.25.2,<0.31.0", "mkdocstrings-python>=1.10.9", 
"mike>=2.0.0", "mkdocs-jupyter>=0.24.3", @@ -81,7 +82,7 @@ docs = [ build = [ "twine>=5.1.1,<7.0.0", "wheel>=0.40,<0.46", - "build>=0.10,<1.3" + "build>=0.10,<1.4" ] [tool.bandit] diff --git a/supervision/__init__.py b/supervision/__init__.py index 925628be3c..00820076fe 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -65,6 +65,7 @@ polygon_to_xyxy, xcycwh_to_xyxy, xywh_to_xyxy, + xyxy_to_mask, xyxy_to_polygons, xyxy_to_xcycarh, xyxy_to_xywh, @@ -86,12 +87,14 @@ calculate_masks_centroids, contains_holes, contains_multiple_segments, + filter_segments_by_distance, move_masks, ) from supervision.detection.utils.polygons import ( approximate_polygon, filter_polygons_by_area, ) +from supervision.detection.utils.vlms import edit_distance, fuzzy_match_index from supervision.detection.vlm import LMM, VLM from supervision.draw.color import Color, ColorPalette from supervision.draw.utils import ( @@ -107,24 +110,26 @@ ) from supervision.geometry.core import Point, Position, Rect from supervision.geometry.utils import get_polygon_center -from supervision.keypoint.annotators import ( +from supervision.key_points.annotators import ( EdgeAnnotator, VertexAnnotator, VertexLabelAnnotator, ) -from supervision.keypoint.core import KeyPoints +from supervision.key_points.core import KeyPoints from supervision.metrics.detection import ConfusionMatrix, MeanAveragePrecision from supervision.tracker.byte_tracker.core import ByteTrack from supervision.utils.conversion import cv2_to_pillow, pillow_to_cv2 from supervision.utils.file import list_files_with_extensions from supervision.utils.image import ( ImageSink, - create_tiles, crop_image, + get_image_resolution_wh, + grayscale_image, letterbox_image, overlay_image, resize_image, scale_image, + tint_image, ) from supervision.utils.notebook import plot_image, plot_images_grid from supervision.utils.video import ( @@ -205,7 +210,6 @@ "clip_boxes", "contains_holes", "contains_multiple_segments", - "create_tiles", 
"crop_image", "cv2_to_pillow", "draw_filled_polygon", @@ -215,10 +219,15 @@ "draw_polygon", "draw_rectangle", "draw_text", + "edit_distance", "filter_polygons_by_area", + "filter_segments_by_distance", + "fuzzy_match_index", "get_coco_class_index_mapping", + "get_image_resolution_wh", "get_polygon_center", "get_video_frames_generator", + "grayscale_image", "letterbox_image", "list_files_with_extensions", "mask_iou_batch", @@ -242,8 +251,10 @@ "rle_to_mask", "scale_boxes", "scale_image", + "tint_image", "xcycwh_to_xyxy", "xywh_to_xyxy", + "xyxy_to_mask", "xyxy_to_polygons", "xyxy_to_xcycarh", "xyxy_to_xywh", diff --git a/supervision/annotators/base.py b/supervision/annotators/base.py index 159ad5567e..9b4bbcbe27 100644 --- a/supervision/annotators/base.py +++ b/supervision/annotators/base.py @@ -1,19 +1,7 @@ from abc import ABC, abstractmethod -from typing import TypeVar - -import numpy as np -from PIL import Image from supervision.detection.core import Detections - -ImageType = TypeVar("ImageType", np.ndarray, Image.Image) -""" -An image of type `np.ndarray` or `PIL.Image.Image`. - -Unlike a `Union`, ensures the type remains consistent. If a function -takes an `ImageType` argument and returns an `ImageType`, when you -pass an `np.ndarray`, you will get an `np.ndarray` back. 
-""" +from supervision.draw.base import ImageType class BaseAnnotator(ABC): diff --git a/supervision/annotators/core.py b/supervision/annotators/core.py index 0b7d4b7632..900d823b2c 100644 --- a/supervision/annotators/core.py +++ b/supervision/annotators/core.py @@ -7,8 +7,9 @@ import numpy as np import numpy.typing as npt from PIL import Image, ImageDraw, ImageFont +from scipy.interpolate import splev, splprep -from supervision.annotators.base import BaseAnnotator, ImageType +from supervision.annotators.base import BaseAnnotator from supervision.annotators.utils import ( PENDING_TRACK_ID, ColorLookup, @@ -28,12 +29,13 @@ polygon_to_mask, xyxy_to_polygons, ) +from supervision.draw.base import ImageType from supervision.draw.color import Color, ColorPalette from supervision.draw.utils import draw_polygon, draw_rounded_rectangle, draw_text from supervision.geometry.core import Point, Position, Rect from supervision.utils.conversion import ( - ensure_cv2_image_for_annotation, - ensure_pil_image_for_annotation, + ensure_cv2_image_for_class_method, + ensure_pil_image_for_class_method, ) from supervision.utils.image import ( crop_image, @@ -51,25 +53,28 @@ class _BaseLabelAnnotator(BaseAnnotator): Attributes: color (Union[Color, ColorPalette]): The color to use for the label background. + color_lookup (ColorLookup): The method used to determine the color of the label. text_color (Union[Color, ColorPalette]): The color to use for the label text. text_padding (int): The padding around the label text, in pixels. text_anchor (Position): The position of the text relative to the detection - bounding box. - color_lookup (ColorLookup): The method used to determine the color of the label. + bounding box. + text_offset (Tuple[int, int]): A tuple of 2D coordinates `(x, y)` to + offset the text position from the anchor point, in pixels. border_radius (int): The radius of the label background corners, in pixels. 
smart_position (bool): Whether to intelligently adjust the label position to - avoid overlapping with other elements. + avoid overlapping with other elements. max_line_length (Optional[int]): Maximum number of characters per line before - wrapping the text. None means no wrapping. + wrapping the text. None means no wrapping. """ def __init__( self, color: Color | ColorPalette = ColorPalette.DEFAULT, + color_lookup: ColorLookup = ColorLookup.CLASS, text_color: Color | ColorPalette = Color.WHITE, text_padding: int = 10, text_position: Position = Position.TOP_LEFT, - color_lookup: ColorLookup = ColorLookup.CLASS, + text_offset: tuple[int, int] = (0, 0), border_radius: int = 0, smart_position: bool = False, max_line_length: int | None = None, @@ -79,27 +84,29 @@ def __init__( Args: color (Union[Color, ColorPalette], optional): The color to use for the label - background. + background. + color_lookup (ColorLookup, optional): The method used to determine the color + of the label text_color (Union[Color, ColorPalette], optional): The color to use for the - label text. + label text. text_padding (int, optional): The padding around the label text, in pixels. text_position (Position, optional): The position of the text relative to the - detection bounding box. - color_lookup (ColorLookup, optional): The method used to determine the color - of the label + detection bounding box. + text_offset (Tuple[int, int], optional): A tuple of 2D coordinates + `(x, y)` to offset the text position from the anchor point, in pixels. border_radius (int, optional): The radius of the label background corners, - in pixels. + in pixels. smart_position (bool, optional): Whether to intelligently adjust the label - position to avoid overlapping with other elements. + position to avoid overlapping with other elements. max_line_length (Optional[int], optional): Maximum number of characters per - line before wrapping the text. None means no wrapping. - + line before wrapping the text. 
None means no wrapping. """ self.color: Color | ColorPalette = color + self.color_lookup: ColorLookup = color_lookup self.text_color: Color | ColorPalette = text_color self.text_padding: int = text_padding self.text_anchor: Position = text_position - self.color_lookup: ColorLookup = color_lookup + self.text_offset: tuple[int, int] = text_offset self.border_radius: int = border_radius self.smart_position = smart_position self.max_line_length: int | None = max_line_length @@ -171,7 +178,7 @@ def __init__( self.thickness: int = thickness self.color_lookup: ColorLookup = color_lookup - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -254,7 +261,7 @@ def __init__( self.thickness: int = thickness self.color_lookup: ColorLookup = color_lookup - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -343,7 +350,7 @@ def __init__( self.opacity = opacity self.color_lookup: ColorLookup = color_lookup - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -433,7 +440,7 @@ def __init__( self.thickness: int = thickness self.color_lookup: ColorLookup = color_lookup - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -520,7 +527,7 @@ def __init__( self.color_lookup: ColorLookup = color_lookup self.opacity = opacity - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -616,7 +623,7 @@ def __init__( self.color_lookup: ColorLookup = color_lookup self.kernel_size: int = kernel_size - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -716,7 +723,7 @@ def __init__( self.end_angle: int = end_angle self.color_lookup: ColorLookup = color_lookup - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, 
scene: ImageType, @@ -808,7 +815,7 @@ def __init__( self.corner_length: int = corner_length self.color_lookup: ColorLookup = color_lookup - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -897,7 +904,7 @@ def __init__( self.thickness: int = thickness self.color_lookup: ColorLookup = color_lookup - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -996,7 +1003,7 @@ def __init__( self.outline_thickness = outline_thickness self.outline_color: Color | ColorPalette = outline_color - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -1076,37 +1083,100 @@ class LabelAnnotator(_BaseLabelAnnotator): def __init__( self, color: Color | ColorPalette = ColorPalette.DEFAULT, + color_lookup: ColorLookup = ColorLookup.CLASS, text_color: Color | ColorPalette = Color.WHITE, text_scale: float = 0.5, text_thickness: int = 1, text_padding: int = 10, text_position: Position = Position.TOP_LEFT, - color_lookup: ColorLookup = ColorLookup.CLASS, + text_offset: tuple[int, int] = (0, 0), border_radius: int = 0, smart_position: bool = False, max_line_length: int | None = None, ): + """ + Args: + color (Union[Color, ColorPalette]): The color or color palette to use for + annotating the text background. + color_lookup (ColorLookup): Strategy for mapping colors to annotations. + Options are `INDEX`, `CLASS`, `TRACK`. + text_color (Union[Color, ColorPalette]): The color or color palette to use + for the text. + text_scale (float): Font scale for the text. + text_thickness (int): Thickness of the text characters. + text_padding (int): Padding around the text within its background box. + text_position (Position): Position of the text relative to the detection. + Possible values are defined in the `Position` enum. 
+ text_offset (Tuple[int, int]): A tuple of 2D coordinates `(x, y)` to + offset the text position from the anchor point, in pixels. + border_radius (int): The radius to apply round edges. If the selected + value is higher than the lower dimension, width or height, is clipped. + smart_position (bool): Spread out the labels to avoid overlapping. + max_line_length (Optional[int]): Maximum number of characters per line + before wrapping the text. None means no wrapping. + """ self.text_scale: float = text_scale self.text_thickness: int = text_thickness super().__init__( color=color, + color_lookup=color_lookup, text_color=text_color, text_padding=text_padding, text_position=text_position, - color_lookup=color_lookup, + text_offset=text_offset, border_radius=border_radius, smart_position=smart_position, max_line_length=max_line_length, ) - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, - scene: ImageType, # Ensure scene is initially a NumPy array here + scene: ImageType, detections: Detections, labels: list[str] | None = None, custom_color_lookup: np.ndarray | None = None, ) -> np.ndarray: + """ + Annotates the given scene with labels based on the provided detections. + + Args: + scene (ImageType): The image where labels will be drawn. + `ImageType` is a flexible type, accepting either `numpy.ndarray` + or `PIL.Image.Image`. + detections (Detections): Object detections to annotate. + labels (Optional[List[str]]): Custom labels for each detection. + custom_color_lookup (Optional[np.ndarray]): Custom color lookup array. + Allows to override the default color mapping strategy. + + Returns: + The annotated image, matching the type of `scene` (`numpy.ndarray` + or `PIL.Image.Image`) + + Example: + ```python + import supervision as sv + + image = ... + detections = sv.Detections(...) 
+ + labels = [ + f"{class_name} {confidence:.2f}" + for class_name, confidence + in zip(detections['class_name'], detections.confidence) + ] + + label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER) + annotated_frame = label_annotator.annotate( + scene=image.copy(), + detections=detections, + labels=labels + ) + ``` + + ![label-annotator-example](https://media.roboflow.com/ + supervision-annotator-examples/label-annotator-example-purple.png) + """ assert isinstance(scene, np.ndarray) validate_labels(labels, detections) @@ -1144,7 +1214,12 @@ def _get_label_properties( anchor=self.text_anchor ).astype(int) - for label, center_coords in zip(labels, anchors_coordinates): + for label, center_coordinates in zip(labels, anchors_coordinates): + center_coordinates = ( + center_coordinates[0] + self.text_offset[0], + center_coordinates[1] + self.text_offset[1], + ) + wrapped_lines = wrap_text(label, self.max_line_length) line_heights = [] line_widths = [] @@ -1170,7 +1245,7 @@ def _get_label_properties( height_padded = total_height + 2 * self.text_padding text_background_xyxy = resolve_text_background_xyxy( - center_coordinates=tuple(center_coords), + center_coordinates=center_coordinates, text_wh=(width_padded, height_padded), position=self.text_anchor, ) @@ -1317,31 +1392,54 @@ class RichLabelAnnotator(_BaseLabelAnnotator): def __init__( self, color: Color | ColorPalette = ColorPalette.DEFAULT, + color_lookup: ColorLookup = ColorLookup.CLASS, text_color: Color | ColorPalette = Color.WHITE, font_path: str | None = None, font_size: int = 10, text_padding: int = 10, text_position: Position = Position.TOP_LEFT, - color_lookup: ColorLookup = ColorLookup.CLASS, + text_offset: tuple[int, int] = (0, 0), border_radius: int = 0, smart_position: bool = False, max_line_length: int | None = None, ): + """ + Args: + color (Union[Color, ColorPalette]): The color or color palette to use for + annotating the text background. 
+ color_lookup (ColorLookup): Strategy for mapping colors to annotations. + Options are `INDEX`, `CLASS`, `TRACK`. + text_color (Union[Color, ColorPalette]): The color to use for the text. + font_path (Optional[str]): Path to the font file (e.g., ".ttf" or ".otf") + to use for rendering text. If `None`, the default PIL font will be used. + font_size (int): Font size for the text. + text_padding (int): Padding around the text within its background box. + text_position (Position): Position of the text relative to the detection. + Possible values are defined in the `Position` enum. + text_offset (Tuple[int, int]): A tuple of 2D coordinates `(x, y)` to + offset the text position from the anchor point, in pixels. + border_radius (int): The radius to apply round edges. If the selected + value is higher than the lower dimension, width or height, is clipped. + smart_position (bool): Spread out the labels to avoid overlapping. + max_line_length (Optional[int]): Maximum number of characters per line + before wrapping the text. None means no wrapping. + """ self.font_path = font_path self.font_size = font_size self.font = self._load_font(font_size, font_path) super().__init__( color=color, + color_lookup=color_lookup, text_color=text_color, text_padding=text_padding, text_position=text_position, - color_lookup=color_lookup, + text_offset=text_offset, border_radius=border_radius, smart_position=smart_position, max_line_length=max_line_length, ) - @ensure_pil_image_for_annotation + @ensure_pil_image_for_class_method def annotate( self, scene: ImageType, @@ -1349,6 +1447,44 @@ def annotate( labels: list[str] | None = None, custom_color_lookup: np.ndarray | None = None, ) -> ImageType: + """ + Annotates the given scene with labels based on the provided + detections, with support for Unicode characters. + + Args: + scene (ImageType): The image where labels will be drawn. + `ImageType` is a flexible type, accepting either `numpy.ndarray` + or `PIL.Image.Image`. 
+ detections (Detections): Object detections to annotate. + labels (Optional[List[str]]): Custom labels for each detection. + custom_color_lookup (Optional[np.ndarray]): Custom color lookup array. + Allows to override the default color mapping strategy. + + Returns: + The annotated image, matching the type of `scene` (`numpy.ndarray` + or `PIL.Image.Image`) + + Example: + ```python + import supervision as sv + + image = ... + detections = sv.Detections(...) + + labels = [ + f"{class_name} {confidence:.2f}" + for class_name, confidence + in zip(detections['class_name'], detections.confidence) + ] + + rich_label_annotator = sv.RichLabelAnnotator(font_path="path/to/font.ttf") + annotated_frame = label_annotator.annotate( + scene=image.copy(), + detections=detections, + labels=labels + ) + ``` + """ assert isinstance(scene, Image.Image) validate_labels(labels, detections) @@ -1386,7 +1522,12 @@ def _get_label_properties( anchor=self.text_anchor ).astype(int) - for label, center_coords in zip(labels, anchor_coordinates): + for label, center_coordinates in zip(labels, anchor_coordinates): + center_coordinates = ( + center_coordinates[0] + self.text_offset[0], + center_coordinates[1] + self.text_offset[1], + ) + wrapped_lines = wrap_text(label, self.max_line_length) # Calculate the total text height and maximum width @@ -1409,7 +1550,7 @@ def _get_label_properties( height_padded = int(total_height + 2 * self.text_padding) text_background_xyxy = resolve_text_background_xyxy( - center_coordinates=tuple(center_coords), + center_coordinates=center_coordinates, text_wh=(width_padded, height_padded), position=self.text_anchor, ) @@ -1525,7 +1666,7 @@ def __init__( self.position = icon_position self.offset_xy = offset_xy - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, detections: Detections, icon_path: str | list[str] ) -> ImageType: @@ -1614,7 +1755,7 @@ def __init__(self, kernel_size: int = 15): """ 
self.kernel_size: int = kernel_size - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -1681,6 +1822,7 @@ def __init__( position: Position = Position.CENTER, trace_length: int = 30, thickness: int = 2, + smooth: bool = False, color_lookup: ColorLookup = ColorLookup.CLASS, ): """ @@ -1692,15 +1834,17 @@ def __init__( trace_length (int): The maximum length of the trace in terms of historical points. Defaults to `30`. thickness (int): The thickness of the trace lines. Defaults to `2`. + smooth (bool): Smooth the trace lines. color_lookup (ColorLookup): Strategy for mapping colors to annotations. Options are `INDEX`, `CLASS`, `TRACK`. """ self.color: Color | ColorPalette = color self.trace = Trace(max_size=trace_length, anchor=position) self.thickness = thickness + self.smooth = smooth self.color_lookup: ColorLookup = color_lookup - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -1769,10 +1913,18 @@ def annotate( else custom_color_lookup, ) xy = self.trace.get(tracker_id=tracker_id) + spline_points = xy.astype(np.int32) + + if len(xy) > 3 and self.smooth: + x, y = xy[:, 0], xy[:, 1] + tck, u = splprep([x, y], s=20) + x_new, y_new = splev(np.linspace(0, 1, 100), tck) + spline_points = np.stack([x_new, y_new], axis=1).astype(np.int32) + if len(xy) > 1: scene = cv2.polylines( scene, - [xy.astype(np.int32)], + [spline_points], False, color=color.as_bgr(), thickness=self.thickness, @@ -1814,7 +1966,7 @@ def __init__( self.low_hue = low_hue self.heat_mask: npt.NDArray[np.float32] | None = None - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate(self, scene: ImageType, detections: Detections) -> ImageType: """ Annotates the scene with a heatmap based on the provided detections. 
@@ -1896,7 +2048,7 @@ def __init__(self, pixel_size: int = 20): """ self.pixel_size: int = pixel_size - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -1993,7 +2145,7 @@ def __init__( self.outline_thickness: int = outline_thickness self.outline_color: Color | ColorPalette = outline_color - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -2105,7 +2257,7 @@ def __init__( raise ValueError("roundness attribute must be float between (0, 1.0]") self.roundness: float = roundness - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -2245,7 +2397,7 @@ def __init__( else int(0.15 * self.height) ) - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -2426,7 +2578,7 @@ def __init__( self.border_thickness: int = border_thickness self.border_color_lookup: ColorLookup = border_color_lookup - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, @@ -2575,7 +2727,7 @@ def __init__( self.opacity = opacity self.force_box = force_box - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate(self, scene: ImageType, detections: Detections) -> ImageType: """ Applies a colored overlay to the scene outside of the detected regions. 
@@ -2673,7 +2825,7 @@ def __init__( self.label_scale = label_scale self.text_thickness = int(self.label_scale + 1.2) - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate( self, scene: ImageType, detections_1: Detections, detections_2: Detections ) -> ImageType: diff --git a/supervision/annotators/utils.py b/supervision/annotators/utils.py index a84e2049d4..42a436c9bb 100644 --- a/supervision/annotators/utils.py +++ b/supervision/annotators/utils.py @@ -2,6 +2,7 @@ import textwrap from enum import Enum +from typing import Any import numpy as np @@ -142,35 +143,45 @@ def resolve_color( detection_idx=detection_idx, color_lookup=color_lookup, ) - if color_lookup == ColorLookup.TRACK and idx == PENDING_TRACK_ID: + if ( + isinstance(color_lookup, ColorLookup) + and color_lookup == ColorLookup.TRACK + and idx == PENDING_TRACK_ID + ): return PENDING_TRACK_COLOR return get_color_by_index(color=color, idx=idx) -def wrap_text(text: str, max_line_length=None) -> list[str]: +def wrap_text(text: Any, max_line_length=None) -> list[str]: """ - Wraps text to the specified maximum line length, respecting existing newlines. - Uses the textwrap library for robust text wrapping. + Wrap `text` to the specified maximum line length, respecting existing + newlines. Falls back to str() if `text` is not already a string. Args: - text (str): The text to wrap. + text (Any): The text (or object) to wrap. + max_line_length (int | None): Maximum width for each wrapped line. Returns: - List[str]: A list of text lines after wrapping. + list[str]: Wrapped lines. 
""" if not text: return [""] + if not isinstance(text, str): + text = str(text) + if max_line_length is None: return text.splitlines() or [""] + if max_line_length <= 0: + raise ValueError("max_line_length must be a positive integer") + paragraphs = text.split("\n") - all_lines = [] + all_lines: list[str] = [] for paragraph in paragraphs: - if not paragraph: - # Keep empty lines + if paragraph == "": all_lines.append("") continue @@ -182,12 +193,9 @@ def wrap_text(text: str, max_line_length=None) -> list[str]: drop_whitespace=True, ) - if wrapped: - all_lines.extend(wrapped) - else: - all_lines.append("") + all_lines.extend(wrapped or [""]) - return all_lines if all_lines else [""] + return all_lines or [""] def validate_labels(labels: list[str] | None, detections: Detections): diff --git a/supervision/detection/core.py b/supervision/detection/core.py index ea466b8955..6ca6abc567 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -40,12 +40,14 @@ from supervision.detection.vlm import ( LMM, VLM, + from_deepseek_vl_2, from_florence_2, from_google_gemini_2_0, from_google_gemini_2_5, from_moondream, from_paligemma, from_qwen_2_5_vl, + from_qwen_3_vl, validate_vlm_parameters, ) from supervision.geometry.core import Position @@ -295,18 +297,24 @@ def from_ultralytics(cls, ultralytics_results) -> Detections: class_id=np.arange(len(ultralytics_results)), ) - class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int) - class_names = np.array([ultralytics_results.names[i] for i in class_id]) - return cls( - xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(), - confidence=ultralytics_results.boxes.conf.cpu().numpy(), - class_id=class_id, - mask=extract_ultralytics_masks(ultralytics_results), - tracker_id=ultralytics_results.boxes.id.int().cpu().numpy() - if ultralytics_results.boxes.id is not None - else None, - data={CLASS_NAME_DATA_FIELD: class_names}, - ) + if ( + hasattr(ultralytics_results, "boxes") + and ultralytics_results.boxes 
is not None + ): + class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int) + class_names = np.array([ultralytics_results.names[i] for i in class_id]) + return cls( + xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(), + confidence=ultralytics_results.boxes.conf.cpu().numpy(), + class_id=class_id, + mask=extract_ultralytics_masks(ultralytics_results), + tracker_id=ultralytics_results.boxes.id.int().cpu().numpy() + if ultralytics_results.boxes.id is not None + else None, + data={CLASS_NAME_DATA_FIELD: class_names}, + ) + + return cls.empty() @classmethod def from_yolo_nas(cls, yolo_nas_results) -> Detections: @@ -830,6 +838,7 @@ def from_lmm(cls, lmm: LMM | str, result: str | dict, **kwargs: Any) -> Detectio | Google Gemini 2.0 | `GOOGLE_GEMINI_2_0` | detection | `resolution_wh` | `classes` | | Google Gemini 2.5 | `GOOGLE_GEMINI_2_5` | detection, segmentation | `resolution_wh` | `classes` | | Moondream | `MOONDREAM` | detection | `resolution_wh` | | + | DeepSeek-VL2 | `DEEPSEEK_VL_2` | detection | `resolution_wh` | `classes` | Args: lmm (Union[LMM, str]): The type of LMM (Large Multimodal Model) to use. @@ -943,6 +952,36 @@ def from_lmm(cls, lmm: LMM | str, result: str | dict, **kwargs: Any) -> Detectio # array([0, 1]) ``` + !!! example "Qwen3-VL" + + ```python + import supervision as sv + + qwen_3_vl_result = \"\"\"```json + [ + {"bbox_2d": [139, 768, 315, 954], "label": "cat"}, + {"bbox_2d": [366, 679, 536, 849], "label": "dog"} + ] + ```\"\"\" + detections = sv.Detections.from_lmm( + sv.LMM.QWEN_3_VL, + qwen_3_vl_result, + resolution_wh=(1000, 1000), + classes=['cat', 'dog'], + ) + detections.xyxy + # array([[139., 768., 315., 954.], [366., 679., 536., 849.]]) + + detections.class_id + # array([0, 1]) + + detections.data + # {'class_name': array(['cat', 'dog'], dtype=' Detectio # array([[1752.28, 818.82, 2165.72, 1229.14], # [1908.01, 1346.67, 2585.99, 2024.11]]) ``` + + !!! example "DeepSeek-VL2" + + + ??? 
tip "Prompt engineering" + + To get the best results from DeepSeek-VL2, use optimized prompts that leverage + its object detection and visual grounding capabilities effectively. + + **For general object detection, use the following user prompt:** + + ``` + \\n<|ref|>The giraffe at the front<|/ref|> + ``` + + **For visual grounding, use the following user prompt:** + + ``` + \\n<|grounding|>Detect the giraffes + ``` + + ```python + from PIL import Image + import supervision as sv + + deepseek_vl2_result = "<|ref|>The giraffe at the back<|/ref|><|det|>[[580, 270, 999, 904]]<|/det|><|ref|>The giraffe at the front<|/ref|><|det|>[[26, 31, 632, 998]]<|/det|><|end▁of▁sentence|>" + + detections = sv.Detections.from_vlm( + vlm=sv.VLM.DEEPSEEK_VL_2, result=deepseek_vl2_result, resolution_wh=image.size + ) + + detections.xyxy + # array([[ 420, 293, 724, 982], + # [ 18, 33, 458, 1084]]) + + detections.class_id + # array([0, 1]) + + detections.data + # {'class_name': array(['The giraffe at the back', 'The giraffe at the front'], dtype=' Detectio LMM.PALIGEMMA: VLM.PALIGEMMA, LMM.FLORENCE_2: VLM.FLORENCE_2, LMM.QWEN_2_5_VL: VLM.QWEN_2_5_VL, + LMM.DEEPSEEK_VL_2: VLM.DEEPSEEK_VL_2, LMM.GOOGLE_GEMINI_2_0: VLM.GOOGLE_GEMINI_2_0, LMM.GOOGLE_GEMINI_2_5: VLM.GOOGLE_GEMINI_2_5, } @@ -1161,9 +1242,11 @@ def from_vlm(cls, vlm: VLM | str, result: str | dict, **kwargs: Any) -> Detectio | PaliGemma | `PALIGEMMA` | detection | `resolution_wh` | `classes` | | PaliGemma 2 | `PALIGEMMA` | detection | `resolution_wh` | `classes` | | Qwen2.5-VL | `QWEN_2_5_VL` | detection | `resolution_wh`, `input_wh` | `classes` | + | Qwen3-VL | `QWEN_3_VL` | detection | `resolution_wh`, | `classes` | | Google Gemini 2.0 | `GOOGLE_GEMINI_2_0` | detection | `resolution_wh` | `classes` | | Google Gemini 2.5 | `GOOGLE_GEMINI_2_5` | detection, segmentation | `resolution_wh` | `classes` | | Moondream | `MOONDREAM` | detection | `resolution_wh` | | + | DeepSeek-VL2 | `DEEPSEEK_VL_2` | detection | `resolution_wh` | 
`classes` | Args: vlm (Union[VLM, str]): The type of VLM (Vision Language Model) to use. @@ -1277,6 +1360,36 @@ def from_vlm(cls, vlm: VLM | str, result: str | dict, **kwargs: Any) -> Detectio # array([0, 1]) ``` + !!! example "Qwen3-VL" + + ```python + import supervision as sv + + qwen_3_vl_result = \"\"\"```json + [ + {"bbox_2d": [139, 768, 315, 954], "label": "cat"}, + {"bbox_2d": [366, 679, 536, 849], "label": "dog"} + ] + ```\"\"\" + detections = sv.Detections.from_vlm( + sv.VLM.QWEN_3_VL, + qwen_3_vl_result, + resolution_wh=(1000, 1000), + classes=['cat', 'dog'], + ) + detections.xyxy + # array([[139., 768., 315., 954.], [366., 679., 536., 849.]]) + + detections.class_id + # array([0, 1]) + + detections.data + # {'class_name': array(['cat', 'dog'], dtype=' Detectio # [1908.01, 1346.67, 2585.99, 2024.11]]) ``` + !!! example "DeepSeek-VL2" + + + ??? tip "Prompt engineering" + + To get the best results from DeepSeek-VL2, use optimized prompts that leverage + its object detection and visual grounding capabilities effectively. 
+ + **For general object detection, use the following user prompt:** + + ``` + \\n<|ref|>The giraffe at the front<|/ref|> + ``` + + **For visual grounding, use the following user prompt:** + + ``` + \\n<|grounding|>Detect the giraffes + ``` + + ```python + from PIL import Image + import supervision as sv + + deepseek_vl2_result = "<|ref|>The giraffe at the back<|/ref|><|det|>[[580, 270, 999, 904]]<|/det|><|ref|>The giraffe at the front<|/ref|><|det|>[[26, 31, 632, 998]]<|/det|><|end▁of▁sentence|>" + + detections = sv.Detections.from_vlm( + vlm=sv.VLM.DEEPSEEK_VL_2, result=deepseek_vl2_result, resolution_wh=image.size + ) + + detections.xyxy + # array([[ 420, 293, 724, 982], + # [ 18, 33, 458, 1084]]) + + detections.class_id + # array([0, 1]) + + detections.data + # {'class_name': array(['The giraffe at the back', 'The giraffe at the front'], dtype=' Detectio if vlm == VLM.QWEN_2_5_VL: xyxy, class_id, class_name = from_qwen_2_5_vl(result, **kwargs) data = {CLASS_NAME_DATA_FIELD: class_name} + confidence = np.ones(len(xyxy), dtype=float) + return cls(xyxy=xyxy, class_id=class_id, confidence=confidence, data=data) + + if vlm == VLM.QWEN_3_VL: + xyxy, class_id, class_name = from_qwen_3_vl(result, **kwargs) + data = {CLASS_NAME_DATA_FIELD: class_name} + confidence = np.ones(len(xyxy), dtype=float) + return cls(xyxy=xyxy, class_id=class_id, confidence=confidence, data=data) + + if vlm == VLM.DEEPSEEK_VL_2: + xyxy, class_id, class_name = from_deepseek_vl_2(result, **kwargs) + data = {CLASS_NAME_DATA_FIELD: class_name} return cls(xyxy=xyxy, class_id=class_id, data=data) if vlm == VLM.FLORENCE_2: @@ -1922,6 +2088,43 @@ def box_area(self) -> np.ndarray: """ return (self.xyxy[:, 3] - self.xyxy[:, 1]) * (self.xyxy[:, 2] - self.xyxy[:, 0]) + @property + def box_aspect_ratio(self) -> np.ndarray: + """ + Compute the aspect ratio (width divided by height) for each bounding box. 
+ + Returns: + np.ndarray: Array of shape `(N,)` containing aspect ratios, where `N` is the + number of boxes (width / height for each box). + + Examples: + ```python + import numpy as np + import supervision as sv + + xyxy = np.array([ + [10, 10, 50, 50], + [60, 10, 180, 50], + [10, 60, 50, 180], + ]) + + detections = sv.Detections(xyxy=xyxy) + + detections.box_aspect_ratio + # array([1.0, 3.0, 0.33333333]) + + ar = detections.box_aspect_ratio + detections[(ar < 2.0) & (ar > 0.5)].xyxy + # array([[10., 10., 50., 50.]]) + ``` + """ + widths = self.xyxy[:, 2] - self.xyxy[:, 0] + heights = self.xyxy[:, 3] - self.xyxy[:, 1] + + aspect_ratios = np.full_like(widths, np.nan, dtype=np.float64) + np.divide(widths, heights, out=aspect_ratios, where=heights != 0) + return aspect_ratios + def with_nms( self, threshold: float = 0.5, diff --git a/supervision/detection/tools/inference_slicer.py b/supervision/detection/tools/inference_slicer.py index aaecccb3dc..4cc05f19cc 100644 --- a/supervision/detection/tools/inference_slicer.py +++ b/supervision/detection/tools/inference_slicer.py @@ -11,11 +11,9 @@ from supervision.detection.utils.boxes import move_boxes, move_oriented_boxes from supervision.detection.utils.iou_and_nms import OverlapFilter, OverlapMetric from supervision.detection.utils.masks import move_masks -from supervision.utils.image import crop_image -from supervision.utils.internal import ( - SupervisionWarnings, - warn_deprecated, -) +from supervision.draw.base import ImageType +from supervision.utils.image import crop_image, get_image_resolution_wh +from supervision.utils.internal import SupervisionWarnings def move_detections( @@ -53,111 +51,106 @@ def move_detections( class InferenceSlicer: """ - InferenceSlicer performs slicing-based inference for small target detection. 
This - method, often referred to as - [Slicing Adaptive Inference (SAHI)](https://ieeexplore.ieee.org/document/9897990), - involves dividing a larger image into smaller slices, performing inference on each - slice, and then merging the detections. + Perform tiled inference on large images by slicing them into overlapping patches. + + This class divides an input image into overlapping slices of configurable size + and overlap, runs inference on each slice through a user-provided callback, and + merges the resulting detections. The slicing process allows efficient processing + of large images with limited resources while preserving detection accuracy via + configurable overlap and post-processing of overlaps. Uses multi-threading for + parallel slice inference. Args: - slice_wh (Tuple[int, int]): Dimensions of each slice measured in pixels. The - tuple should be in the format `(width, height)`. - overlap_ratio_wh (Optional[Tuple[float, float]]): [⚠️ Deprecated: please set - to `None` and use `overlap_wh`] A tuple representing the - desired overlap ratio for width and height between consecutive slices. - Each value should be in the range [0, 1), where 0 means no overlap and - a value close to 1 means high overlap. - overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired - overlap for width and height between consecutive slices measured in pixels. - Each value should be greater than or equal to 0. Takes precedence over - `overlap_ratio_wh`. - overlap_filter (Union[OverlapFilter, str]): Strategy for - filtering or merging overlapping detections in slices. - iou_threshold (float): Intersection over Union (IoU) threshold - used when filtering by overlap. - overlap_metric (Union[OverlapMetric, str]): Metric used for matching detections - in slices. - callback (Callable): A function that performs inference on a given image - slice and returns detections. - thread_workers (int): Number of threads for parallel execution. 
- - Note: - The class ensures that slices do not exceed the boundaries of the original - image. As a result, the final slices in the row and column dimensions might be - smaller than the specified slice dimensions if the image's width or height is - not a multiple of the slice's width or height minus the overlap. + callback (Callable[[ImageType], Detections]): Inference function that takes + a sliced image and returns a `Detections` object. + slice_wh (int or tuple[int, int]): Size of each slice `(width, height)`. + If int, both width and height are set to this value. + overlap_wh (int or tuple[int, int]): Overlap size `(width, height)` between + slices. If int, both width and height are set to this value. + overlap_filter (OverlapFilter or str): Strategy to merge overlapping + detections (`NON_MAX_SUPPRESSION`, `NON_MAX_MERGE`, or `NONE`). + iou_threshold (float): IOU threshold used in merging overlap filtering. + overlap_metric (OverlapMetric or str): Metric to compute overlap + (`IOU` or `IOS`). + thread_workers (int): Number of threads for concurrent slice inference. + + Raises: + ValueError: If `slice_wh` or `overlap_wh` are invalid or inconsistent. 
+ + Example: + ```python + import cv2 + import supervision as sv + from rfdetr import RFDETRMedium + + model = RFDETRMedium() + + def callback(tile): + return model.predict(tile) + + slicer = sv.InferenceSlicer(callback, slice_wh=640, overlap_wh=100) + + image = cv2.imread("example.png") + detections = slicer(image) + ``` + + ```python + import supervision as sv + from PIL import Image + from ultralytics import YOLO + + model = YOLO("yolo11m.pt") + + def callback(tile): + results = model(tile)[0] + return sv.Detections.from_ultralytics(results) + + slicer = sv.InferenceSlicer(callback, slice_wh=640, overlap_wh=100) + + image = Image.open("example.png") + detections = slicer(image) + ``` """ def __init__( self, - callback: Callable[[np.ndarray], Detections], - slice_wh: tuple[int, int] = (320, 320), - overlap_ratio_wh: tuple[float, float] | None = (0.2, 0.2), - overlap_wh: tuple[int, int] | None = None, + callback: Callable[[ImageType], Detections], + slice_wh: int | tuple[int, int] = 640, + overlap_wh: int | tuple[int, int] = 100, overlap_filter: OverlapFilter | str = OverlapFilter.NON_MAX_SUPPRESSION, iou_threshold: float = 0.5, overlap_metric: OverlapMetric | str = OverlapMetric.IOU, thread_workers: int = 1, ): - if overlap_ratio_wh is not None: - warn_deprecated( - "`overlap_ratio_wh` in `InferenceSlicer.__init__` is deprecated and " - "will be removed in `supervision-0.27.0`. Please manually set it to " - "`None` and use `overlap_wh` instead." 
- ) + slice_wh_norm = self._normalize_slice_wh(slice_wh) + overlap_wh_norm = self._normalize_overlap_wh(overlap_wh) - self._validate_overlap(overlap_ratio_wh, overlap_wh) - self.overlap_ratio_wh = overlap_ratio_wh - self.overlap_wh = overlap_wh + self._validate_overlap(slice_wh=slice_wh_norm, overlap_wh=overlap_wh_norm) - self.slice_wh = slice_wh + self.slice_wh = slice_wh_norm + self.overlap_wh = overlap_wh_norm self.iou_threshold = iou_threshold self.overlap_metric = OverlapMetric.from_value(overlap_metric) self.overlap_filter = OverlapFilter.from_value(overlap_filter) self.callback = callback self.thread_workers = thread_workers - def __call__(self, image: np.ndarray) -> Detections: + def __call__(self, image: ImageType) -> Detections: """ - Performs slicing-based inference on the provided image using the specified - callback. + Perform tiled inference on the full image and return merged detections. Args: - image (np.ndarray): The input image on which inference needs to be - performed. The image should be in the format - `(height, width, channels)`. + image (ImageType): The full image to run inference on. Returns: - Detections: A collection of detections for the entire image after merging - results from all slices and applying NMS. - - Example: - ```python - import cv2 - import supervision as sv - from ultralytics import YOLO - - image = cv2.imread(SOURCE_IMAGE_PATH) - model = YOLO(...) - - def callback(image_slice: np.ndarray) -> sv.Detections: - result = model(image_slice)[0] - return sv.Detections.from_ultralytics(result) - - slicer = sv.InferenceSlicer( - callback=callback, - overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION, - ) - - detections = slicer(image) - ``` + Detections: Merged detections across all slices. 
""" - detections_list = [] - resolution_wh = (image.shape[1], image.shape[0]) + detections_list: list[Detections] = [] + resolution_wh = get_image_resolution_wh(image) + offsets = self._generate_offset( resolution_wh=resolution_wh, slice_wh=self.slice_wh, - overlap_ratio_wh=self.overlap_ratio_wh, overlap_wh=self.overlap_wh, ) @@ -171,129 +164,178 @@ def callback(image_slice: np.ndarray) -> sv.Detections: merged = Detections.merge(detections_list=detections_list) if self.overlap_filter == OverlapFilter.NONE: return merged - elif self.overlap_filter == OverlapFilter.NON_MAX_SUPPRESSION: + if self.overlap_filter == OverlapFilter.NON_MAX_SUPPRESSION: return merged.with_nms( - threshold=self.iou_threshold, overlap_metric=self.overlap_metric + threshold=self.iou_threshold, + overlap_metric=self.overlap_metric, ) - elif self.overlap_filter == OverlapFilter.NON_MAX_MERGE: + if self.overlap_filter == OverlapFilter.NON_MAX_MERGE: return merged.with_nmm( - threshold=self.iou_threshold, overlap_metric=self.overlap_metric + threshold=self.iou_threshold, + overlap_metric=self.overlap_metric, ) - else: - warnings.warn( - f"Invalid overlap filter strategy: {self.overlap_filter}", - category=SupervisionWarnings, - ) - return merged - def _run_callback(self, image, offset) -> Detections: + warnings.warn( + f"Invalid overlap filter strategy: {self.overlap_filter}", + category=SupervisionWarnings, + ) + return merged + + def _run_callback(self, image: ImageType, offset: np.ndarray) -> Detections: """ - Run the provided callback on a slice of an image. + Run detection callback on a sliced portion of the image and adjust coordinates. Args: - image (np.ndarray): The input image on which inference needs to run - offset (np.ndarray): An array of shape `(4,)` containing coordinates - for the slice. + image (ImageType): The full image. + offset (numpy.ndarray): Coordinates `(x_min, y_min, x_max, y_max)` defining + the slice region. 
Returns: - Detections: A collection of detections for the slice. + Detections: Detections adjusted to the full image coordinate system. """ - image_slice = crop_image(image=image, xyxy=offset) + image_slice: ImageType = crop_image(image=image, xyxy=offset) detections = self.callback(image_slice) - resolution_wh = (image.shape[1], image.shape[0]) + resolution_wh = get_image_resolution_wh(image) + detections = move_detections( - detections=detections, offset=offset[:2], resolution_wh=resolution_wh + detections=detections, + offset=offset[:2], + resolution_wh=resolution_wh, ) - return detections + @staticmethod + def _normalize_slice_wh( + slice_wh: int | tuple[int, int], + ) -> tuple[int, int]: + if isinstance(slice_wh, int): + if slice_wh <= 0: + raise ValueError( + f"`slice_wh` must be a positive integer. Received: {slice_wh}" + ) + return slice_wh, slice_wh + + if isinstance(slice_wh, tuple) and len(slice_wh) == 2: + width, height = slice_wh + if width <= 0 or height <= 0: + raise ValueError( + f"`slice_wh` values must be positive. Received: {slice_wh}" + ) + return width, height + + raise ValueError( + "`slice_wh` must be an int or a tuple of two positive integers " + "(slice_w, slice_h). " + f"Received: {slice_wh}" + ) + + @staticmethod + def _normalize_overlap_wh( + overlap_wh: int | tuple[int, int], + ) -> tuple[int, int]: + if isinstance(overlap_wh, int): + if overlap_wh < 0: + raise ValueError( + "`overlap_wh` must be a non negative integer. " + f"Received: {overlap_wh}" + ) + return overlap_wh, overlap_wh + + if isinstance(overlap_wh, tuple) and len(overlap_wh) == 2: + overlap_w, overlap_h = overlap_wh + if overlap_w < 0 or overlap_h < 0: + raise ValueError( + f"`overlap_wh` values must be non negative. Received: {overlap_wh}" + ) + return overlap_w, overlap_h + + raise ValueError( + "`overlap_wh` must be an int or a tuple of two non negative integers " + "(overlap_w, overlap_h). 
" + f"Received: {overlap_wh}" + ) + @staticmethod def _generate_offset( resolution_wh: tuple[int, int], slice_wh: tuple[int, int], - overlap_ratio_wh: tuple[float, float] | None, - overlap_wh: tuple[int, int] | None, + overlap_wh: tuple[int, int], ) -> np.ndarray: """ - Generate offset coordinates for slicing an image based on the given resolution, - slice dimensions, and overlap ratios. + Generate bounding boxes defining the coordinates of image slices with overlap. Args: - resolution_wh (Tuple[int, int]): A tuple representing the width and height - of the image to be sliced. - slice_wh (Tuple[int, int]): Dimensions of each slice measured in pixels. The - tuple should be in the format `(width, height)`. - overlap_ratio_wh (Optional[Tuple[float, float]]): A tuple representing the - desired overlap ratio for width and height between consecutive slices. - Each value should be in the range [0, 1), where 0 means no overlap and - a value close to 1 means high overlap. - overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired - overlap for width and height between consecutive slices measured in - pixels. Each value should be greater than or equal to 0. + resolution_wh (tuple[int, int]): Image resolution `(width, height)`. + slice_wh (tuple[int, int]): Size of each slice `(width, height)`. + overlap_wh (tuple[int, int]): Overlap size between slices `(width, height)`. Returns: - np.ndarray: An array of shape `(n, 4)` containing coordinates for each - slice in the format `[xmin, ymin, xmax, ymax]`. - - Note: - The function ensures that slices do not exceed the boundaries of the - original image. As a result, the final slices in the row and column - dimensions might be smaller than the specified slice dimensions if the - image's width or height is not a multiple of the slice's width or - height minus the overlap. + numpy.ndarray: Array of shape `(num_slices, 4)` with each row as + `(x_min, y_min, x_max, y_max)` coordinates for a slice. 
""" slice_width, slice_height = slice_wh image_width, image_height = resolution_wh - overlap_width = ( - overlap_wh[0] - if overlap_wh is not None - else int(overlap_ratio_wh[0] * slice_width) + overlap_width, overlap_height = overlap_wh + + stride_x = slice_width - overlap_width + stride_y = slice_height - overlap_height + + def _compute_axis_starts( + image_size: int, + slice_size: int, + stride: int, + ) -> list[int]: + if image_size <= slice_size: + return [0] + + if stride == slice_size: + return np.arange(0, image_size, stride).tolist() + + last_start = image_size - slice_size + starts = np.arange(0, last_start, stride).tolist() + if not starts or starts[-1] != last_start: + starts.append(last_start) + return starts + + x_starts = _compute_axis_starts( + image_size=image_width, + slice_size=slice_width, + stride=stride_x, ) - overlap_height = ( - overlap_wh[1] - if overlap_wh is not None - else int(overlap_ratio_wh[1] * slice_height) + y_starts = _compute_axis_starts( + image_size=image_height, + slice_size=slice_height, + stride=stride_y, ) - width_stride = slice_width - overlap_width - height_stride = slice_height - overlap_height + x_min, y_min = np.meshgrid(x_starts, y_starts) + x_max = np.clip(x_min + slice_width, 0, image_width) + y_max = np.clip(y_min + slice_height, 0, image_height) - ws = np.arange(0, image_width, width_stride) - hs = np.arange(0, image_height, height_stride) - - xmin, ymin = np.meshgrid(ws, hs) - xmax = np.clip(xmin + slice_width, 0, image_width) - ymax = np.clip(ymin + slice_height, 0, image_height) - - offsets = np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4) + offsets = np.stack( + [x_min, y_min, x_max, y_max], + axis=-1, + ).reshape(-1, 4) return offsets @staticmethod def _validate_overlap( - overlap_ratio_wh: tuple[float, float] | None, - overlap_wh: tuple[int, int] | None, + slice_wh: tuple[int, int], + overlap_wh: tuple[int, int], ) -> None: - if overlap_ratio_wh is not None and overlap_wh is not None: + overlap_w, 
overlap_h = overlap_wh + slice_w, slice_h = slice_wh + + if overlap_w < 0 or overlap_h < 0: raise ValueError( - "Both `overlap_ratio_wh` and `overlap_wh` cannot be provided. " - "Please provide only one of them." + "Overlap values must be greater than or equal to 0. " + f"Received: {overlap_wh}" ) - if overlap_ratio_wh is None and overlap_wh is None: + + if overlap_w >= slice_w or overlap_h >= slice_h: raise ValueError( - "Either `overlap_ratio_wh` or `overlap_wh` must be provided. " - "Please provide one of them." + "`overlap_wh` must be smaller than `slice_wh` in both dimensions " + f"to keep a positive stride. Received overlap_wh={overlap_wh}, " + f"slice_wh={slice_wh}." ) - - if overlap_ratio_wh is not None: - if not (0 <= overlap_ratio_wh[0] < 1 and 0 <= overlap_ratio_wh[1] < 1): - raise ValueError( - "Overlap ratios must be in the range [0, 1). " - f"Received: {overlap_ratio_wh}" - ) - if overlap_wh is not None: - if not (overlap_wh[0] >= 0 and overlap_wh[1] >= 0): - raise ValueError( - "Overlap values must be greater than or equal to 0. " - f"Received: {overlap_wh}" - ) diff --git a/supervision/detection/utils/boxes.py b/supervision/detection/utils/boxes.py index 3b01fcb68b..52e4b69569 100644 --- a/supervision/detection/utils/boxes.py +++ b/supervision/detection/utils/boxes.py @@ -14,8 +14,8 @@ def clip_boxes(xyxy: np.ndarray, resolution_wh: tuple[int, int]) -> np.ndarray: xyxy (np.ndarray): A numpy array of shape `(N, 4)` where each row corresponds to a bounding box in the format `(x_min, y_min, x_max, y_max)`. - resolution_wh (Tuple[int, int]): A tuple of the form `(width, height)` - representing the resolution of the frame. + resolution_wh (Tuple[int, int]): A tuple of the form + `(width, height)` representing the resolution of the frame. 
Returns: np.ndarray: A numpy array of shape `(N, 4)` where each row @@ -95,24 +95,27 @@ def pad_boxes(xyxy: np.ndarray, px: int, py: int | None = None) -> np.ndarray: def denormalize_boxes( - normalized_xyxy: np.ndarray, + xyxy: np.ndarray, resolution_wh: tuple[int, int], normalization_factor: float = 1.0, ) -> np.ndarray: """ - Converts normalized bounding box coordinates to absolute pixel values. + Convert normalized bounding box coordinates to absolute pixel coordinates. + + Multiplies each bounding box coordinate by image size and divides by + `normalization_factor`, mapping values from normalized `[0, normalization_factor]` + to absolute pixel values for a given resolution. Args: - normalized_xyxy (np.ndarray): A numpy array of shape `(N, 4)` where each row - contains normalized coordinates in the format `(x_min, y_min, x_max, y_max)`, - with values between 0 and `normalization_factor`. - resolution_wh (Tuple[int, int]): A tuple `(width, height)` representing the - target image resolution. - normalization_factor (float, optional): The normalization range of the input - coordinates. Defaults to 1.0. + xyxy (`numpy.ndarray`): Normalized bounding boxes of shape `(N, 4)`, + where each row is `(x_min, y_min, x_max, y_max)`, values in + `[0, normalization_factor]`. + resolution_wh (`tuple[int, int]`): Target image resolution as `(width, height)`. + normalization_factor (`float`): Maximum value of input coordinate range. + Defaults to `1.0`. Returns: - np.ndarray: An array of shape `(N, 4)` with absolute coordinates in + (`numpy.ndarray`): Array of shape `(N, 4)` with absolute coordinates in `(x_min, y_min, x_max, y_max)` format. 
Examples: @@ -120,35 +123,39 @@ def denormalize_boxes( import numpy as np import supervision as sv - # Default normalization (0-1) - normalized_xyxy = np.array([ + xyxy = np.array([ [0.1, 0.2, 0.5, 0.6], - [0.3, 0.4, 0.7, 0.8] + [0.3, 0.4, 0.7, 0.8], + [0.2, 0.1, 0.6, 0.5] ]) - resolution_wh = (100, 200) - sv.denormalize_boxes(normalized_xyxy, resolution_wh) + + sv.denormalize_boxes(xyxy, (1280, 720)) # array([ - # [ 10., 40., 50., 120.], - # [ 30., 80., 70., 160.] + # [128., 144., 640., 432.], + # [384., 288., 896., 576.], + # [256., 72., 768., 360.] # ]) + ``` - # Custom normalization (0-100) - normalized_xyxy = np.array([ - [10., 20., 50., 60.], - [30., 40., 70., 80.] + ``` + import numpy as np + import supervision as sv + + xyxy = np.array([ + [256., 128., 768., 640.] ]) - sv.denormalize_boxes(normalized_xyxy, resolution_wh, normalization_factor=100.0) + + sv.denormalize_boxes(xyxy, (1280, 720), normalization_factor=1024.0) # array([ - # [ 10., 40., 50., 120.], - # [ 30., 80., 70., 160.] + # [320., 90., 960., 450.] # ]) ``` - """ # noqa E501 // docs + """ width, height = resolution_wh - result = normalized_xyxy.copy() + result = xyxy.copy() - result[[0, 2]] = (result[[0, 2]] * width) / normalization_factor - result[[1, 3]] = (result[[1, 3]] * height) / normalization_factor + result[:, [0, 2]] = (result[:, [0, 2]] * width) / normalization_factor + result[:, [1, 3]] = (result[:, [1, 3]] * height) / normalization_factor return result diff --git a/supervision/detection/utils/converters.py b/supervision/detection/utils/converters.py index 9e02783a04..4aef2dc87c 100644 --- a/supervision/detection/utils/converters.py +++ b/supervision/detection/utils/converters.py @@ -229,6 +229,70 @@ def mask_to_xyxy(masks: np.ndarray) -> np.ndarray: return xyxy +def xyxy_to_mask(boxes: np.ndarray, resolution_wh: tuple[int, int]) -> np.ndarray: + """ + Converts a 2D `np.ndarray` of bounding boxes into a 3D `np.ndarray` of bool masks. 
+ + Parameters: + boxes (np.ndarray): A 2D `np.ndarray` of shape `(N, 4)` + containing bounding boxes `(x_min, y_min, x_max, y_max)` + resolution_wh (Tuple[int, int]): A tuple `(width, height)` specifying + the resolution of the output masks + + Returns: + np.ndarray: A 3D `np.ndarray` of shape `(N, height, width)` + containing 2D bool masks for each bounding box + + Examples: + ```python + import numpy as np + import supervision as sv + + boxes = np.array([[0, 0, 2, 2]]) + + sv.xyxy_to_mask(boxes, (5, 5)) + # array([ + # [[ True, True, True, False, False], + # [ True, True, True, False, False], + # [ True, True, True, False, False], + # [False, False, False, False, False], + # [False, False, False, False, False]] + # ]) + + boxes = np.array([[0, 0, 1, 1], [3, 3, 4, 4]]) + + sv.xyxy_to_mask(boxes, (5, 5)) + # array([ + # [[ True, True, False, False, False], + # [ True, True, False, False, False], + # [False, False, False, False, False], + # [False, False, False, False, False], + # [False, False, False, False, False]], + # + # [[False, False, False, False, False], + # [False, False, False, False, False], + # [False, False, False, False, False], + # [False, False, False, True, True], + # [False, False, False, True, True]] + # ]) + ``` + """ + width, height = resolution_wh + n = boxes.shape[0] + masks = np.zeros((n, height, width), dtype=bool) + + for i, (x_min, y_min, x_max, y_max) in enumerate(boxes): + x_min = max(0, int(x_min)) + y_min = max(0, int(y_min)) + x_max = min(width - 1, int(x_max)) + y_max = min(height - 1, int(y_max)) + + if x_max >= x_min and y_max >= y_min: + masks[i, y_min : y_max + 1, x_min : x_max + 1] = True + + return masks + + def mask_to_polygons(mask: np.ndarray) -> list[np.ndarray]: """ Converts a binary mask to a list of polygons. 
diff --git a/supervision/detection/utils/internal.py b/supervision/detection/utils/internal.py index bc6579a8b9..f5d6dc9fbf 100644 --- a/supervision/detection/utils/internal.py +++ b/supervision/detection/utils/internal.py @@ -271,7 +271,6 @@ def merge_metadata(metadata_list: list[dict[str, Any]]) -> dict[str, Any]: "{type(value)}, {type(other_value)}." ) else: - print("hm") if merged_metadata[key] != value: raise ValueError(f"Conflicting metadata for key: '{key}'.") diff --git a/supervision/detection/utils/iou_and_nms.py b/supervision/detection/utils/iou_and_nms.py index 1a6f80bc58..8a444cf16e 100644 --- a/supervision/detection/utils/iou_and_nms.py +++ b/supervision/detection/utils/iou_and_nms.py @@ -64,7 +64,7 @@ class OverlapMetric(Enum): IOS = "IOS" @classmethod - def list(cls): + def list(cls) -> list[str]: return list(map(lambda c: c.value, cls)) @classmethod @@ -72,7 +72,7 @@ def from_value(cls, value: OverlapMetric | str) -> OverlapMetric: if isinstance(value, cls): return value if isinstance(value, str): - value = value.lower() + value = value.upper() try: return cls(value) except ValueError: @@ -86,91 +86,107 @@ def from_value(cls, value: OverlapMetric | str) -> OverlapMetric: def box_iou( box_true: list[float] | np.ndarray, box_detection: list[float] | np.ndarray, + overlap_metric: OverlapMetric | str = OverlapMetric.IOU, ) -> float: - r""" - Compute the Intersection over Union (IoU) between two bounding boxes. - - \[ - \text{IoU} = \frac{|\text{box}_{\text{true}} \cap \text{box}_{\text{detection}}|}{|\text{box}_{\text{true}} \cup \text{box}_{\text{detection}}|} - \] + """ + Compute overlap metric between two bounding boxes. - Note: - Use `box_iou` when computing IoU between two individual boxes. - For comparing multiple boxes (arrays of boxes), use `box_iou_batch` for better - performance. + Supports standard IOU (intersection-over-union) and IOS + (intersection-over-smaller-area) metrics. Returns the overlap value in range + `[0, 1]`. 
Args: - box_true (Union[List[float], np.ndarray]): A single bounding box represented as - [x_min, y_min, x_max, y_max]. - box_detection (Union[List[float], np.ndarray]): - A single bounding box represented as [x_min, y_min, x_max, y_max]. + box_true (`list[float]` or `numpy.array`): Ground truth box in format + `(x_min, y_min, x_max, y_max)`. + box_detection (`list[float]` or `numpy.array`): Detected box in format + `(x_min, y_min, x_max, y_max)`. + overlap_metric (`OverlapMetric` or `str`): Overlap type. + Use `OverlapMetric.IOU` for IOU or + `OverlapMetric.IOS` for IOS. Defaults to `OverlapMetric.IOU`. Returns: - IoU (float): IoU score between the two boxes. Ranges from 0.0 (no overlap) - to 1.0 (perfect overlap). + (`float`): Overlap value between boxes in `[0, 1]`. + + Raises: + ValueError: If `overlap_metric` is not IOU or IOS. Examples: - ```python - import numpy as np + ``` import supervision as sv - box_true = np.array([100, 100, 200, 200]) - box_detection = np.array([150, 150, 250, 250]) + box_true = [100, 100, 200, 200] + box_detection = [150, 150, 250, 250] - sv.box_iou(box_true=box_true, box_detection=box_detection) - # 0.14285814285714285 + sv.box_iou(box_true, box_detection, overlap_metric=sv.OverlapMetric.IOU) + # 0.14285714285714285 + + sv.box_iou(box_true, box_detection, overlap_metric=sv.OverlapMetric.IOS) + # 0.25 ``` - """ # noqa: E501 - box_true = np.array(box_true) - box_detection = np.array(box_detection) + """ + overlap_metric = OverlapMetric.from_value(overlap_metric) + x_min_true, y_min_true, x_max_true, y_max_true = np.array(box_true) + x_min_det, y_min_det, x_max_det, y_max_det = np.array(box_detection) - inter_x1 = max(box_true[0], box_detection[0]) - inter_y1 = max(box_true[1], box_detection[1]) - inter_x2 = min(box_true[2], box_detection[2]) - inter_y2 = min(box_true[3], box_detection[3]) + x_min_inter = max(x_min_true, x_min_det) + y_min_inter = max(y_min_true, y_min_det) + x_max_inter = min(x_max_true, x_max_det) + y_max_inter = 
min(y_max_true, y_max_det) - inter_w = max(0, inter_x2 - inter_x1) - inter_h = max(0, inter_y2 - inter_y1) + inter_w = max(0.0, x_max_inter - x_min_inter) + inter_h = max(0.0, y_max_inter - y_min_inter) - inter_area = inter_w * inter_h + area_inter = inter_w * inter_h - area_true = (box_true[2] - box_true[0]) * (box_true[3] - box_true[1]) - area_detection = (box_detection[2] - box_detection[0]) * ( - box_detection[3] - box_detection[1] - ) + area_true = (x_max_true - x_min_true) * (y_max_true - y_min_true) + area_det = (x_max_det - x_min_det) * (y_max_det - y_min_det) - union_area = area_true + area_detection - inter_area + if overlap_metric == OverlapMetric.IOU: + area_norm = area_true + area_det - area_inter + elif overlap_metric == OverlapMetric.IOS: + area_norm = min(area_true, area_det) + else: + raise ValueError( + f"overlap_metric {overlap_metric} is not supported, " + "only 'IOU' and 'IOS' are supported" + ) - return inter_area / union_area + 1e-6 + if area_norm <= 0.0: + return 0.0 + + return float(area_inter / area_norm) def box_iou_batch( boxes_true: np.ndarray, boxes_detection: np.ndarray, - overlap_metric: OverlapMetric = OverlapMetric.IOU, + overlap_metric: OverlapMetric | str = OverlapMetric.IOU, ) -> np.ndarray: """ - Compute Intersection over Union (IoU) of two sets of bounding boxes - - `boxes_true` and `boxes_detection`. Both sets - of boxes are expected to be in `(x_min, y_min, x_max, y_max)` format. + Compute pairwise overlap scores between batches of bounding boxes. - Note: - Use `box_iou` when computing IoU between two individual boxes. - For comparing multiple boxes (arrays of boxes), use `box_iou_batch` for better - performance. + Supports standard IOU (intersection-over-union) and IOS + (intersection-over-smaller-area) metrics for all `boxes_true` and + `boxes_detection` pairs. Returns a matrix of overlap values in range + `[0, 1]`, matching each box from the first batch to each from the second. 
Args: - boxes_true (np.ndarray): 2D `np.ndarray` representing ground-truth boxes. - `shape = (N, 4)` where `N` is number of true objects. - boxes_detection (np.ndarray): 2D `np.ndarray` representing detection boxes. - `shape = (M, 4)` where `M` is number of detected objects. - overlap_metric (OverlapMetric): Metric used to compute the degree of overlap - between pairs of boxes (e.g., IoU, IoS). + boxes_true (`numpy.array`): Array of reference boxes in + shape `(N, 4)` as `(x_min, y_min, x_max, y_max)`. + boxes_detection (`numpy.array`): Array of detected boxes in + shape `(M, 4)` as `(x_min, y_min, x_max, y_max)`. + overlap_metric (`OverlapMetric` or `str`): Overlap type. + Use `OverlapMetric.IOU` for intersection-over-union, + `OverlapMetric.IOS` for intersection-over-smaller-area. + Defaults to `OverlapMetric.IOU`. Returns: - np.ndarray: Pairwise IoU of boxes from `boxes_true` and `boxes_detection`. - `shape = (N, M)` where `N` is number of true objects and - `M` is number of detected objects. + (`numpy.array`): Overlap matrix of shape `(N, M)`, where entry + `[i, j]` is the overlap between `boxes_true[i]` and + `boxes_detection[j]`. + + Raises: + ValueError: If `overlap_metric` is not IOU or IOS. Examples: ```python @@ -186,49 +202,57 @@ def box_iou_batch( [320, 320, 420, 420] ]) - sv.box_iou_batch(boxes_true=boxes_true, boxes_detection=boxes_detection) - # array([ - # [0.14285714, 0. ], - # [0. , 0.47058824] - # ]) + sv.box_iou_batch(boxes_true, boxes_detection, overlap_metric=OverlapMetric.IOU) + # array([[0.14285715, 0. ], + # [0. , 0.47058824]]) + + sv.box_iou_batch(boxes_true, boxes_detection, overlap_metric=OverlapMetric.IOS) + # array([[0.25, 0. ], + # [0. 
, 0.64]]) ``` """ + overlap_metric = OverlapMetric.from_value(overlap_metric) + x_min_true, y_min_true, x_max_true, y_max_true = boxes_true.T + x_min_det, y_min_det, x_max_det, y_max_det = boxes_detection.T + count_true, count_det = boxes_true.shape[0], boxes_detection.shape[0] + + if count_true == 0 or count_det == 0: + return np.empty((count_true, count_det), dtype=np.float32) - def box_area(box): - return (box[2] - box[0]) * (box[3] - box[1]) + x_min_inter = np.empty((count_true, count_det), dtype=np.float32) + x_max_inter = np.empty_like(x_min_inter) + y_min_inter = np.empty_like(x_min_inter) + y_max_inter = np.empty_like(x_min_inter) - area_true = box_area(boxes_true.T) - area_detection = box_area(boxes_detection.T) + np.maximum(x_min_true[:, None], x_min_det[None, :], out=x_min_inter) + np.minimum(x_max_true[:, None], x_max_det[None, :], out=x_max_inter) + np.maximum(y_min_true[:, None], y_min_det[None, :], out=y_min_inter) + np.minimum(y_max_true[:, None], y_max_det[None, :], out=y_max_inter) - top_left = np.maximum(boxes_true[:, None, :2], boxes_detection[:, :2]) - bottom_right = np.minimum(boxes_true[:, None, 2:], boxes_detection[:, 2:]) + # we reuse x_max_inter and y_max_inter to store inter_w and inter_h + np.subtract(x_max_inter, x_min_inter, out=x_max_inter) # inter_w + np.subtract(y_max_inter, y_min_inter, out=y_max_inter) # inter_h + np.clip(x_max_inter, 0.0, None, out=x_max_inter) + np.clip(y_max_inter, 0.0, None, out=y_max_inter) - area_inter = np.prod(np.clip(bottom_right - top_left, a_min=0, a_max=None), 2) + area_inter = x_max_inter * y_max_inter # inter_w * inter_h + + area_true = (x_max_true - x_min_true) * (y_max_true - y_min_true) + area_det = (x_max_det - x_min_det) * (y_max_det - y_min_det) if overlap_metric == OverlapMetric.IOU: - union_area = area_true[:, None] + area_detection - area_inter - ious = np.divide( - area_inter, - union_area, - out=np.zeros_like(area_inter, dtype=float), - where=union_area != 0, - ) + area_norm = area_true[:, 
None] + area_det[None, :] - area_inter elif overlap_metric == OverlapMetric.IOS: - small_area = np.minimum(area_true[:, None], area_detection) - ious = np.divide( - area_inter, - small_area, - out=np.zeros_like(area_inter, dtype=float), - where=small_area != 0, - ) + area_norm = np.minimum(area_true[:, None], area_det[None, :]) else: raise ValueError( f"overlap_metric {overlap_metric} is not supported, " "only 'IOU' and 'IOS' are supported" ) - ious = np.nan_to_num(ious) - return ious + out = np.zeros_like(area_inter, dtype=np.float32) + np.divide(area_inter, area_norm, out=out, where=area_norm > 0) + return out def _jaccard(box_a: list[float], box_b: list[float], is_crowd: bool) -> float: diff --git a/supervision/detection/utils/masks.py b/supervision/detection/utils/masks.py index c5cfee0172..04502dba05 100644 --- a/supervision/detection/utils/masks.py +++ b/supervision/detection/utils/masks.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Literal + import cv2 import numpy as np import numpy.typing as npt @@ -260,3 +262,139 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray: resized_masks = masks[:, yv, xv] return resized_masks.reshape(masks.shape[0], new_height, new_width) + + +def filter_segments_by_distance( + mask: npt.NDArray[np.bool_], + absolute_distance: float | None = 100.0, + relative_distance: float | None = None, + connectivity: int = 8, + mode: Literal["edge", "centroid"] = "edge", +) -> npt.NDArray[np.bool_]: + """ + Keep the largest connected component and any other components within a distance + threshold. + + Distance can be absolute in pixels or relative to the image diagonal. + + Args: + mask: Boolean mask HxW. + absolute_distance: Max allowed distance in pixels to the main component. + Ignored if `relative_distance` is provided. + relative_distance: Fraction of the diagonal. If set, threshold = fraction * sqrt(H^2 + W^2). 
+ connectivity: Defines which neighboring pixels are considered connected. + - 4-connectedness: Only orthogonal neighbors. + ``` + [ ][X][ ] + [X][O][X] + [ ][X][ ] + ``` + - 8-connectedness: Includes diagonal neighbors. + ``` + [X][X][X] + [X][O][X] + [X][X][X] + ``` + Default is 8. + mode: Defines how distance between components is measured. + - "edge": Uses distance between nearest edges (via distance transform). + - "centroid": Uses distance between component centroids. + + Returns: + Boolean mask after filtering. + + Examples: + ```python + import numpy as np + import supervision as sv + + mask = np.array([ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], dtype=bool) + + sv.filter_segments_by_distance( + mask, + absolute_distance=2, + mode="edge", + connectivity=8 + ).astype(int) + + # np.array([ + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + # ], dtype=bool) + + # The nearby 2Γ—2 block at columns 6–7 is kept because its edge distance + # is within 2 pixels. The distant block at columns 9-10 is removed. 
+ ``` + """ # noqa E501 // docs + if mask.dtype != bool: + raise TypeError("mask must be boolean") + + height, width = mask.shape + if not np.any(mask): + return mask.copy() + + image = mask.astype(np.uint8) + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( + image, connectivity=connectivity + ) + + if num_labels <= 1: + return mask.copy() + + areas = stats[1:, cv2.CC_STAT_AREA] + main_label = 1 + int(np.argmax(areas)) + + if relative_distance is not None: + diagonal = float(np.hypot(height, width)) + threshold = float(relative_distance) * diagonal + else: + threshold = float(absolute_distance) + + keep_labels = np.zeros(num_labels, dtype=bool) + keep_labels[main_label] = True + + if mode == "centroid": + differences = centroids[1:] - centroids[main_label] + distances = np.sqrt(np.sum(differences**2, axis=1)) + nearby = 1 + np.where(distances <= threshold)[0] + keep_labels[nearby] = True + elif mode == "edge": + main_mask = (labels == main_label).astype(np.uint8) + inverse = 1 - main_mask + distance_transform = cv2.distanceTransform(inverse, cv2.DIST_L2, 3) + for label in range(1, num_labels): + if label == main_label: + continue + component = labels == label + if not np.any(component): + continue + min_distance = float(distance_transform[component].min()) + if min_distance <= threshold: + keep_labels[label] = True + else: + raise ValueError("mode must be 'edge' or 'centroid'") + + return keep_labels[labels] diff --git a/supervision/detection/utils/vlms.py b/supervision/detection/utils/vlms.py new file mode 100644 index 0000000000..8c4aca74aa --- /dev/null +++ b/supervision/detection/utils/vlms.py @@ -0,0 +1,105 @@ +from __future__ import annotations + + +def edit_distance(string_1: str, string_2: str, case_sensitive: bool = True) -> int: + """ + Calculates the minimum number of single-character edits required + to transform one string into another. Allowed operations are insertion, + deletion, and substitution. 
+ + Args: + string_1 (str): The source string to be transformed. + string_2 (str): The target string to transform into. + case_sensitive (bool, optional): Whether comparison should be case-sensitive. + Defaults to True. + + Returns: + int: The minimum number of edits required to convert `string_1` + into `string_2`. + + Examples: + ```python + import supervision as sv + + sv.edit_distance("hello", "hello") + # 0 + + sv.edit_distance("Test", "test", case_sensitive=True) + # 1 + + sv.edit_distance("abc", "xyz") + # 3 + + sv.edit_distance("hello", "") + # 5 + + sv.edit_distance("", "") + # 0 + + sv.edit_distance("hello world", "helloworld") + # 1 + ``` + """ + if not case_sensitive: + string_1 = string_1.lower() + string_2 = string_2.lower() + + if len(string_1) < len(string_2): + string_1, string_2 = string_2, string_1 + + prev_row = list(range(len(string_2) + 1)) + curr_row = [0] * (len(string_2) + 1) + + for i in range(1, len(string_1) + 1): + curr_row[0] = i + for j in range(1, len(string_2) + 1): + if string_1[i - 1] == string_2[j - 1]: + substitution_cost = 0 + else: + substitution_cost = 1 + curr_row[j] = min( + prev_row[j] + 1, + curr_row[j - 1] + 1, + prev_row[j - 1] + substitution_cost, + ) + prev_row, curr_row = curr_row, prev_row + + return prev_row[len(string_2)] + + +def fuzzy_match_index( + candidates: list[str], + query: str, + threshold: int, + case_sensitive: bool = True, +) -> int | None: + """ + Searches for the first string in `candidates` whose edit distance + to `query` is less than or equal to `threshold`. + + Args: + candidates (list[str]): List of strings to search. + query (str): String to compare against the candidates. + threshold (int): Maximum allowed edit distance for a match. + case_sensitive (bool, optional): Whether matching should be case-sensitive. + + Returns: + Optional[int]: Index of the first matching string in candidates, + or None if no match is found. 
+ + Examples: + ```python + fuzzy_match_index(["cat", "dog", "rat"], "dat", threshold=1) + # 0 + + fuzzy_match_index(["alpha", "beta", "gamma"], "bata", threshold=1) + # 1 + + fuzzy_match_index(["one", "two", "three"], "ten", threshold=2) + # None + ``` + """ + for idx, candidate in enumerate(candidates): + if edit_distance(candidate, query, case_sensitive=case_sensitive) <= threshold: + return idx + return None diff --git a/supervision/detection/vlm.py b/supervision/detection/vlm.py index 1a5ad231e6..97988c9f09 100644 --- a/supervision/detection/vlm.py +++ b/supervision/detection/vlm.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast import base64 import io import json @@ -27,7 +28,8 @@ class LMM(Enum): Attributes: PALIGEMMA: Google's PaliGemma vision-language model. FLORENCE_2: Microsoft's Florence-2 vision-language model. - QWEN_2_5_VL: Qwen2.5-VL open vision-language model from Alibaba. + QWEN_2_5_VL: Qwen2.5-VL open vision-language model from Alibaba.\ + QWEN_3_VL: Qwen3-VL open vision-language model from Alibaba. GOOGLE_GEMINI_2_0: Google Gemini 2.0 vision-language model. GOOGLE_GEMINI_2_5: Google Gemini 2.5 vision-language model. MOONDREAM: The Moondream vision-language model. @@ -36,6 +38,8 @@ class LMM(Enum): PALIGEMMA = "paligemma" FLORENCE_2 = "florence_2" QWEN_2_5_VL = "qwen_2_5_vl" + QWEN_3_VL = "qwen_3_vl" + DEEPSEEK_VL_2 = "deepseek_vl_2" GOOGLE_GEMINI_2_0 = "gemini_2_0" GOOGLE_GEMINI_2_5 = "gemini_2_5" MOONDREAM = "moondream" @@ -68,6 +72,7 @@ class VLM(Enum): PALIGEMMA: Google's PaliGemma vision-language model. FLORENCE_2: Microsoft's Florence-2 vision-language model. QWEN_2_5_VL: Qwen2.5-VL open vision-language model from Alibaba. + QWEN_3_VL: Qwen3-VL open vision-language model from Alibaba. GOOGLE_GEMINI_2_0: Google Gemini 2.0 vision-language model. GOOGLE_GEMINI_2_5: Google Gemini 2.5 vision-language model. MOONDREAM: The Moondream vision-language model. 
@@ -76,6 +81,8 @@ class VLM(Enum): PALIGEMMA = "paligemma" FLORENCE_2 = "florence_2" QWEN_2_5_VL = "qwen_2_5_vl" + QWEN_3_VL = "qwen_3_vl" + DEEPSEEK_VL_2 = "deepseek_vl_2" GOOGLE_GEMINI_2_0 = "gemini_2_0" GOOGLE_GEMINI_2_5 = "gemini_2_5" MOONDREAM = "moondream" @@ -104,6 +111,8 @@ def from_value(cls, value: VLM | str) -> VLM: VLM.PALIGEMMA: str, VLM.FLORENCE_2: dict, VLM.QWEN_2_5_VL: str, + VLM.QWEN_3_VL: str, + VLM.DEEPSEEK_VL_2: str, VLM.GOOGLE_GEMINI_2_0: str, VLM.GOOGLE_GEMINI_2_5: str, VLM.MOONDREAM: dict, @@ -113,6 +122,8 @@ def from_value(cls, value: VLM | str) -> VLM: VLM.PALIGEMMA: ["resolution_wh"], VLM.FLORENCE_2: ["resolution_wh"], VLM.QWEN_2_5_VL: ["input_wh", "resolution_wh"], + VLM.QWEN_3_VL: ["resolution_wh"], + VLM.DEEPSEEK_VL_2: ["resolution_wh"], VLM.GOOGLE_GEMINI_2_0: ["resolution_wh"], VLM.GOOGLE_GEMINI_2_5: ["resolution_wh"], VLM.MOONDREAM: ["resolution_wh"], @@ -122,6 +133,8 @@ def from_value(cls, value: VLM | str) -> VLM: VLM.PALIGEMMA: ["resolution_wh", "classes"], VLM.FLORENCE_2: ["resolution_wh"], VLM.QWEN_2_5_VL: ["input_wh", "resolution_wh", "classes"], + VLM.QWEN_3_VL: ["resolution_wh", "classes"], + VLM.DEEPSEEK_VL_2: ["resolution_wh", "classes"], VLM.GOOGLE_GEMINI_2_0: ["resolution_wh", "classes"], VLM.GOOGLE_GEMINI_2_5: ["resolution_wh", "classes"], VLM.MOONDREAM: ["resolution_wh"], @@ -230,6 +243,51 @@ def from_paligemma( return xyxy, class_id, class_name +def recover_truncated_qwen_2_5_vl_response(text: str) -> Any | None: + """ + Attempt to recover and parse a truncated or malformed JSON snippet from Qwen-2.5-VL + output. + + This utility extracts a JSON-like portion from a string that may be truncated or + malformed, cleans trailing commas, and attempts to parse it into a Python object. + + Args: + text (str): Raw text containing the JSON snippet possibly truncated or + incomplete. + + Returns: + Parsed Python object (usually list) if recovery and parsing succeed; + otherwise `None`. 
+ """ + try: + first_bracket = text.find("[") + if first_bracket == -1: + return None + snippet = text[first_bracket:] + + last_brace = snippet.rfind("}") + if last_brace == -1: + return None + + snippet = snippet[: last_brace + 1] + + prefix_end = snippet.find("[") + if prefix_end == -1: + return None + + prefix = snippet[: prefix_end + 1] + body = snippet[prefix_end + 1 :].rstrip() + + if body.endswith(","): + body = body[:-1].rstrip() + + repaired = prefix + body + "]" + + return json.loads(repaired) + except Exception: + return None + + def from_qwen_2_5_vl( result: str, input_wh: tuple[int, int], @@ -237,7 +295,7 @@ def from_qwen_2_5_vl( classes: list[str] | None = None, ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]: """ - Parse and scale bounding boxes from Qwen-2.5-VL style JSON output. + Parse and rescale bounding boxes and class labels from Qwen-2.5-VL JSON output. The JSON is expected to be enclosed in triple backticks with the format: ```json @@ -248,38 +306,52 @@ def from_qwen_2_5_vl( ``` Args: - result: String containing the JSON snippet enclosed by triple backticks. - input_wh: (input_width, input_height) describing the original bounding box - scale. - resolution_wh: (output_width, output_height) to which we rescale the boxes. - classes: Optional list of valid class names. If provided, returned boxes/labels - are filtered to only those classes found here. + result (str): String containing Qwen-2.5-VL JSON bounding box and label data. + input_wh (tuple[int, int]): Width and height of the coordinate space where boxes + are normalized. + resolution_wh (tuple[int, int]): Target width and height to scale bounding + boxes. + classes (list[str] or None): Optional list of valid class names to filter + results. If provided, only boxes with labels in this list are returned. 
Returns: - xyxy (np.ndarray): An array of shape `(n, 4)` containing - the bounding boxes coordinates in format `[x1, y1, x2, y2]` - class_id (Optional[np.ndarray]): An array of shape `(n,)` containing - the class indices for each bounding box (or None if `classes` is not - provided) - class_name (np.ndarray): An array of shape `(n,)` containing - the class labels for each bounding box + xyxy (np.ndarray): Array of shape `(N, 4)` with rescaled bounding boxes in + `(x_min, y_min, x_max, y_max)` format. + class_id (np.ndarray or None): Array of shape `(N,)` with indices of classes, + or `None` if no filtering applied. + class_name (np.ndarray): Array of shape `(N,)` with class names as strings. """ in_w, in_h = validate_resolution(input_wh) out_w, out_h = validate_resolution(resolution_wh) - pattern = re.compile(r"```json\s*(.*?)\s*```", re.DOTALL) - - match = pattern.search(result) - if not match: - return np.empty((0, 4)), None, np.empty((0,), dtype=str) + text = result.strip() + text = re.sub(r"^```(json)?", "", text, flags=re.IGNORECASE).strip() + text = re.sub(r"```$", "", text).strip() - json_snippet = match.group(1) + start = text.find("[") + end = text.rfind("]") + if start != -1 and end != -1 and end > start: + text = text[start : end + 1].strip() try: - data = json.loads(json_snippet) + data = json.loads(text) except json.JSONDecodeError: - return np.empty((0, 4)), None, np.empty((0,), dtype=str) + repaired = recover_truncated_qwen_2_5_vl_response(text) + if repaired is not None: + data = repaired + else: + try: + data = ast.literal_eval(text) + except (ValueError, SyntaxError, TypeError): + return ( + np.empty((0, 4)), + np.empty((0,), dtype=int), + np.empty((0,), dtype=str), + ) + + if not isinstance(data, list): + return (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty((0,), dtype=str)) boxes_list = [] labels_list = [] @@ -291,7 +363,7 @@ def from_qwen_2_5_vl( labels_list.append(item["label"]) if not boxes_list: - return np.empty((0, 4)), None, 
np.empty((0,), dtype=str) + return (np.empty((0, 4)), np.empty((0,), dtype=int), np.empty((0,), dtype=str)) xyxy = np.array(boxes_list, dtype=float) class_name = np.array(labels_list, dtype=str) @@ -310,6 +382,109 @@ def from_qwen_2_5_vl( return xyxy, class_id, class_name +def from_qwen_3_vl( + result: str, + resolution_wh: tuple[int, int], + classes: list[str] | None = None, +) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]: + """ + Parse and scale bounding boxes from Qwen-3-VL style JSON output. + + Args: + result (str): String containing the Qwen-3-VL JSON output. + resolution_wh (tuple[int, int]): Target resolution `(width, height)` to + scale bounding boxes. + classes (list[str] or None): Optional list of valid classes to filter + results. + + Returns: + xyxy (np.ndarray): Array of bounding boxes with shape `(N, 4)` in + `(x_min, y_min, x_max, y_max)` format scaled to `resolution_wh`. + class_id (np.ndarray or None): Array of class indices for each box, or + None if no filtering by classes. + class_name (np.ndarray): Array of class names as strings. + """ + return from_qwen_2_5_vl( + result=result, + input_wh=(1000, 1000), + resolution_wh=resolution_wh, + classes=classes, + ) + + +def from_deepseek_vl_2( + result: str, resolution_wh: tuple[int, int], classes: list[str] | None = None +) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]: + """ + Parse bounding boxes from deepseek-vl2-formatted text, scale them to the specified + resolution, and optionally filter by classes. + + The DeepSeek-VL2 output typically contains pairs of <|ref|> ... <|/ref|> labels + and <|det|> ... <|/det|> bounding box definitions. Each <|det|> section may + contain one or more bounding boxes in the form [[x1, y1, x2, y2], [x1, y1, x2, y2], ...] + (scaled to 0..999). 
For example: + + ``` + <|ref|>The giraffe at the back<|/ref|><|det|>[[580, 270, 999, 904]]<|/det|><|ref|>The giraffe at the front<|/ref|><|det|>[[26, 31, 632, 998]]<|/det|><|end▁of▁sentence|> + ``` + + Args: + result: String containing deepseek-vl2-formatted locations and labels. + resolution_wh: Tuple (width, height) to which we scale the box coordinates. + classes: Optional list of valid class names. If provided, boxes and labels not + in this list are filtered out. + + Returns: + xyxy (np.ndarray): An array of shape `(n, 4)` containing + the bounding boxes coordinates in format `[x1, y1, x2, y2]`. + class_id (Optional[np.ndarray]): An array of shape `(n,)` containing + the class indices for each bounding box (or `None` if classes is not + provided). + class_name (np.ndarray): An array of shape `(n,)` containing + the class labels for each bounding box. + """ # noqa: E501 + + width, height = resolution_wh + label_segments = re.findall(r"<\|ref\|>(.*?)<\|/ref\|>", result, flags=re.S) + detection_segments = re.findall(r"<\|det\|>(.*?)<\|/det\|>", result, flags=re.S) + + if len(label_segments) != len(detection_segments): + raise ValueError( + f"Number of ref tags ({len(label_segments)}) " + f"and det tags ({len(detection_segments)}) in the result must be equal." 
+ ) + + xyxy, class_name_list = [], [] + for label, detection_blob in zip(label_segments, detection_segments): + current_class_name = label.strip() + for box in re.findall(r"\[(.*?)\]", detection_blob): + x1, y1, x2, y2 = map(float, box.strip("[]").split(",")) + xyxy.append( + [ + (x1 / 999 * width), + (y1 / 999 * height), + (x2 / 999 * width), + (y2 / 999 * height), + ] + ) + class_name_list.append(current_class_name) + + xyxy = np.array(xyxy, dtype=np.float32) + class_name = np.array(class_name_list) + + if classes is not None: + mask = np.array([name in classes for name in class_name], dtype=bool) + xyxy = xyxy[mask] + class_name = class_name[mask] + class_id = np.array([classes.index(name) for name in class_name]) + else: + unique_classes = sorted(list(set(class_name))) + class_to_id = {name: i for i, name in enumerate(unique_classes)} + class_id = np.array([class_to_id[name] for name in class_name]) + + return xyxy, class_id, class_name + + def from_florence_2( result: dict, resolution_wh: tuple[int, int] ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None, np.ndarray | None]: @@ -460,7 +635,7 @@ def from_google_gemini_2_0( return np.empty((0, 4)), None, np.empty((0,), dtype=str) labels = [] - boxes_list = [] + xyxy = [] for item in data: if "box_2d" not in item or "label" not in item: @@ -468,18 +643,16 @@ def from_google_gemini_2_0( labels.append(item["label"]) box = item["box_2d"] # Gemini bbox order is [y_min, x_min, y_max, x_max] - boxes_list.append( - denormalize_boxes( - np.array([box[1], box[0], box[3], box[2]]).astype(np.float64), - resolution_wh=(w, h), - normalization_factor=1000, - ) - ) + xyxy.append([box[1], box[0], box[3], box[2]]) - if not boxes_list: + if len(xyxy) == 0: return np.empty((0, 4)), None, np.empty((0,), dtype=str) - xyxy = np.array(boxes_list) + xyxy = denormalize_boxes( + np.array(xyxy, dtype=np.float64), + resolution_wh=(w, h), + normalization_factor=1000, + ) class_name = np.array(labels) class_id = None @@ -571,10 
+744,10 @@ def from_google_gemini_2_5( box = item["box_2d"] # Gemini bbox order is [y_min, x_min, y_max, x_max] absolute_bbox = denormalize_boxes( - np.array([box[1], box[0], box[3], box[2]]).astype(np.float64), + np.array([[box[1], box[0], box[3], box[2]]]).astype(np.float64), resolution_wh=(w, h), normalization_factor=1000, - ) + )[0] boxes_list.append(absolute_bbox) if "mask" in item: @@ -657,7 +830,7 @@ def from_google_gemini_2_5( def from_moondream( result: dict, resolution_wh: tuple[int, int], -) -> tuple[np.ndarray]: +) -> np.ndarray: """ Parse and scale bounding boxes from moondream JSON output. @@ -695,7 +868,7 @@ def from_moondream( if "objects" not in result or not isinstance(result["objects"], list): return np.empty((0, 4), dtype=float) - denormalize_xyxy = [] + xyxy = [] for item in result["objects"]: if not all(k in item for k in ["x_min", "y_min", "x_max", "y_max"]): @@ -706,14 +879,12 @@ def from_moondream( x_max = item["x_max"] y_max = item["y_max"] - denormalize_xyxy.append( - denormalize_boxes( - np.array([x_min, y_min, x_max, y_max]).astype(np.float64), - resolution_wh=(w, h), - ) - ) + xyxy.append([x_min, y_min, x_max, y_max]) - if not denormalize_xyxy: + if len(xyxy) == 0: return np.empty((0, 4)) - return np.array(denormalize_xyxy, dtype=float) + return denormalize_boxes( + np.array(xyxy).astype(np.float64), + resolution_wh=(w, h), + ) diff --git a/supervision/draw/base.py b/supervision/draw/base.py new file mode 100644 index 0000000000..e27c1d3c6b --- /dev/null +++ b/supervision/draw/base.py @@ -0,0 +1,13 @@ +from typing import TypeVar + +import numpy as np +from PIL import Image + +ImageType = TypeVar("ImageType", np.ndarray, Image.Image) +""" +An image of type `np.ndarray` or `PIL.Image.Image`. + +Unlike a `Union`, ensures the type remains consistent. If a function +takes an `ImageType` argument and returns an `ImageType`, when you +pass an `np.ndarray`, you will get an `np.ndarray` back. 
+""" diff --git a/supervision/draw/utils.py b/supervision/draw/utils.py index ed4a903746..0d9ffe1241 100644 --- a/supervision/draw/utils.py +++ b/supervision/draw/utils.py @@ -346,28 +346,50 @@ def draw_image( def calculate_optimal_text_scale(resolution_wh: tuple[int, int]) -> float: """ - Calculate font scale based on the resolution of an image. + Calculate optimal font scale based on image resolution. Adjusts font scale + proportionally to the smallest dimension of the given image resolution for + consistent readability. - Parameters: - resolution_wh (Tuple[int, int]): A tuple representing the width and height - of the image. + Args: + resolution_wh (tuple[int, int]): (width, height) of the image in pixels Returns: - float: The calculated font scale factor. + float: recommended font scale factor + + Examples: + ```python + import supervision as sv + + sv.calculate_optimal_text_scale((1920, 1080)) + # 1.08 + sv.calculate_optimal_text_scale((640, 480)) + # 0.48 + ``` """ return min(resolution_wh) * 1e-3 def calculate_optimal_line_thickness(resolution_wh: tuple[int, int]) -> int: """ - Calculate line thickness based on the resolution of an image. + Calculate optimal line thickness based on image resolution. Adjusts the line + thickness for readability depending on the smallest dimension of the provided + image resolution. - Parameters: - resolution_wh (Tuple[int, int]): A tuple representing the width and height - of the image. + Args: + resolution_wh (tuple[int, int]): (width, height) of the image in pixels Returns: - int: The calculated line thickness in pixels. 
+ int: recommended line thickness in pixels + + Examples: + ```python + import supervision as sv + + sv.calculate_optimal_line_thickness((1920, 1080)) + # 4 + sv.calculate_optimal_line_thickness((640, 480)) + # 2 + ``` """ if min(resolution_wh) < 1080: return 2 diff --git a/supervision/keypoint/__init__.py b/supervision/key_points/__init__.py similarity index 100% rename from supervision/keypoint/__init__.py rename to supervision/key_points/__init__.py diff --git a/supervision/keypoint/annotators.py b/supervision/key_points/annotators.py similarity index 97% rename from supervision/keypoint/annotators.py rename to supervision/key_points/annotators.py index d83d7d5e0f..c3f9e984b1 100644 --- a/supervision/keypoint/annotators.py +++ b/supervision/key_points/annotators.py @@ -6,14 +6,14 @@ import cv2 import numpy as np -from supervision.annotators.base import ImageType from supervision.detection.utils.boxes import pad_boxes, spread_out_boxes +from supervision.draw.base import ImageType from supervision.draw.color import Color from supervision.draw.utils import draw_rounded_rectangle from supervision.geometry.core import Rect -from supervision.keypoint.core import KeyPoints -from supervision.keypoint.skeletons import SKELETONS_BY_VERTEX_COUNT -from supervision.utils.conversion import ensure_cv2_image_for_annotation +from supervision.key_points.core import KeyPoints +from supervision.key_points.skeletons import SKELETONS_BY_VERTEX_COUNT +from supervision.utils.conversion import ensure_cv2_image_for_class_method class BaseKeyPointAnnotator(ABC): @@ -43,7 +43,7 @@ def __init__( self.color = color self.radius = radius - @ensure_cv2_image_for_annotation + @ensure_cv2_image_for_class_method def annotate(self, scene: ImageType, key_points: KeyPoints) -> ImageType: """ Annotates the given scene with skeleton vertices based on the provided key @@ -120,7 +120,7 @@ def __init__( self.thickness = thickness self.edges = edges - @ensure_cv2_image_for_annotation + 
@ensure_cv2_image_for_class_method def annotate(self, scene: ImageType, key_points: KeyPoints) -> ImageType: """ Annotates the given scene by drawing lines between specified key points to form diff --git a/supervision/keypoint/core.py b/supervision/key_points/core.py similarity index 81% rename from supervision/keypoint/core.py rename to supervision/key_points/core.py index 0d57c56183..66341f9df5 100644 --- a/supervision/keypoint/core.py +++ b/supervision/key_points/core.py @@ -10,7 +10,7 @@ from supervision.config import CLASS_NAME_DATA_FIELD from supervision.detection.core import Detections from supervision.detection.utils.internal import get_data_item, is_data_equal -from supervision.validators import validate_keypoints_fields +from supervision.validators import validate_key_points_fields @dataclass @@ -23,7 +23,7 @@ class simplifies data manipulation and filtering, providing a uniform API for === "Ultralytics" - Use [`sv.KeyPoints.from_ultralytics`](/latest/keypoint/core/#supervision.keypoint.core.KeyPoints.from_ultralytics) + Use [`sv.KeyPoints.from_ultralytics`](/latest/keypoint/core/#supervision.key_points.core.KeyPoints.from_ultralytics) method, which accepts [YOLOv8-pose](https://docs.ultralytics.com/models/yolov8/), [YOLO11-pose](https://docs.ultralytics.com/models/yolo11/) [pose](https://docs.ultralytics.com/tasks/pose/) result. @@ -41,7 +41,7 @@ class simplifies data manipulation and filtering, providing a uniform API for === "Inference" - Use [`sv.KeyPoints.from_inference`](/latest/keypoint/core/#supervision.keypoint.core.KeyPoints.from_inference) + Use [`sv.KeyPoints.from_inference`](/latest/keypoint/core/#supervision.key_points.core.KeyPoints.from_inference) method, which accepts [Inference](https://inference.roboflow.com/) pose result. 
```python @@ -58,7 +58,7 @@ class simplifies data manipulation and filtering, providing a uniform API for === "MediaPipe" - Use [`sv.KeyPoints.from_mediapipe`](/latest/keypoint/core/#supervision.keypoint.core.KeyPoints.from_mediapipe) + Use [`sv.KeyPoints.from_mediapipe`](/latest/keypoint/core/#supervision.key_points.core.KeyPoints.from_mediapipe) method, which accepts [MediaPipe](https://github.com/google-ai-edge/mediapipe) pose result. @@ -89,10 +89,61 @@ class simplifies data manipulation and filtering, providing a uniform API for pose_landmarker_result, (image_width, image_height)) ``` + === "Transformers" + + Use [`sv.KeyPoints.from_transformers`](/latest/keypoint/core/#supervision.key_points.core.KeyPoints.from_transformers) + method, which accepts [ViTPose](https://huggingface.co/docs/transformers/en/model_doc/vitpose) result. + + ```python + from PIL import Image + import requests + import supervision as sv + import torch + from transformers import ( + AutoProcessor, + RTDetrForObjectDetection, + VitPoseForPoseEstimation, + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + frame = Image.open("path/to/image.jpg") + + DETECTION_MODEL_ID = "PekingU/rtdetr_r50vd_coco_o365" + + detection_processor = AutoProcessor.from_pretrained(DETECTION_MODEL_ID, use_fast=True) + detection_model = RTDetrForObjectDetection.from_pretrained(DETECTION_MODEL_ID, device_map=device) + + inputs = detection_processor(images=frame, return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = detection_model(**inputs) + + target_size = torch.tensor([(frame.height, frame.width)]) + results = detection_processor.post_process_object_detection( + outputs, target_sizes=target_size, threshold=0.3) + + detections = sv.Detections.from_transformers(results[0]) + boxes = sv.xyxy_to_xywh(detections[detections.class_id == 0].xyxy) + + POSE_ESTIMATION_MODEL_ID = "usyd-community/vitpose-base-simple" + + pose_estimation_processor = AutoProcessor.from_pretrained(POSE_ESTIMATION_MODEL_ID) + 
pose_estimation_model = VitPoseForPoseEstimation.from_pretrained( + POSE_ESTIMATION_MODEL_ID, device_map=device) + + inputs = pose_estimation_processor(frame, boxes=[boxes], return_tensors="pt").to(device) + + with torch.no_grad(): + outputs = pose_estimation_model(**inputs) + + results = pose_estimation_processor.post_process_pose_estimation(outputs, boxes=[boxes]) + key_points = sv.KeyPoints.from_transformers(results[0]) + ``` + Attributes: xy (np.ndarray): An array of shape `(n, m, 2)` containing `n` detected objects, each composed of `m` equally-sized - sets of keypoints, where each point is `[x, y]`. + sets of key points, where each point is `[x, y]`. class_id (Optional[np.ndarray]): An array of shape `(n,)` containing the class ids of the detected objects. confidence (Optional[np.ndarray]): An array of shape @@ -109,7 +160,7 @@ class simplifies data manipulation and filtering, providing a uniform API for data: dict[str, npt.NDArray[Any] | list] = field(default_factory=dict) def __post_init__(self): - validate_keypoints_fields( + validate_key_points_fields( xy=self.xy, confidence=self.confidence, class_id=self.class_id, @@ -514,13 +565,13 @@ def from_detectron2(cls, detectron2_results: Any) -> KeyPoints: return cls.empty() @classmethod - def from_transformers(cls, transfomers_results: Any) -> KeyPoints: + def from_transformers(cls, transformers_results: Any) -> KeyPoints: """ Create a `sv.KeyPoints` object from the [Transformers](https://github.com/huggingface/transformers) inference result. Args: - transfomers_results (Any): The output of a + transformers_results (Any): The output of a Transformers model containing instances with prediction data.
Returns: @@ -545,9 +596,9 @@ def from_transformers(cls, transfomers_results: Any) -> KeyPoints: DETECTION_MODEL_ID = "PekingU/rtdetr_r50vd_coco_o365" detection_processor = AutoProcessor.from_pretrained(DETECTION_MODEL_ID, use_fast=True) - detection_model = RTDetrForObjectDetection.from_pretrained(DETECTION_MODEL_ID, device_map=DEVICE) + detection_model = RTDetrForObjectDetection.from_pretrained(DETECTION_MODEL_ID, device_map=device) - inputs = detection_processor(images=frame, return_tensors="pt").to(DEVICE) + inputs = detection_processor(images=frame, return_tensors="pt").to(device) with torch.no_grad(): outputs = detection_model(**inputs) @@ -563,9 +614,9 @@ def from_transformers(cls, transfomers_results: Any) -> KeyPoints: pose_estimation_processor = AutoProcessor.from_pretrained(POSE_ESTIMATION_MODEL_ID) pose_estimation_model = VitPoseForPoseEstimation.from_pretrained( - POSE_ESTIMATION_MODEL_ID, device_map=DEVICE) + POSE_ESTIMATION_MODEL_ID, device_map=device) - inputs = pose_estimation_processor(frame, boxes=[boxes], return_tensors="pt").to(DEVICE) + inputs = pose_estimation_processor(frame, boxes=[boxes], return_tensors="pt").to(device) with torch.no_grad(): outputs = pose_estimation_model(**inputs) @@ -576,8 +627,8 @@ def from_transformers(cls, transfomers_results: Any) -> KeyPoints: """ # noqa: E501 // docs - if "keypoints" in transfomers_results[0]: - if transfomers_results[0]["keypoints"].cpu().numpy().size == 0: + if "keypoints" in transformers_results[0]: + if transformers_results[0]["keypoints"].cpu().numpy().size == 0: return cls.empty() result_data = [ @@ -585,7 +636,7 @@ def from_transformers(cls, transfomers_results: Any) -> KeyPoints: result["keypoints"].cpu().numpy(), result["scores"].cpu().numpy(), ) - for result in transfomers_results + for result in transformers_results ] xy, scores = zip(*result_data) @@ -599,55 +650,68 @@ def from_transformers(cls, transfomers_results: Any) -> KeyPoints: return cls.empty() def __getitem__( - self, index: 
int | slice | list[int] | np.ndarray | str - ) -> KeyPoints | list | np.ndarray | None: - """ - Get a subset of the `sv.KeyPoints` object or access an item from its data field. - - When provided with an integer, slice, list of integers, or a numpy array, this - method returns a new `sv.KeyPoints` object that represents a subset of the - original `sv.KeyPoints`. When provided with a string, it accesses the - corresponding item in the data dictionary. - - Args: - index (Union[int, slice, List[int], np.ndarray, str]): The index, indices, - or key to access a subset of the `sv.KeyPoints` or an item from the - data. - - Returns: - A subset of the `sv.KeyPoints` object or an item from the data field. - - Examples: - ```python - import supervision as sv - - key_points = sv.KeyPoints() - - # access the first keypoint using an integer index - key_points[0] - - # access the first 10 keypoints using index slice - key_points[0:10] - - # access selected keypoints using a list of indices - key_points[[0, 2, 4]] - - # access keypoints with selected class_id - key_points[key_points.class_id == 0] - - # access keypoints with confidence greater than 0.5 - key_points[key_points.confidence > 0.5] - ``` - """ + self, index: int | slice | list[int] | np.ndarray | tuple | str + ) -> KeyPoints | np.ndarray | list | None: if isinstance(index, str): return self.data.get(index) - if isinstance(index, int): - index = [index] + + if not isinstance(index, tuple): + index = (index, slice(None)) + + i, j = index + + if isinstance(i, int): + i = [i] + + if isinstance(i, list) and all(isinstance(x, bool) for x in i): + i = np.array(i) + if isinstance(j, list) and all(isinstance(x, bool) for x in j): + j = np.array(j) + + if isinstance(i, np.ndarray) and i.dtype == bool: + i = np.flatnonzero(i) + if isinstance(j, np.ndarray) and j.dtype == bool: + j = np.flatnonzero(j) + + if ( + isinstance(i, (list, np.ndarray)) + and isinstance(j, (list, np.ndarray)) + and not np.isscalar(i) + and not 
np.isscalar(j) + ): + i, j = np.ix_(i, j) + + xy_selected = self.xy[i, j] + + conf_selected = self.confidence[i, j] if self.confidence is not None else None + + class_id_selected = self.class_id[i] if self.class_id is not None else None + + data_selected = get_data_item(self.data, i) + + if xy_selected.ndim == 1: + xy_selected = xy_selected.reshape(1, 1, 2) + if conf_selected is not None: + conf_selected = conf_selected.reshape(1, 1) + elif xy_selected.ndim == 2: + if np.isscalar(index[0]) or ( + isinstance(index[0], np.ndarray) and index[0].ndim == 0 + ): + xy_selected = xy_selected[np.newaxis, ...] + if conf_selected is not None: + conf_selected = conf_selected[np.newaxis, ...] + elif np.isscalar(index[1]) or ( + isinstance(index[1], np.ndarray) and index[1].ndim == 0 + ): + xy_selected = xy_selected[:, np.newaxis, :] + if conf_selected is not None: + conf_selected = conf_selected[:, np.newaxis] + return KeyPoints( - xy=self.xy[index], - confidence=self.confidence[index] if self.confidence is not None else None, - class_id=self.class_id[index] if self.class_id is not None else None, - data=get_data_item(self.data, index), + xy=xy_selected, + confidence=conf_selected, + class_id=class_id_selected, + data=data_selected, ) def __setitem__(self, key: str, value: np.ndarray | list): @@ -668,12 +732,12 @@ def __setitem__(self, key: str, value: np.ndarray | list): model = YOLO('yolov8s.pt') result = model(image)[0] - keypoints = sv.KeyPoints.from_ultralytics(result) + key_points = sv.KeyPoints.from_ultralytics(result) - keypoints['class_name'] = [ + key_points['class_name'] = [ model.model.names[class_id] for class_id - in keypoints.class_id + in key_points.class_id ] ``` """ @@ -688,7 +752,7 @@ def __setitem__(self, key: str, value: np.ndarray | list): @classmethod def empty(cls) -> KeyPoints: """ - Create an empty Keypoints object with no keypoints. + Create an empty KeyPoints object with no key points. Returns: An empty `sv.KeyPoints` object. 
@@ -706,9 +770,9 @@ def is_empty(self) -> bool: """ Returns `True` if the `KeyPoints` object is considered empty. """ - empty_keypoints = KeyPoints.empty() - empty_keypoints.data = self.data - return self == empty_keypoints + empty_key_points = KeyPoints.empty() + empty_key_points.data = self.data + return self == empty_key_points def as_detections( self, selected_keypoint_indices: Iterable[int] | None = None @@ -716,21 +780,21 @@ def as_detections( """ Convert a KeyPoints object to a Detections object. This approximates the bounding box of the detected object by - taking the bounding box that fits all keypoints. + taking the bounding box that fits all key points. Arguments: selected_keypoint_indices (Optional[Iterable[int]]): The - indices of the keypoints to include in the bounding box - calculation. This helps focus on a subset of keypoints, - e.g. when some are occluded. Captures all keypoints by default. + indices of the key points to include in the bounding box + calculation. This helps focus on a subset of key points, + e.g. when some are occluded. Captures all key points by default. Returns: detections (Detections): The converted detections object. Examples: ```python - keypoints = sv.KeyPoints.from_inference(...) - detections = keypoints.as_detections() + key_points = sv.KeyPoints.from_inference(...) 
+ detections = key_points.as_detections() ``` """ if self.is_empty(): diff --git a/supervision/keypoint/skeletons.py b/supervision/key_points/skeletons.py similarity index 100% rename from supervision/keypoint/skeletons.py rename to supervision/key_points/skeletons.py diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py index 43a3116ac2..6967361702 100644 --- a/supervision/metrics/mean_average_precision.py +++ b/supervision/metrics/mean_average_precision.py @@ -102,12 +102,12 @@ def __str__(self) -> str: f"maxDets=100 ] = {self.map50:.3f}\n" f"Average Precision (AP) @[ IoU=0.75 | area= all | " f"maxDets=100 ] = {self.map75:.3f}\n" - f"Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] " - f"= {self.small_objects.map50_95:.3f}\n" - f"Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] \ - = {self.medium_objects.map50_95:.3f}\n" - f"Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] \ - = {self.large_objects.map50_95:.3f}" + f"Average Precision (AP) @[ IoU=0.50:0.95 | area= small | " + f"maxDets=100 ] = {self.small_objects.map50_95:.3f}\n" + f"Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | " + f"maxDets=100 ] = {self.medium_objects.map50_95:.3f}\n" + f"Average Precision (AP) @[ IoU=0.50:0.95 | area= large | " + f"maxDets=100 ] = {self.large_objects.map50_95:.3f}" ) def to_pandas(self) -> pd.DataFrame: diff --git a/supervision/utils/conversion.py b/supervision/utils/conversion.py index 79ec500300..b1c8f16bfd 100644 --- a/supervision/utils/conversion.py +++ b/supervision/utils/conversion.py @@ -4,10 +4,10 @@ import numpy as np from PIL import Image -from supervision.annotators.base import ImageType +from supervision.draw.base import ImageType -def ensure_cv2_image_for_annotation(annotate_func): +def ensure_cv2_image_for_class_method(annotate_func): """ Decorates `BaseAnnotator.annotate` implementations, converts scene to an image type used internally 
by the annotators, converts back when annotation @@ -32,7 +32,7 @@ def wrapper(self, scene: ImageType, *args, **kwargs): return wrapper -def ensure_cv2_image_for_processing(image_processing_fun): +def ensure_cv2_image_for_standalone_function(image_processing_fun): """ Decorates image processing functions that accept np.ndarray, converting `image` to np.ndarray, converts back when processing is complete. @@ -55,7 +55,7 @@ def wrapper(image: ImageType, *args, **kwargs): return wrapper -def ensure_pil_image_for_annotation(annotate_func): +def ensure_pil_image_for_class_method(annotate_func): """ Decorates image processing functions that accept np.ndarray, converting `image` to PIL image, converts back when processing is complete. diff --git a/supervision/utils/image.py b/supervision/utils/image.py index 6960986768..4d4348d69f 100644 --- a/supervision/utils/image.py +++ b/supervision/utils/image.py @@ -1,119 +1,107 @@ from __future__ import annotations -import itertools -import math import os import shutil -from collections.abc import Callable -from functools import partial -from typing import Literal import cv2 import numpy as np import numpy.typing as npt +from PIL import Image from supervision.annotators.base import ImageType from supervision.draw.color import Color, unify_to_bgr -from supervision.draw.utils import calculate_optimal_text_scale, draw_text -from supervision.geometry.core import Point from supervision.utils.conversion import ( - cv2_to_pillow, - ensure_cv2_image_for_processing, - images_to_cv2, + ensure_cv2_image_for_standalone_function, ) -from supervision.utils.iterables import create_batches, fill +from supervision.utils.internal import deprecated -RelativePosition = Literal["top", "bottom"] -MAX_COLUMNS_FOR_SINGLE_ROW_GRID = 3 - - -@ensure_cv2_image_for_processing def crop_image( image: ImageType, xyxy: npt.NDArray[int] | list[int] | tuple[int, int, int, int], ) -> ImageType: """ - Crops the given image based on the given bounding box. 
+ Crop image based on bounding box coordinates. Args: - image (ImageType): The image to be cropped. `ImageType` is a flexible type, - accepting either `numpy.ndarray` or `PIL.Image.Image`. - xyxy (Union[np.ndarray, List[int], Tuple[int, int, int, int]]): A bounding box - coordinates in the format `(x_min, y_min, x_max, y_max)`, accepted as either - a `numpy.ndarray`, a `list`, or a `tuple`. + image (`numpy.ndarray` or `PIL.Image.Image`): The image to crop. + xyxy (`numpy.array`, `list[int]`, or `tuple[int, int, int, int]`): + Bounding box coordinates in `(x_min, y_min, x_max, y_max)` format. Returns: - (ImageType): The cropped image. The type is determined by the input type and - may be either a `numpy.ndarray` or `PIL.Image.Image`. - - === "OpenCV" + (`numpy.ndarray` or `PIL.Image.Image`): Cropped image matching input + type. + Examples: ```python import cv2 import supervision as sv - image = cv2.imread() + image = cv2.imread("source.png") image.shape # (1080, 1920, 3) - xyxy = [200, 400, 600, 800] + xyxy = (400, 400, 800, 800) cropped_image = sv.crop_image(image=image, xyxy=xyxy) cropped_image.shape # (400, 400, 3) ``` - === "Pillow" - ```python from PIL import Image import supervision as sv - image = Image.open() + image = Image.open("source.png") image.size # (1920, 1080) - xyxy = [200, 400, 600, 800] + xyxy = (400, 400, 800, 800) cropped_image = sv.crop_image(image=image, xyxy=xyxy) cropped_image.size # (400, 400) ``` - ![crop_image](https://media.roboflow.com/supervision-docs/crop-image.png){ align=center width="800" } + ![crop-image](https://media.roboflow.com/supervision-docs/supervision-docs-crop-image-2.png){ align=center width="1000" } """ # noqa E501 // docs - if isinstance(xyxy, (list, tuple)): xyxy = np.array(xyxy) + xyxy = np.round(xyxy).astype(int) x_min, y_min, x_max, y_max = xyxy.flatten() - return image[y_min:y_max, x_min:x_max] + + if isinstance(image, np.ndarray): + return image[y_min:y_max, x_min:x_max] + + if isinstance(image, Image.Image): + 
return image.crop((x_min, y_min, x_max, y_max)) + + raise TypeError( + f"`image` must be a numpy.ndarray or PIL.Image.Image. Received {type(image)}" + ) -@ensure_cv2_image_for_processing +@ensure_cv2_image_for_standalone_function def scale_image(image: ImageType, scale_factor: float) -> ImageType: """ - Scales the given image based on the given scale factor. + Scale image by given factor. Scale factor > 1.0 zooms in, < 1.0 zooms out. Args: - image (ImageType): The image to be scaled. `ImageType` is a flexible type, - accepting either `numpy.ndarray` or `PIL.Image.Image`. - scale_factor (float): The factor by which the image will be scaled. Scale - factor > `1.0` zooms in, < `1.0` zooms out. + image (`numpy.ndarray` or `PIL.Image.Image`): The image to scale. + scale_factor (`float`): Factor by which to scale the image. Returns: - (ImageType): The scaled image. The type is determined by the input type and - may be either a `numpy.ndarray` or `PIL.Image.Image`. + (`numpy.ndarray` or `PIL.Image.Image`): Scaled image matching input + type. Raises: - ValueError: If the scale factor is non-positive. - - === "OpenCV" + ValueError: If scale factor is non-positive. 
+ Examples: ```python import cv2 import supervision as sv - image = cv2.imread() + image = cv2.imread("source.png") image.shape # (1080, 1920, 3) @@ -122,13 +110,11 @@ def scale_image(image: ImageType, scale_factor: float) -> ImageType: # (540, 960, 3) ``` - === "Pillow" - ```python from PIL import Image import supervision as sv - image = Image.open() + image = Image.open("source.png") image.size # (1920, 1080) @@ -136,7 +122,9 @@ def scale_image(image: ImageType, scale_factor: float) -> ImageType: scaled_image.size # (960, 540) ``` - """ + + ![scale-image](https://media.roboflow.com/supervision-docs/supervision-docs-scale-image-2.png){ align=center width="1000" } + """ # noqa E501 // docs if scale_factor <= 0: raise ValueError("Scale factor must be positive.") @@ -146,35 +134,31 @@ def scale_image(image: ImageType, scale_factor: float) -> ImageType: return cv2.resize(image, (width_new, height_new), interpolation=cv2.INTER_LINEAR) -@ensure_cv2_image_for_processing +@ensure_cv2_image_for_standalone_function def resize_image( image: ImageType, resolution_wh: tuple[int, int], keep_aspect_ratio: bool = False, ) -> ImageType: """ - Resizes the given image to a specified resolution. Can maintain the original aspect - ratio or resize directly to the desired dimensions. + Resize image to specified resolution. Can optionally maintain aspect ratio. Args: - image (ImageType): The image to be resized. `ImageType` is a flexible type, - accepting either `numpy.ndarray` or `PIL.Image.Image`. - resolution_wh (Tuple[int, int]): The target resolution as - `(width, height)`. - keep_aspect_ratio (bool): Flag to maintain the image's original - aspect ratio. Defaults to `False`. + image (`numpy.ndarray` or `PIL.Image.Image`): The image to resize. + resolution_wh (`tuple[int, int]`): Target resolution as `(width, height)`. + keep_aspect_ratio (`bool`): Flag to maintain original aspect ratio. + Defaults to `False`. Returns: - (ImageType): The resized image. 
The type is determined by the input type and - may be either a `numpy.ndarray` or `PIL.Image.Image`. - - === "OpenCV" + (`numpy.ndarray` or `PIL.Image.Image`): Resized image matching input + type. + Examples: ```python import cv2 import supervision as sv - image = cv2.imread() + image = cv2.imread("source.png") image.shape # (1080, 1920, 3) @@ -185,13 +169,11 @@ def resize_image( # (562, 1000, 3) ``` - === "Pillow" - ```python from PIL import Image import supervision as sv - image = Image.open() + image = Image.open("source.png") image.size # (1920, 1080) @@ -202,7 +184,7 @@ def resize_image( # (1000, 562) ``` - ![resize_image](https://media.roboflow.com/supervision-docs/resize-image.png){ align=center width="800" } + ![resize-image](https://media.roboflow.com/supervision-docs/supervision-docs-resize-image-2.png){ align=center width="1000" } """ # noqa E501 // docs if keep_aspect_ratio: image_ratio = image.shape[1] / image.shape[0] @@ -219,59 +201,58 @@ def resize_image( return cv2.resize(image, (width_new, height_new), interpolation=cv2.INTER_LINEAR) -@ensure_cv2_image_for_processing +@ensure_cv2_image_for_standalone_function def letterbox_image( image: ImageType, resolution_wh: tuple[int, int], color: tuple[int, int, int] | Color = Color.BLACK, ) -> ImageType: """ - Resizes and pads an image to a specified resolution with a given color, maintaining - the original aspect ratio. + Resize image and pad with color to achieve desired resolution while + maintaining aspect ratio. Args: - image (ImageType): The image to be resized. `ImageType` is a flexible type, - accepting either `numpy.ndarray` or `PIL.Image.Image`. - resolution_wh (Tuple[int, int]): The target resolution as - `(width, height)`. - color (Union[Tuple[int, int, int], Color]): The color to pad with. If tuple - provided it should be in BGR format. + image (`numpy.ndarray` or `PIL.Image.Image`): The image to resize and pad. + resolution_wh (`tuple[int, int]`): Target resolution as `(width, height)`. 
+ color (`tuple[int, int, int]` or `Color`): Padding color. If tuple, should + be in BGR format. Defaults to `Color.BLACK`. Returns: - (ImageType): The resized image. The type is determined by the input type and - may be either a `numpy.ndarray` or `PIL.Image.Image`. - - === "OpenCV" + (`numpy.ndarray` or `PIL.Image.Image`): Letterboxed image matching input + type. + Examples: ```python import cv2 import supervision as sv - image = cv2.imread() + image = cv2.imread("source.png") image.shape # (1080, 1920, 3) - letterboxed_image = sv.letterbox_image(image=image, resolution_wh=(1000, 1000)) + letterboxed_image = sv.letterbox_image( + image=image, resolution_wh=(1000, 1000) + ) letterboxed_image.shape # (1000, 1000, 3) ``` - === "Pillow" - ```python from PIL import Image import supervision as sv - image = Image.open() + image = Image.open("source.png") image.size # (1920, 1080) - letterboxed_image = sv.letterbox_image(image=image, resolution_wh=(1000, 1000)) + letterboxed_image = sv.letterbox_image( + image=image, resolution_wh=(1000, 1000) + ) letterboxed_image.size # (1000, 1000) ``` - ![letterbox_image](https://media.roboflow.com/supervision-docs/letterbox-image.png){ align=center width="800" } + ![letterbox-image](https://media.roboflow.com/supervision-docs/supervision-docs-letterbox-image-2.png){ align=center width="1000" } """ # noqa E501 // docs assert isinstance(image, np.ndarray) color = unify_to_bgr(color=color) @@ -302,37 +283,59 @@ def letterbox_image( return image_with_borders +@deprecated( + "`overlay_image` function is deprecated and will be removed in " + "`supervision-0.32.0`. Use `draw_image` instead." +) def overlay_image( image: npt.NDArray[np.uint8], overlay: npt.NDArray[np.uint8], anchor: tuple[int, int], ) -> npt.NDArray[np.uint8]: """ - Places an image onto a scene at a given anchor point, handling cases where - the image's position is partially or completely outside the scene's bounds. + Overlay image onto scene at specified anchor point. 
Handles cases where + overlay position is partially or completely outside scene bounds. Args: - image (np.ndarray): The background scene onto which the image is placed. - overlay (np.ndarray): The image to be placed onto the scene. - anchor (Tuple[int, int]): The `(x, y)` coordinates in the scene where the - top-left corner of the image will be placed. + image (`numpy.array`): Background scene with shape `(height, width, 3)`. + overlay (`numpy.array`): Image to overlay with shape + `(height, width, 3)` or `(height, width, 4)`. + anchor (`tuple[int, int]`): Coordinates `(x, y)` where top-left corner + of overlay will be placed. Returns: - (np.ndarray): The result image with overlay. + (`numpy.array`): Scene with overlay applied, shape `(height, width, 3)`. Examples: - ```python + ``` import cv2 import numpy as np import supervision as sv - image = cv2.imread() + image = cv2.imread("source.png") overlay = np.zeros((400, 400, 3), dtype=np.uint8) - result_image = sv.overlay_image(image=image, overlay=overlay, anchor=(200, 400)) + overlay[:] = (0, 255, 0) # Green overlay + + result_image = sv.overlay_image( + image=image, overlay=overlay, anchor=(200, 400) + ) + cv2.imwrite("target.png", result_image) ``` - ![overlay_image](https://media.roboflow.com/supervision-docs/overlay-image.png){ align=center width="800" } - """ # noqa E501 // docs + ``` + import cv2 + import numpy as np + import supervision as sv + + image = cv2.imread("source.png") + overlay = cv2.imread("overlay.png", cv2.IMREAD_UNCHANGED) + + result_image = sv.overlay_image( + image=image, overlay=overlay, anchor=(100, 100) + ) + cv2.imwrite("target.png", result_image) + ``` + """ scene_height, scene_width = image.shape[:2] image_height, image_width = overlay.shape[:2] anchor_x, anchor_y = anchor @@ -371,6 +374,158 @@ def overlay_image( return image +@ensure_cv2_image_for_standalone_function +def tint_image( + image: ImageType, + color: Color = Color.BLACK, + opacity: float = 0.5, +) -> ImageType: + """ + 
Tint image with solid color overlay at specified opacity. + + Args: + image (`numpy.ndarray` or `PIL.Image.Image`): The image to tint. + color (`Color`): Overlay tint color. Defaults to `Color.BLACK`. + opacity (`float`): Blend ratio between overlay and image (0.0-1.0). + Defaults to `0.5`. + + Returns: + (`numpy.ndarray` or `PIL.Image.Image`): Tinted image matching input + type. + + Raises: + ValueError: If opacity is outside range [0.0, 1.0]. + + Examples: + ```python + import cv2 + import supervision as sv + + image = cv2.imread("source.png") + tinted_image = sv.tint_image( + image=image, color=sv.Color.ROBOFLOW, opacity=0.5 + ) + cv2.imwrite("target.png", tinted_image) + ``` + + ```python + from PIL import Image + import supervision as sv + + image = Image.open("source.png") + tinted_image = sv.tint_image( + image=image, color=sv.Color.ROBOFLOW, opacity=0.5 + ) + tinted_image.save("target.png") + ``` + + ![tint-image](https://media.roboflow.com/supervision-docs/supervision-docs-tint-image-2.png){ align=center width="1000" } + """ # noqa E501 // docs + if not 0.0 <= opacity <= 1.0: + raise ValueError("opacity must be between 0.0 and 1.0") + + overlay = np.full_like(image, fill_value=color.as_bgr(), dtype=image.dtype) + cv2.addWeighted( + src1=overlay, alpha=opacity, src2=image, beta=1 - opacity, gamma=0, dst=image + ) + return image + + +@ensure_cv2_image_for_standalone_function +def grayscale_image(image: ImageType) -> ImageType: + """ + Convert image to 3-channel grayscale. Luminance channel is broadcast to + all three channels for compatibility with color-based drawing helpers. + + Args: + image (`numpy.ndarray` or `PIL.Image.Image`): The image to convert to + grayscale. + + Returns: + (`numpy.ndarray` or `PIL.Image.Image`): 3-channel grayscale image + matching input type. 
+ + Examples: + ```python + import cv2 + import supervision as sv + + image = cv2.imread("source.png") + grayscale_image = sv.grayscale_image(image=image) + cv2.imwrite("target.png", grayscale_image) + ``` + + ```python + from PIL import Image + import supervision as sv + + image = Image.open("source.png") + grayscale_image = sv.grayscale_image(image=image) + grayscale_image.save("target.png") + ``` + + ![grayscale-image](https://media.roboflow.com/supervision-docs/supervision-docs-grayscale-image-2.png){ align=center width="1000" } + """ # noqa E501 // docs + grayscaled = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + return cv2.cvtColor(grayscaled, cv2.COLOR_GRAY2BGR) + + +def get_image_resolution_wh(image: ImageType) -> tuple[int, int]: + """ + Get image width and height as a tuple `(width, height)` for various image formats. + + Supports both `numpy.ndarray` images (with shape `(H, W, ...)`) and + `PIL.Image.Image` inputs. + + Args: + image (`numpy.ndarray` or `PIL.Image.Image`): Input image. + + Returns: + (`tuple[int, int]`): Image resolution as `(width, height)`. + + Raises: + ValueError: If a `numpy.ndarray` image has fewer than 2 dimensions. + TypeError: If `image` is not a supported type (`numpy.ndarray` or + `PIL.Image.Image`). + + Examples: + ```python + import cv2 + import supervision as sv + + image = cv2.imread("example.png") + sv.get_image_resolution_wh(image) + # (1920, 1080) + ``` + + ```python + from PIL import Image + import supervision as sv + + image = Image.open("example.png") + sv.get_image_resolution_wh(image) + # (1920, 1080) + ``` + """ + if isinstance(image, np.ndarray): + if image.ndim < 2: + raise ValueError( + "NumPy image must have at least 2 dimensions (H, W, ...). 
" + f"Received shape: {image.shape}" + ) + height, width = image.shape[:2] + return int(width), int(height) + + if isinstance(image, Image.Image): + width, height = image.size + return int(width), int(height) + + raise TypeError( + "`image` must be a numpy.ndarray or PIL.Image.Image. " + f"Received type: {type(image)}" + ) + + class ImageSink: def __init__( self, @@ -379,27 +534,64 @@ def __init__( image_name_pattern: str = "image_{:05d}.png", ): """ - Initialize a context manager for saving images. + Initialize context manager for saving images to directory. Args: - target_dir_path (str): The target directory where images will be saved. - overwrite (bool): Whether to overwrite the existing directory. - Defaults to False. - image_name_pattern (str): The image file name pattern. - Defaults to "image_{:05d}.png". + target_dir_path (`str`): Target directory path where images will be + saved. + overwrite (`bool`): Whether to overwrite existing directory. + Defaults to `False`. + image_name_pattern (`str`): File name pattern for saved images. + Defaults to `"image_{:05d}.png"`. 
Examples: ```python import supervision as sv - frames_generator = sv.get_video_frames_generator(, stride=2) + frames_generator = sv.get_video_frames_generator( + "source.mp4", stride=2 + ) - with sv.ImageSink(target_dir_path=) as sink: + with sv.ImageSink(target_dir_path="output_frames") as sink: for image in frames_generator: sink.save_image(image=image) + + # Directory structure: + # output_frames/ + # β”œβ”€β”€ image_00000.png + # β”œβ”€β”€ image_00001.png + # β”œβ”€β”€ image_00002.png + # └── image_00003.png ``` - """ # noqa E501 // docs + ```python + import cv2 + import supervision as sv + + image = cv2.imread("source.png") + crop_boxes = [ + ( 0, 0, 400, 400), + (400, 0, 800, 400), + ( 0, 400, 400, 800), + (400, 400, 800, 800) + ] + + with sv.ImageSink( + target_dir_path="image_crops", + overwrite=True + ) as sink: + for i, xyxy in enumerate(crop_boxes): + crop = sv.crop_image(image=image, xyxy=xyxy) + sink.save_image(image=crop, image_name=f"crop_{i}.png") + + # Directory structure: + # image_crops/ + # β”œβ”€β”€ crop_0.png + # β”œβ”€β”€ crop_1.png + # β”œβ”€β”€ crop_2.png + # └── crop_3.png + ``` + """ self.target_dir_path = target_dir_path self.overwrite = overwrite self.image_name_pattern = image_name_pattern @@ -417,14 +609,14 @@ def __enter__(self): def save_image(self, image: np.ndarray, image_name: str | None = None): """ - Save a given image in the target directory. + Save image to target directory with optional custom filename. Args: - image (np.ndarray): The image to be saved. The image must be in BGR color - format. - image_name (Optional[str]): The name to use for the saved image. - If not provided, a name will be - generated using the `image_name_pattern`. + image (`numpy.array`): Image to save with shape `(height, width, 3)` + in BGR format. + image_name (`str` or `None`): Custom filename for saved image. If + `None`, generates name using `image_name_pattern`. Defaults to + `None`. 
""" if image_name is None: image_name = self.image_name_pattern.format(self.image_count) @@ -435,355 +627,3 @@ def save_image(self, image: np.ndarray, image_name: str | None = None): def __exit__(self, exc_type, exc_value, exc_traceback): pass - - -def create_tiles( - images: list[ImageType], - grid_size: tuple[int | None, int | None] | None = None, - single_tile_size: tuple[int, int] | None = None, - tile_scaling: Literal["min", "max", "avg"] = "avg", - tile_padding_color: tuple[int, int, int] | Color = Color.from_hex("#D9D9D9"), - tile_margin: int = 10, - tile_margin_color: tuple[int, int, int] | Color = Color.from_hex("#BFBEBD"), - return_type: Literal["auto", "cv2", "pillow"] = "auto", - titles: list[str | None] | None = None, - titles_anchors: Point | list[Point | None] | None = None, - titles_color: tuple[int, int, int] | Color = Color.from_hex("#262523"), - titles_scale: float | None = None, - titles_thickness: int = 1, - titles_padding: int = 10, - titles_text_font: int = cv2.FONT_HERSHEY_SIMPLEX, - titles_background_color: tuple[int, int, int] | Color = Color.from_hex("#D9D9D9"), - default_title_placement: RelativePosition = "top", -) -> ImageType: - """ - Creates tiles mosaic from input images, automating grid placement and - converting images to common resolution maintaining aspect ratio. It is - also possible to render text titles on tiles, using optional set of - parameters specifying text drawing (see parameters description). - - Automated grid placement will try to maintain square shape of grid - (with size being the nearest integer square root of #images), up to two exceptions: - * if there are up to 3 images - images will be displayed in single row - * if square-grid placement causes last row to be empty - number of rows is trimmed - until last row has at least one image - - Args: - images (List[ImageType]): Images to create tiles. Elements can be either - np.ndarray or PIL.Image, common representation will be agreed by the - function. 
- grid_size (Optional[Tuple[Optional[int], Optional[int]]]): Expected grid - size in format (n_rows, n_cols). If not given - automated grid placement - will be applied. One may also provide only one out of two elements of the - tuple - then grid will be created with either n_rows or n_cols fixed, - leaving the other dimension to be adjusted by the number of images - single_tile_size (Optional[Tuple[int, int]]): sizeof a single tile element - provided in (width, height) format. If not given - size of tile will be - automatically calculated based on `tile_scaling` parameter. - tile_scaling (Literal["min", "max", "avg"]): If `single_tile_size` is not - given - parameter will be used to calculate tile size - using - min / max / avg size of image provided in `images` list. - tile_padding_color (Union[Tuple[int, int, int], sv.Color]): Color to be used in - images letterbox procedure (while standardising tiles sizes) as a padding. - If tuple provided - should be BGR. - tile_margin (int): size of margin between tiles (in pixels) - tile_margin_color (Union[Tuple[int, int, int], sv.Color]): Color of tile margin. - If tuple provided - should be BGR. - return_type (Literal["auto", "cv2", "pillow"]): Parameter dictates the format of - return image. One may choose specific type ("cv2" or "pillow") to enforce - conversion. "auto" mode takes a majority vote between types of elements in - `images` list - resolving draws in favour of OpenCV format. "auto" can be - safely used when all input images are of the same type. - titles (Optional[List[Optional[str]]]): Optional titles to be added to tiles. - Elements of that list may be empty - then specific tile (in order presented - in `images` parameter) will not be filled with title. It is possible to - provide list of titles shorter than `images` - then remaining titles will - be assumed empty. - titles_anchors (Optional[Union[Point, List[Optional[Point]]]]): Parameter to - specify anchor points for titles. 
It is possible to specify anchor either - globally or for specific tiles (following order of `images`). - If not given (either globally, or for specific element of the list), - it will be calculated automatically based on `default_title_placement`. - titles_color (Union[Tuple[int, int, int], Color]): Color of titles text. - If tuple provided - should be BGR. - titles_scale (Optional[float]): Scale of titles. If not provided - value will - be calculated using `calculate_optimal_text_scale(...)`. - titles_thickness (int): Thickness of titles text. - titles_padding (int): Size of titles padding. - titles_text_font (int): Font to be used to render titles. Must be integer - constant representing OpenCV font. - (See docs: https://docs.opencv.org/4.x/d6/d6e/group__imgproc__draw.html) - titles_background_color (Union[Tuple[int, int, int], Color]): Color of title - text padding. - default_title_placement (Literal["top", "bottom"]): Parameter specifies title - anchor placement in case if explicit anchor is not provided. - - Returns: - ImageType: Image with all input images located in tails grid. The output type is - determined by `return_type` parameter. - - Raises: - ValueError: In case when input images list is empty, provided `grid_size` is too - small to fit all images, `tile_scaling` mode is invalid. 
- """ - if len(images) == 0: - raise ValueError("Could not create image tiles from empty list of images.") - if return_type == "auto": - return_type = _negotiate_tiles_format(images=images) - tile_padding_color = unify_to_bgr(color=tile_padding_color) - tile_margin_color = unify_to_bgr(color=tile_margin_color) - images = images_to_cv2(images=images) - if single_tile_size is None: - single_tile_size = _aggregate_images_shape(images=images, mode=tile_scaling) - resized_images = [ - letterbox_image( - image=i, resolution_wh=single_tile_size, color=tile_padding_color - ) - for i in images - ] - grid_size = _establish_grid_size(images=images, grid_size=grid_size) - if len(images) > grid_size[0] * grid_size[1]: - raise ValueError( - f"Could not place {len(images)} in grid with size: {grid_size}." - ) - if titles is not None: - titles = fill(sequence=titles, desired_size=len(images), content=None) - titles_anchors = ( - [titles_anchors] - if not issubclass(type(titles_anchors), list) - else titles_anchors - ) - titles_anchors = fill( - sequence=titles_anchors, desired_size=len(images), content=None - ) - titles_color = unify_to_bgr(color=titles_color) - titles_background_color = unify_to_bgr(color=titles_background_color) - tiles = _generate_tiles( - images=resized_images, - grid_size=grid_size, - single_tile_size=single_tile_size, - tile_padding_color=tile_padding_color, - tile_margin=tile_margin, - tile_margin_color=tile_margin_color, - titles=titles, - titles_anchors=titles_anchors, - titles_color=titles_color, - titles_scale=titles_scale, - titles_thickness=titles_thickness, - titles_padding=titles_padding, - titles_text_font=titles_text_font, - titles_background_color=titles_background_color, - default_title_placement=default_title_placement, - ) - if return_type == "pillow": - tiles = cv2_to_pillow(image=tiles) - return tiles - - -def _negotiate_tiles_format(images: list[ImageType]) -> Literal["cv2", "pillow"]: - number_of_np_arrays = sum(issubclass(type(i), 
np.ndarray) for i in images) - if number_of_np_arrays >= (len(images) // 2): - return "cv2" - return "pillow" - - -def _calculate_aggregated_images_shape( - images: list[np.ndarray], aggregator: Callable[[list[int]], float] -) -> tuple[int, int]: - height = round(aggregator([i.shape[0] for i in images])) - width = round(aggregator([i.shape[1] for i in images])) - return width, height - - -SHAPE_AGGREGATION_FUN = { - "min": partial(_calculate_aggregated_images_shape, aggregator=np.min), - "max": partial(_calculate_aggregated_images_shape, aggregator=np.max), - "avg": partial(_calculate_aggregated_images_shape, aggregator=np.average), -} - - -def _aggregate_images_shape( - images: list[np.ndarray], mode: Literal["min", "max", "avg"] -) -> tuple[int, int]: - if mode not in SHAPE_AGGREGATION_FUN: - raise ValueError( - f"Could not aggregate images shape - provided unknown mode: {mode}. " - f"Supported modes: {list(SHAPE_AGGREGATION_FUN.keys())}." - ) - return SHAPE_AGGREGATION_FUN[mode](images) - - -def _establish_grid_size( - images: list[np.ndarray], grid_size: tuple[int | None, int | None] | None -) -> tuple[int, int]: - if grid_size is None or all(e is None for e in grid_size): - return _negotiate_grid_size(images=images) - if grid_size[0] is None: - return math.ceil(len(images) / grid_size[1]), grid_size[1] - if grid_size[1] is None: - return grid_size[0], math.ceil(len(images) / grid_size[0]) - return grid_size - - -def _negotiate_grid_size(images: list[np.ndarray]) -> tuple[int, int]: - if len(images) <= MAX_COLUMNS_FOR_SINGLE_ROW_GRID: - return 1, len(images) - nearest_sqrt = math.ceil(np.sqrt(len(images))) - proposed_columns = nearest_sqrt - proposed_rows = nearest_sqrt - while proposed_columns * (proposed_rows - 1) >= len(images): - proposed_rows -= 1 - return proposed_rows, proposed_columns - - -def _generate_tiles( - images: list[np.ndarray], - grid_size: tuple[int, int], - single_tile_size: tuple[int, int], - tile_padding_color: tuple[int, int, int], - 
tile_margin: int, - tile_margin_color: tuple[int, int, int], - titles: list[str | None] | None, - titles_anchors: list[Point | None], - titles_color: tuple[int, int, int], - titles_scale: float | None, - titles_thickness: int, - titles_padding: int, - titles_text_font: int, - titles_background_color: tuple[int, int, int], - default_title_placement: RelativePosition, -) -> np.ndarray: - images = _draw_texts( - images=images, - titles=titles, - titles_anchors=titles_anchors, - titles_color=titles_color, - titles_scale=titles_scale, - titles_thickness=titles_thickness, - titles_padding=titles_padding, - titles_text_font=titles_text_font, - titles_background_color=titles_background_color, - default_title_placement=default_title_placement, - ) - rows, columns = grid_size - tiles_elements = list(create_batches(sequence=images, batch_size=columns)) - while len(tiles_elements[-1]) < columns: - tiles_elements[-1].append( - _generate_color_image(shape=single_tile_size, color=tile_padding_color) - ) - while len(tiles_elements) < rows: - tiles_elements.append( - [_generate_color_image(shape=single_tile_size, color=tile_padding_color)] - * columns - ) - return _merge_tiles_elements( - tiles_elements=tiles_elements, - grid_size=grid_size, - single_tile_size=single_tile_size, - tile_margin=tile_margin, - tile_margin_color=tile_margin_color, - ) - - -def _draw_texts( - images: list[np.ndarray], - titles: list[str | None] | None, - titles_anchors: list[Point | None], - titles_color: tuple[int, int, int], - titles_scale: float | None, - titles_thickness: int, - titles_padding: int, - titles_text_font: int, - titles_background_color: tuple[int, int, int], - default_title_placement: RelativePosition, -) -> list[np.ndarray]: - if titles is None: - return images - titles_anchors = _prepare_default_titles_anchors( - images=images, - titles_anchors=titles_anchors, - default_title_placement=default_title_placement, - ) - if titles_scale is None: - image_height, image_width = 
images[0].shape[:2] - titles_scale = calculate_optimal_text_scale( - resolution_wh=(image_width, image_height) - ) - result = [] - for image, text, anchor in zip(images, titles, titles_anchors): - if text is None: - result.append(image) - continue - processed_image = draw_text( - scene=image, - text=text, - text_anchor=anchor, - text_color=Color.from_bgr_tuple(titles_color), - text_scale=titles_scale, - text_thickness=titles_thickness, - text_padding=titles_padding, - text_font=titles_text_font, - background_color=Color.from_bgr_tuple(titles_background_color), - ) - result.append(processed_image) - return result - - -def _prepare_default_titles_anchors( - images: list[np.ndarray], - titles_anchors: list[Point | None], - default_title_placement: RelativePosition, -) -> list[Point]: - result = [] - for image, anchor in zip(images, titles_anchors): - if anchor is not None: - result.append(anchor) - continue - image_height, image_width = image.shape[:2] - if default_title_placement == "top": - default_anchor = Point(x=image_width / 2, y=image_height * 0.1) - else: - default_anchor = Point(x=image_width / 2, y=image_height * 0.9) - result.append(default_anchor) - return result - - -def _merge_tiles_elements( - tiles_elements: list[list[np.ndarray]], - grid_size: tuple[int, int], - single_tile_size: tuple[int, int], - tile_margin: int, - tile_margin_color: tuple[int, int, int], -) -> np.ndarray: - vertical_padding = ( - np.ones((single_tile_size[1], tile_margin, 3)) * tile_margin_color - ) - merged_rows = [ - np.concatenate( - list( - itertools.chain.from_iterable( - zip(row, [vertical_padding] * grid_size[1]) - ) - )[:-1], - axis=1, - ) - for row in tiles_elements - ] - row_width = merged_rows[0].shape[1] - horizontal_padding = ( - np.ones((tile_margin, row_width, 3), dtype=np.uint8) * tile_margin_color - ) - rows_with_paddings = [] - for row in merged_rows: - rows_with_paddings.append(row) - rows_with_paddings.append(horizontal_padding) - return np.concatenate( - 
rows_with_paddings[:-1], - axis=0, - ).astype(np.uint8) - - -def _generate_color_image( - shape: tuple[int, int], color: tuple[int, int, int] -) -> np.ndarray: - return np.ones((*shape[::-1], 3), dtype=np.uint8) * color diff --git a/supervision/utils/notebook.py b/supervision/utils/notebook.py index 9262f12bc4..3af09ebbec 100644 --- a/supervision/utils/notebook.py +++ b/supervision/utils/notebook.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt from PIL import Image -from supervision.annotators.base import ImageType +from supervision.draw.base import ImageType from supervision.utils.conversion import pillow_to_cv2 diff --git a/supervision/utils/video.py b/supervision/utils/video.py index 3b281b4e22..0ece0916da 100644 --- a/supervision/utils/video.py +++ b/supervision/utils/video.py @@ -1,9 +1,11 @@ from __future__ import annotations +import threading import time from collections import deque from collections.abc import Callable, Generator from dataclasses import dataclass +from queue import Queue import cv2 import numpy as np @@ -196,63 +198,126 @@ def process_video( source_path: str, target_path: str, callback: Callable[[np.ndarray, int], np.ndarray], + *, max_frames: int | None = None, + prefetch: int = 32, + writer_buffer: int = 32, show_progress: bool = False, progress_message: str = "Processing video", ) -> None: """ - Process a video file by applying a callback function on each frame - and saving the result to a target video file. + Process video frames asynchronously using a threaded pipeline. + + This function orchestrates a three-stage pipeline to optimize video processing + throughput: + + 1. Reader thread: Continuously reads frames from the source video file and + enqueues them into a bounded queue (`frame_read_queue`). The queue size is + limited by the `prefetch` parameter to control memory usage. + 2. 
Main thread (Processor): Dequeues frames from `frame_read_queue`, applies the + user-defined `callback` function to process each frame, then enqueues the + processed frames into another bounded queue (`frame_write_queue`) for writing. + The processing happens in the main thread, simplifying use of stateful objects + without synchronization. + 3. Writer thread: Dequeues processed frames from `frame_write_queue` and writes + them sequentially to the output video file. Args: - source_path (str): The path to the source video file. - target_path (str): The path to the target video file. - callback (Callable[[np.ndarray, int], np.ndarray]): A function that takes in - a numpy ndarray representation of a video frame and an - int index of the frame and returns a processed numpy ndarray - representation of the frame. - max_frames (Optional[int]): The maximum number of frames to process. - show_progress (bool): Whether to show a progress bar. - progress_message (str): The message to display in the progress bar. + source_path (str): Path to the input video file. + target_path (str): Path where the processed video will be saved. + callback (Callable[[numpy.ndarray, int], numpy.ndarray]): Function called for + each frame, accepting the frame as a numpy array and its zero-based index, + returning the processed frame. + max_frames (int | None): Optional maximum number of frames to process. + If None, the entire video is processed (default). + prefetch (int): Maximum number of frames buffered by the reader thread. + Controls memory use; default is 32. + writer_buffer (int): Maximum number of frames buffered before writing. + Controls output buffer size; default is 32. + show_progress (bool): Whether to display a tqdm progress bar during processing. + Default is False. + progress_message (str): Description shown in the progress bar. 
- Examples: + Returns: + None + + Example: ```python + import cv2 import supervision as sv + from rfdetr import RFDETRMedium - def callback(scene: np.ndarray, index: int) -> np.ndarray: - ... + model = RFDETRMedium() + + def callback(frame, frame_index): + return model.predict(frame) process_video( - source_path=, - target_path=, - callback=callback + source_path="source.mp4", + target_path="target.mp4", + callback=frame_callback, ) ``` """ - source_video_info = VideoInfo.from_video_path(video_path=source_path) - video_frames_generator = get_video_frames_generator( - source_path=source_path, end=max_frames + video_info = VideoInfo.from_video_path(video_path=source_path) + total_frames = ( + min(video_info.total_frames, max_frames) + if max_frames is not None + else video_info.total_frames ) - with VideoSink(target_path=target_path, video_info=source_video_info) as sink: - total_frames = ( - min(source_video_info.total_frames, max_frames) - if max_frames is not None - else source_video_info.total_frames + + frame_read_queue: Queue[tuple[int, np.ndarray] | None] = Queue(maxsize=prefetch) + frame_write_queue: Queue[np.ndarray | None] = Queue(maxsize=writer_buffer) + + def reader_thread() -> None: + frame_generator = get_video_frames_generator( + source_path=source_path, + end=max_frames, ) - for index, frame in enumerate( - tqdm( - video_frames_generator, - total=total_frames, - disable=not show_progress, - desc=progress_message, - ) - ): - result_frame = callback(frame, index) - sink.write_frame(frame=result_frame) - else: - for index, frame in enumerate(video_frames_generator): - result_frame = callback(frame, index) - sink.write_frame(frame=result_frame) + for frame_index, frame in enumerate(frame_generator): + frame_read_queue.put((frame_index, frame)) + frame_read_queue.put(None) + + def writer_thread(video_sink: VideoSink) -> None: + while True: + frame = frame_write_queue.get() + if frame is None: + break + video_sink.write_frame(frame=frame) + + reader_worker 
= threading.Thread(target=reader_thread, daemon=True) + with VideoSink(target_path=target_path, video_info=video_info) as video_sink: + writer_worker = threading.Thread( + target=writer_thread, + args=(video_sink,), + daemon=True, + ) + + reader_worker.start() + writer_worker.start() + + progress_bar = tqdm( + total=total_frames, + disable=not show_progress, + desc=progress_message, + ) + + try: + while True: + read_item = frame_read_queue.get() + if read_item is None: + break + + frame_index, frame = read_item + processed_frame = callback(frame, frame_index) + + frame_write_queue.put(processed_frame) + progress_bar.update(1) + finally: + frame_write_queue.put(None) + reader_worker.join() + writer_worker.join() + progress_bar.close() class FPSMonitor: diff --git a/supervision/validators/__init__.py b/supervision/validators/__init__.py index 97fedabdd9..f051d89d7a 100644 --- a/supervision/validators/__init__.py +++ b/supervision/validators/__init__.py @@ -53,7 +53,7 @@ def validate_confidence(confidence: Any, n: int) -> None: ) -def validate_keypoint_confidence(confidence: Any, n: int, m: int) -> None: +def validate_key_point_confidence(confidence: Any, n: int, m: int) -> None: expected_shape = f"({n, m})" actual_shape = str(getattr(confidence, "shape", None)) @@ -126,7 +126,7 @@ def validate_detections_fields( validate_data(data, n) -def validate_keypoints_fields( +def validate_key_points_fields( xy: Any, class_id: Any, confidence: Any, @@ -136,7 +136,7 @@ def validate_keypoints_fields( m = len(xy[0]) if len(xy) > 0 else 0 validate_xy(xy, n, m) validate_class_id(class_id, n) - validate_keypoint_confidence(confidence, n, m) + validate_key_point_confidence(confidence, n, m) validate_data(data, n) diff --git a/test/annotators/test_utils.py b/test/annotators/test_utils.py index 2fdccec116..3ab0f9b902 100644 --- a/test/annotators/test_utils.py +++ b/test/annotators/test_utils.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from supervision.annotators.utils import 
ColorLookup, resolve_color_idx +from supervision.annotators.utils import ColorLookup, resolve_color_idx, wrap_text from supervision.detection.core import Detections from test.test_utils import mock_detections @@ -97,7 +97,7 @@ def test_resolve_color_idx( detections: Detections, detection_idx: int, - color_lookup: ColorLookup, + color_lookup: ColorLookup | np.ndarray, expected_result: int | None, exception: Exception, ) -> None: @@ -108,3 +108,67 @@ def test_resolve_color_idx( color_lookup=color_lookup, ) assert result == expected_result + + +@pytest.mark.parametrize( + "text, max_line_length, expected_result, exception", + [ + (None, None, [""], DoesNotRaise()), # text is None + ("", None, [""], DoesNotRaise()), # empty string + (" \t ", 3, [""], DoesNotRaise()), # whitespace-only (spaces + tab) + (12345, None, ["12345"], DoesNotRaise()), # plain integer + (-6789, None, ["-6789"], DoesNotRaise()), # negative integer + (np.int64(1000), None, ["1000"], DoesNotRaise()), # NumPy int64 + ([1, 2, 3], None, ["[1, 2, 3]"], DoesNotRaise()), # list to string + ( + "When you play the game of thrones, you win or you die.\nFear cuts deeper than swords.\nA mind needs books as a sword needs a whetstone.", # noqa: E501 + None, + [ + "When you play the game of thrones, you win or you die.", + "Fear cuts deeper than swords.", + "A mind needs books as a sword needs a whetstone.", + ], + DoesNotRaise(), + ), # Game-of-Thrones quotes, multiline + ("\n", None, [""], DoesNotRaise()), # single newline + ( + "valarmorghulisvalardoharis", + 6, + ["valarm", "orghul", "isvala", "rdohar", "is"], + DoesNotRaise(), + ), # long Valyrian phrase, wrapped + ( + "Winter is coming\nFire and blood", + 10, + [ + "Winter is", + "coming", + "Fire and", + "blood", + ], + DoesNotRaise(), + ), # mix of short/long with newline + ( + "What is dead may never die", + 0, + None, + pytest.raises(ValueError), + ), # width 0 - invalid + ( + "A Lannister always pays his debts", + -1, + None, + 
pytest.raises(ValueError), + ), # width -1 - invalid + (None, 10, [""], DoesNotRaise()), # text None, width set + ], +) +def test_wrap_text( + text: object, + max_line_length: int | None, + expected_result: list[str], + exception: Exception, +) -> None: + with exception: + result = wrap_text(text=text, max_line_length=max_line_length) + assert result == expected_result diff --git a/test/detection/test_vlm.py b/test/detection/test_vlm.py index 1b4c2f1894..9a0195f780 100644 --- a/test/detection/test_vlm.py +++ b/test/detection/test_vlm.py @@ -6,7 +6,10 @@ import numpy as np import pytest +from supervision.config import CLASS_NAME_DATA_FIELD +from supervision.detection.core import Detections from supervision.detection.vlm import ( + VLM, from_florence_2, from_google_gemini_2_0, from_google_gemini_2_5, @@ -317,6 +320,43 @@ def test_from_paligemma( np.array(["dog"], dtype=str), ), ), # out-of-bounds box + ( + does_not_raise(), + """[ + {'bbox_2d': [10, 20, 110, 120], 'label': 'cat'} + ]""", + (640, 640), + (1280, 720), + None, + ( + np.array([[20.0, 22.5, 220.0, 135.0]]), + None, + np.array(["cat"], dtype=str), + ), + ), # python-style list, single quotes, no fences + ( + does_not_raise(), + """```json + [ + {"bbox_2d": [0, 0, 64, 64], "label": "dog"}, + {"bbox_2d": [10, 20, 110, 120], "label": "cat"}, + {"bbox_2d": [30, 40, 130, 140], "label": + """, + (640, 640), + (640, 640), + None, + ( + np.array( + [ + [0.0, 0.0, 64.0, 64.0], + [10.0, 20.0, 110.0, 120.0], + ], + dtype=float, + ), + None, + np.array(["dog", "cat"], dtype=str), + ), + ), # truncated response, last object unfinished, previous ones recovered ( pytest.raises(ValueError), """```json @@ -327,8 +367,8 @@ def test_from_paligemma( (0, 640), (1280, 720), None, - None, # won't be compared because we expect an exception - ), # zero input width -> ValueError + None, # invalid input_wh + ), ( pytest.raises(ValueError), """```json @@ -339,8 +379,8 @@ def test_from_paligemma( (640, 640), (1280, -100), None, - 
None, - ), # negative resolution height -> ValueError + None, # invalid resolution_wh + ), ], ) def test_from_qwen_2_5_vl( @@ -1122,3 +1162,110 @@ def test_from_google_gemini_2_5( assert masks is not None assert masks.shape == expected_results[4].shape assert np.array_equal(masks, expected_results[4]) + + +@pytest.mark.parametrize( + "exception, result, resolution_wh, classes, expected_detections", + [ + ( + pytest.raises(ValueError), + "", + (100, 100), + None, + None, + ), # empty text + ( + pytest.raises(ValueError), + "random text", + (100, 100), + None, + None, + ), # random text + ( + does_not_raise(), + "<|ref|>cat<|/ref|><|det|>[[100, 200, 300, 400]]<|/det|>", + (1000, 1000), + None, + Detections( + xyxy=np.array([[100.1, 200.2, 300.3, 400.4]]), + class_id=np.array([0]), + data={CLASS_NAME_DATA_FIELD: np.array(["cat"])}, + ), + ), # single box, no classes + ( + does_not_raise(), + "<|ref|>cat<|/ref|><|det|>[[100, 200, 300, 400]]<|/det|>", + (1000, 1000), + ["cat", "dog"], + Detections( + xyxy=np.array([[100.1, 200.2, 300.3, 400.4]]), + class_id=np.array([0]), + data={CLASS_NAME_DATA_FIELD: np.array(["cat"])}, + ), + ), # single box, with classes + ( + does_not_raise(), + "<|ref|>person<|/ref|><|det|>[[100, 200, 300, 400]]<|/det|>", + (1000, 1000), + ["cat", "dog"], + Detections.empty(), + ), # single box, wrong class + ( + does_not_raise(), + ( + "<|ref|>cat<|/ref|><|det|>[[100, 200, 300, 400]]<|/det|>" + "<|ref|>dog<|/ref|><|det|>[[500, 600, 700, 800]]<|/det|>" + ), + (1000, 1000), + ["cat"], + Detections( + xyxy=np.array([[100.1, 200.2, 300.3, 400.4]]), + class_id=np.array([0]), + data={CLASS_NAME_DATA_FIELD: np.array(["cat"])}, + ), + ), # multiple boxes, one class correct + ( + pytest.raises(ValueError), + "<|ref|>cat<|/ref|>", + (100, 100), + None, + None, + ), # only ref + ( + pytest.raises(ValueError), + "<|det|>[[100, 200, 300, 400]]<|/det|>", + (100, 100), + None, + None, + ), # only det + ], +) +def test_from_deepseek_vl_2( + exception, + result: 
str, + resolution_wh: tuple[int, int], + classes: list[str] | None, + expected_detections: Detections, +): + with exception: + detections = Detections.from_vlm( + vlm=VLM.DEEPSEEK_VL_2, + result=result, + resolution_wh=resolution_wh, + classes=classes, + ) + + if expected_detections is None: + return + + assert len(detections) == len(expected_detections) + + if len(detections) == 0: + return + + assert np.allclose(detections.xyxy, expected_detections.xyxy, atol=1e-1) + assert np.array_equal(detections.class_id, expected_detections.class_id) + assert np.array_equal( + detections.data[CLASS_NAME_DATA_FIELD], + expected_detections.data[CLASS_NAME_DATA_FIELD], + ) diff --git a/test/detection/tools/test_inference_slicer.py b/test/detection/tools/test_inference_slicer.py index 2185b77f20..7c313841f3 100644 --- a/test/detection/tools/test_inference_slicer.py +++ b/test/detection/tools/test_inference_slicer.py @@ -1,13 +1,10 @@ from __future__ import annotations -from contextlib import ExitStack as DoesNotRaise - import numpy as np import pytest from supervision.detection.core import Detections from supervision.detection.tools.inference_slicer import InferenceSlicer -from supervision.detection.utils.iou_and_nms import OverlapFilter @pytest.fixture @@ -20,54 +17,10 @@ def callback(_: np.ndarray) -> Detections: return callback -@pytest.mark.parametrize( - "slice_wh, overlap_ratio_wh, overlap_wh, expected_overlap, exception", - [ - # Valid case: explicit overlap_wh in pixels - ((128, 128), None, (26, 26), (26, 26), DoesNotRaise()), - # Valid case: overlap_wh in pixels - ((128, 128), None, (20, 20), (20, 20), DoesNotRaise()), - # Invalid case: negative overlap_wh, should raise ValueError - ((128, 128), None, (-10, 20), None, pytest.raises(ValueError)), - # Invalid case: no overlaps defined - ((128, 128), None, None, None, pytest.raises(ValueError)), - # Valid case: overlap_wh = 50 pixels - ((256, 256), None, (50, 50), (50, 50), DoesNotRaise()), - # Valid case: overlap_wh = 60 
pixels - ((200, 200), None, (60, 60), (60, 60), DoesNotRaise()), - # Valid case: small overlap_wh values - ((100, 100), None, (0.1, 0.1), (0.1, 0.1), DoesNotRaise()), - # Invalid case: negative overlap_wh values - ((128, 128), None, (-10, -10), None, pytest.raises(ValueError)), - # Invalid case: overlap_wh greater than slice size - ((128, 128), None, (150, 150), (150, 150), DoesNotRaise()), - # Valid case: zero overlap - ((128, 128), None, (0, 0), (0, 0), DoesNotRaise()), - ], -) -def test_inference_slicer_overlap( - mock_callback, - slice_wh: tuple[int, int], - overlap_ratio_wh: tuple[float, float] | None, - overlap_wh: tuple[int, int] | None, - expected_overlap: tuple[int, int] | None, - exception: Exception, -) -> None: - with exception: - slicer = InferenceSlicer( - callback=mock_callback, - slice_wh=slice_wh, - overlap_ratio_wh=overlap_ratio_wh, - overlap_wh=overlap_wh, - overlap_filter=OverlapFilter.NONE, - ) - assert slicer.overlap_wh == expected_overlap - - @pytest.mark.parametrize( "resolution_wh, slice_wh, overlap_wh, expected_offsets", [ - # Case 1: No overlap, exact slices fit within image dimensions + # Case 1: Square image, square slices, no overlap ( (256, 256), (128, 128), @@ -81,7 +34,7 @@ def test_inference_slicer_overlap( ] ), ), - # Case 2: Overlap of 64 pixels in both directions + # Case 2: Square image, square slices, non-zero overlap ( (256, 256), (128, 128), @@ -91,96 +44,154 @@ def test_inference_slicer_overlap( [0, 0, 128, 128], [64, 0, 192, 128], [128, 0, 256, 128], - [192, 0, 256, 128], [0, 64, 128, 192], [64, 64, 192, 192], [128, 64, 256, 192], - [192, 64, 256, 192], [0, 128, 128, 256], [64, 128, 192, 256], [128, 128, 256, 256], - [192, 128, 256, 256], - [0, 192, 128, 256], - [64, 192, 192, 256], - [128, 192, 256, 256], - [192, 192, 256, 256], ] ), ), - # Case 3: Image not perfectly divisible by slice size (no overlap) + # Case 3: Rectangle image (horizontal), square slices, no overlap ( - (300, 300), - (128, 128), + (192, 128), + (64, 
64), (0, 0), np.array( [ - [0, 0, 128, 128], - [128, 0, 256, 128], - [256, 0, 300, 128], - [0, 128, 128, 256], - [128, 128, 256, 256], - [256, 128, 300, 256], - [0, 256, 128, 300], - [128, 256, 256, 300], - [256, 256, 300, 300], + [0, 0, 64, 64], + [64, 0, 128, 64], + [128, 0, 192, 64], + [0, 64, 64, 128], + [64, 64, 128, 128], + [128, 64, 192, 128], ] ), ), - # Case 4: Overlap of 32 pixels, image not perfectly divisible by slice size + # Case 4: Rectangle image (horizontal), square slices, non-zero overlap ( - (300, 300), - (128, 128), + (192, 128), + (64, 64), (32, 32), np.array( [ - [0, 0, 128, 128], - [96, 0, 224, 128], - [192, 0, 300, 128], - [288, 0, 300, 128], - [0, 96, 128, 224], - [96, 96, 224, 224], - [192, 96, 300, 224], - [288, 96, 300, 224], - [0, 192, 128, 300], - [96, 192, 224, 300], - [192, 192, 300, 300], - [288, 192, 300, 300], - [0, 288, 128, 300], - [96, 288, 224, 300], - [192, 288, 300, 300], - [288, 288, 300, 300], + [0, 0, 64, 64], + [32, 0, 96, 64], + [64, 0, 128, 64], + [96, 0, 160, 64], + [128, 0, 192, 64], + [0, 32, 64, 96], + [32, 32, 96, 96], + [64, 32, 128, 96], + [96, 32, 160, 96], + [128, 32, 192, 96], + [0, 64, 64, 128], + [32, 64, 96, 128], + [64, 64, 128, 128], + [96, 64, 160, 128], + [128, 64, 192, 128], ] ), ), - # Case 5: Image smaller than slice size (no overlap) + # Case 5: Rectangle image (vertical), square slices, no overlap ( - (100, 100), - (128, 128), + (128, 192), + (64, 64), (0, 0), np.array( [ - [0, 0, 100, 100], + [0, 0, 64, 64], + [64, 0, 128, 64], + [0, 64, 64, 128], + [64, 64, 128, 128], + [0, 128, 64, 192], + [64, 128, 128, 192], + ] + ), + ), + # Case 6: Rectangle image (vertical), square slices, non-zero overlap + ( + (128, 192), + (64, 64), + (32, 32), + np.array( + [ + [0, 0, 64, 64], + [32, 0, 96, 64], + [64, 0, 128, 64], + [0, 32, 64, 96], + [32, 32, 96, 96], + [64, 32, 128, 96], + [0, 64, 64, 128], + [32, 64, 96, 128], + [64, 64, 128, 128], + [0, 96, 64, 160], + [32, 96, 96, 160], + [64, 96, 128, 160], + 
[0, 128, 64, 192], + [32, 128, 96, 192], + [64, 128, 128, 192], + ] + ), + ), + # Case 7: Square image, rectangular slices (horizontal), no overlap + ( + (160, 160), + (80, 40), + (0, 0), + np.array( + [ + [0, 0, 80, 40], + [80, 0, 160, 40], + [0, 40, 80, 80], + [80, 40, 160, 80], + [0, 80, 80, 120], + [80, 80, 160, 120], + [0, 120, 80, 160], + [80, 120, 160, 160], + ] + ), + ), + # Case 8: Square image, rectangular slices (vertical), non-zero overlap + ( + (160, 160), + (40, 80), + (10, 20), + np.array( + [ + [0, 0, 40, 80], + [30, 0, 70, 80], + [60, 0, 100, 80], + [90, 0, 130, 80], + [120, 0, 160, 80], + [0, 60, 40, 140], + [30, 60, 70, 140], + [60, 60, 100, 140], + [90, 60, 130, 140], + [120, 60, 160, 140], + [0, 80, 40, 160], + [30, 80, 70, 160], + [60, 80, 100, 160], + [90, 80, 130, 160], + [120, 80, 160, 160], ] ), ), - # Case 6: Overlap_wh is greater than the slice size - ((256, 256), (128, 128), (150, 150), np.array([]).reshape(0, 4)), ], ) def test_generate_offset( resolution_wh: tuple[int, int], slice_wh: tuple[int, int], - overlap_wh: tuple[int, int] | None, + overlap_wh: tuple[int, int], expected_offsets: np.ndarray, ) -> None: offsets = InferenceSlicer._generate_offset( resolution_wh=resolution_wh, slice_wh=slice_wh, - overlap_ratio_wh=None, overlap_wh=overlap_wh, ) - # Verify that the generated offsets match the expected offsets assert np.array_equal(offsets, expected_offsets), ( f"Expected {expected_offsets}, got {offsets}" ) diff --git a/test/detection/utils/test_boxes.py b/test/detection/utils/test_boxes.py index 919989287a..66d0d999c8 100644 --- a/test/detection/utils/test_boxes.py +++ b/test/detection/utils/test_boxes.py @@ -5,7 +5,12 @@ import numpy as np import pytest -from supervision.detection.utils.boxes import clip_boxes, move_boxes, scale_boxes +from supervision.detection.utils.boxes import ( + clip_boxes, + denormalize_boxes, + move_boxes, + scale_boxes, +) @pytest.mark.parametrize( @@ -142,3 +147,88 @@ def test_scale_boxes( with 
exception: result = scale_boxes(xyxy=xyxy, factor=factor) assert np.array_equal(result, expected_result) + + +@pytest.mark.parametrize( + "xyxy, resolution_wh, normalization_factor, expected_result, exception", + [ + ( + np.empty(shape=(0, 4)), + (1280, 720), + 1.0, + np.empty(shape=(0, 4)), + DoesNotRaise(), + ), # empty array + ( + np.array([[0.1, 0.2, 0.5, 0.6]]), + (1280, 720), + 1.0, + np.array([[128.0, 144.0, 640.0, 432.0]]), + DoesNotRaise(), + ), # single box with default normalization + ( + np.array([[0.1, 0.2, 0.5, 0.6], [0.3, 0.4, 0.7, 0.8]]), + (1280, 720), + 1.0, + np.array([[128.0, 144.0, 640.0, 432.0], [384.0, 288.0, 896.0, 576.0]]), + DoesNotRaise(), + ), # two boxes with default normalization + ( + np.array( + [[0.1, 0.2, 0.5, 0.6], [0.3, 0.4, 0.7, 0.8], [0.2, 0.1, 0.6, 0.5]] + ), + (1280, 720), + 1.0, + np.array( + [ + [128.0, 144.0, 640.0, 432.0], + [384.0, 288.0, 896.0, 576.0], + [256.0, 72.0, 768.0, 360.0], + ] + ), + DoesNotRaise(), + ), # three boxes - regression test for issue #1959 + ( + np.array([[10.0, 20.0, 50.0, 60.0]]), + (100, 200), + 100.0, + np.array([[10.0, 40.0, 50.0, 120.0]]), + DoesNotRaise(), + ), # single box with custom normalization factor + ( + np.array([[10.0, 20.0, 50.0, 60.0], [30.0, 40.0, 70.0, 80.0]]), + (100, 200), + 100.0, + np.array([[10.0, 40.0, 50.0, 120.0], [30.0, 80.0, 70.0, 160.0]]), + DoesNotRaise(), + ), # two boxes with custom normalization factor + ( + np.array([[0.0, 0.0, 1.0, 1.0]]), + (1920, 1080), + 1.0, + np.array([[0.0, 0.0, 1920.0, 1080.0]]), + DoesNotRaise(), + ), # full frame box + ( + np.array([[0.5, 0.5, 0.5, 0.5]]), + (640, 480), + 1.0, + np.array([[320.0, 240.0, 320.0, 240.0]]), + DoesNotRaise(), + ), # zero-area box (point) + ], +) +def test_denormalize_boxes( + xyxy: np.ndarray, + resolution_wh: tuple[int, int], + normalization_factor: float, + expected_result: np.ndarray, + exception: Exception, +) -> None: + with exception: + result = denormalize_boxes( + xyxy=xyxy, + 
resolution_wh=resolution_wh, + normalization_factor=normalization_factor, + ) + assert np.allclose(result, expected_result) diff --git a/test/detection/utils/test_converters.py b/test/detection/utils/test_converters.py index e13b150042..52a3b52004 100644 --- a/test/detection/utils/test_converters.py +++ b/test/detection/utils/test_converters.py @@ -6,6 +6,7 @@ from supervision.detection.utils.converters import ( xcycwh_to_xyxy, xywh_to_xyxy, + xyxy_to_mask, xyxy_to_xcycarh, xyxy_to_xywh, ) @@ -129,3 +130,174 @@ def test_xyxy_to_xcycarh(xyxy: np.ndarray, expected_result: np.ndarray) -> None: def test_xcycwh_to_xyxy(xcycwh: np.ndarray, expected_result: np.ndarray) -> None: result = xcycwh_to_xyxy(xcycwh) np.testing.assert_array_equal(result, expected_result) + + +@pytest.mark.parametrize( + "boxes,resolution_wh,expected", + [ + # 0) Empty input + ( + np.array([], dtype=float).reshape(0, 4), + (5, 4), + np.array([], dtype=bool).reshape(0, 4, 5), + ), + # 1) Single pixel box + ( + np.array([[2, 1, 2, 1]], dtype=float), + (5, 4), + np.array( + [ + [ + [False, False, False, False, False], + [False, False, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ] + ], + dtype=bool, + ), + ), + # 2) Horizontal line, inclusive bounds + ( + np.array([[1, 2, 3, 2]], dtype=float), + (5, 4), + np.array( + [ + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, True, True, True, False], + [False, False, False, False, False], + ] + ], + dtype=bool, + ), + ), + # 3) Vertical line, inclusive bounds + ( + np.array([[3, 0, 3, 2]], dtype=float), + (5, 4), + np.array( + [ + [ + [False, False, False, True, False], + [False, False, False, True, False], + [False, False, False, True, False], + [False, False, False, False, False], + ] + ], + dtype=bool, + ), + ), + # 4) Proper rectangle fill + ( + np.array([[1, 1, 3, 2]], dtype=float), + (5, 4), + np.array( + [ + [ + [False, False, False, False, False], + 
[False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ] + ], + dtype=bool, + ), + ), + # 5) Negative coordinates clipped to [0, 0] + ( + np.array([[-2, -1, 1, 1]], dtype=float), + (5, 4), + np.array( + [ + [ + [True, True, False, False, False], + [True, True, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ] + ], + dtype=bool, + ), + ), + # 6) Overflow coordinates clipped to width-1 and height-1 + ( + np.array([[3, 2, 10, 10]], dtype=float), + (5, 4), + np.array( + [ + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, True, True], + [False, False, False, True, True], + ] + ], + dtype=bool, + ), + ), + # 7) Invalid box where max < min after ints, mask stays empty + ( + np.array([[3, 2, 1, 4]], dtype=float), + (5, 4), + np.array( + [ + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ] + ], + dtype=bool, + ), + ), + # 8) Fractional coordinates are floored by int conversion + # (0.2,0.2)-(2.8,1.9) -> (0,0)-(2,1) + ( + np.array([[0.2, 0.2, 2.8, 1.9]], dtype=float), + (5, 4), + np.array( + [ + [ + [True, True, True, False, False], + [True, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ] + ], + dtype=bool, + ), + ), + # 9) Multiple boxes, separate masks + ( + np.array([[0, 0, 1, 0], [2, 1, 4, 3]], dtype=float), + (5, 4), + np.array( + [ + # Box 0: row 0, cols 0..1 + [ + [True, True, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + # Box 1: rows 1..3, cols 2..4 + [ + [False, False, False, False, False], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, True, True, True], + ], + ], + dtype=bool, + ), + ), + 
], +) +def test_xyxy_to_mask(boxes: np.ndarray, resolution_wh, expected: np.ndarray) -> None: + result = xyxy_to_mask(boxes, resolution_wh) + assert result.dtype == np.bool_ + assert result.shape == expected.shape + np.testing.assert_array_equal(result, expected) diff --git a/test/detection/utils/test_iou_and_nms.py b/test/detection/utils/test_iou_and_nms.py index 8039bf2425..ab7586483c 100644 --- a/test/detection/utils/test_iou_and_nms.py +++ b/test/detection/utils/test_iou_and_nms.py @@ -6,11 +6,15 @@ import pytest from supervision.detection.utils.iou_and_nms import ( + OverlapMetric, _group_overlapping_boxes, + box_iou, + box_iou_batch, box_non_max_suppression, mask_non_max_merge, mask_non_max_suppression, ) +from test.test_utils import random_boxes @pytest.mark.parametrize( @@ -631,3 +635,497 @@ def test_mask_non_max_merge( sorted_result = sorted([sorted(group) for group in result]) sorted_expected_result = sorted([sorted(group) for group in expected_result]) assert sorted_result == sorted_expected_result + + +@pytest.mark.parametrize( + "box_true, box_detection, overlap_metric, expected_overlap, exception", + [ + ( + [100.0, 100.0, 200.0, 200.0], + [150.0, 150.0, 250.0, 250.0], + OverlapMetric.IOU, + 0.14285714285714285, + DoesNotRaise(), + ), # partial overlap, IOU + ( + [100.0, 100.0, 200.0, 200.0], + [150.0, 150.0, 250.0, 250.0], + OverlapMetric.IOS, + 0.25, + DoesNotRaise(), + ), # partial overlap, IOS + ( + np.array([0.0, 0.0, 10.0, 10.0], dtype=np.float32), + np.array([0.0, 0.0, 10.0, 10.0], dtype=np.float32), + OverlapMetric.IOU, + 1.0, + DoesNotRaise(), + ), # identical boxes, both boxes are arrays, IOU + ( + np.array([0.0, 0.0, 10.0, 10.0], dtype=np.float32), + np.array([0.0, 0.0, 10.0, 10.0], dtype=np.float32), + OverlapMetric.IOS, + 1.0, + DoesNotRaise(), + ), # identical boxes, both boxes are arrays, IOS + ( + [0.0, 0.0, 10.0, 10.0], + [0.0, 0.0, 10.0, 10.0], + "iou", + 1.0, + DoesNotRaise(), + ), # identical boxes, both boxes are arrays, IOU as 
lowercase string + ( + [0.0, 0.0, 10.0, 10.0], + [0.0, 0.0, 10.0, 10.0], + "ios", + 1.0, + DoesNotRaise(), + ), # identical boxes, both boxes are arrays, IOS as lowercase string + ( + [0.0, 0.0, 10.0, 10.0], + [0.0, 0.0, 10.0, 10.0], + "IOU", + 1.0, + DoesNotRaise(), + ), # identical boxes, both boxes are arrays, IOU as uppercase string + ( + [0.0, 0.0, 10.0, 10.0], + [0.0, 0.0, 10.0, 10.0], + "IOU", + 1.0, + DoesNotRaise(), + ), # identical boxes, both boxes are arrays, IOS as uppercase string + ( + [0.0, 0.0, 10.0, 10.0], + [20.0, 20.0, 30.0, 30.0], + OverlapMetric.IOU, + 0.0, + DoesNotRaise(), + ), # no overlap, IOU + ( + [0.0, 0.0, 10.0, 10.0], + [20.0, 20.0, 30.0, 30.0], + OverlapMetric.IOS, + 0.0, + DoesNotRaise(), + ), # no overlap, IOS + ( + [0.0, 0.0, 10.0, 10.0], + [10.0, 0.0, 20.0, 10.0], + OverlapMetric.IOU, + 0.0, + DoesNotRaise(), + ), # boxes touch at edge, zero intersection, IOU + ( + [0.0, 0.0, 10.0, 10.0], + [10.0, 0.0, 20.0, 10.0], + OverlapMetric.IOS, + 0.0, + DoesNotRaise(), + ), # boxes touch at edge, zero intersection, IOU + ( + [0.0, 0.0, 10.0, 10.0], + [2.0, 2.0, 8.0, 8.0], + OverlapMetric.IOU, + 0.36, + DoesNotRaise(), + ), # one box inside another, IOU + ( + [0.0, 0.0, 10.0, 10.0], + [2.0, 2.0, 8.0, 8.0], + OverlapMetric.IOS, + 1.0, + DoesNotRaise(), + ), # one box inside another, IOS + ( + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 10.0, 10.0], + OverlapMetric.IOU, + 0.0, + DoesNotRaise(), + ), # degenerate true box with zero area, IOU + ( + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 10.0, 10.0], + OverlapMetric.IOS, + 0.0, + DoesNotRaise(), + ), # degenerate true box with zero area, IOS + ( + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + OverlapMetric.IOU, + 0.0, + DoesNotRaise(), + ), # both boxes fully degenerate, IOU + ( + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + OverlapMetric.IOS, + 0.0, + DoesNotRaise(), + ), # both boxes fully degenerate, IOS + ( + [-5.0, 0.0, 5.0, 10.0], + [0.0, 0.0, 10.0, 10.0], + OverlapMetric.IOU, + 1.0 / 3.0, + 
DoesNotRaise(), + ), # negative x_min, overlapping boxes, IOU is 1/3 + ( + [-5.0, 0.0, 5.0, 10.0], + [0.0, 0.0, 10.0, 10.0], + OverlapMetric.IOS, + 0.5, + DoesNotRaise(), + ), # negative x_min, overlapping boxes, IOS is 0.5 + ( + [0.0, 0.0, 1.0, 1.0], + [0.5, 0.5, 1.5, 1.5], + OverlapMetric.IOU, + 0.14285714285714285, + DoesNotRaise(), + ), # partial overlap with fractional coordinates, IOU + ( + [0.0, 0.0, 1.0, 1.0], + [0.5, 0.5, 1.5, 1.5], + OverlapMetric.IOS, + 0.25, + DoesNotRaise(), + ), # partial overlap with fractional coordinates, IOS + ], +) +def test_box_iou( + box_true: list[float] | np.ndarray, + box_detection: list[float] | np.ndarray, + overlap_metric: str | OverlapMetric, + expected_overlap: float, + exception: Exception, +) -> None: + with exception: + result = box_iou( + box_true=box_true, + box_detection=box_detection, + overlap_metric=overlap_metric, + ) + assert result == pytest.approx(expected_overlap, rel=1e-6, abs=1e-12) + + +@pytest.mark.parametrize( + "boxes_true, boxes_detection, overlap_metric, expected_overlap, exception", + [ + # both inputs empty + ( + np.empty((0, 4), dtype=np.float32), + np.empty((0, 4), dtype=np.float32), + OverlapMetric.IOU, + np.empty((0, 0), dtype=np.float32), + DoesNotRaise(), + ), + # one true box, no detections + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.empty((0, 4), dtype=np.float32), + OverlapMetric.IOU, + np.empty((1, 0), dtype=np.float32), + DoesNotRaise(), + ), + # no true boxes, one detection + ( + np.empty((0, 4), dtype=np.float32), + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + OverlapMetric.IOU, + np.empty((0, 1), dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 partial overlap, IOU + ( + np.array([[100.0, 100.0, 200.0, 200.0]], dtype=np.float32), + np.array([[150.0, 150.0, 250.0, 250.0]], dtype=np.float32), + OverlapMetric.IOU, + np.array([[0.14285715]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 partial overlap, IOS + ( + np.array([[100.0, 100.0, 200.0, 
200.0]], dtype=np.float32), + np.array([[150.0, 150.0, 250.0, 250.0]], dtype=np.float32), + OverlapMetric.IOS, + np.array([[0.25]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 identical boxes, IOU as lowercase string + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + "iou", + np.array([[1.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 identical boxes, IOS as lowercase string + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + "ios", + np.array([[1.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 identical boxes, IOU as uppercase string + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + "IOU", + np.array([[1.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 identical boxes, IOS as uppercase string + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + "IOS", + np.array([[1.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 no overlap, IOU + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[20.0, 20.0, 30.0, 30.0]], dtype=np.float32), + OverlapMetric.IOU, + np.array([[0.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 no overlap, IOS + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[20.0, 20.0, 30.0, 30.0]], dtype=np.float32), + OverlapMetric.IOS, + np.array([[0.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 touching at edge, zero intersection, IOU + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[10.0, 0.0, 20.0, 10.0]], dtype=np.float32), + OverlapMetric.IOU, + np.array([[0.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 touching at edge, zero intersection, IOS + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[10.0, 0.0, 20.0, 10.0]], dtype=np.float32), + 
OverlapMetric.IOS, + np.array([[0.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 box inside another, IOU + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[2.0, 2.0, 8.0, 8.0]], dtype=np.float32), + OverlapMetric.IOU, + np.array([[0.36]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 box inside another, IOS + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[2.0, 2.0, 8.0, 8.0]], dtype=np.float32), + OverlapMetric.IOS, + np.array([[1.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 degenerate true box, IOU + ( + np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32), + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + OverlapMetric.IOU, + np.array([[0.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 degenerate true box, IOS + ( + np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32), + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + OverlapMetric.IOS, + np.array([[0.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 both boxes degenerate, IOU + ( + np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32), + np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32), + OverlapMetric.IOU, + np.array([[0.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 both boxes degenerate, IOS + ( + np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32), + np.array([[0.0, 0.0, 0.0, 0.0]], dtype=np.float32), + OverlapMetric.IOS, + np.array([[0.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 negative coordinate, partial overlap, IOU + ( + np.array([[-5.0, 0.0, 5.0, 10.0]], dtype=np.float32), + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + OverlapMetric.IOU, + np.array([[1.0 / 3.0]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 negative coordinate, partial overlap, IOS + ( + np.array([[-5.0, 0.0, 5.0, 10.0]], dtype=np.float32), + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + OverlapMetric.IOS, + np.array([[0.5]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 fractional coordinates, 
partial overlap, IOU + ( + np.array([[0.0, 0.0, 1.0, 1.0]], dtype=np.float32), + np.array([[0.5, 0.5, 1.5, 1.5]], dtype=np.float32), + OverlapMetric.IOU, + np.array([[0.14285715]], dtype=np.float32), + DoesNotRaise(), + ), + # 1x1 fractional coordinates, partial overlap, IOS + ( + np.array([[0.0, 0.0, 1.0, 1.0]], dtype=np.float32), + np.array([[0.5, 0.5, 1.5, 1.5]], dtype=np.float32), + OverlapMetric.IOS, + np.array([[0.25]], dtype=np.float32), + DoesNotRaise(), + ), + # true batch case, 2x2, IOU + ( + np.array( + [ + [0.0, 0.0, 10.0, 10.0], + [10.0, 10.0, 20.0, 20.0], + ], + dtype=np.float32, + ), + np.array( + [ + [0.0, 0.0, 10.0, 10.0], + [5.0, 5.0, 15.0, 15.0], + ], + dtype=np.float32, + ), + OverlapMetric.IOU, + np.array( + [ + [1.0, 0.14285715], + [0.0, 0.14285715], + ], + dtype=np.float32, + ), + DoesNotRaise(), + ), + # true batch case, 2x2, IOS + ( + np.array( + [ + [0.0, 0.0, 10.0, 10.0], + [10.0, 10.0, 20.0, 20.0], + ], + dtype=np.float32, + ), + np.array( + [ + [0.0, 0.0, 10.0, 10.0], + [5.0, 5.0, 15.0, 15.0], + ], + dtype=np.float32, + ), + OverlapMetric.IOS, + np.array( + [ + [1.0, 0.25], + [0.0, 0.25], + ], + dtype=np.float32, + ), + DoesNotRaise(), + ), + # invalid overlap_metric + ( + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + np.array([[0.0, 0.0, 10.0, 10.0]], dtype=np.float32), + "invalid", + None, + pytest.raises(ValueError), + ), + ], +) +def test_box_iou_batch( + boxes_true: np.ndarray, + boxes_detection: np.ndarray, + overlap_metric: str | OverlapMetric, + expected_overlap: np.ndarray | None, + exception: Exception, +) -> None: + with exception: + result = box_iou_batch( + boxes_true=boxes_true, + boxes_detection=boxes_detection, + overlap_metric=overlap_metric, + ) + + assert isinstance(result, np.ndarray) + assert result.shape == expected_overlap.shape + assert np.allclose( + result, + expected_overlap, + rtol=1e-6, + atol=1e-12, + ) + + +@pytest.mark.parametrize( + "num_true, num_det", + [ + (5, 5), + (5, 10), + (10, 5), + 
(10, 10), + (20, 30), + (30, 20), + (50, 50), + (100, 100), + ], +) +@pytest.mark.parametrize( + "overlap_metric", + [OverlapMetric.IOU, OverlapMetric.IOS], +) +def test_box_iou_batch_symmetric_large( + num_true: int, + num_det: int, + overlap_metric: OverlapMetric, +) -> None: + boxes_true = random_boxes(num_true) + boxes_det = random_boxes(num_det) + + result_ab = box_iou_batch( + boxes_true=boxes_true, + boxes_detection=boxes_det, + overlap_metric=overlap_metric, + ) + result_ba = box_iou_batch( + boxes_true=boxes_det, + boxes_detection=boxes_true, + overlap_metric=overlap_metric, + ) + + assert result_ab.shape == (num_true, num_det) + assert result_ba.shape == (num_det, num_true) + assert np.allclose( + result_ab, + result_ba.T, + rtol=1e-6, + atol=1e-12, + ) diff --git a/test/detection/utils/test_masks.py b/test/detection/utils/test_masks.py index 2097f6082c..b41f208edb 100644 --- a/test/detection/utils/test_masks.py +++ b/test/detection/utils/test_masks.py @@ -10,6 +10,7 @@ calculate_masks_centroids, contains_holes, contains_multiple_segments, + filter_segments_by_distance, move_masks, ) @@ -500,3 +501,228 @@ def test_contains_multiple_segments( with exception: result = contains_multiple_segments(mask=mask, connectivity=connectivity) assert result == expected_result + + +@pytest.mark.parametrize( + "mask, connectivity, mode, absolute_distance, relative_distance, expected_result, exception", # noqa: E501 + [ + # single component, unchanged + ( + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + dtype=bool, + ), + 8, + "edge", + 2.0, + None, + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + dtype=bool, + ), + DoesNotRaise(), + ), + # two components, edge distance 2, kept with abs=1 + ( + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 1], + [0, 1, 1, 
1, 0, 1], + [0, 1, 1, 1, 0, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + dtype=bool, + ), + 8, + "edge", + 2.0, + None, + np.array( + [ + [0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 1], + [0, 1, 1, 1, 0, 1], + [0, 1, 1, 1, 0, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + dtype=bool, + ), + DoesNotRaise(), + ), + # centroid mode, far centroids, dropped with small relative threshold + ( + np.array( + [ + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + ], + dtype=bool, + ), + 8, + "centroid", + None, + 0.3, # diagonal ~8.49, threshold ~2.55, centroid gap ~4.24 + np.array( + [ + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + ], + dtype=bool, + ), + DoesNotRaise(), + ), + # centroid mode, larger relative threshold, kept + ( + np.array( + [ + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + ], + dtype=bool, + ), + 8, + "centroid", + None, + 0.6, # diagonal ~8.49, threshold ~5.09, centroid gap ~4.24 + np.array( + [ + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + ], + dtype=bool, + ), + DoesNotRaise(), + ), + # empty mask + ( + np.zeros((4, 4), dtype=bool), + 4, + "edge", + 2.0, + None, + np.zeros((4, 4), dtype=bool), + DoesNotRaise(), + ), + # full mask + ( + np.ones((4, 4), dtype=bool), + 8, + "centroid", + None, + 0.2, + np.ones((4, 4), dtype=bool), + DoesNotRaise(), + ), + # two components, pixel distance = 2, kept with abs=2 + ( + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 1, 1, 1], + [0, 1, 1, 1, 0, 1, 1, 1], + [0, 1, 1, 1, 0, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=bool, + ), + 8, + "edge", + 2.0, # was 1.0 + None, + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 
1, 1, 0, 1, 1, 1], + [0, 1, 1, 1, 0, 1, 1, 1], + [0, 1, 1, 1, 0, 1, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=bool, + ), + DoesNotRaise(), + ), + # two components, pixel distance = 3, dropped with abs=2 + ( + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 1, 1], + [0, 1, 1, 1, 0, 0, 0, 1, 1], + [0, 1, 1, 1, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=bool, + ), + 8, + "edge", + 2.0, # keep threshold below 3 so the right blob is removed + None, + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=bool, + ), + DoesNotRaise(), + ), + ], +) +def test_filter_segments_by_distance_sweep( + mask: npt.NDArray, + connectivity: int, + mode: str, + absolute_distance: float | None, + relative_distance: float | None, + expected_result: npt.NDArray | None, + exception: Exception, +) -> None: + with exception: + result = filter_segments_by_distance( + mask=mask, + connectivity=connectivity, + mode=mode, # type: ignore[arg-type] + absolute_distance=absolute_distance, + relative_distance=relative_distance, + ) + assert np.array_equal(result, expected_result) diff --git a/test/detection/utils/test_vlms.py b/test/detection/utils/test_vlms.py new file mode 100644 index 0000000000..a6fe649b73 --- /dev/null +++ b/test/detection/utils/test_vlms.py @@ -0,0 +1,117 @@ +import pytest + +from supervision.detection.utils.vlms import edit_distance, fuzzy_match_index + + +@pytest.mark.parametrize( + "string_1, string_2, case_sensitive, expected_result", + [ + # identical strings, various cases + ("hello", "hello", True, 0), + ("hello", "hello", False, 0), + # case sensitive vs insensitive + ("Test", "test", True, 1), + ("Test", "test", False, 0), + ("CASE", "case", True, 4), + ("CASE", "case", False, 0), + # completely different + 
("abc", "xyz", True, 3), + ("abc", "xyz", False, 3), + # one string empty + ("hello", "", True, 5), + ("", "world", True, 5), + # single character cases + ("a", "b", True, 1), + ("A", "a", True, 1), + ("A", "a", False, 0), + # whitespaces + ("hello world", "helloworld", True, 1), + ("test", " test", True, 1), + # unicode and emoji + ("😊", "😊", True, 0), + ("😊", "😒", True, 1), + # long string vs empty + ("a" * 100, "", True, 100), + ("", "b" * 100, True, 100), + # prefix/suffix + ("prefix", "prefixes", True, 2), + ("suffix", "asuffix", True, 1), + # leading/trailing whitespace + (" hello", "hello", True, 1), + ("hello", "hello ", True, 1), + # long almost-equal string + ( + "The quick brown fox jumps over the lazy dog", + "The quick brown fox jumps over the lazy cog", + True, + 1, + ), + ( + "The quick brown fox jumps over the lazy dog", + "The quick brown fox jumps over the lazy cog", + False, + 1, + ), + # both empty + ("", "", True, 0), + ("", "", False, 0), + # mixed case with symbols + ("123ABC!", "123abc!", True, 3), + ("123ABC!", "123abc!", False, 0), + ], +) +def test_edit_distance(string_1, string_2, case_sensitive, expected_result): + assert ( + edit_distance(string_1, string_2, case_sensitive=case_sensitive) + == expected_result + ) + + +@pytest.mark.parametrize( + "candidates, query, threshold, case_sensitive, expected_result", + [ + # exact match at index 0 + (["cat", "dog", "rat"], "cat", 0, True, 0), + # match at index 2 within threshold + (["cat", "dog", "rat"], "dat", 1, True, 0), + # no match due to high threshold + (["cat", "dog", "rat"], "bat", 0, True, None), + # multiple possible matches, returns first + (["apple", "apply", "appla"], "apple", 1, True, 0), + # case-insensitive match + (["Alpha", "beta", "Gamma"], "alpha", 0, False, 0), + # case-sensitive: no match + (["Alpha", "beta", "Gamma"], "alpha", 0, True, None), + # threshold boundary + (["alpha", "beta", "gamma"], "bata", 1, True, 1), + # no match (all distances too high) + (["one", 
"two", "three"], "ten", 1, True, None), + # unicode/emoji match + (["😊", "😒", "😁"], "πŸ˜„", 1, True, 0), + (["😊", "😒", "😁"], "😊", 0, True, 0), + # empty candidates + ([], "any", 2, True, None), + # empty query, non-empty candidates + (["", "abc"], "", 0, True, 0), + (["", "abc"], "", 1, True, 0), + (["a", "b", "c"], "", 1, True, 0), + # non-empty query, empty candidate + (["", ""], "a", 1, True, 0), + # all candidates require higher edit than threshold + (["short", "words", "only"], "longerword", 2, True, None), + # repeated candidates + (["a", "a", "a"], "b", 1, True, 0), + ], +) +def test_fuzzy_match_index( + candidates, query, threshold, case_sensitive, expected_result +): + assert ( + fuzzy_match_index( + candidates=candidates, + query=query, + threshold=threshold, + case_sensitive=case_sensitive, + ) + == expected_result + ) diff --git a/test/key_points/__init__.py b/test/key_points/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/key_points/test_core.py b/test/key_points/test_core.py new file mode 100644 index 0000000000..5a8244cd87 --- /dev/null +++ b/test/key_points/test_core.py @@ -0,0 +1,268 @@ +from contextlib import nullcontext as DoesNotRaise + +import numpy as np +import pytest + +from supervision.key_points.core import KeyPoints +from test.test_utils import mock_key_points + +KEY_POINTS = mock_key_points( + xy=[ + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], + [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]], + [[20, 21], [22, 23], [24, 25], [26, 27], [28, 29]], + ], + confidence=[ + [0.8, 0.2, 0.6, 0.1, 0.5], + [0.7, 0.9, 0.3, 0.4, 0.0], + [0.1, 0.6, 0.8, 0.2, 0.7], + ], + class_id=[0, 1, 2], +) + + +@pytest.mark.parametrize( + "key_points, index, expected_result, exception", + [ + ( + KeyPoints.empty(), + slice(None), + KeyPoints.empty(), + DoesNotRaise(), + ), # slice all key points when key points object empty + ( + KEY_POINTS, + slice(None), + KEY_POINTS, + DoesNotRaise(), + ), # slice all key points when key points 
object nonempty + ( + KEY_POINTS, + slice(0, 1), + mock_key_points( + xy=[[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]], + confidence=[[0.8, 0.2, 0.6, 0.1, 0.5]], + class_id=[0], + ), + DoesNotRaise(), + ), # select the first skeleton by slice + ( + KEY_POINTS, + slice(0, 2), + mock_key_points( + xy=[ + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], + [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]], + ], + confidence=[ + [0.8, 0.2, 0.6, 0.1, 0.5], + [0.7, 0.9, 0.3, 0.4, 0.0], + ], + class_id=[0, 1], + ), + DoesNotRaise(), + ), # select the first skeleton by slice + ( + KEY_POINTS, + 0, + mock_key_points( + xy=[[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]], + confidence=[[0.8, 0.2, 0.6, 0.1, 0.5]], + class_id=[0], + ), + DoesNotRaise(), + ), # select the first skeleton by index + ( + KEY_POINTS, + -1, + mock_key_points( + xy=[[[20, 21], [22, 23], [24, 25], [26, 27], [28, 29]]], + confidence=[[0.1, 0.6, 0.8, 0.2, 0.7]], + class_id=[2], + ), + DoesNotRaise(), + ), # select the last skeleton by index + ( + KEY_POINTS, + [0, 1], + mock_key_points( + xy=[ + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], + [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]], + ], + confidence=[ + [0.8, 0.2, 0.6, 0.1, 0.5], + [0.7, 0.9, 0.3, 0.4, 0.0], + ], + class_id=[0, 1], + ), + DoesNotRaise(), + ), # select the first two skeletons by index; list + ( + KEY_POINTS, + np.array([0, 1]), + mock_key_points( + xy=[ + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], + [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]], + ], + confidence=[ + [0.8, 0.2, 0.6, 0.1, 0.5], + [0.7, 0.9, 0.3, 0.4, 0.0], + ], + class_id=[0, 1], + ), + DoesNotRaise(), + ), # select the first two skeletons by index; np.array + ( + KEY_POINTS, + [True, True, False], + mock_key_points( + xy=[ + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], + [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]], + ], + confidence=[ + [0.8, 0.2, 0.6, 0.1, 0.5], + [0.7, 0.9, 0.3, 0.4, 0.0], + ], + class_id=[0, 1], + ), + DoesNotRaise(), + ), # select only skeletons 
associated with positive filter; list + ( + KEY_POINTS, + np.array([True, True, False]), + mock_key_points( + xy=[ + [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], + [[10, 11], [12, 13], [14, 15], [16, 17], [18, 19]], + ], + confidence=[ + [0.8, 0.2, 0.6, 0.1, 0.5], + [0.7, 0.9, 0.3, 0.4, 0.0], + ], + class_id=[0, 1], + ), + DoesNotRaise(), + ), # select only skeletons associated with positive filter; list + ( + KEY_POINTS, + (slice(None), slice(None)), + KEY_POINTS, + DoesNotRaise(), + ), # slice all anchors from all skeletons + ( + KEY_POINTS, + (slice(None), slice(0, 1)), + mock_key_points( + xy=[[[0, 1]], [[10, 11]], [[20, 21]]], + confidence=[[0.8], [0.7], [0.1]], + class_id=[0, 1, 2], + ), + DoesNotRaise(), + ), # slice the first anchor from every skeleton + ( + KEY_POINTS, + (slice(None), slice(0, 2)), + mock_key_points( + xy=[[[0, 1], [2, 3]], [[10, 11], [12, 13]], [[20, 21], [22, 23]]], + confidence=[[0.8, 0.2], [0.7, 0.9], [0.1, 0.6]], + class_id=[0, 1, 2], + ), + DoesNotRaise(), + ), # slice the first anchor two anchors from every skeleton + ( + KEY_POINTS, + (slice(None), 0), + mock_key_points( + xy=[[[0, 1]], [[10, 11]], [[20, 21]]], + confidence=[[0.8], [0.7], [0.1]], + class_id=[0, 1, 2], + ), + DoesNotRaise(), + ), # select the first anchor from every skeleton by index + ( + KEY_POINTS, + (slice(None), -1), + mock_key_points( + xy=[[[8, 9]], [[18, 19]], [[28, 29]]], + confidence=[[0.5], [0.0], [0.7]], + class_id=[0, 1, 2], + ), + DoesNotRaise(), + ), # select the last anchor from every skeleton by index + ( + KEY_POINTS, + (slice(None), [0, 1]), + mock_key_points( + xy=[[[0, 1], [2, 3]], [[10, 11], [12, 13]], [[20, 21], [22, 23]]], + confidence=[[0.8, 0.2], [0.7, 0.9], [0.1, 0.6]], + class_id=[0, 1, 2], + ), + DoesNotRaise(), + ), # select the first two anchors from every skeleton by index; list + ( + KEY_POINTS, + (slice(None), np.array([0, 1])), + mock_key_points( + xy=[[[0, 1], [2, 3]], [[10, 11], [12, 13]], [[20, 21], [22, 23]]], + confidence=[[0.8, 
0.2], [0.7, 0.9], [0.1, 0.6]], + class_id=[0, 1, 2], + ), + DoesNotRaise(), + ), # select the first two anchors from every skeleton by index; np.array + ( + KEY_POINTS, + (slice(None), [True, True, False, False, False]), + mock_key_points( + xy=[[[0, 1], [2, 3]], [[10, 11], [12, 13]], [[20, 21], [22, 23]]], + confidence=[[0.8, 0.2], [0.7, 0.9], [0.1, 0.6]], + class_id=[0, 1, 2], + ), + DoesNotRaise(), + ), # select only anchors associated with positive filter; list + ( + KEY_POINTS, + (slice(None), np.array([True, True, False, False, False])), + mock_key_points( + xy=[[[0, 1], [2, 3]], [[10, 11], [12, 13]], [[20, 21], [22, 23]]], + confidence=[[0.8, 0.2], [0.7, 0.9], [0.1, 0.6]], + class_id=[0, 1, 2], + ), + DoesNotRaise(), + ), # select only anchors associated with positive filter; np.array + ( + KEY_POINTS, + (0, 0), + mock_key_points( + xy=[ + [[0, 1]], + ], + confidence=[ + [0.8], + ], + class_id=[0], + ), + DoesNotRaise(), + ), # select the first anchor from the first skeleton by index + ( + KEY_POINTS, + (0, -1), + mock_key_points( + xy=[ + [[8, 9]], + ], + confidence=[ + [0.5], + ], + class_id=[0], + ), + DoesNotRaise(), + ), # select the last anchor from the first skeleton by index + ], +) +def test_key_points_getitem(key_points, index, expected_result, exception): + with exception: + result = key_points[index] + assert result == expected_result diff --git a/test/test_utils.py b/test/test_utils.py index 19fffad5bb..961f8cee1b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,16 +1,16 @@ from __future__ import annotations +import random from typing import Any import numpy as np -import numpy.typing as npt from supervision.detection.core import Detections -from supervision.keypoint.core import KeyPoints +from supervision.key_points.core import KeyPoints def mock_detections( - xyxy: npt.NDArray[np.float32], + xyxy: list[list[float]], mask: list[np.ndarray] | None = None, confidence: list[float] | None = None, class_id: list[int] | None = None, @@ 
-34,9 +34,9 @@ def convert_data(data: dict[str, list[Any]]): ) -def mock_keypoints( - xy: npt.NDArray[np.float32], - confidence: list[float] | None = None, +def mock_key_points( + xy: list[list[list[float]]], + confidence: list[list[float]] | None = None, class_id: list[int] | None = None, data: dict[str, list[Any]] | None = None, ) -> KeyPoints: @@ -53,5 +53,49 @@ def convert_data(data: dict[str, list[Any]]): ) +def random_boxes( + count: int, + image_size: tuple[int, int] = (1920, 1080), + min_box_size: int = 20, + max_box_size: int = 200, + seed: int | None = None, +) -> np.ndarray: + """ + Generate random bounding boxes within given image dimensions and size constraints. + + Creates `count` bounding boxes randomly positioned and sized, ensuring each + stays within image bounds and has width and height in the specified range. + + Args: + count (`int`): Number of random bounding boxes to generate. + image_size (`tuple[int, int]`): Image size as `(width, height)`. + min_box_size (`int`): Minimum side length (pixels) for generated boxes. + max_box_size (`int`): Maximum side length (pixels) for generated boxes. + seed (`int` or `None`): Optional random seed for reproducibility. + + Returns: + (`numpy.ndarray`): Array of shape `(count, 4)` with bounding boxes as + `(x_min, y_min, x_max, y_max)`. + """ + if seed is not None: + random.seed(seed) + + img_w, img_h = image_size + out = np.zeros((count, 4), dtype=np.float32) + + for i in range(count): + w = random.uniform(min_box_size, max_box_size) + h = random.uniform(min_box_size, max_box_size) + + x_min = random.uniform(0, img_w - w) + y_min = random.uniform(0, img_h - h) + x_max = x_min + w + y_max = y_min + h + + out[i] = (x_min, y_min, x_max, y_max) + + return out + + def assert_almost_equal(actual, expected, tolerance=1e-5): assert abs(actual - expected) < tolerance, f"Expected {expected}, but got {actual}." 
diff --git a/test/utils/test_conversion.py b/test/utils/test_conversion.py index 65cbd8a1ca..e9fabb0d81 100644 --- a/test/utils/test_conversion.py +++ b/test/utils/test_conversion.py @@ -3,7 +3,7 @@ from supervision.utils.conversion import ( cv2_to_pillow, - ensure_cv2_image_for_processing, + ensure_cv2_image_for_standalone_function, images_to_cv2, pillow_to_cv2, ) @@ -16,7 +16,7 @@ def test_ensure_cv2_image_for_processing_when_pillow_image_submitted( param_a_value = 3 param_b_value = "some" - @ensure_cv2_image_for_processing + @ensure_cv2_image_for_standalone_function def my_custom_processing_function( image: np.ndarray, param_a: int, @@ -55,7 +55,7 @@ def test_ensure_cv2_image_for_processing_when_cv2_image_submitted( param_a_value = 3 param_b_value = "some" - @ensure_cv2_image_for_processing + @ensure_cv2_image_for_standalone_function def my_custom_processing_function( image: np.ndarray, param_a: int, diff --git a/test/utils/test_image.py b/test/utils/test_image.py index 39640330e7..688f938b70 100644 --- a/test/utils/test_image.py +++ b/test/utils/test_image.py @@ -2,8 +2,12 @@ import pytest from PIL import Image, ImageChops -from supervision import Color, Point -from supervision.utils.image import create_tiles, letterbox_image, resize_image +from supervision.utils.image import ( + crop_image, + get_image_resolution_wh, + letterbox_image, + resize_image, +) def test_resize_image_for_opencv_image() -> None: @@ -98,145 +102,59 @@ def test_letterbox_image_for_pillow_image() -> None: ) -def test_create_tiles_with_one_image( - one_image: np.ndarray, single_image_tile: np.ndarray -) -> None: - # when - result = create_tiles(images=[one_image], single_tile_size=(240, 240)) - - # # then - assert np.allclose(result, single_image_tile, atol=5.0) - - -def test_create_tiles_with_one_image_and_enforced_grid( - one_image: np.ndarray, single_image_tile_enforced_grid: np.ndarray -) -> None: - # when - result = create_tiles( - images=[one_image], - grid_size=(None, 3), - 
single_tile_size=(240, 240), - ) - - # then - assert np.allclose(result, single_image_tile_enforced_grid, atol=5.0) - - -def test_create_tiles_with_two_images( - two_images: list[np.ndarray], two_images_tile: np.ndarray -) -> None: - # when - result = create_tiles(images=two_images, single_tile_size=(240, 240)) - - # then - assert np.allclose(result, two_images_tile, atol=5.0) - - -def test_create_tiles_with_three_images( - three_images: list[np.ndarray], three_images_tile: np.ndarray -) -> None: - # when - result = create_tiles(images=three_images, single_tile_size=(240, 240)) - - # then - assert np.allclose(result, three_images_tile, atol=5.0) - - -def test_create_tiles_with_four_images( - four_images: list[np.ndarray], - four_images_tile: np.ndarray, -) -> None: - # when - result = create_tiles(images=four_images, single_tile_size=(240, 240)) - - # then - assert np.allclose(result, four_images_tile, atol=5.0) - - -def test_create_tiles_with_all_images( - all_images: list[np.ndarray], - all_images_tile: np.ndarray, -) -> None: - # when - result = create_tiles(images=all_images, single_tile_size=(240, 240)) - - # then - assert np.allclose(result, all_images_tile, atol=5.0) - - -def test_create_tiles_with_all_images_and_custom_grid( - all_images: list[np.ndarray], all_images_tile_and_custom_grid: np.ndarray -) -> None: - # when - result = create_tiles( - images=all_images, - grid_size=(3, 3), - single_tile_size=(240, 240), - ) - - # then - assert np.allclose(result, all_images_tile_and_custom_grid, atol=5.0) - - -def test_create_tiles_with_all_images_and_custom_colors( - all_images: list[np.ndarray], all_images_tile_and_custom_colors: np.ndarray -) -> None: - # when - result = create_tiles( - images=all_images, - tile_margin_color=(127, 127, 127), - tile_padding_color=(224, 224, 224), - single_tile_size=(240, 240), - ) - - # then - assert np.allclose(result, all_images_tile_and_custom_colors, atol=5.0) - - -def test_create_tiles_with_all_images_and_titles( - 
all_images: list[np.ndarray], - all_images_tile_and_custom_colors_and_titles: np.ndarray, -) -> None: - # when - result = create_tiles( - images=all_images, - titles=["Image 1", None, "Image 3", "Image 4"], - single_tile_size=(240, 240), - ) - - # then - assert np.allclose(result, all_images_tile_and_custom_colors_and_titles, atol=5.0) - - -def test_create_tiles_with_all_images_and_titles_with_custom_configs( - all_images: list[np.ndarray], - all_images_tile_and_titles_with_custom_configs: np.ndarray, -) -> None: - # when - result = create_tiles( - images=all_images, - titles=["Image 1", None, "Image 3", "Image 4"], - single_tile_size=(240, 240), - titles_anchors=[ - Point(x=200, y=300), - Point(x=300, y=400), - None, - Point(x=300, y=400), - ], - titles_color=Color.RED, - titles_scale=1.5, - titles_thickness=3, - titles_padding=20, - titles_background_color=Color.BLACK, - default_title_placement="bottom", - ) - - # then - assert np.allclose(result, all_images_tile_and_titles_with_custom_configs, atol=5.0) - - -def test_create_tiles_with_all_images_and_custom_grid_to_small_to_fit_images( - all_images: list[np.ndarray], -) -> None: - with pytest.raises(ValueError): - _ = create_tiles(images=all_images, grid_size=(2, 2)) +@pytest.mark.parametrize( + "image, xyxy, expected_size", + [ + # NumPy RGB + ( + np.zeros((4, 6, 3), dtype=np.uint8), + (2, 1, 5, 3), + (3, 2), # width = 5-2, height = 3-1 + ), + # NumPy grayscale + ( + np.zeros((5, 5), dtype=np.uint8), + (1, 1, 4, 4), + (3, 3), + ), + # Pillow RGB + ( + Image.new("RGB", (6, 4), color=0), + (2, 1, 5, 3), + (3, 2), + ), + # Pillow grayscale + ( + Image.new("L", (5, 5), color=0), + (1, 1, 4, 4), + (3, 3), + ), + ], +) +def test_crop_image(image, xyxy, expected_size): + cropped = crop_image(image=image, xyxy=xyxy) + if isinstance(image, np.ndarray): + assert isinstance(cropped, np.ndarray) + assert cropped.shape[1] == expected_size[0] # width + assert cropped.shape[0] == expected_size[1] # height + else: + assert 
isinstance(cropped, Image.Image) + assert cropped.size == expected_size + + +@pytest.mark.parametrize( + "image, expected", + [ + # NumPy RGB + (np.zeros((4, 6, 3), dtype=np.uint8), (6, 4)), + # NumPy grayscale + (np.zeros((10, 20), dtype=np.uint8), (20, 10)), + # Pillow RGB + (Image.new("RGB", (6, 4), color=0), (6, 4)), + # Pillow grayscale + (Image.new("L", (20, 10), color=0), (20, 10)), + ], +) +def test_get_image_resolution_wh(image, expected): + resolution = get_image_resolution_wh(image) + assert resolution == expected diff --git a/test/utils/test_internal.py b/test/utils/test_internal.py index 07674f39cb..749d4be3a8 100644 --- a/test/utils/test_internal.py +++ b/test/utils/test_internal.py @@ -145,6 +145,7 @@ def __private_property(self): "metadata", "area", "box_area", + "box_aspect_ratio", }, DoesNotRaise(), ),