diff --git a/scripts/setup/cambrian_s_install.sh b/scripts/setup/cambrian_s_install.sh
new file mode 100644
index 00000000..dc912bab
--- /dev/null
+++ b/scripts/setup/cambrian_s_install.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+# scripts/setup/cambrian_s_install.sh
+# Description: Set up the environment for Cambrian-S inference in OpenWorldLib
+# Usage: bash scripts/setup/cambrian_s_install.sh
+
+echo "=== [1/4] Installing the base environment ==="
+pip install torch==2.6.0 torchvision torchaudio
+pip install git+https://github.com/openai/CLIP.git
+
+echo "=== [2/4] Installing the OpenWorldLib requirements (transformers_low extra) ==="
+pip install -e ".[transformers_low]"
+
+echo "=== [3/4] Installing Cambrian-S runtime dependencies ==="
+pip install sentencepiece decord
+
+echo "=== [4/4] Installing flash attention ==="
+pip install "flash-attn==2.5.9.post1" --no-build-isolation
+
+echo "=== Setup completed! ==="
diff --git a/src/openworldlib/operators/cambrian_s_operator.py b/src/openworldlib/operators/cambrian_s_operator.py
new file mode 100644
index 00000000..ebff8150
--- /dev/null
+++ b/src/openworldlib/operators/cambrian_s_operator.py
@@ -0,0 +1,112 @@
+from pathlib import Path
+from typing import Any
+
+import torch
+
+from .base_operator import BaseOperator
+from ..reasoning.spatial_reasoning.cambrian_s.constants import DEFAULT_SYSTEM_PROMPT
+from ..reasoning.spatial_reasoning.cambrian_s.conversation import (
+    build_qwen_chat_prompt,
+    extract_media_inputs,
+)
+from ..reasoning.spatial_reasoning.cambrian_s.mm_utils import (
+    preprocess_single_image,
+    preprocess_video_frames,
+    tokenizer_image_token,
+)
+
+
+class CambrianSOperator(BaseOperator):
+    """
+    Lightweight operator placeholder for Cambrian-S.
+    It tracks interactions and converts OpenWorldLib chat messages into
+    Cambrian-S prompt tokens plus image/video tensors.
+ """ + + def __init__(self, operation_types=None, interaction_template=None): + super().__init__(operation_types=operation_types or ["reasoning"]) + self.interaction_template = interaction_template or [] + self.interaction_template_init() + + @classmethod + def from_pretrained(cls, *args, **kwargs) -> "CambrianSOperator": + return cls() + + def check_interaction(self, interaction): + return True + + def get_interaction(self, interaction): + if self.check_interaction(interaction): + self.current_interaction.append(interaction) + + def process_interaction(self, *args, **kwargs): + return self.current_interaction + + def process_perception( + self, + messages: list[dict[str, Any]], + tokenizer, + image_processors: list[Any], + model_config: Any = None, + system_prompt: str = DEFAULT_SYSTEM_PROMPT, + video_max_frames: int | None = None, + ) -> dict[str, Any]: + prompt = build_qwen_chat_prompt(messages, system_prompt=system_prompt) + media_inputs = extract_media_inputs(messages) + + if media_inputs and not image_processors: + raise RuntimeError("Cambrian-S received visual inputs, but no image processor is available.") + + image_tensors = [] + image_sizes = [] + if media_inputs: + processor = image_processors[0] + image_count = sum(1 for media in media_inputs if media.get("type") == "image") + image_aspect_ratio = getattr(model_config, "image_aspect_ratio", "pad") + anyres_max_subimages = int(getattr(model_config, "anyres_max_subimages", 1)) + for media in media_inputs: + media_type = media.get("type") + if media_type == "image": + use_anyres = image_count == 1 and image_aspect_ratio == "anyres" + pixel_values, original_size = preprocess_single_image( + media["image"], + processor, + image_aspect_ratio="anyres" if use_anyres else "pad", + anyres_max_subimages=anyres_max_subimages, + ) + elif media_type == "video": + video_input = media["video"] + num_threads = -1 + resolved_video_max_frames = ( + video_max_frames + if video_max_frames is not None + else int(getattr(model_config, "video_max_frames", 32)) + ) + if isinstance(video_input, (str, Path)): + video_name = str(video_input) + if "Ego4D" in video_name or "video_mmmu" in video_name: + num_threads = 1 + pixel_values, original_size = preprocess_video_frames( + video_input, + processor, + max_frames=resolved_video_max_frames, + model_config=model_config, + num_threads=num_threads, + ) + else: + continue + image_tensors.append(pixel_values) + image_sizes.append(original_size) + + input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt").unsqueeze(0) + attention_mask = torch.ones_like(input_ids, dtype=torch.bool) + return { + "prompt": prompt, + "input_ids": input_ids, + "attention_mask": attention_mask, + "images": image_tensors, + "image_sizes": image_sizes, + } + + def delete_last_interaction(self): + super().delete_last_interaction() diff --git a/src/openworldlib/pipelines/cambrian_s/pipeline_cambrian_s.py b/src/openworldlib/pipelines/cambrian_s/pipeline_cambrian_s.py new file mode 100644 index 00000000..20736a9a --- /dev/null +++ b/src/openworldlib/pipelines/cambrian_s/pipeline_cambrian_s.py @@ -0,0 +1,107 @@ +from typing import List, Optional, Sequence, Union + +from PIL import Image as PILImage + +from ...operators.cambrian_s_operator import CambrianSOperator +from ...reasoning.spatial_reasoning.cambrian_s.cambrian_s_reasoning import CambrianSReasoning + + +class CambrianSPipeline: + """ + Pipeline that builds Cambrian-S multimodal inputs and runs Cambrian-S reasoning. 
+ """ + + def __init__(self, reasoning: CambrianSReasoning, operator: CambrianSOperator): + self.reasoning = reasoning + self.operator = operator + + @classmethod + def from_pretrained( + cls, + model_path: str = "nyu-visionx/Cambrian-S-7B", + device: Optional[Union[str, "torch.device"]] = None, + weight_dtype: "torch.dtype" = None, + **kwargs, + ) -> "CambrianSPipeline": + reasoning = CambrianSReasoning.from_pretrained( + model_path=model_path, + device=device, + weight_dtype=weight_dtype, + **kwargs, + ) + operator = CambrianSOperator.from_pretrained() + return cls(reasoning=reasoning, operator=operator) + + def _build_messages( + self, + images: Optional[Union[str, PILImage.Image, Sequence[Union[str, PILImage.Image]]]], + videos: Optional[ + Union[ + str, + list[PILImage.Image], + Sequence[Union[str, list[PILImage.Image]]], + ] + ], + prompt: str, + ): + if images is None: + images = [] + if videos is None: + videos = [] + + if isinstance(images, (str, PILImage.Image)): + images = [images] + if isinstance(videos, str): + videos = [videos] + elif isinstance(videos, list) and videos and isinstance(videos[0], PILImage.Image): + videos = [videos] + + content = [{"type": "image", "image": image} for image in images] + content += [{"type": "video", "video": video} for video in videos] + content.append({"type": "text", "text": prompt}) + return [{"role": "user", "content": content}] + + def __call__( + self, + prompt: str, + images: Optional[Union[str, PILImage.Image, Sequence[Union[str, PILImage.Image]]]] = None, + videos: Optional[ + Union[ + str, + list[PILImage.Image], + Sequence[Union[str, list[PILImage.Image]]], + ] + ] = None, + max_new_tokens: int = 2048, + messages: Optional[list] = None, + generation_kwargs: Optional[dict] = None, + ) -> List[str]: + self.operator.get_interaction(prompt) + self.operator.process_interaction() + + if messages is None: + batched_messages = [self._build_messages(images=images, videos=videos, prompt=prompt)] + else: + if not messages: + raise ValueError("messages must be non-empty.") + batched_messages = [messages] if isinstance(messages[0], dict) else messages + + outputs: List[str] = [] + for sample_messages in batched_messages: + model_config = getattr(getattr(self.reasoning, "model", None), "config", None) + model_inputs = self.operator.process_perception( + sample_messages, + tokenizer=self.reasoning.tokenizer, + image_processors=self.reasoning.image_processors, + model_config=model_config, + ) + outputs.extend( + self.reasoning.inference( + inputs=model_inputs, + max_new_tokens=max_new_tokens, + generation_kwargs=generation_kwargs, + ) + ) + + self.operator.delete_last_interaction() + return outputs diff --git a/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/cambrian_s_reasoning.py b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/cambrian_s_reasoning.py new file mode 100644 index 00000000..84045d3b --- /dev/null +++ b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/cambrian_s_reasoning.py @@ -0,0 +1,137 @@ + + +from typing import Any, List, Optional, Union + +import torch +from transformers import AutoTokenizer + +from ...base_reasoning import BaseReasoning +from .constants import ( + DEFAULT_IMAGE_PATCH_TOKEN, + DEFAULT_IM_END_TOKEN, + DEFAULT_IM_START_TOKEN, +) +from .mm_utils import validate_cambrian_s_environment +from .modeling_cambrian_s import CambrianSForCausalLM + + +class CambrianSReasoning(BaseReasoning): + """ + Cambrian-S: https://arxiv.org/abs/2511.04670 + """ + + def __init__( + self, + model: 
CambrianSForCausalLM, + tokenizer: Any, + image_processors: list[Any], + device: Union[str, "torch.device"] = "cuda", + ): + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.image_processors = image_processors + self.processor = image_processors[0] if image_processors else None + self.device = torch.device(device) + + @classmethod + def from_pretrained( + cls, + model_path: str = "nyu-visionx/Cambrian-S-7B", + device: Optional[Union[str, "torch.device"]] = None, + weight_dtype: "torch.dtype" = None, + attn_implementation: Optional[str] = None, + **kwargs, + ) -> "CambrianSReasoning": + validate_cambrian_s_environment(require_video=False) + + config_override_names = ( + "video_max_frames", + "video_fps", + "video_force_sample", + "add_time_instruction", + "miv_token_len", + "si_token_len", + "image_aspect_ratio", + "anyres_max_subimages", + ) + config_overrides = { + attr_name: kwargs.pop(attr_name) + for attr_name in config_override_names + if attr_name in kwargs + } + + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + if weight_dtype is None: + weight_dtype = torch.float16 + + tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) + model = CambrianSForCausalLM.from_pretrained( + model_path, + torch_dtype=weight_dtype, + low_cpu_mem_usage=kwargs.pop("low_cpu_mem_usage", True), + attn_implementation=attn_implementation, + **kwargs, + ) + mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) + mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) + if mm_use_im_patch_token: + tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + if mm_use_im_start_end: + tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], + special_tokens=True, + ) + model.resize_token_embeddings(len(tokenizer)) + for attr_name, attr_value in config_overrides.items(): + setattr(model.config, attr_name, attr_value) + model = model.to(device) + model.load_vision_towers(device=device, dtype=weight_dtype) + image_processors = [tower.image_processor for tower in model.get_model().get_vision_tower_aux_list()] + return cls(model=model, tokenizer=tokenizer, image_processors=image_processors, device=device) + + def api_init(self, api_key, endpoint): + raise NotImplementedError("API init is not supported for Cambrian-S.") + + def _get_default_device(self): + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + @torch.no_grad() + def inference( + self, + inputs: dict[str, Any], + max_new_tokens: int = 2048, + generation_kwargs: Optional[dict] = None, + ) -> List[str]: + generation_config = {"max_new_tokens": max_new_tokens} + if generation_kwargs: + generation_config.update(generation_kwargs) + + input_ids = inputs["input_ids"].to(self.device) + attention_mask = inputs.get("attention_mask") + if attention_mask is not None: + attention_mask = attention_mask.to(self.device) + + images = inputs.get("images") or [] + model_dtype = next(self.model.parameters()).dtype + images = [image.to(device=self.device, dtype=model_dtype) for image in images] + image_sizes = inputs.get("image_sizes") or [] + + generated_ids = self.model.generate( + inputs=input_ids, + attention_mask=attention_mask, + images=images, + image_sizes=image_sizes, + **generation_config, + ) + if hasattr(generated_ids, "sequences"): + generated_ids = generated_ids.sequences + + return self.tokenizer.batch_decode( + generated_ids, + skip_special_tokens=True, + 
clean_up_tokenization_spaces=False,
+        )
diff --git a/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/constants.py b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/constants.py
new file mode 100644
index 00000000..5ec9cbd2
--- /dev/null
+++ b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/constants.py
@@ -0,0 +1,14 @@
+IGNORE_INDEX = -100
+IMAGE_TOKEN_INDEX = -200
+
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
+
+SUPPORTED_TRANSFORMERS_MIN = "4.39.2"
+SUPPORTED_TRANSFORMERS_MAX = "4.45.1"
+SUPPORTED_TRANSFORMERS_RANGE = (
+    f">={SUPPORTED_TRANSFORMERS_MIN},<={SUPPORTED_TRANSFORMERS_MAX}"
+)
diff --git a/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/conversation.py b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/conversation.py
new file mode 100644
index 00000000..06eb8d38
--- /dev/null
+++ b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/conversation.py
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+from typing import Any, Iterable
+
+from .constants import DEFAULT_IMAGE_TOKEN, DEFAULT_SYSTEM_PROMPT
+
+
+def _iter_content_items(content: Any) -> Iterable[dict[str, Any]]:
+    if isinstance(content, list):
+        for item in content:
+            if isinstance(item, dict):
+                yield item
+
+
+def render_message_content(content: Any) -> str:
+    if isinstance(content, str):
+        return content
+
+    image_prefix_parts: list[str] = []
+    text_parts: list[str] = []
+    for item in _iter_content_items(content):
+        item_type = item.get("type")
+        if item_type in {"image", "video"}:
+            image_prefix_parts.append(DEFAULT_IMAGE_TOKEN)
+        elif item_type == "text":
+            text = str(item.get("text", ""))
+            if text:
+                text_parts.append(text)
+
+    return "".join(image_prefix_parts) + "".join(text_parts)
+
+
+def extract_media_inputs(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    media_inputs: list[dict[str, Any]] = []
+    for message in messages:
+        for item in _iter_content_items(message.get("content")):
+            if item.get("type") in {"image", "video"}:
+                media_inputs.append(item)
+    return media_inputs
+
+
+def normalize_messages(
+    messages: list[dict[str, Any]],
+    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+) -> tuple[str, list[dict[str, str]]]:
+    normalized: list[dict[str, str]] = []
+    resolved_system_prompt = system_prompt
+
+    for message in messages:
+        role = str(message.get("role", "user"))
+
+        text = render_message_content(message.get("content", ""))
+        if role == "system":
+            resolved_system_prompt = text or system_prompt
+        elif role == "assistant":
+            normalized.append({"role": "assistant", "text": text})
+        else:
+            normalized.append({"role": "user", "text": text})
+
+    return resolved_system_prompt, normalized
+
+
+def build_qwen_chat_prompt(
+    messages: list[dict[str, Any]],
+    system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+) -> str:
+    resolved_system_prompt, normalized_messages = normalize_messages(
+        messages,
+        system_prompt=system_prompt,
+    )
+    prompt = f"<|im_start|>system\n{resolved_system_prompt}<|im_end|>\n"
+    role_prefixes = {
+        "user": "<|im_start|>user",
+        "assistant": "<|im_start|>assistant",
+    }
+    for message in normalized_messages:
+        role_prefix = role_prefixes[message["role"]]
+        if message["text"]:
+            prompt += f"{role_prefix}\n{message['text']}<|im_end|>\n"
+        else:
+            prompt += f"{role_prefix}\n"
+    prompt += "<|im_start|>assistant\n"
+    return prompt
diff --git a/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/mm_utils.py
b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/mm_utils.py new file mode 100644 index 00000000..a1391c9a --- /dev/null +++ b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/mm_utils.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +import importlib.util +import math +import re +from importlib.metadata import PackageNotFoundError, version as package_version +from pathlib import Path +from typing import Any, Callable, Sequence + +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image + +from .constants import ( + IMAGE_TOKEN_INDEX, + SUPPORTED_TRANSFORMERS_MAX, + SUPPORTED_TRANSFORMERS_MIN, + SUPPORTED_TRANSFORMERS_RANGE, +) +from .siglip_vision import SigLipImageProcessor, SigLipVisionTower + + +def parse_version(version: str) -> tuple[int, ...]: + parts = re.findall(r"\d+", version) + return tuple(int(part) for part in parts[:3]) + + +def module_exists(module_name: str) -> bool: + return importlib.util.find_spec(module_name) is not None + + +def get_package_version(package_name: str) -> str | None: + try: + return package_version(package_name) + except PackageNotFoundError: + return None + + +def validate_cambrian_s_environment( + require_video: bool = False, + module_checker: Callable[[str], bool] | None = None, + version_getter: Callable[[str], str | None] | None = None, +) -> None: + module_checker = module_checker or module_exists + version_getter = version_getter or get_package_version + + required_modules = { + "torch": "torch", + "transformers": "transformers", + "tokenizers": "tokenizers", + "sentencepiece": "sentencepiece", + } + if require_video: + required_modules["decord"] = "decord" + + missing = [ + package_name + for package_name, module_name in required_modules.items() + if not module_checker(module_name) + ] + if missing: + raise RuntimeError( + "Cambrian-S requires missing dependencies: " + + ", ".join(missing) + + ". Use the repository's transformers_low environment and install the missing packages." + ) + + transformers_version = version_getter("transformers") + if not transformers_version: + raise RuntimeError( + "Cambrian-S requires the transformers package, but no installed version was detected." + ) + + min_version = parse_version(SUPPORTED_TRANSFORMERS_MIN) + max_version = parse_version(SUPPORTED_TRANSFORMERS_MAX) + current_version = parse_version(transformers_version) + if current_version < min_version or current_version > max_version: + raise RuntimeError( + "Cambrian-S v1 is validated for transformers" + f" {SUPPORTED_TRANSFORMERS_RANGE}; found {transformers_version}. " + "Use the repository's transformers_low environment." 
+ ) + + +SiglipVisionTower = SigLipVisionTower + + +def expand_to_square( + image: Image.Image, + background_color: Sequence[float] = (127, 127, 127), +) -> Image.Image: + width, height = image.size + if width == height: + return image + + fill_color = tuple(int(value) for value in background_color) + side = max(width, height) + canvas = Image.new(image.mode, (side, side), fill_color) + paste_x = (side - width) // 2 + paste_y = (side - height) // 2 + canvas.paste(image, (paste_x, paste_y)) + return canvas + + +def select_best_resolution( + original_size: Sequence[int], + possible_resolutions: Sequence[tuple[int, int]], +) -> tuple[int, int]: + original_width, original_height = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for width, height in possible_resolutions: + scale = min(width / max(original_width, 1), height / max(original_height, 1)) + downscaled_width = int(original_width * scale) + downscaled_height = int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if ( + effective_resolution > max_effective_resolution + or ( + effective_resolution == max_effective_resolution + and wasted_resolution < min_wasted_resolution + ) + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (width, height) + + if best_fit is None: + raise ValueError("Cambrian-S could not determine a valid anyres resolution.") + return best_fit + + +def resize_and_pad_image( + image: Image.Image, + target_resolution: tuple[int, int], + background_color: Sequence[int] = (0, 0, 0), +) -> Image.Image: + original_width, original_height = image.size + target_width, target_height = target_resolution + + scale_w = target_width / max(original_width, 1) + scale_h = target_height / max(original_height, 1) + if scale_w < scale_h: + new_width = target_width + new_height = min(math.ceil(original_height * scale_w), target_height) + else: + new_height = target_height + new_width = min(math.ceil(original_width * scale_h), target_width) + + resized_image = image.resize((new_width, new_height)) + padded_image = Image.new("RGB", (target_width, target_height), tuple(int(value) for value in background_color)) + paste_x = (target_width - new_width) // 2 + paste_y = (target_height - new_height) // 2 + padded_image.paste(resized_image, (paste_x, paste_y)) + return padded_image + + +def divide_to_patches(image: Image.Image, patch_size: int) -> list[Image.Image]: + patches: list[Image.Image] = [] + width, height = image.size + for offset_y in range(0, height, patch_size): + for offset_x in range(0, width, patch_size): + patches.append(image.crop((offset_x, offset_y, offset_x + patch_size, offset_y + patch_size))) + return patches + + +def unpad_image(features: torch.Tensor, original_size: Sequence[int]) -> torch.Tensor: + original_w, original_h = original_size + current_h, current_w = features.shape[1:3] + + original_aspect_ratio = original_w / max(original_h, 1) + current_aspect_ratio = current_w / max(current_h, 1) + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_w / max(original_w, 1) + new_height = max(int(original_h * scale_factor), 1) + padding = max((current_h - new_height) // 2, 0) + if padding > 0: + return features[:, padding : current_h - padding, :, :] + return features + + scale_factor = current_h / max(original_h, 1) + new_width = 
max(int(original_w * scale_factor), 1) + padding = max((current_w - new_width) // 2, 0) + if padding > 0: + return features[:, :, padding : current_w - padding, :] + return features + + +def _load_pil_image(image: str | Path | Image.Image) -> Image.Image: + if isinstance(image, Image.Image): + return image.convert("RGB") + return Image.open(image).convert("RGB") + + +def _processor_background_color(processor: Any) -> tuple[int, int, int]: + return tuple(int(value * 255) for value in getattr(processor, "image_mean", (0.5, 0.5, 0.5))) + + +def _processor_target_resolution(processor: Any) -> int: + crop_size = getattr(processor, "crop_size", None) + if isinstance(crop_size, dict): + return int(crop_size.get("height") or crop_size.get("width") or 384) + + size = getattr(processor, "size", None) + if isinstance(size, dict): + return int( + size.get("height") + or size.get("width") + or size.get("shortest_edge") + or 384 + ) + if isinstance(size, (tuple, list)): + return int(size[0]) + return 384 + + +def preprocess_single_image( + image: str | Path | Image.Image, + processor: SigLipImageProcessor | Any, + image_aspect_ratio: str = "pad", + anyres_max_subimages: int = 1, +) -> tuple[torch.Tensor, tuple[int, ...]]: + pil_image = _load_pil_image(image) + original_size = pil_image.size + target_resolution = _processor_target_resolution(processor) + background_color = _processor_background_color(processor) + + if image_aspect_ratio == "anyres": + snapshot_image = expand_to_square(pil_image, background_color=background_color).resize( + (target_resolution, target_resolution) + ) + possible_resolutions = [ + (int(grid_width * target_resolution), int(grid_height * target_resolution)) + for grid_width in range(1, anyres_max_subimages + 1) + for grid_height in range(1, anyres_max_subimages + 1) + if (grid_width * grid_height) <= anyres_max_subimages + ] + best_resolution = select_best_resolution(pil_image.size, possible_resolutions) + anyres_image = resize_and_pad_image( + pil_image, + best_resolution, + background_color=background_color, + ) + patches = divide_to_patches(anyres_image, target_resolution) + image_patches = [snapshot_image] + patches + pixel_values = torch.stack( + [ + processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] + for image_patch in image_patches + ], + dim=0, + ) + anyres_grid = ( + best_resolution[1] // target_resolution, + best_resolution[0] // target_resolution, + ) + return pixel_values, (*original_size, *anyres_grid) + + square_image = expand_to_square(pil_image, background_color=background_color).resize( + (target_resolution, target_resolution) + ) + pixel_values = processor.preprocess(square_image, return_tensors="pt")["pixel_values"] + return pixel_values, original_size + + +def _get_model_video_attr(model_config: Any, attr_name: str, default: Any) -> Any: + if model_config is None: + return default + return getattr(model_config, attr_name, default) + + +def process_video_with_decord( + video_file: str | Path, + model_config: Any = None, + max_frames: int | None = None, + num_threads: int = -1, +) -> tuple[np.ndarray, float, str, int]: + validate_cambrian_s_environment(require_video=True) + from decord import VideoReader, cpu + + if num_threads < 1: + reader = VideoReader(str(video_file), ctx=cpu(0)) + else: + reader = VideoReader(str(video_file), ctx=cpu(0), num_threads=num_threads) + + total_frame_num = len(reader) + video_time = total_frame_num / reader.get_avg_fps() + avg_fps = round(reader.get_avg_fps() / _get_model_video_attr(model_config, 
"video_fps", 1)) + frame_idx = [frame for frame in range(0, total_frame_num, avg_fps)] + frame_time = [frame / avg_fps for frame in frame_idx] + + video_max_frames = ( + max_frames + if max_frames is not None + else _get_model_video_attr(model_config, "video_max_frames", 32) + ) + video_force_sample = _get_model_video_attr(model_config, "video_force_sample", False) + if video_max_frames > 0: + if len(frame_idx) > video_max_frames or video_force_sample: + uniform_sampled_frames = np.linspace( + 0, + total_frame_num - 1, + video_max_frames, + dtype=int, + ) + frame_idx = uniform_sampled_frames.tolist() + frame_time = [frame / reader.get_avg_fps() for frame in frame_idx] + + video = reader.get_batch(frame_idx).asnumpy() + frame_time_str = ",".join([f"{timestamp:.2f}s" for timestamp in frame_time]) + num_frames_to_sample = len(frame_idx) + reader.seek(0) + return video, video_time, frame_time_str, num_frames_to_sample + + +def sample_video_frames( + video_path: str | Path, + max_frames: int = 8, + model_config: Any = None, + num_threads: int = -1, +) -> list[Image.Image]: + if model_config is None: + model_config = type( + "CambrianVideoConfig", + (), + { + "video_fps": 1, + "video_max_frames": max_frames, + "video_force_sample": False, + }, + )() + video, _, _, _ = process_video_with_decord( + video_path, + model_config=model_config, + num_threads=num_threads, + ) + return [Image.fromarray(frame).convert("RGB") for frame in video] + + +def preprocess_video_frames( + frames: Sequence[Image.Image] | str | Path, + processor: SigLipImageProcessor | Any, + max_frames: int | None = None, + model_config: Any = None, + num_threads: int = -1, +) -> tuple[torch.Tensor, tuple[int, int, int]]: + if isinstance(frames, (str, Path)): + raw_video, _, _, _ = process_video_with_decord( + frames, + model_config=model_config, + max_frames=max_frames, + num_threads=num_threads, + ) + frames = [Image.fromarray(frame).convert("RGB") for frame in raw_video] + else: + frames = [_load_pil_image(frame) for frame in frames] + + if not frames: + raise ValueError("Cambrian-S received an empty video input.") + + original_size = frames[0].size + square_frames = [ + expand_to_square(frame, background_color=_processor_background_color(processor)) + for frame in frames + ] + pixel_values = processor.preprocess(square_frames, return_tensors="pt")["pixel_values"] + return pixel_values, (original_size[0], original_size[1], len(frames)) + + +def tokenizer_image_token( + prompt: str, + tokenizer: Any, + image_token_index: int = IMAGE_TOKEN_INDEX, + return_tensors: str | None = None, +) -> list[int] | torch.Tensor: + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("")] + + input_ids: list[int] = [] + offset = 0 + if prompt_chunks and prompt_chunks[0] and prompt_chunks[0][0] == getattr(tokenizer, "bos_token_id", None): + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for chunk_index, chunk_ids in enumerate(prompt_chunks): + input_ids.extend(chunk_ids[offset:]) + if chunk_index < len(prompt_chunks) - 1: + input_ids.append(image_token_index) + + if return_tensors is None: + return input_ids + if return_tensors != "pt": + raise ValueError(f"Unsupported tensor type: {return_tensors}") + return torch.tensor(input_ids, dtype=torch.long) + + +def resize_patch_grid(features: torch.Tensor, target_side: int) -> torch.Tensor: + token_count = features.shape[1] + current_side = int(token_count ** 0.5) + if current_side * current_side != token_count and (token_count - 1) > 0: + maybe_square = int((token_count - 1) ** 
0.5) + if maybe_square * maybe_square == (token_count - 1): + features = features[:, 1:, :] + token_count = features.shape[1] + current_side = int(token_count ** 0.5) + + if current_side == target_side: + return features + + batch, _, channels = features.shape + grid = features.view(batch, current_side, current_side, channels).permute(0, 3, 1, 2) + grid = F.interpolate(grid.float(), size=(target_side, target_side), mode="bilinear", align_corners=False) + return grid.permute(0, 2, 3, 1).reshape(batch, target_side * target_side, channels).to(features.dtype) + + +__all__ = [ + "SigLipImageProcessor", + "SigLipVisionTower", + "SiglipVisionTower", + "divide_to_patches", + "expand_to_square", + "process_video_with_decord", + "preprocess_single_image", + "preprocess_video_frames", + "resize_and_pad_image", + "resize_patch_grid", + "sample_video_frames", + "select_best_resolution", + "tokenizer_image_token", + "unpad_image", + "validate_cambrian_s_environment", +] diff --git a/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/modeling_cambrian_s.py b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/modeling_cambrian_s.py new file mode 100644 index 00000000..35ed72f0 --- /dev/null +++ b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/modeling_cambrian_s.py @@ -0,0 +1,334 @@ +from __future__ import annotations + +import math +import re +from typing import Any + +import torch +from torch import nn +from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2ForCausalLM, Qwen2Model + +from .constants import IMAGE_TOKEN_INDEX +from .mm_utils import ( + SigLipVisionTower, + resize_patch_grid, + unpad_image, + validate_cambrian_s_environment, +) + + +def _build_vision_projector(config): + projector_type = getattr(config, "mm_projector_type", "linear") + if projector_type == "linear": + return nn.Linear(config.mm_hidden_size, config.hidden_size) + + mlp_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) + if mlp_match: + depth = int(mlp_match.group(1)) + modules: list[nn.Module] = [nn.Linear(config.mm_hidden_size, config.hidden_size)] + for _ in range(1, depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.hidden_size)) + return nn.Sequential(*modules) + + raise ValueError(f"Unsupported Cambrian-S projector type: {projector_type}") + + +class CambrianSConfig(Qwen2Config): + model_type = "cambrian_qwen" + + +class CambrianSModel(Qwen2Model): + config_class = CambrianSConfig + + def __init__(self, config: CambrianSConfig): + super().__init__(config) + + self.mm_projector = _build_vision_projector(config) + embed_std = 1 / math.sqrt(config.hidden_size) + self.image_newline = nn.Parameter( + torch.randn(config.hidden_size, dtype=self.embed_tokens.weight.dtype) * embed_std + ) + vision_tower_names = list( + getattr(config, "mm_vision_tower_aux_list", None) + or getattr(config, "vision_tower_aux_list", None) + or [] + ) + self.vision_tower_aux_list = [ + SigLipVisionTower(vision_tower_name, delay_load=True) + for vision_tower_name in vision_tower_names + ] + + def get_vision_tower_aux_list(self) -> list[SigLipVisionTower]: + return self.vision_tower_aux_list + + def load_vision_towers( + self, + device: str | torch.device, + dtype: torch.dtype | None = None, + ) -> None: + for tower in self.vision_tower_aux_list: + tower.load_model() + tower.to(device=device, dtype=dtype or self.embed_tokens.weight.dtype) + + def _project_features(self, features: torch.Tensor) -> torch.Tensor: + projector_dtype = 
next(self.mm_projector.parameters()).dtype + projected = self.mm_projector(features.to(projector_dtype)) + return projected.to(self.embed_tokens.weight.dtype) + + def _use_image_newline_token(self) -> bool: + return not hasattr(self.config, "mm_use_im_newline_token") or bool( + self.config.mm_use_im_newline_token + ) + + def _append_newline_token(self, features: torch.Tensor) -> torch.Tensor: + if not self._use_image_newline_token(): + return features + hidden_size = features.shape[-1] + newline = self.image_newline.to(features.dtype)[None, None, None, :].expand( + features.shape[0], + features.shape[1], + 1, + hidden_size, + ) + return torch.cat([features, newline], dim=2) + + def _format_image_features( + self, + image_features: torch.Tensor, + original_size: tuple[int, int], + ) -> torch.Tensor: + target_side = int(getattr(self.config, "si_token_len", image_features.shape[1]) ** 0.5) + image_features = resize_patch_grid(image_features, target_side) + batch_size, _, hidden_size = image_features.shape + grid = image_features.view(batch_size, target_side, target_side, hidden_size) + grid = unpad_image(grid, original_size) + return self._append_newline_token(grid).flatten(1, 2) + + def _format_anyres_image_features( + self, + image_features: torch.Tensor, + image_size: tuple[int, int, int, int], + ) -> torch.Tensor: + target_side = int(getattr(self.config, "si_token_len", image_features.shape[1]) ** 0.5) + image_features = resize_patch_grid(image_features, target_side) + _, _, hidden_size = image_features.shape + grid = image_features.view(image_features.shape[0], target_side, target_side, hidden_size) + + snapshot_features = grid[0].unsqueeze(0) + patch_rows, patch_cols = image_size[2:] + anyres_features = grid[1:].unflatten(0, (patch_rows, patch_cols)) + anyres_features = anyres_features.permute(0, 2, 1, 3, 4).flatten(2, 3).flatten(0, 1).unsqueeze(0) + + original_size = image_size[:2] + snapshot_features = unpad_image(snapshot_features, original_size) + anyres_features = unpad_image(anyres_features, original_size) + + snapshot_features = self._append_newline_token(snapshot_features).flatten(1, 2) + anyres_features = self._append_newline_token(anyres_features).flatten(1, 2) + return torch.cat([snapshot_features, anyres_features], dim=1) + + def _format_video_features( + self, + video_features: torch.Tensor, + video_size: tuple[int, int, int], + ) -> torch.Tensor: + target_side = int(max(getattr(self.config, "miv_token_len", video_features.shape[1]), 1) ** 0.5) + video_features = resize_patch_grid(video_features, target_side) + _, _, hidden_size = video_features.shape + grid = video_features.view(video_features.shape[0], target_side, target_side, hidden_size) + grid = unpad_image(grid, video_size[:2]) + return self._append_newline_token(grid).flatten(1, 2).flatten(0, 1).unsqueeze(0) + + def prepare_media_embeddings( + self, + images: list[torch.Tensor], + image_sizes: list[tuple[int, ...]], + ) -> list[torch.Tensor]: + if not images: + return [] + if not self.vision_tower_aux_list: + raise RuntimeError("Cambrian-S model does not define a vision tower.") + if len(images) != len(image_sizes): + raise ValueError( + "Cambrian-S expected image tensors and image_sizes to align, " + f"but received {len(images)} tensors and {len(image_sizes)} size entries." 
+ ) + + tower = self.vision_tower_aux_list[0] + media_counts = [media.shape[0] for media in images] + stacked_images = torch.cat(images, dim=0).to(device=tower.device, dtype=tower.dtype) + vision_features = tower(stacked_images) + projected_features = self._project_features(vision_features) + split_features = list(torch.split(projected_features, media_counts, dim=0)) + + media_embeddings: list[torch.Tensor] = [] + for features, media_size in zip(split_features, image_sizes): + if len(media_size) == 2: + media_embeddings.append(self._format_image_features(features, media_size)) + elif len(media_size) == 3: + media_embeddings.append(self._format_video_features(features, media_size)) + elif len(media_size) == 4: + media_embeddings.append(self._format_anyres_image_features(features, media_size)) + else: + raise ValueError(f"Unsupported Cambrian-S media size: {media_size}") + return media_embeddings + + +class CambrianSForCausalLM(Qwen2ForCausalLM): + config_class = CambrianSConfig + + def __init__(self, config: CambrianSConfig): + Qwen2ForCausalLM.__init__(self, config) + config.model_type = "cambrian_qwen" + config.rope_scaling = None + + self.model = CambrianSModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.post_init() + + def get_model(self) -> CambrianSModel: + return self.model + + def load_vision_towers( + self, + device: str | torch.device, + dtype: torch.dtype | None = None, + ) -> None: + self.get_model().load_vision_towers(device=device, dtype=dtype) + + def prepare_inputs_labels_for_multimodal_for_generation( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor | None, + attention_mask: torch.Tensor | None, + past_key_values: Any, + labels: torch.Tensor | None, + images: list[torch.Tensor] | None, + image_sizes: list[tuple[int, ...]] | None = None, + ): + if not images or input_ids.shape[1] == 1: + return input_ids, position_ids, attention_mask, past_key_values, None, labels + + if input_ids.shape[0] != 1: + raise ValueError("Cambrian-S v1 only supports single-sample generation.") + + media_embeddings = self.get_model().prepare_media_embeddings(images, image_sizes or []) + + input_ids_for_embed = torch.where(input_ids == IMAGE_TOKEN_INDEX, 0, input_ids) + input_embeds = self.get_model().embed_tokens(input_ids_for_embed) + image_positions = input_ids[0].eq(IMAGE_TOKEN_INDEX).nonzero(as_tuple=False).flatten().tolist() + + if len(image_positions) != len(media_embeddings): + raise ValueError( + "Cambrian-S found a mismatch between visual placeholders and prepared media embeddings: " + f"{len(image_positions)} placeholders vs {len(media_embeddings)} media items." 
+ ) + + pieces: list[torch.Tensor] = [] + start = 0 + for image_position, media_embedding in zip(image_positions, media_embeddings): + pieces.append(input_embeds[:, start:image_position]) + pieces.append(media_embedding.to(dtype=input_embeds.dtype, device=input_embeds.device)) + start = image_position + 1 + pieces.append(input_embeds[:, start:]) + + new_input_embeds = torch.cat(pieces, dim=1) + attention_mask = torch.ones( + new_input_embeds.shape[:2], + device=new_input_embeds.device, + dtype=torch.bool, + ) + return None, None, attention_mask, past_key_values, new_input_embeds, labels + + @torch.no_grad() + def generate( + self, + inputs: torch.Tensor | None = None, + images: list[torch.Tensor] | None = None, + image_sizes: list[tuple[int, ...]] | None = None, + **kwargs, + ): + position_ids = kwargs.pop("position_ids", None) + attention_mask = kwargs.pop("attention_mask", None) + if "inputs_embeds" in kwargs: + raise NotImplementedError("Cambrian-S does not accept external inputs_embeds.") + + if images: + ( + inputs, + position_ids, + attention_mask, + _, + inputs_embeds, + _, + ) = self.prepare_inputs_labels_for_multimodal_for_generation( + inputs, + position_ids, + attention_mask, + None, + None, + images, + image_sizes=image_sizes, + ) + else: + if inputs is None: + raise ValueError("Cambrian-S generate requires token inputs when no images are provided.") + inputs_embeds = self.get_model().embed_tokens(inputs) + + return super().generate( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.Tensor, + past_key_values: Any = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ): + images = kwargs.pop("images", None) + image_sizes = kwargs.pop("image_sizes", None) + prepared = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + **kwargs, + ) + if images is not None: + prepared["images"] = images + if image_sizes is not None: + prepared["image_sizes"] = image_sizes + return prepared + + +CambrianQwenConfig = CambrianSConfig +CambrianQwenForCausalLM = CambrianSForCausalLM + + +def register_cambrian_s_autoclasses() -> None: + try: + AutoConfig.register("cambrian_qwen", CambrianSConfig) + except ValueError: + pass + + try: + AutoModelForCausalLM.register(CambrianSConfig, CambrianSForCausalLM) + except ValueError: + pass + + +register_cambrian_s_autoclasses() + +__all__ = [ + "CambrianSConfig", + "CambrianSForCausalLM", + "CambrianSModel", + "CambrianQwenConfig", + "CambrianQwenForCausalLM", + "register_cambrian_s_autoclasses", + "validate_cambrian_s_environment", +] diff --git a/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/siglip_vision.py b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/siglip_vision.py new file mode 100644 index 00000000..a419dc94 --- /dev/null +++ b/src/openworldlib/reasoning/spatial_reasoning/cambrian_s/siglip_vision.py @@ -0,0 +1,477 @@ +from __future__ import annotations + +from dataclasses import dataclass +from functools import partial, reduce +from typing import Dict, Optional, Tuple, Union + +import torch +from PIL import Image +from torch import nn +from transformers import PretrainedConfig +from transformers.activations import ACT2FN +from transformers.image_processing_utils import BatchFeature, get_size_dict +from transformers.image_transforms import ( + convert_to_rgb, + normalize, + rescale, + resize, + to_channel_dimension_format, +) 
+from transformers.image_utils import ( + ChannelDimension, + PILImageResampling, + to_numpy_array, +) +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput + + +class SigLipImageProcessor: + def __init__( + self, + image_mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + image_std: tuple[float, float, float] = (0.5, 0.5, 0.5), + size: tuple[int, int] = (384, 384), + crop_size: Dict[str, int] | None = None, + resample=PILImageResampling.BICUBIC, + rescale_factor: float = 1 / 255, + data_format: ChannelDimension = ChannelDimension.FIRST, + ): + crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384} + self.crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") + self.image_mean = image_mean + self.image_std = image_std + self.size = size + self.resample = resample + self.rescale_factor = rescale_factor + self.data_format = data_format + + def preprocess(self, images, return_tensors: str): + if isinstance(images, Image.Image): + images = [images] + else: + images = [to_numpy_array(image) for image in images] + + transforms = [ + convert_to_rgb, + to_numpy_array, + partial(resize, size=self.size, resample=self.resample, data_format=self.data_format), + partial(rescale, scale=self.rescale_factor, data_format=self.data_format), + partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format), + partial( + to_channel_dimension_format, + channel_dim=self.data_format, + input_channel_dim=self.data_format, + ), + ] + images = reduce(lambda value, fn: [*map(fn, value)], transforms, images) + return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) + + +class SigLipVisionConfig(PretrainedConfig): + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size: int = 1152, + image_mean: tuple[float, float, float] = (0.5, 0.5, 0.5), + intermediate_size: int = 4304, + num_hidden_layers: int = 27, + num_attention_heads: int = 16, + num_channels: int = 3, + image_size: int = 384, + patch_size: int = 14, + hidden_act: str = "gelu_pytorch_tanh", + layer_norm_eps: float = 1e-6, + attention_dropout: float = 0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.image_mean = image_mean + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.image_size = image_size + self.patch_size = patch_size + self.hidden_act = hidden_act + self.layer_norm_eps = layer_norm_eps + self.attention_dropout = attention_dropout + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs) -> "SigLipVisionConfig": + cls._set_token_in_kwargs(kwargs) + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + return cls.from_dict(config_dict, **kwargs) + + +@dataclass +class SigLipVisionModelOutput(ModelOutput): + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor | None = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class SigLipVisionEmbeddings(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.embed_dim = 
config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding=0, + ) + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_patches).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + return embeddings + self.position_embedding(self.position_ids) + + +class SigLipAttention(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got embed_dim={self.embed_dim}, num_heads={self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + batch_size, query_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2) + + attention_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + if attention_mask is not None: + attention_weights = attention_weights + attention_mask + + attention_weights = nn.functional.softmax(attention_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attention_weights = nn.functional.dropout(attention_weights, p=self.dropout, training=self.training) + attention_output = torch.matmul(attention_weights, value_states) + + attention_output = attention_output.transpose(1, 2).contiguous() + attention_output = attention_output.reshape(batch_size, query_len, self.embed_dim) + attention_output = self.out_proj(attention_output) + if not output_attentions: + attention_weights = None + return attention_output, attention_weights + + +class SigLipMLP(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + return self.fc2(hidden_states) + + +class SigLipEncoderLayer(nn.Module): + def __init__(self, config: 
SigLipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SigLipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + output_attentions: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states, attn_weights + + +class SigLipEncoder(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + inputs_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + hidden_states = inputs_embeds + + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + hidden_states, attn_weights = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + if output_attentions: + all_attentions = all_attentions + (attn_weights,) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(value for value in [hidden_states, encoder_states, all_attentions] if value is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class SigLipMultiheadAttentionPoolingHead(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = nn.MultiheadAttention( + config.hidden_size, + config.num_attention_heads, + batch_first=True, + ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + return hidden_state[:, 0] + + +class SigLipVisionTransformer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.embeddings = 
SigLipVisionEmbeddings(config) + self.encoder = SigLipEncoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.head = SigLipMultiheadAttentionPoolingHead(config) + + def forward( + self, + pixel_values: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.embeddings(pixel_values) + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = self.post_layernorm(encoder_outputs[0]) + pooled_output = self.head(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SigLipPreTrainedModel(PreTrainedModel): + config_class = SigLipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + return None + + +class SigLipVisionModel(SigLipPreTrainedModel): + config_class = SigLipVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["SigLipEncoderLayer"] + + def __init__(self, config: SigLipVisionConfig): + super().__init__(config) + self.vision_model = SigLipVisionTransformer(config) + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values: torch.Tensor, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SigLipVisionTower(nn.Module): + def __init__(self, vision_tower_name: str, delay_load: bool = True): + super().__init__() + self.vision_tower_name = vision_tower_name.split("-interp")[0] + self.image_processor = SigLipImageProcessor() + self.config = SigLipVisionConfig() + self.is_loaded = False + self.vision_tower: SigLipVisionModel | None = None + if not delay_load: + self.load_model() + + def load_model(self, device_map=None) -> None: + if self.is_loaded: + return + + self.vision_tower = SigLipVisionModel.from_pretrained( + self.vision_tower_name, + device_map=device_map, + ) + del self.vision_tower.vision_model.encoder.layers[-1:] + self.vision_tower.vision_model.head = nn.Identity() + self.vision_tower.requires_grad_(False) + self.is_loaded = True + + def to(self, device, dtype: torch.dtype | None = None): + if not self.is_loaded: + self.load_model() + if self.vision_tower is None: + raise RuntimeError("SigLip vision tower is not loaded.") + + kwargs = {"device": device} + if dtype is not None: + kwargs["dtype"] = dtype + self.vision_tower = self.vision_tower.to(**kwargs) + return self 
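+
+    # Note: load_model() above drops the final SigLip encoder layer and replaces
+    # the pooling head with nn.Identity, so forward() below returns patch features
+    # read from hidden_states[-1], i.e. the output of the truncated encoder stack.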
+ + def forward(self, images: torch.Tensor) -> torch.Tensor: + if not self.is_loaded or self.vision_tower is None: + raise RuntimeError("SigLip vision tower is not loaded.") + + outputs = self.vision_tower( + images.to(device=self.device, dtype=self.dtype), + output_hidden_states=True, + ) + image_features = outputs.hidden_states[-1].to(images.dtype) + if image_features.shape[-2] != self.num_patches: + raise RuntimeError( + "Unexpected SigLIP token count: " + f"expected {self.num_patches}, found {image_features.shape[-2]}." + ) + return image_features + + @property + def dtype(self) -> torch.dtype: + if self.vision_tower is None: + return torch.float32 + return next(self.vision_tower.parameters()).dtype + + @property + def device(self) -> torch.device: + if self.vision_tower is None: + return torch.device("cpu") + return next(self.vision_tower.parameters()).device + + @property + def hidden_size(self) -> int: + return self.config.hidden_size + + @property + def num_patches(self) -> int: + return (self.config.image_size // self.config.patch_size) ** 2 + + @property + def num_patches_per_side(self) -> int: + return self.config.image_size // self.config.patch_size + + @property + def image_size(self) -> int: + return self.config.image_size + + +__all__ = [ + "SigLipImageProcessor", + "SigLipVisionConfig", + "SigLipVisionModel", + "SigLipVisionTower", +] diff --git a/test/test_cambrian_s.py b/test/test_cambrian_s.py new file mode 100644 index 00000000..50e7d3cf --- /dev/null +++ b/test/test_cambrian_s.py @@ -0,0 +1,50 @@ +import torch +from PIL import Image + +from openworldlib.pipelines.cambrian_s.pipeline_cambrian_s import CambrianSPipeline + + +MODEL_PATH = "nyu-visionx/Cambrian-S-0.5B" +DEVICE = "cuda" +WEIGHT_DTYPE = torch.float16 + +IMAGE_PATH = "./data/test_case/test_image_case1/ref_image.png" +VIDEO_PATH = "./data/test_case/test_video_case1/talking_man.mp4" + + +def test_cambrian_s_pipeline_pil_image(): + pipe = CambrianSPipeline.from_pretrained( + model_path=MODEL_PATH, + device=DEVICE, + weight_dtype=WEIGHT_DTYPE, + ) + pil_image = Image.open(IMAGE_PATH).convert("RGB") + instruction = "Describe the scene." + output = pipe( + prompt=instruction, + images=pil_image, + max_new_tokens=64, + ) + assert isinstance(output, list) and len(output) == 1 + print("[PIL.Image] output:", output[0]) + + +def test_cambrian_s_pipeline_video(): + pipe = CambrianSPipeline.from_pretrained( + model_path=MODEL_PATH, + device=DEVICE, + weight_dtype=WEIGHT_DTYPE, + ) + instruction = "Summarize the video content." + output = pipe( + prompt=instruction, + videos=VIDEO_PATH, + max_new_tokens=64, + ) + assert isinstance(output, list) and len(output) == 1 + print("[video] output:", output[0]) + + +if __name__ == "__main__": + test_cambrian_s_pipeline_pil_image() + test_cambrian_s_pipeline_video()
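+
+
+# Minimal illustrative sketch: CambrianSPipeline.__call__ also accepts a
+# pre-built OpenWorldLib-style `messages` list in place of the images/videos
+# arguments; the instruction text below is an arbitrary example.
+def test_cambrian_s_pipeline_messages():
+    pipe = CambrianSPipeline.from_pretrained(
+        model_path=MODEL_PATH,
+        device=DEVICE,
+        weight_dtype=WEIGHT_DTYPE,
+    )
+    instruction = "What objects are visible in the image?"
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": IMAGE_PATH},
+                {"type": "text", "text": instruction},
+            ],
+        }
+    ]
+    output = pipe(prompt=instruction, messages=messages, max_new_tokens=64)
+    assert isinstance(output, list) and len(output) == 1
+    print("[messages] output:", output[0])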