diff --git a/README.md b/README.md index be35fcb..2b77fae 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ A collection of plugins for [Lightly Studio](https://github.com/lightly-ai/light | Plugin | Description | Maintainer | Install | |---|---|---|---| | [BBox auto propagation nano tracker](plugins/bbox_auto_propagation_nano_tracker/)|Auto bbox propagation using nano tracker|Lightly| `pip install git+https://github.com/lightly-ai/lightly-studio-plugins.git#subdirectory=plugins/bbox_auto_propagation_nano_tracker/`| +| [SAM3 Segmentation](plugins/sam3_segmentation/)|Automatic instance segmentation using SAM3 with a text prompt. Requires HuggingFace access to `facebook/sam3`.|Lightly| `pip install git+https://github.com/lightly-ai/lightly-studio-plugins.git#subdirectory=plugins/sam3_segmentation/`| | [LightlyTrain object detection inference](plugins/lightly_train_object_detection_inference/)|LightlyTrain inference operator for object detection auto-labeling|Lightly| `pip install git+https://github.com/lightly-ai/lightly-studio-plugins.git#subdirectory=plugins/lightly_train_object_detection_inference/`| ## Adding a New Plugin diff --git a/plugins.toml b/plugins.toml index 554d198..5d9f5df 100644 --- a/plugins.toml +++ b/plugins.toml @@ -6,6 +6,14 @@ source = "local:plugins/bbox_auto_propagation_nano_tracker" maintainer = "lightly" tags = ["auto-labeling", "tracking"] +[[plugins]] +id = "lightly_plugins_sam3_segmentation" +name = "SAM3 Segmentation" +description = "Automatic instance segmentation using SAM3 with a text prompt" +source = "local:plugins/sam3_segmentation" +maintainer = "lightly" +tags = ["auto-labeling", "segmentation"] + [[plugins]] id = "lightly_plugins_lightly_train_object_detection_inference" name = "LightlyTrain object detection inference" diff --git a/plugins/sam3_segmentation/LICENSE b/plugins/sam3_segmentation/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/plugins/sam3_segmentation/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/plugins/sam3_segmentation/Makefile b/plugins/sam3_segmentation/Makefile new file mode 100644 index 0000000..394c786 --- /dev/null +++ b/plugins/sam3_segmentation/Makefile @@ -0,0 +1,13 @@ +.PHONY: install format type-check + +install: + uv venv + uv pip install -r ../../dev-requirements.txt + uv pip install -e . + +format: + uv run ruff format . + uv run ruff check --fix . + +type-check: + uv run mypy --config-file ../../mypy.ini . diff --git a/plugins/sam3_segmentation/README.md b/plugins/sam3_segmentation/README.md new file mode 100644 index 0000000..76e6241 --- /dev/null +++ b/plugins/sam3_segmentation/README.md @@ -0,0 +1,42 @@ +# SAM3 Segmentation Plugin + +Automatic instance segmentation using [SAM3](https://huggingface.co/facebook/sam3) with a text prompt. Runs on image collections in Lightly Studio. + +## Setup + +### 1. Request access to the model + +Visit [facebook/sam3](https://huggingface.co/facebook/sam3) on HuggingFace and request access. + +### 2. Authenticate with HuggingFace + +```bash +hf auth login +``` + +Paste your HuggingFace token when prompted. Generate one at https://huggingface.co/settings/tokens (needs read access). + +### 3. Install the plugin + +```bash +uv pip install "git+https://github.com/lightly-ai/lightly-studio-plugins.git#subdirectory=plugins/sam3_segmentation/" +``` + +### 4. GPU (optional) + +By default the plugin runs on CUDA if available. To use a CUDA GPU, reinstall PyTorch with the appropriate CUDA build: + +```bash +uv pip install torch --index-url https://download.pytorch.org/whl/cu121 +``` + +If CUDA is not available, the plugin will run on CPU automatically. + +## Parameters + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `model_id` | string | `"facebook/sam3"` | HuggingFace model ID — `facebook/sam3` or `facebook/sam3.1` | +| `prompt` | string | `"person"` | Text describing what to segment (e.g. `"car"`, `"dog"`) | +| `confidence_threshold` | float | `0.5` | Minimum score to keep a prediction | +| `collection_name` | string | `"SAM3_auto_label"` | Target annotation collection for generated segmentations. Override this to store the results in a different collection. | diff --git a/plugins/sam3_segmentation/pyproject.toml b/plugins/sam3_segmentation/pyproject.toml new file mode 100644 index 0000000..c27c4fe --- /dev/null +++ b/plugins/sam3_segmentation/pyproject.toml @@ -0,0 +1,25 @@ +[project] +name = "lightly_plugins_sam3_segmentation" +version = "0.1.0" +description = "SAM3 instance segmentation plugin for Lightly Studio" +requires-python = ">=3.10" +dependencies = [ + "lightly_studio>=0.4.13", + "torch", + "transformers>=4.57.2", + "Pillow", + "numpy", + "huggingface-hub>=0.34", + "sqlmodel", + "labelformat", +] + +[project.entry-points."lightly_studio.plugins"] +sam3_segmentation = "lightly_plugins_sam3_segmentation.operator:SAM3SegmentationOperator" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/lightly_plugins_sam3_segmentation"] diff --git a/plugins/sam3_segmentation/src/lightly_plugins_sam3_segmentation/__init__.py b/plugins/sam3_segmentation/src/lightly_plugins_sam3_segmentation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/sam3_segmentation/src/lightly_plugins_sam3_segmentation/operator.py b/plugins/sam3_segmentation/src/lightly_plugins_sam3_segmentation/operator.py new file mode 100644 index 0000000..7cda75a --- /dev/null +++ b/plugins/sam3_segmentation/src/lightly_plugins_sam3_segmentation/operator.py @@ -0,0 +1,262 @@ +"""SAM3 instance segmentation plugin for Lightly Studio.""" + +from __future__ import annotations + +import dataclasses +import logging +from dataclasses import dataclass +from typing import Any, cast +from uuid import UUID + +import PIL.Image +import torch +from sqlmodel import Session +from transformers import Sam3Model, Sam3Processor + +from lightly_studio.models.annotation.annotation_base import ( + AnnotationCreate, + AnnotationType, +) +from lightly_studio.models.annotation_label import AnnotationLabelCreate +from lightly_studio.plugins.base_operator import BaseOperator, OperatorResult +from lightly_studio.plugins.operator_context import ExecutionContext, OperatorScope +from lightly_studio.resolvers.image_filter import ImageFilter +from lightly_studio.resolvers.sample_resolver.sample_filter import SampleFilter +from lightly_studio.plugins.parameter import ( + BaseParameter, + FloatParameter, + StringParameter, +) +from lightly_studio.resolvers import ( + annotation_label_resolver, + annotation_resolver, + collection_resolver, + image_resolver, +) + +from lightly_plugins_sam3_segmentation.utils import prepare_segmentation_entries + +logger = logging.getLogger(__name__) + +_DEFAULT_MODEL_ID = "facebook/sam3" +_SEGMENTATION_ANNOTATION_TYPE: AnnotationType = cast( + AnnotationType, + getattr(AnnotationType, "SEGMENTATION_MASK", None) + or getattr(AnnotationType, "INSTANCE_SEGMENTATION"), +) + + +def _get_or_create_label(session: Session, dataset_id: UUID, label_name: str) -> UUID: + label = annotation_label_resolver.get_by_label_name( + session=session, dataset_id=dataset_id, label_name=label_name + ) + if label is None: + label = annotation_label_resolver.create( + session=session, + label=AnnotationLabelCreate( + dataset_id=dataset_id, annotation_label_name=label_name + ), + ) + return label.annotation_label_id + + +@dataclass +class SAM3SegmentationOperator(BaseOperator): + """Instance segmentation using SAM3 driven by a text prompt.""" + + name: str = "SAM3 Segmentation" + description: str = ( + "Automatic instance segmentation using SAM3 (facebook/sam3). " + "Requires HuggingFace access — authenticate with `hf auth login` first." + ) + _model: Any = dataclasses.field(default=None, init=False, repr=False) + _processor: Any = dataclasses.field(default=None, init=False, repr=False) + _model_device: str = dataclasses.field(default="", init=False, repr=False) + _loaded_model_id: str = dataclasses.field(default="", init=False, repr=False) + + @property + def parameters(self) -> list[BaseParameter]: + return [ + StringParameter( + name="model_id", + required=True, + default=_DEFAULT_MODEL_ID, + description="HuggingFace model ID (e.g. 'facebook/sam3' or 'facebook/sam3.1')", + ), + StringParameter( + name="prompt", + required=True, + default="person", + description="Text prompt describing what to segment (e.g. 'person', 'car')", + ), + FloatParameter( + name="confidence_threshold", + required=False, + default=0.5, + description="Minimum confidence score for keeping a prediction", + ), + StringParameter( + name="collection_name", + required=True, + default="SAM3_auto_label", + description="The target annotation collection name.", + ), + ] + + @property + def supported_scopes(self) -> list[OperatorScope]: + return [OperatorScope.IMAGE] + + def _load_model(self, model_id: str, device: str) -> None: + if ( + self._model is not None + and self._model_device == device + and self._loaded_model_id == model_id + ): + return + + logger.info("Loading SAM3 model (%s) on device: %s", model_id, device) + # `facebook/sam3` may resolve to the video processor through AutoProcessor in + # recent `transformers` builds. Load the image SAM3 classes explicitly so + # text-prompted image segmentation stays on the correct code path. + self._model = Sam3Model.from_pretrained(model_id).to(device).eval() # type: ignore[arg-type] + self._processor = Sam3Processor.from_pretrained(model_id) + self._model_device = device + self._loaded_model_id = model_id + + def _build_runtime_error_result(self, exc: Exception) -> OperatorResult: + logger.exception("SAM3 segmentation failed: %s", exc) + return OperatorResult( + success=False, + message=( + "SAM3 segmentation failed. Verify HuggingFace access for the selected " + "model, run `hf auth login`, and check the logs for details." + ), + ) + + def execute( + self, + *, + session: Session, + context: ExecutionContext, + parameters: dict[str, Any], + ) -> OperatorResult: + model_id: str = parameters.get("model_id", _DEFAULT_MODEL_ID) + prompt_value = parameters.get("prompt") + if prompt_value is None: + return OperatorResult( + success=False, + message="Please provide a prompt.", + ) + prompt: str = prompt_value + confidence_threshold: float = parameters.get("confidence_threshold", 0.5) + device = "cuda" if torch.cuda.is_available() else "cpu" + collection_name_value = parameters.get("collection_name") + if collection_name_value is None: + return OperatorResult( + success=False, + message="Please provide a collection name.", + ) + collection_name: str = collection_name_value + + collection = collection_resolver.get_by_id( + session=session, collection_id=context.collection_id + ) + if collection is None: + return OperatorResult(success=False, message="Collection not found.") + + context_filter: ImageFilter | None = None + if isinstance(context.context_filter, SampleFilter): + context_filter = ImageFilter(sample_filter=context.context_filter) + elif isinstance(context.context_filter, ImageFilter): + context_filter = context.context_filter + + result = image_resolver.get_all_by_collection_id( + session=session, collection_id=context.collection_id, filters=context_filter + ) + + samples = list(result.samples) + if not samples: + return OperatorResult( + success=True, + message="No samples found for current view.", + ) + + try: + self._load_model(model_id, device) + except Exception as exc: + return self._build_runtime_error_result(exc) + + raw_detections: list[tuple[Any, Any]] = [] # (sample, entry) + for sample in samples: + try: + with PIL.Image.open(sample.file_path_abs) as opened_image: + image = opened_image.convert("RGB") + except Exception: + logger.warning( + "Could not open image: %s — skipping.", sample.file_path_abs + ) + continue + + try: + inputs = self._processor(images=image, text=prompt, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = self._model(**inputs) + + post_results = self._processor.post_process_instance_segmentation( + outputs, + threshold=confidence_threshold, + target_sizes=[(sample.height, sample.width)], + ) + detections = post_results[0] + except Exception as exc: + return self._build_runtime_error_result(exc) + + entries = prepare_segmentation_entries( + boxes=detections["boxes"], + masks=detections["masks"], + scores=detections["scores"], + image_size=(sample.width, sample.height), + ) + for entry in entries: + raw_detections.append((sample, entry)) + + if not raw_detections: + return OperatorResult( + success=True, + message="Segmentation complete. No annotations created.", + ) + + label_id = _get_or_create_label( + session=session, dataset_id=collection.dataset_id, label_name=prompt + ) + + annotation_creates: list[AnnotationCreate] = [] + for sample, entry in raw_detections: + x, y, w, h = entry["box"] + annotation_creates.append( + AnnotationCreate( + annotation_label_id=label_id, + annotation_type=_SEGMENTATION_ANNOTATION_TYPE, + parent_sample_id=sample.sample_id, + confidence=entry["score"], + x=x, + y=y, + width=w, + height=h, + segmentation_mask=entry["rle"], + ) + ) + + annotation_resolver.create_many( + session=session, + parent_collection_id=context.collection_id, + annotations=annotation_creates, + collection_name=collection_name, + ) + return OperatorResult( + success=True, + message=f"Segmentation complete. Created {len(annotation_creates)} annotations.", + ) diff --git a/plugins/sam3_segmentation/src/lightly_plugins_sam3_segmentation/utils.py b/plugins/sam3_segmentation/src/lightly_plugins_sam3_segmentation/utils.py new file mode 100644 index 0000000..87b630f --- /dev/null +++ b/plugins/sam3_segmentation/src/lightly_plugins_sam3_segmentation/utils.py @@ -0,0 +1,60 @@ +"""Utilities for post-processing SAM3 model outputs.""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +from labelformat.model.binary_mask_segmentation import BinaryMaskSegmentation +from labelformat.model.bounding_box import BoundingBox + + +def _clamp_xyxy_to_xywh( + box: Any, + width: int, + height: int, +) -> tuple[int, int, int, int]: + """Convert an xyxy box to (x, y, w, h), clamped to image bounds.""" + x1 = max(0, min(int(box[0]), width - 1)) + y1 = max(0, min(int(box[1]), height - 1)) + x2 = max(x1 + 1, min(int(box[2]), width)) + y2 = max(y1 + 1, min(int(box[3]), height)) + return x1, y1, x2 - x1, y2 - y1 + + +def prepare_segmentation_entries( + boxes: Any, + masks: Any, + scores: Any, + image_size: tuple[int, int], +) -> list[dict[str, Any]]: + """Convert SAM3 post-processed outputs to annotation-ready entries. + + Args: + boxes: Tensor (N, 4) absolute-pixel xyxy coordinates. + masks: Tensor (N, H, W) boolean masks. + scores: Tensor (N,) confidence scores. + image_size: (width, height) of the source image. + + Returns: + List of dicts with keys 'box' (x, y, w, h), 'score' (float), + and 'rle' (list[int] row-wise run-length encoding). + """ + img_w, img_h = image_size + entries = [] + for box, mask, score in zip(boxes, masks, scores): + x, y, w, h = _clamp_xyxy_to_xywh(box, img_w, img_h) + binary_mask: NDArray[np.int_] = mask.cpu().numpy().astype(np.int_) + bounding_box = BoundingBox( + xmin=float(x), + ymin=float(y), + xmax=float(x + w), + ymax=float(y + h), + ) + seg = BinaryMaskSegmentation.from_binary_mask(binary_mask, bounding_box) + entries.append( + {"box": (x, y, w, h), "score": float(score), "rle": seg.get_rle()} + ) + return entries