Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
83b5582
Add base class for quantizers
LTMeyer Apr 7, 2025
23c75bf
Add base class for tokenizers
LTMeyer Apr 7, 2025
b250bf4
Change arborescence for tokenizers
LTMeyer Apr 7, 2025
dc75b8c
Add utils functions
LTMeyer Apr 7, 2025
8389823
Add MagVitAE model
LTMeyer Apr 7, 2025
b51e963
Add subsampler module
LTMeyer Apr 7, 2025
57926ab
Add FiniteScale quantizer
LTMeyer Apr 7, 2025
2733ce7
Use quantizer encode instead of quantize
LTMeyer Apr 7, 2025
c190989
Add MagVitAE image tokenizer
LTMeyer Apr 7, 2025
cd1f705
Add test for MagVitAE image tokenizer
LTMeyer Apr 7, 2025
71ea594
Add tests to CI
LTMeyer Apr 7, 2025
2f531e0
Fix github CI
LTMeyer Apr 7, 2025
c22e5fb
Remove unnecessary package listing to fix tests
LTMeyer Apr 7, 2025
7ec31f9
Update aion/utils.py
EiffL Apr 8, 2025
f4c3bff
Update aion/tokenizers/base.py
EiffL Apr 8, 2025
806da4b
Add ruff cache to gitignore
LTMeyer Apr 9, 2025
0b5c6e4
Move tokenizers to dedicated codecs module
LTMeyer Apr 9, 2025
9cd08ba
Rename FiniteScaleQuantizer->FiniteScalarQuantizer
LTMeyer May 12, 2025
d37fb30
Add channel mask as input
LTMeyer May 13, 2025
1cf702b
Add test to ensure previous results consistency
LTMeyer May 13, 2025
8c37cc4
Make the tokenizer a pytorch module
LTMeyer May 14, 2025
aa71c1e
Update test to load only one model checkpoint
LTMeyer May 14, 2025
66b2da2
Upload data to HF
LTMeyer May 15, 2025
6e61347
Fix weight_only loading default value
LTMeyer May 15, 2025
e82942c
Download git lfs files in github actions
LTMeyer May 15, 2025
627130b
Fix tokenizer decode method
LTMeyer May 20, 2025
f2f6d10
Add image codec decoded batch to test data.
LTMeyer May 20, 2025
88b75da
Merge branch 'main' into add_tokenizers
LTMeyer May 21, 2025
660f0fd
Add missing run keyword to CI
LTMeyer May 21, 2025
9156c0a
Add numpy to dependencies
LTMeyer May 21, 2025
16b2232
Restore HF token in CI
LTMeyer May 21, 2025
fdbe656
Restore lfs checkout in CI
LTMeyer May 21, 2025
8097f35
Update huggingface_hub dependency
LTMeyer May 21, 2025
cb426a8
Investigate why image tokenizer test is failing
LTMeyer May 21, 2025
9823d2f
Bis
LTMeyer May 21, 2025
e3576b0
Ter
LTMeyer May 21, 2025
06f246d
Remove gitattributes to get rid of lfs
LTMeyer May 21, 2025
802495b
Remove test data lfs from repo
LTMeyer May 21, 2025
f4a7b73
Add test data for image tokenizer without lfs
LTMeyer May 21, 2025
9b531d0
Update pyproject.toml
EiffL May 22, 2025
e94d703
Update test.yaml
EiffL May 22, 2025
7a29aa7
Update test_image_tokenizer.py
EiffL May 22, 2025
536e2df
Update test_image_tokenizer.py
EiffL May 22, 2025
682c48c
Update pyproject.toml
EiffL May 22, 2025
e9e77cc
Merge pull request #12 from PolymathicAI/add_tokenizers_tests
EiffL May 22, 2025
a9d63c5
Rename MagViTAEImageCodec to ImageCodec
LTMeyer May 22, 2025
d1b9f58
Prepare migration to lfs
LTMeyer May 22, 2025
e050180
Track tokenizer test data with lfs
LTMeyer May 22, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
tests/test_data/image_codec_decoded_batch.pt filter=lfs diff=lfs merge=lfs -text
tests/test_data/image_codec_encoded_batch.pt filter=lfs diff=lfs merge=lfs -text
tests/test_data/image_codec_input_batch.pt filter=lfs diff=lfs merge=lfs -text
10 changes: 8 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@ jobs:
uses: pre-commit/action@v3.0.1
with:
extra_args: --all-files
install:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
lfs: true
- uses: astral-sh/setup-uv@v5
with:
enable-cache: true
- uses: actions/setup-python@v5
with:
python-version: "3.13"
python-version: "3.11"
- name: Install AION
run: uv sync --all-extras --dev
- name: Run tests
env:
HF_TOKEN: ${{ secrets.AION_HF_TOKEN }}
run: uv run pytest tests
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,13 @@ venv.bak/
/site

# mypy
.mypy_cache/
.mypy_cache
.dmypy.json
dmypy.json

# Ruff
.ruff_cache

# Pyre type checker
.pyre/

Expand Down
Empty file added aion/codecs/__init__.py
Empty file.
Empty file added aion/codecs/modules/__init__.py
Empty file.
214 changes: 214 additions & 0 deletions aion/codecs/modules/magvit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


def cast_tuple(t, length=1):
    """Return *t* unchanged if it is already a tuple, else a tuple repeating *t* `length` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * length


class SameConv2d(torch.nn.Module):
    """2D convolution with 'same' padding: spatial size is preserved for odd kernels."""

    def __init__(self, dim_in, dim_out, kernel_size):
        super().__init__()
        size_pair = cast_tuple(kernel_size, 2)
        # Half-kernel padding per spatial dimension keeps H and W unchanged
        # for odd kernel sizes.
        pad = [side // 2 for side in size_pair]
        self.conv = torch.nn.Conv2d(
            dim_in, dim_out, kernel_size=size_pair, padding=pad
        )

    def forward(self, x: torch.Tensor):
        return self.conv(x)


class SqueezeExcite(torch.nn.Module):
    # global context network - attention-esque squeeze-excite variant (https://arxiv.org/abs/2012.13375)
    #
    # Pools the input into a single per-channel context vector via a learned
    # softmax attention map, then predicts per-channel gates in [0, 1] that
    # rescale the input. Uses native torch reshapes (flatten/unsqueeze) instead
    # of einops for these trivial one-liners.

    def __init__(self, dim, *, dim_out=None, dim_hidden_min=16, init_bias=-10):
        """Args:
        dim: Number of input channels.
        dim_out: Number of gate channels; defaults to ``dim`` (must equal
            ``dim`` for the final ``gates * x`` broadcast to rescale channels).
        dim_hidden_min: Lower bound on the hidden width of the gating MLP.
        init_bias: Initial bias of the last conv; a large negative value makes
            the initial gates ~sigmoid(init_bias) ~ 0, so the block starts
            near-silent (identity when wrapped in a residual connection).
        """
        super().__init__()
        dim_out = dim_out if dim_out is not None else dim

        # 1x1 conv producing a single-channel spatial attention logit map.
        self.to_k = torch.nn.Conv2d(dim, 1, 1)
        dim_hidden = max(dim_hidden_min, dim_out // 2)

        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_hidden, 1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.Conv2d(dim_hidden, dim_out, 1),
            torch.nn.Sigmoid(),
        )

        # Zero the last conv's weight and set its bias to init_bias so the
        # initial gates are deterministic and close to zero.
        torch.nn.init.zeros_(self.net[-2].weight)
        torch.nn.init.constant_(self.net[-2].bias, init_bias)

    def forward(self, x):
        context = self.to_k(x)

        # (b, 1, h, w) -> (b, 1, h*w), softmax over spatial positions.
        context = context.flatten(2).softmax(dim=-1)
        # (b, c, h, w) -> (b, c, h*w)
        spatial_flattened_input = x.flatten(2)

        # Attention-weighted average over spatial positions -> (b, c, 1).
        out = torch.einsum("b i n, b c n -> b c i", context, spatial_flattened_input)
        # (b, c, 1) -> (b, c, 1, 1) so the 1x1 convs in `net` can consume it.
        out = out.unsqueeze(-1)
        gates = self.net(out)

        # Per-channel gating, broadcast over the spatial dimensions.
        return gates * x


class ResidualUnit(torch.nn.Module):
    """Residual block: same-padded conv -> ELU -> 1x1 conv -> ELU -> squeeze-excite, with skip connection."""

    def __init__(self, dim: int, kernel_size: int | tuple[int, int, int]):
        super().__init__()
        stages = [
            SameConv2d(dim, dim, kernel_size),
            torch.nn.ELU(),
            torch.nn.Conv2d(dim, dim, 1),
            torch.nn.ELU(),
            SqueezeExcite(dim),
        ]
        self.net = torch.nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        # Additive skip connection around the whole stack.
        residual = x
        return self.net(x) + residual


class SpatialDownsample2x(torch.nn.Module):
def __init__(
self,
dim: int,
dim_out: int = None,
kernel_size: int = 3,
):
super().__init__()
dim_out = dim_out if dim_out is not None else dim
self.conv = torch.nn.Conv2d(
dim, dim_out, kernel_size, stride=2, padding=kernel_size // 2
)

def forward(self, x: torch.Tensor):
out = self.conv(x)
return out


class SpatialUpsample2x(torch.nn.Module):
def __init__(self, dim: int, dim_out: int = None):
super().__init__()
dim_out = dim_out if dim_out is not None else dim
conv = torch.nn.Conv2d(dim, dim_out * 4, 1)

self.net = torch.nn.Sequential(
conv,
torch.nn.SiLU(),
Rearrange("b (c p1 p2) h w -> b c (h p1) (w p2)", p1=2, p2=2),
)

self.init_conv_(conv)

def init_conv_(self, conv: torch.nn.Module):
o, i, h, w = conv.weight.shape
conv_weight = torch.empty(o // 4, i, h, w)
torch.nn.init.kaiming_uniform_(conv_weight)
conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")

conv.weight.data.copy_(conv_weight)
torch.nn.init.zeros_(conv.bias.data)

def forward(self, x: torch.Tensor):
out = self.net(x)
return out


class MagVitAE(torch.nn.Module):
    """MagViTAE implementation from Yu, et al. (2024), adapted for Pytorch.
    Code borrowed from https://github.com/lucidrains/magvit2-pytorch, and adapted for images.

    The encoder and decoder are built as mirrored stacks: every layer appended
    to ``encoder_layers`` has its counterpart inserted at the front of
    ``decoder_layers``, so decoding reverses the encoder stage by stage.
    """

    def __init__(
        self,
        n_bands: int = 3,
        hidden_dims: int = 512,
        residual_conv_kernel_size: int = 3,
        n_compressions: int = 2,
        num_consecutive: int = 2,
    ):
        """Args:
        n_bands: Number of input (and reconstructed) image channels.
        hidden_dims: Channel width at the bottleneck after all compressions.
        residual_conv_kernel_size: Kernel size inside each ResidualUnit.
        n_compressions: Number of 2x spatial down/up-sampling stages.
        num_consecutive: Number of consecutive residual units per stage.
        """
        super().__init__()

        self.encoder_layers = torch.nn.ModuleList([])
        self.decoder_layers = torch.nn.ModuleList([])
        # First-stage width; doubled at every compression until hidden_dims.
        init_dim = int(hidden_dims / 2**n_compressions)
        dim = init_dim

        self.conv_in = SameConv2d(n_bands, init_dim, 7)
        self.conv_out = SameConv2d(init_dim, n_bands, 3)

        # Residual layers
        encoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
        decoder_layer = ResidualUnit(dim, residual_conv_kernel_size)
        self.encoder_layers.append(encoder_layer)
        self.decoder_layers.insert(0, decoder_layer)

        # Compressions: each stage halves H/W in the encoder (and doubles them
        # in the decoder) while doubling the channel count.
        for _ in range(n_compressions):
            dim_out = dim * 2
            encoder_layer = SpatialDownsample2x(dim, dim_out)
            decoder_layer = SpatialUpsample2x(dim_out, dim)
            self.encoder_layers.append(encoder_layer)
            self.decoder_layers.insert(0, decoder_layer)
            dim = dim_out

        # Consecutive residual layers
        encoder_layer = torch.nn.Sequential(
            *[
                ResidualUnit(dim, residual_conv_kernel_size)
                for _ in range(num_consecutive)
            ]
        )
        decoder_layer = torch.nn.Sequential(
            *[
                ResidualUnit(dim, residual_conv_kernel_size)
                for _ in range(num_consecutive)
            ]
        )
        self.encoder_layers.append(encoder_layer)
        self.decoder_layers.insert(0, decoder_layer)

        # Add a final non-compress layer (spatial size and width unchanged).
        dim_out = dim
        encoder_layer = SameConv2d(dim, dim_out, 7)
        decoder_layer = SameConv2d(dim_out, dim, 3)
        self.encoder_layers.append(encoder_layer)
        self.decoder_layers.insert(0, decoder_layer)
        dim = dim_out

        # Consecutive residual layers
        encoder_layer = torch.nn.Sequential(
            *[
                ResidualUnit(dim, residual_conv_kernel_size)
                for _ in range(num_consecutive)
            ]
        )
        decoder_layer = torch.nn.Sequential(
            *[
                ResidualUnit(dim, residual_conv_kernel_size)
                for _ in range(num_consecutive)
            ]
        )
        self.encoder_layers.append(encoder_layer)
        self.decoder_layers.insert(0, decoder_layer)

        # add a final norm just before quantization layer
        self.encoder_layers.append(
            torch.nn.Sequential(
                Rearrange("b c ... -> b ... c"),
                torch.nn.LayerNorm(dim),
                Rearrange("b ... c -> b c ..."),
            )
        )

    def encode(self, x: torch.Tensor):
        """Map an image batch (b, n_bands, H, W) to latent features
        (b, hidden_dims, H / 2**n_compressions, W / 2**n_compressions)."""
        x = self.conv_in(x)
        for layer in self.encoder_layers:
            x = layer(x)
        return x

    def decode(self, x: torch.Tensor):
        """Reconstruct an image batch from latent features; inverse of :meth:`encode`."""
        for layer in self.decoder_layers:
            x = layer(x)
        x = self.conv_out(x)
        return x
60 changes: 60 additions & 0 deletions aion/codecs/modules/subsampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from jaxtyping import Bool, Float


class SubsampledLinear(torch.nn.Module):
    def __init__(self, dim_in: int, dim_out: int, subsample_in: bool = True):
        """
        Subsampled linear layer for the encoder.
        It takes in a zero-padded tensor and a mask.
        It projects the tensor into some shared projection space.
        It can also be used to reverse out of the space with the mask.

        Args:
            dim_in : Number of total possible bands.
            dim_out : Number of embedding dimensions.
            subsample_in : Whether to subsample the input. Defaults to True.
        """
        super().__init__()
        self.subsample_in = subsample_in
        self.dim_in = dim_in  # Number of total possible bands
        self.dim_out = dim_out  # Number of embedding dimensions
        # Create a Linear only to reuse its default initialization, then own
        # the weight/bias directly so F.linear can be applied manually.
        temp_linear = torch.nn.Linear(dim_in, dim_out)
        self.weight = torch.nn.Parameter(temp_linear.weight)
        self.bias = torch.nn.Parameter(temp_linear.bias)

    def _subsample_in(self, x, labels: torch.Tensor):
        # x: (b, h, w, c); labels: bool (b, c). Zero out absent channels.
        mask = labels[:, None, None, :].float()
        x = x * mask

        # Rescale by sqrt(dim_in / n_present) so activation magnitude is
        # independent of how many channels are actually present.
        label_sizes = labels.sum(dim=1, keepdim=True)  # (b, 1)
        scales = (self.dim_in / label_sizes) ** 0.5  # (b, 1)

        # Reshape to (b, 1, 1, 1) with view instead of squeeze(): squeeze()
        # also collapsed the batch axis when b == 1, which made the subsequent
        # broadcast indexing fail for batch-size-1 inputs.
        return scales.view(-1, 1, 1, 1) * F.linear(x, self.weight, self.bias)

    def _subsample_out(self, x, labels):
        # Get mask
        mask = labels[:, None, None, :].float()

        # Apply linear layer and mask
        return F.linear(x, self.weight, self.bias) * mask

    def forward(self, x: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Project a channel-masked image batch.

        Args:
            x: Float tensor of shape (b, c, h, w), zero-padded on absent bands.
            labels: Bool tensor of shape (b, c); True marks a present band.

        Returns:
            Float tensor of shape (b, c', h, w) with c' = dim_out.
        """
        # (b, c, h, w) -> (b, h, w, c) so the linear layer acts on channels.
        x = x.permute(0, 2, 3, 1)

        if self.subsample_in:
            x = self._subsample_in(x, labels)
        else:
            x = self._subsample_out(x, labels)

        # (b, h, w, c) -> (b, c, h, w)
        x = x.permute(0, 3, 1, 2)

        return x
Loading
Loading