Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ <h2>Constraints</h2>
<li>1 &le; <code>M</code>, <code>N</code> &le; 8192</li>
<li><code>TILE_SIZE</code> &in; {16, 32, 64, 128}</li>

<li>Performance is measured with <code>M</code> = 8,192, <code>N</code> = 8,192</li>
<li>Performance is measured with <code>M</code> = 8,192, <code>N</code> = 8,192, <code>TILE_SIZE</code> = 128</li>
</ul>
101 changes: 57 additions & 44 deletions challenges/medium/64_weight_dequantization/challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,19 @@ def __init__(self):
def reference_impl(
self, X: torch.Tensor, S: torch.Tensor, Y: torch.Tensor, M: int, N: int, TILE_SIZE: int
):
# S shape: (ceil(M/TILE_SIZE), ceil(N/TILE_SIZE))
# We expand S to match X's shape (M, N)

# Expand rows
S_expanded = S.repeat_interleave(TILE_SIZE, dim=0)
# Crop if M is not a multiple of TILE_SIZE
if S_expanded.shape[0] > M:
S_expanded = S_expanded[:M, :]
s_rows = (M + TILE_SIZE - 1) // TILE_SIZE
s_cols = (N + TILE_SIZE - 1) // TILE_SIZE
assert X.shape == (M, N)
assert S.shape == (s_rows, s_cols)
assert Y.shape == (M, N)
assert X.dtype == torch.float32
assert S.dtype == torch.float32
assert Y.dtype == torch.float32

# Expand cols
S_expanded = S_expanded.repeat_interleave(TILE_SIZE, dim=1)
# Crop if N is not a multiple of TILE_SIZE
if S_expanded.shape[1] > N:
S_expanded = S_expanded[:, :N]
S_expanded = S.repeat_interleave(TILE_SIZE, dim=0)[:M, :]
S_expanded = S_expanded.repeat_interleave(TILE_SIZE, dim=1)[:, :N]

# Perform element-wise multiplication
# Ensure Y is updated in-place
Y.copy_(X.to(Y.dtype) * S_expanded.to(Y.dtype))
Y.copy_(X * S_expanded)

def get_solve_signature(self) -> Dict[str, tuple]:
return {
Expand Down Expand Up @@ -65,45 +60,63 @@ def generate_example_test(self) -> Dict[str, Any]:
def generate_functional_test(self) -> List[Dict[str, Any]]:
tests = []

# Case 1: Perfect Multiple
M, N = 256, 256
TILE_SIZE = 128
tests.append(
{
"name": "perfect_multiple",
"X": torch.randn(M, N, device="cuda", dtype=torch.float32),
"S": torch.randn(2, 2, device="cuda", dtype=torch.float32),
"Y": torch.zeros(M, N, device="cuda", dtype=torch.float32),
"M": M,
"N": N,
"TILE_SIZE": TILE_SIZE,
}
)

# Case 2: Odd sizes (padding needed)
M, N = 130, 200
TILE_SIZE = 128
# Rows: ceil(130/128) = 2. Cols: ceil(200/128) = 2.
test_configs = [
# Edge cases - small sizes
(1, 1, 16),
(2, 3, 16),
(4, 4, 16),
# Power-of-2 sizes
(64, 64, 32),
(128, 128, 64),
(256, 256, 128),
(512, 512, 128),
# Non-power-of-2 sizes (padding needed)
(30, 50, 16),
(100, 100, 32),
(130, 200, 128),
(255, 255, 64),
# Realistic sizes
(1024, 1024, 128),
(2048, 4096, 128),
]

for M, N, TILE_SIZE in test_configs:
s_rows = (M + TILE_SIZE - 1) // TILE_SIZE
s_cols = (N + TILE_SIZE - 1) // TILE_SIZE
tests.append(
{
"X": torch.randn(M, N, device="cuda", dtype=torch.float32),
"S": torch.randn(s_rows, s_cols, device="cuda", dtype=torch.float32),
"Y": torch.zeros(M, N, device="cuda", dtype=torch.float32),
"M": M,
"N": N,
"TILE_SIZE": TILE_SIZE,
}
)

# Zero input
M, N, TILE_SIZE = 64, 64, 32
s_rows = (M + TILE_SIZE - 1) // TILE_SIZE
s_cols = (N + TILE_SIZE - 1) // TILE_SIZE
tests.append(
{
"name": "irregular_size",
"X": torch.randn(M, N, device="cuda", dtype=torch.float32),
"S": torch.randn(2, 2, device="cuda", dtype=torch.float32),
"X": torch.zeros(M, N, device="cuda", dtype=torch.float32),
"S": torch.randn(s_rows, s_cols, device="cuda", dtype=torch.float32),
"Y": torch.zeros(M, N, device="cuda", dtype=torch.float32),
"M": M,
"N": N,
"TILE_SIZE": TILE_SIZE,
}
)

# Case 3: Small Tile Size
M, N = 64, 64
TILE_SIZE = 32
# Negative values
M, N, TILE_SIZE = 128, 128, 64
s_rows = (M + TILE_SIZE - 1) // TILE_SIZE
s_cols = (N + TILE_SIZE - 1) // TILE_SIZE
tests.append(
{
"name": "small_tiles",
"X": torch.randn(M, N, device="cuda", dtype=torch.float32),
"S": torch.randn(2, 2, device="cuda", dtype=torch.float32),
"X": torch.randn(M, N, device="cuda", dtype=torch.float32).sub_(0.5),
"S": torch.randn(s_rows, s_cols, device="cuda", dtype=torch.float32).sub_(0.5),
"Y": torch.zeros(M, N, device="cuda", dtype=torch.float32),
"M": M,
"N": N,
Expand Down
2 changes: 1 addition & 1 deletion scripts/run_challenge.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def submit_solution(
"gpu": gpu,
"mode": "accelerated",
"public": public,
"challenge_id": challenge_id,
"challengeId": challenge_id,
},
}
ws.send(json.dumps(submission))
Expand Down