Skip to content

btrude/jet-pytorch

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Jet PyTorch

This repo contains a PyTorch port of the code from Jet: A Modern Transformer-Based Normalizing Flow which was originally written in Jax and included as part of big_vision.

A checkpoint pretrained on downsampled imagenet 64x64 is provided via the hugging face hub as well as barebones single-GPU training and inference scripts.

Installation

The code in this repo requires a CUDA-capable GPU which can perform operations at bfloat16 precision.

pip install jet-pytorch

The models from the original paper/this reproduction are trained using downsampled imagenet which is available from a few different sources, including Kaggle. The ImageDataset provided in this repository does not load/support labels and does not require any specific directory structure.

Usage

from jet_pytorch import Jet

jet_config = dict(
    patch_size=4,
    patch_dim=48,
    n_patches=256,
    coupling_layers=32,
    block_depth=2,
    block_width=512,
    num_heads=8,
    scale_factor=2.0,
    coupling_types=(
        "channels", "channels",
        "channels", "channels",
        "spatial",
    ),
    spatial_coupling_projs=(
        "checkerboard", "checkerboard-inv",
        "vstripes", "vstripes-inv",
        "hstripes", "hstripes-inv",
    )
)
model = Jet(**jet_config)

This is the default Jet configuration and that which matches the pretrained weights available on the huggingface hub

Download and/or load pretrained imagenet 64x64 weights

from jet_pytorch.util import get_pretrained

weights = get_pretrained()
model.load_state_dict(weights)

Sample from a Jet

from torch.distributions import Normal

batch_size = 16
n_patches = 256
patch_dim = 48
pdf = Normal(0, 1)
z = pdf.sample((batch_size, n_patches, patch_dim))
img, logdet = model.inverse(z)

Training a Jet

from jet_pytorch.train import train

jet_config = dict(...)
train(
    jet_config=jet_config,
    batch_size=64,
    accumulate_steps=16,
    device="cuda:0",
    epochs=50,
    warmup_percentage=0.1,
    max_grad_norm=1.0,
    learning_rate=3e-4,
    weight_decay=1e-5,
    adam_betas=(0.9, 0.95),
    images_path_train="/path/to/train/images",
    images_path_valid="/path/to/validation/images",
    num_workers=8,
    checkpoint_path="jet.pt",
)

The training code favors gradient accumulation over alternatives like gradient checkpointing, allowing this script to run on GPUs with less than 8GB of VRAM. The true batch size is thus equal to batch_size * accumulate_steps. Note that the default configuration assumes at least 24GB of VRAM.

Create visualizations

from jet_pytorch.sample import sample

# Creates visualization using the default Jet config/pretrained weights
sample("path/to/your/images")

# Creates visualization using default Jet config/a local checkpoint
sample(
    "path/to/your/images",
    checkpoint_path="path/to/your/checkpoint.pt",
)

# Creates visualization using custom Jet config/a local checkpoint
jet_config = dict(...)
sample(
    "path/to/your/images",
    jet_config=jet_config,
    checkpoint_path="path/to/your/checkpoint.pt",
)

Visualization results are stored at output/jet.png

About

A Pytorch port of the code from "Jet: A Modern Transformer-Based Normalizing Flow"

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages