-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
141 lines (105 loc) · 5.33 KB
/
test.py
File metadata and controls
141 lines (105 loc) · 5.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Test harness for the FinalCompressedTokens module (star-imported from Compress).
import torch
# NOTE(review): nn and math appear unused in this file — presumably kept for the
# star-import's benefit or leftover from an earlier revision; confirm before removing.
import torch.nn as nn
import math
from Compress import *
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
# Module-level default config shared by the first test below.
config = Qwen3Config()
def test_final_compressed_tokens():
    """Exercise FinalCompressedTokens construction and forward passes.

    Builds random query/key tensors shaped per the module-level Qwen3
    ``config``, runs the compression module on a fixed case plus several
    (batch, seq_len, window) combinations, prints compression ratios, and
    asserts one shape invariant. Uses ``FinalCompressedTokens`` from the
    star-import of ``Compress``.
    """
    print("Testing FinalCompressedTokens Class")
    print("=" * 50)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Test parameters.
    batch_size = 2
    seq_len = 50
    window_size = 10
    r = 0.6  # compression-ratio argument passed through to the module
    M = 0    # presumably a minimum-token / offset argument — TODO confirm in Compress

    # Window queries and compressed-memory keys; transpose(1, 2) yields the
    # (batch, heads, tokens, head_dim) layout.
    q_w = torch.randn(batch_size, window_size, config.num_attention_heads,
                      config.head_dim, device=device).transpose(1, 2)
    km_cmp = torch.randn(batch_size, seq_len, config.num_key_value_heads,
                         config.head_dim, device=device).transpose(1, 2)
    print("Input shapes:")
    print(f" q_w: {q_w.shape}")
    print(f" km_cmp: {km_cmp.shape}")
    print()

    final_compression = FinalCompressedTokens(config, q_w, km_cmp, r, M)
    print(f"Final compression parameters: {sum(p.numel() for p in final_compression.parameters()):,}")

    # Memory tensors: compressed memory (seq_len tokens) and raw memory (2x longer).
    xm_cmp = torch.randn(batch_size, seq_len, config.hidden_size, device=device)
    x_m = torch.randn(batch_size, 2 * seq_len, config.hidden_size, device=device)
    output = final_compression(x_m, xm_cmp)
    print(f"Output shape: {output.shape}")
    compression_ratio = 1 - (output.shape[1] - window_size) / (2 * seq_len)
    print(f"Compression ratio: {compression_ratio:.2%}")
    print(f"Memory tensor shape: {xm_cmp.shape}")
    print()

    # NOTE(review): this call passes xm_cmp as BOTH arguments, unlike the
    # (x_m, xm_cmp) calls elsewhere in this file — confirm this is intentional;
    # the shape assertion below depends on it.
    with torch.no_grad():
        output = final_compression(xm_cmp, xm_cmp)
        print("Forward pass successful!")
        print(f"Output shape: {output.shape}")

    # Verify the no-grad pass preserved the input shape.
    assert output.shape == xm_cmp.shape, f"Output shape {output.shape} doesn't match input shape {xm_cmp.shape}"
    print("✓ Shape verification passed")  # f-prefix removed: no placeholders (F541)

    # Repeat over several sizes; loop variables renamed so they no longer
    # shadow the fixed test parameters above.
    test_cases = [
        (1, 32, 8, "Small sequence"),
        (2, 100, 20, "Medium sequence"),
        (1, 200, 50, "Large sequence"),
    ]
    print("\nTesting different sequence sizes:")
    for case_batch, case_seq, case_window, description in test_cases:
        print(f"\nTest: {description}")
        print(f"Input: batch={case_batch}, seq_len={case_seq}, window={case_window}")
        q_w = torch.randn(case_batch, config.num_attention_heads, case_window,
                          config.head_dim, device=device)
        km_cmp = torch.randn(case_batch, config.num_key_value_heads, case_seq,
                             config.head_dim, device=device)
        xm_cmp = torch.randn(case_batch, case_seq, config.hidden_size, device=device)
        x_m = torch.randn(case_batch, 2 * case_seq, config.hidden_size, device=device)
        final_compression = FinalCompressedTokens(config, q_w, km_cmp, r, M)
        output = final_compression(x_m, xm_cmp)
        print(f" Input shape: {xm_cmp.shape}")
        print(f" Output shape: {output.shape}")
        compression_ratio = 1 - (output.shape[1] - case_window) / (2 * case_seq)
        print(f"Compression ratio: {compression_ratio:.2%}")
    print("\n✓ All FinalCompressedTokens tests passed!")  # f-prefix removed (F541)
def test_final_compressed_tokens_gradients():
    """Check that gradients flow through FinalCompressedTokens.

    Runs a forward pass on grad-enabled memory tensors, backpropagates a
    scalar mean loss, and reports which inputs and module parameters
    received gradients.
    """
    print("\n" + "=" * 50)
    print("Testing FinalCompressedTokens Gradient Flow")
    print("=" * 50)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Local config (shadows the module-level one with an identical default).
    config = Qwen3Config()

    batch_size = 1
    seq_len = 32
    window_size = 8
    r = 0.6
    M = 5  # nonzero here, unlike the shape test — presumably exercises the M path; confirm

    # Flattened query/key representations; only xm_cmp / x_m carry gradients.
    q_w = torch.randn(batch_size, config.num_attention_heads, window_size,
                      config.head_dim, device=device)
    km_cmp = torch.randn(batch_size, config.num_key_value_heads, seq_len - window_size,
                         config.head_dim, device=device)
    xm_cmp = torch.randn(batch_size, seq_len - window_size, config.hidden_size,
                         device=device, requires_grad=True)
    x_m = torch.randn(batch_size, 2 * seq_len, config.hidden_size,
                      device=device, requires_grad=True)
    print(f"Input requires grad: {xm_cmp.requires_grad}")

    final_compression = FinalCompressedTokens(config, q_w, km_cmp, r, M)
    final_compression.train()  # Enable training mode
    output = final_compression(x_m, xm_cmp)

    # Scalar loss so backward() needs no explicit gradient argument.
    loss = output.mean()
    loss.backward()
    print(f"Output shape: {output.shape}")
    print(f"Loss: {loss.item():.6f}")
    if xm_cmp.grad is not None:
        print(f"Input grad norm: {xm_cmp.grad.norm().item():.6f}")
    else:
        print("Input grad: None")

    # Count how many trainable parameters actually received gradients.
    param_grads = [p.grad is not None for p in final_compression.parameters() if p.requires_grad]
    print(f"Parameters with gradients: {sum(param_grads)}/{len(param_grads)}")
    print("✓ Gradient test passed")  # f-prefix removed: no placeholders (F541)
if __name__ == "__main__":
    # Script entry point: run both test routines, then print a closing banner.
    test_final_compressed_tokens()
    test_final_compressed_tokens_gradients()
    banner = "=" * 50
    print("\n" + banner)
    print("🎉 All FinalCompressedTokens tests completed!")
    print(banner)