-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencode_decode.py
More file actions
350 lines (307 loc) · 11.9 KB
/
encode_decode.py
File metadata and controls
350 lines (307 loc) · 11.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
# Copyright 2026 Linum Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
File: encode_decode.py
Description: Encode and decode images or videos through the Linum VAE. Supports both
single files and directories. Videos are processed in temporal chunks to fit GPU memory.
The script saves both the preprocessed original and the VAE reconstruction so users can
compare them directly:
- **Images**: By default, images are resized to the closest training-resolution size bucket
(the VAE was trained on specific resolutions, so this gives the best reconstruction
quality). Use ``--no-resize`` to skip bucket resizing and encode at the image's native
resolution. The saved "original" reflects exactly what the VAE received after resizing.
- **Videos**: Frames are sub-sampled to 24 FPS (videos below 24 FPS are rejected). Spatial
dimensions are floored to the nearest multiple of 8. The saved "original" shows the
sub-sampled, resized frames — exactly what the VAE encoded.
"""
import argparse
import glob
import os
import einops
import torch
from PIL import Image
from image_video_vae.autoencoder import Autoencoder, gen_upsample_shapes
from image_video_vae.io import (
IMAGE_EXTENSIONS,
denormalize_pixels,
get_video_chunk_frames,
preprocess_image,
preprocess_image_native,
preprocess_video,
save_video_as_mp4,
)
from image_video_vae.size_buckets import SIZE_CHOICES, get_best_size_bucket
# --------------------------------
# HIGH-LEVEL ORCHESTRATION
# --------------------------------
def parse_args() -> argparse.Namespace:
    """
    Build the command-line interface and parse ``sys.argv``.

    ``--size`` and ``--no-resize`` are registered in a mutually exclusive group, so
    argparse rejects invocations that pass both.

    Returns:
        argparse.Namespace:
            Parsed arguments with attributes ``mode``, ``input``, ``checkpoint``,
            ``output_dir``, ``size``, ``no_resize`` and ``seed``.
    """
    # Long help texts are named up front so the add_argument calls stay compact.
    checkpoint_help = (
        "Path to local checkpoint (.safetensors or .ckpt). "
        "Downloads from HuggingFace Hub if omitted."
    )
    size_help = (
        "Size bucket for resizing images. The VAE was trained on specific "
        "resolutions, so by default images are resized to the closest "
        "training bucket for best quality. Auto-selected if omitted."
    )
    no_resize_help = (
        "Skip bucket resizing and encode images at native resolution. "
        "Both dimensions must be at least 8 pixels. Best reconstruction "
        "quality comes from the default bucket-resize behavior."
    )

    cli = argparse.ArgumentParser(
        description="Encode and decode images/videos through the Linum VAE",
    )

    # Required: what to process and where it lives.
    cli.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=["image", "video"],
        help="Processing mode: 'image' or 'video'",
    )
    cli.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to image file/directory or video file/directory",
    )

    # Optional: model weights and output location.
    cli.add_argument("--checkpoint", type=str, default=None, help=checkpoint_help)
    cli.add_argument(
        "--output_dir",
        type=str,
        default="./test_output",
        help="Directory for outputs (default: ./test_output)",
    )

    # Image sizing: an explicit bucket and native resolution cannot be combined.
    sizing = cli.add_mutually_exclusive_group()
    sizing.add_argument(
        "--size",
        type=str,
        default=None,
        choices=SIZE_CHOICES,
        help=size_help,
    )
    sizing.add_argument("--no-resize", action="store_true", help=no_resize_help)

    cli.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducible latent sampling (default: 42)",
    )

    return cli.parse_args()
def _collect_image_files(input_path: str):
    """
    Resolve ``input_path`` to the list of image files to process.

    Args:
        input_path (str):
            Path to a single image file or to a directory that is searched recursively
            using every glob pattern in ``IMAGE_EXTENSIONS``.

    Returns:
        Optional[list]:
            ``[input_path]`` for a file, a sorted list of unique matches for a directory
            (possibly empty), or ``None`` when the path is neither a file nor a directory.
    """
    if os.path.isfile(input_path):
        return [input_path]
    if not os.path.isdir(input_path):
        return None
    # A set de-duplicates in case overlapping patterns (or a case-insensitive
    # filesystem) match the same file more than once.
    matches = set()
    for pattern in IMAGE_EXTENSIONS:
        matches.update(
            glob.glob(os.path.join(input_path, "**", pattern), recursive=True)
        )
    return sorted(matches)


def _collect_video_files(input_path: str):
    """
    Resolve ``input_path`` to the list of ``.mp4`` files to process.

    Args:
        input_path (str):
            Path to a single video file or to a directory that is searched recursively
            for ``*.mp4``.

    Returns:
        Optional[list]:
            ``[input_path]`` for a file, a sorted list of matches for a directory
            (possibly empty), or ``None`` when the path is neither a file nor a directory.
    """
    if os.path.isfile(input_path):
        return [input_path]
    if not os.path.isdir(input_path):
        return None
    return sorted(glob.glob(os.path.join(input_path, "**", "*.mp4"), recursive=True))


def _auto_size_bucket(image_path: str):
    """
    Pick the size bucket for an image from its pixel dimensions and report the choice.

    Args:
        image_path (str):
            Path to the image whose dimensions are inspected.

    Returns:
        The value returned by ``get_best_size_bucket`` for the image's width and height.
    """
    # The context manager guarantees the file handle is released even if reading
    # the size raises.
    with Image.open(image_path) as img:
        width, height = img.size
    size_bucket = get_best_size_bucket(width=width, height=height)
    print(f" Auto-selected size bucket: {size_bucket} for {width}x{height}")
    return size_bucket


@torch.inference_mode()
def main(args: argparse.Namespace) -> None:
    """
    Run VAE encode-decode on images or videos.

    Inputs are resolved and validated before the model is loaded, so a missing path or
    an empty directory is reported immediately, without loading (or downloading) the
    checkpoint and without creating the output directory.

    Args:
        args (argparse.Namespace):
            Parsed command-line arguments. Reads ``mode``, ``input``, ``checkpoint``,
            ``output_dir``, ``size`` and ``no_resize``.
    """
    # --- Resolve inputs first (cheap) ---
    input_files = []
    if args.mode == "image":
        input_files = _collect_image_files(input_path=args.input)
        if input_files is None:
            print(f"Image input not found: {args.input}")
            return
        if not input_files:
            print(f"No images found in {args.input}")
            return
        print(f"Found {len(input_files)} images")
    elif args.mode == "video":
        input_files = _collect_video_files(input_path=args.input)
        if input_files is None:
            print(f"Video input not found: {args.input}")
            return
        if not input_files:
            print(f"No video files found in {args.input}")
            return
        print(f"Found {len(input_files)} video(s)")

    # --- Load the model only once there is work to do ---
    autoencoder = Autoencoder.from_pretrained(
        checkpoint_path=args.checkpoint,
    )
    os.makedirs(args.output_dir, exist_ok=True)

    # --- Process ---
    if args.mode == "image":
        for image_path in input_files:
            if args.no_resize:
                encode_decode_image(
                    autoencoder=autoencoder,
                    image_path=image_path,
                    output_dir=args.output_dir,
                    no_resize=True,
                )
                continue
            # An explicit --size wins; otherwise choose a bucket per image.
            if args.size is not None:
                size_bucket = args.size
            else:
                size_bucket = _auto_size_bucket(image_path=image_path)
            encode_decode_image(
                autoencoder=autoencoder,
                image_path=image_path,
                output_dir=args.output_dir,
                size_bucket=size_bucket,
            )
    elif args.mode == "video":
        for video_path in input_files:
            print(f"Processing video: {video_path}")
            encode_decode_video(
                autoencoder=autoencoder,
                video_path=video_path,
                output_dir=args.output_dir,
            )

    print(f"\nDone! Outputs saved to: {args.output_dir}")
# --------------------------------
# ENCODE-DECODE
# --------------------------------
def _first_frame_as_uint8(batched):
    """
    Convert a batched model tensor into a uint8 array for its first frame.

    Drops the leading batch axis, casts to float32, maps values back to pixel range with
    ``denormalize_pixels``, and moves the channel axis last before taking frame 0.
    """
    pixels = denormalize_pixels(frames=batched.squeeze(0).float()).byte()
    frames_last = einops.rearrange(pixels, "c t h w -> t h w c")
    return frames_last.cpu().numpy()[0]


def encode_decode_image(
    autoencoder: Autoencoder,
    image_path: str,
    output_dir: str,
    size_bucket: str = None,
    no_resize: bool = False,
) -> None:
    """
    Round-trip one image through the VAE and write the input and reconstruction as JPEGs.

    With ``no_resize`` False (default) the image is resized to a training-resolution
    bucket and the saved original is that resized version. With ``no_resize`` True the
    image is encoded at native resolution and ``gen_upsample_shapes`` supplies the
    shapes the decoder needs to reproduce the exact input dimensions.

    Args:
        autoencoder (Autoencoder):
            Loaded VAE model.
        image_path (str):
            Path to input image.
        output_dir (str):
            Directory for output files.
        size_bucket (Optional[str]):
            Size category for target dimensions. Ignored when ``no_resize`` is True.
        no_resize (bool):
            If True, skip bucket resizing and encode at native resolution.
    """
    if no_resize:
        image_tensor = preprocess_image_native(image_path=image_path)
    else:
        image_tensor = preprocess_image(image_path=image_path, size_bucket=size_bucket)
        if image_tensor is None:
            print(f" Skipped (aspect ratio doesn't fit any bucket): {image_path}")
            return

    # Autoencoder operates on [0, 1] normalized tensors; add the batch axis here.
    x = image_tensor.unsqueeze(0).to(device="cuda:0", dtype=torch.bfloat16)
    z = autoencoder.sample(distribution_params=autoencoder.encode(x=x))

    # Native resolutions that are not divisible by 8 need explicit upsample shapes so
    # the decoder lands on the exact input size. Bucket-resized inputs are multiples
    # of 8, for which gen_upsample_shapes returns None.
    _, _, in_t, in_h, in_w = x.shape
    _, _, lat_t, lat_h, lat_w = z.shape
    shapes = gen_upsample_shapes(
        latents_resolution=(lat_t, lat_h, lat_w),
        target_resolution=(in_t, in_h, in_w),
    )
    reconstruction = autoencoder.decode(z=z, upsample_shapes=shapes)

    # Tensor -> I/O: back to [0, 255] uint8, channels-last, single frame.
    recon_array = _first_frame_as_uint8(reconstruction)
    original_array = _first_frame_as_uint8(x)

    name, _ = os.path.splitext(os.path.basename(image_path))
    # The filename records whether the saved "original" was bucket-resized.
    original_filename = (
        f"{name}_original.jpg" if no_resize else f"{name}_original_resized.jpg"
    )
    original_out = os.path.join(output_dir, original_filename)
    recon_out = os.path.join(output_dir, f"{name}_reconstruction.jpg")

    Image.fromarray(original_array).save(original_out, format="JPEG", quality=95)
    Image.fromarray(recon_array).save(recon_out, format="JPEG", quality=95)
    print(f" {name} -> {original_out}, {recon_out}")
def encode_decode_video(
    autoencoder: Autoencoder,
    video_path: str,
    output_dir: str,
) -> None:
    """
    Round-trip one video through the VAE in temporal chunks and save both versions.

    Frames are sub-sampled to 24 FPS and spatial dimensions are floored to the nearest
    multiple of 8. The preprocessed original and the reconstruction are each written as
    an MP4. A video that ``preprocess_video`` rejects with ``ValueError`` is skipped with
    a warning.

    Args:
        autoencoder (Autoencoder):
            Loaded VAE model.
        video_path (str):
            Path to input video.
        output_dir (str):
            Directory for output files.
    """
    frames_per_chunk = get_video_chunk_frames(video_path=video_path)
    print(f" Using {frames_per_chunk}-frame chunks")

    # I/O -> tensor: decode video, subsample to 24 FPS, resize, normalize to [0, 1]
    try:
        clip, frame_count, summary = preprocess_video(
            video_path=video_path,
            chunk_frames=frames_per_chunk,
        )
    except ValueError as err:
        print(f" WARNING: Skipping {video_path}: {err}")
        return

    height, width = clip.shape[2], clip.shape[3]
    print(f" {summary}")
    print(f" Preprocessed: {frame_count} frames at {height}x{width}")

    # Autoencoder operates on [0, 1] normalized tensors; add the batch axis here.
    x = clip.unsqueeze(0).to(device="cuda:0", dtype=torch.bfloat16)

    print(" Encoding...")
    z = autoencoder.sample(
        distribution_params=autoencoder.encode_chunked(
            x=x,
            chunk_frames=frames_per_chunk,
        )
    )

    print(" Decoding...")
    reconstruction = autoencoder.decode_chunked(
        z=z,
        target_chunk_resolution=(frames_per_chunk, height, width),
        chunk_frames=frames_per_chunk,
    )

    # Tensor -> I/O: both clips are written as MP4 under the source file's base name.
    stem, _ = os.path.splitext(os.path.basename(video_path))

    original_out = os.path.join(output_dir, f"{stem}_original.mp4")
    save_video_as_mp4(video_tensor=x.float(), output_path=original_out)
    print(f" Saved {original_out}")

    recon_out = os.path.join(output_dir, f"{stem}_reconstruction.mp4")
    save_video_as_mp4(video_tensor=reconstruction, output_path=recon_out)
    print(f" Saved {recon_out}")
if __name__ == "__main__":
    cli_args = parse_args()

    # Seed torch's default generator and every CUDA device before any latent
    # sampling so that repeated runs with the same --seed are reproducible.
    seed = cli_args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    main(args=cli_args)