This guide explains how to modify an existing image ViT to work with AutoGaze. We use our SigLIP implementation as a running example, comparing it against the original SigLIP from HuggingFace Transformers (v4.51.1).
AutoGaze predicts which patches to attend to (the "gaze") for each video frame. The ViT then processes only the gazed patches across all frames, rather than every patch of a single image as in the standard setup. This requires two changes to a standard image ViT:
- Patch Embedding: Instead of embedding all patches, only embed the gazed patches selected by AutoGaze for each frame.
- Attention Mask: Since we are retargeting an image ViT into a video ViT, we need to construct an attention mask to control how tokens from different frames interact. Three types are supported: block-causal (default), causal, and bidirectional.
The rest of the ViT (encoder layers, attention, MLP, layer norms, etc.) stays unchanged.
An example of using SigLIP+AutoGaze can be found in QUICK_START.md.
AutoGaze outputs a gazing_info dict that is passed to the ViT. It contains:
| Key | Shape | Description |
|---|---|---|
| `gazing_pos` | (B, N) | Indices of the gazed patches across all frames, where N = sum(num_gazing_each_frame). |
| `num_gazing_each_frame` | (T,) | Number of gazed patches per frame (including padding). |
| `if_padded_gazing` | (B, N) | Boolean mask indicating which positions are padding (not real gazes). |
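To make the shapes concrete, here is a hypothetical toy `gazing_info` dict with made-up values (T = 2 frames, 4 patches per frame, batch size 1; the specific indices are illustrative only):

```python
import torch

# Toy example: AutoGaze selected 3 patches in frame 0 and 2 in frame 1;
# frame 1 is padded to 3 so both frames contribute 3 tokens (N = 6).
num_gazing_each_frame = torch.tensor([3, 3])           # (T,), includes padding
gazing_pos = torch.tensor([[0, 2, 3, 4, 6, 0]])        # (B, N), indices into the
                                                       # flattened T*num_patches axis
if_padded_gazing = torch.tensor([[False, False, False,
                                  False, False, True]])  # (B, N), last slot is padding

gazing_info = {
    'gazing_pos': gazing_pos,
    'num_gazing_each_frame': num_gazing_each_frame,
    'if_padded_gazing': if_padded_gazing,
}

# N must equal the sum over frames, padding included
assert gazing_pos.shape[1] == int(num_gazing_each_frame.sum())
```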
In a standard ViT, patch embedding converts all image patches into tokens. With AutoGaze, we only keep the patches at positions specified by gazing_info['gazing_pos'].
Add a method to select only the gazed patches from the full sequence of patches (and their corresponding position embeddings). In our SigLIP implementation, this is mask_with_gazing:
```python
def mask_with_gazing(self, sequence, gazing_info):
    """
    Select only the gazed patches from the full sequence.
    Padded gazing positions are mapped to a dummy token (index 0).
    """
    gazing_pos = gazing_info['gazing_pos'].clone()
    if_padded_gazing = gazing_info['if_padded_gazing'].clone()
    B = sequence.shape[0]
    # Map padded gazing positions to a dummy token
    gazing_pos[if_padded_gazing] = 0
    # Gather only the gazed tokens
    sequence_gazed = sequence[torch.arange(B)[:, None], gazing_pos]
    return sequence_gazed
```

Then in the embedding forward pass, after computing all patches and position embeddings, apply this selection:
```python
# Compute all patches and position embeddings (across frames and scales)
# patches: (B, T*num_patches, patch_dim)
# pos_embeddings: (B, T*num_patches, embed_dim)

# Select only the gazed patches
patches = self.mask_with_gazing(patches, gazing_info)
pos_embeddings = self.mask_with_gazing(pos_embeddings, gazing_info)

# Then embed and add position embeddings as usual
embeddings = linear(patches) + pos_embeddings
```

Key point: the input `pixel_values` shape changes from (B, C, H, W) to (B, T, C, H, W) since AutoGaze operates on video frames. Patches from all frames are flattened into a single sequence before gazing selection.
See `SiglipVisionEmbeddings` in `modeling_siglip.py` for the complete implementation.
AutoGaze supports multi-scale patches (e.g., 32+64+112+224), where the video is resized to each scale and patchified independently. The patches and position embeddings from all scales are concatenated before gazing selection. This is handled in get_gazed_patches_and_other_embeddings in our SigLIP. If you only use a single scale, this simply reduces to the standard single-resolution patchification.
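The multi-scale patchification can be sketched as follows. This is a simplified illustration of the idea, not the actual `get_gazed_patches_and_other_embeddings`; the function name and patch size here are assumptions:

```python
import torch
import torch.nn.functional as F

def multiscale_patchify(video, scales, patch_size=16):
    """Sketch: resize every frame to each scale, patchify independently,
    and concatenate the patches from all scales into one sequence."""
    B, T, C, H, W = video.shape
    all_patches = []
    for s in scales:
        # Resize all frames to this scale
        resized = F.interpolate(video.flatten(0, 1), size=(s, s),
                                mode='bilinear', align_corners=False)  # (B*T, C, s, s)
        # Unfold into non-overlapping patch_size x patch_size patches
        patches = F.unfold(resized, kernel_size=patch_size, stride=patch_size)
        patches = patches.transpose(1, 2)                       # (B*T, n, C*p*p)
        patches = patches.reshape(B, T * patches.shape[1], -1)  # flatten frames
        all_patches.append(patches)
    # Concatenate across scales along the sequence axis, then apply gazing selection
    return torch.cat(all_patches, dim=1)

video = torch.randn(1, 2, 3, 224, 224)
# Scale 32 -> (32/16)^2 = 4 patches per frame; scale 64 -> 16 per frame.
# With 2 frames: sequence length = 2*4 + 2*16 = 40.
patches = multiscale_patchify(video, scales=[32, 64])
```

With a single scale, the loop runs once and this reduces to standard patchification.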
Since we are repurposing an image ViT to process multiple video frames as a single sequence, we need an attention mask to control how tokens from different frames interact. We support three attention types (configured via attn_type):
| `attn_type` | Inter-frame attention | Intra-frame attention | Description |
|---|---|---|---|
| `block_causal` | Causal (past frames only) | Bidirectional | Tokens attend to all tokens in the same frame and all tokens from previous frames. Recommended default. |
| `causal` | Causal | Causal | Strictly causal: each token attends only to preceding tokens in the flattened sequence. |
| `bidirectional` | Bidirectional | Bidirectional | Full attention across all tokens (all frames see each other). |
In all cases, padded gazing tokens are masked out so they are not attended to.
Not all attention types work with all backends:
| Backend | Supported `attn_type` |
|---|---|
| `flash_attention_2` | `causal`, `bidirectional` |
| `sdpa`, `eager`, `flex_attention` | `block_causal` |
The reason is that flash_attention_2 natively supports causal masking (via the is_causal flag) and simple padding masks, but does not accept arbitrary 2D attention masks needed for block-causal attention. Conversely, the other backends construct an explicit (B, num_heads, N, N) additive mask, which can express block-causal patterns but would be redundant for the simpler causal/bidirectional cases that flash attention handles more efficiently.
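One way to enforce this compatibility constraint is to validate it at model construction time. The helper below is a hypothetical sketch (the name and error message are illustrative); the mapping itself follows the table above:

```python
# Which attn_type each backend can express:
# flash_attention_2 only takes an is_causal flag plus a padding mask,
# while the other backends build the explicit additive (B, H, N, N) mask.
SUPPORTED_ATTN_TYPES = {
    'flash_attention_2': {'causal', 'bidirectional'},
    'sdpa': {'block_causal'},
    'eager': {'block_causal'},
    'flex_attention': {'block_causal'},
}

def check_attn_backend(attn_type, backend):
    """Raise early if the configured attn_type cannot run on this backend."""
    if attn_type not in SUPPORTED_ATTN_TYPES[backend]:
        raise ValueError(
            f"attn_type={attn_type!r} is not supported by backend={backend!r}; "
            f"supported: {sorted(SUPPORTED_ATTN_TYPES[backend])}"
        )
```

Failing fast here is cheaper than debugging a silently wrong attention pattern at training time.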
Add a method to construct the appropriate attention mask. In our SigLIP, this is get_causal_mask in SiglipVisionTransformer. Here is a simplified version showing the block-causal case:
```python
def get_causal_mask(self, num_tokens_each_frame, batch_size, num_heads,
                    token_mask=None, dtype=torch.float32):
    T = len(num_tokens_each_frame)
    N = int(num_tokens_each_frame.sum())
    # Start with a causal (lower-triangular) mask
    mask = torch.tril(torch.ones(batch_size, N, N, dtype=dtype))
    # Allow full bidirectional attention within each frame
    for t in range(T):
        start = int(num_tokens_each_frame[:t].sum())
        end = int(num_tokens_each_frame[:t + 1].sum())
        mask[:, start:end, start:end] = 1
    # Zero out columns for padded tokens
    if token_mask is not None:
        mask = mask * (~token_mask.unsqueeze(1)).to(dtype)
    # Convert to additive mask (0 for attend, -inf for ignore)
    mask = torch.where(mask == 1, 0.0, -torch.inf).to(dtype)
    mask = mask.unsqueeze(1).expand(-1, num_heads, -1, -1)
    return mask
```

Then in the transformer forward pass, construct and pass this mask to the encoder:
```python
encoder_attn_mask = self.get_causal_mask(
    gazing_info['num_gazing_each_frame'],
    batch_size=B,
    num_heads=self.config.num_attention_heads,
    token_mask=gazing_info['if_padded_gazing'],
    dtype=pixel_values.dtype,
)
encoder_outputs = self.encoder(inputs_embeds=hidden_states, attention_mask=encoder_attn_mask)
```

See `SiglipVisionTransformer.get_causal_mask` in `modeling_siglip.py` for the complete implementation covering all three attention types and backends.
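A quick standalone sanity check of the block-causal pattern on a toy case (T = 2 frames with 2 tokens each, batch size 1, last token padded) can confirm the three properties the mask should have:

```python
import torch

# Build the block-causal mask the same way get_causal_mask does
num_tokens_each_frame = torch.tensor([2, 2])
N = int(num_tokens_each_frame.sum())
mask = torch.tril(torch.ones(1, N, N))
offsets = [0] + torch.cumsum(num_tokens_each_frame, 0).tolist()
for start, end in zip(offsets[:-1], offsets[1:]):
    mask[:, start:end, start:end] = 1          # bidirectional within each frame
token_mask = torch.tensor([[False, False, False, True]])  # last token is padding
mask = mask * (~token_mask.unsqueeze(1)).float()
additive = torch.where(mask == 1, 0.0, -torch.inf)

# 1) Intra-frame is bidirectional: token 0 sees token 1 (a "future" position)
assert additive[0, 0, 1] == 0
# 2) Inter-frame is causal: token 0 cannot see frame 1's tokens
assert additive[0, 0, 2] == -torch.inf
# 3) Padded tokens are never attended to
assert (additive[:, :, 3] == -torch.inf).all()
```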
Add these fields to your vision config (see configuration_siglip.py):
- `scales` (str): Multi-scale resolutions separated by `+`, e.g., `'32+64+112+224'`. Use `'224'` for single-scale.
- `attn_type` (str): Attention type: `'block_causal'` (default; causal across frames, bidirectional within each frame), `'causal'` (strictly causal), or `'bidirectional'` (full attention). See Step 2 for details and backend compatibility.
- `frame_independent_encoding` (bool): If `True`, tokens from different frames cannot attend to each other (only intra-frame attention).
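A minimal sketch of these config fields, assuming a plain dataclass rather than the actual `SiglipVisionConfig` (the class name here is hypothetical; only the field names and defaults follow the text):

```python
from dataclasses import dataclass

@dataclass
class GazeVisionConfig:
    # Hypothetical container for the three AutoGaze-related fields
    scales: str = '32+64+112+224'          # '+'-separated resolutions
    attn_type: str = 'block_causal'        # 'block_causal' | 'causal' | 'bidirectional'
    frame_independent_encoding: bool = False

    def scale_list(self):
        """Parse the '+'-separated scales string into integers."""
        return [int(s) for s in self.scales.split('+')]

cfg = GazeVisionConfig(scales='224')       # single-scale configuration
assert cfg.scale_list() == [224]
```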
Add `gazing_info: Optional[dict] = None` to the forward methods of both the embedding module and the transformer module.
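The point of defaulting to `None` is that the module keeps working as a plain image ViT when no gazing info is supplied. A minimal sketch (class and return values here are illustrative placeholders, not the real modules):

```python
from typing import Optional

class EmbeddingsSketch:
    """Illustrates the signature change only; the real module returns tensors."""
    def forward(self, pixel_values, gazing_info: Optional[dict] = None):
        if gazing_info is None:
            return 'all_patches'        # standard image-ViT path
        return 'gazed_patches_only'     # AutoGaze path: select via gazing_pos

emb = EmbeddingsSketch()
```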
| Component | Original ViT | With AutoGaze |
|---|---|---|
| Input shape | (B, C, H, W) | (B, T, C, H, W) |
| Patch embedding | Embeds all patches | Embeds only gazed patches via `mask_with_gazing` |
| Attention mask | None (full attention) | Block-causal / causal / bidirectional mask from `get_causal_mask` |
| Encoder / MLP / LayerNorm | No change | No change |
| Config | Standard | + `scales`, `attn_type`, `frame_independent_encoding` |
After adding AutoGaze to a ViT, using it in an MLLM is conceptually simple: feed the ViT features of the gazed patches into the LLM.
As an example, we open-sourced NVILA-HD-Video, an MLLM with AutoGaze. See instructions in the VILA repo for using NVILA-HD-Video.