-
Notifications
You must be signed in to change notification settings - Fork 52
Expand file tree
/
Copy pathwan_loader.py
More file actions
94 lines (83 loc) · 2.7 KB
/
wan_loader.py
File metadata and controls
94 lines (83 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import torch
from pipelines.wan_video import WanVideoPipeline, ModelConfig
from pipelines.wan_video_face_swap import WanVideoPipeline_FaceSwap
def load_wan_pipe(
    base_path,
    wan_version: str = "2.2",
    torch_dtype=torch.bfloat16,
    face_swap=False,
    use_vace=False,
    device="cuda",
):
    """Build a Wan video pipeline from locally downloaded weight files.

    Args:
        base_path: Directory that contains the model weight files.
        wan_version: "2.1" (single DiT checkpoint) or "2.2"
            (separate high-noise / low-noise DiT checkpoints).
        torch_dtype: Dtype used when loading the pipeline weights.
        face_swap: If True, instantiate the face-swap pipeline variant.
        use_vace: Load the VACE shard layout (supported for Wan 2.1 only).
        device: Device the pipeline runs on.

    Returns:
        The constructed pipeline with VRAM management enabled.

    Raises:
        ValueError: If ``use_vace`` is requested together with Wan 2.2.
    """
    if wan_version == "2.2" and use_vace:
        raise ValueError("Wan2.2 does not support use_vace=True. Please disable this option.")

    def _shard_names(num_shards):
        # Safetensors shards follow the "-0000i-of-0000N" naming convention
        # (single-digit shard counts only, which covers the 6/7-shard layouts here).
        return [
            f"diffusion_pytorch_model-0000{i}-of-0000{num_shards}.safetensors"
            for i in range(1, num_shards + 1)
        ]

    def _dit_config(*subdirs, num_shards):
        # One ModelConfig pointing at every shard of a DiT checkpoint;
        # weights stay on CPU ("offload_device") until the pipeline needs them.
        return ModelConfig(
            path=[
                os.path.join(base_path, *subdirs, fname)
                for fname in _shard_names(num_shards)
            ],
            offload_device="cpu",
            skip_download=True,
        )

    model_configs = []
    if wan_version == "2.1":
        # The VACE checkpoint ships with one extra shard (7 instead of 6).
        model_configs.append(_dit_config(num_shards=7 if use_vace else 6))
    else:
        # Wan 2.2 splits the DiT into separate high-noise and low-noise models.
        model_configs.append(_dit_config("high_noise_model", num_shards=6))
        model_configs.append(_dit_config("low_noise_model", num_shards=6))

    # Text encoder and VAE are shared across versions; both 2.1 and 2.2
    # load the same "Wan2.1_VAE.pth" file from base_path.
    model_configs.extend([
        ModelConfig(
            path=os.path.join(base_path, "models_t5_umt5-xxl-enc-bf16.pth"),
            offload_device="cpu",
            skip_download=True,
        ),
        ModelConfig(
            # Plain string literal: the original used a redundant f-string here.
            path=os.path.join(base_path, "Wan2.1_VAE.pth"),
            offload_device="cpu",
            skip_download=True,
        ),
    ])

    pipe_cls = WanVideoPipeline_FaceSwap if face_swap else WanVideoPipeline
    pipe = pipe_cls.from_pretrained(
        torch_dtype=torch_dtype,
        device=device,
        model_configs=model_configs,
        tokenizer_config=ModelConfig(
            path=os.path.join(base_path, "google/umt5-xxl/"),
            offload_device="cpu",
            skip_download=True,
        ),
    )
    pipe.enable_vram_management()
    return pipe