diff --git a/challenges/medium/64_weight_dequantization/challenge.html b/challenges/medium/64_weight_dequantization/challenge.html
index aca2164..4ee6f3a 100644
--- a/challenges/medium/64_weight_dequantization/challenge.html
+++ b/challenges/medium/64_weight_dequantization/challenge.html
@@ -50,5 +50,5 @@
Constraints
1 ≤ M, N ≤ 8192
TILE_SIZE ∈ {16, 32, 64, 128}
- Performance is measured with M = 8,192, N = 8,192
+ Performance is measured with M = 8,192, N = 8,192, and TILE_SIZE = 128
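Note: with the benchmark configuration above (M = N = 8192, TILE_SIZE = 128), the per-tile scale matrix S is 64 x 64. A quick sketch of the shape arithmetic (the helper name is illustrative, not part of the challenge code):

    def scale_grid_shape(M: int, N: int, TILE_SIZE: int) -> tuple:
        # Same ceil division as the reference implementation: one scale per tile,
        # with partial edge tiles still getting their own scale
        return ((M + TILE_SIZE - 1) // TILE_SIZE, (N + TILE_SIZE - 1) // TILE_SIZE)

    print(scale_grid_shape(8192, 8192, 128))  # -> (64, 64)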
diff --git a/challenges/medium/64_weight_dequantization/challenge.py b/challenges/medium/64_weight_dequantization/challenge.py
index a210661..db8db37 100644
--- a/challenges/medium/64_weight_dequantization/challenge.py
+++ b/challenges/medium/64_weight_dequantization/challenge.py
@@ -14,24 +14,19 @@ def __init__(self):
def reference_impl(
self, X: torch.Tensor, S: torch.Tensor, Y: torch.Tensor, M: int, N: int, TILE_SIZE: int
):
- # S shape: (ceil(M/TILE_SIZE), ceil(N/TILE_SIZE))
- # We expand S to match X's shape (M, N)
-
- # Expand rows
- S_expanded = S.repeat_interleave(TILE_SIZE, dim=0)
- # Crop if M is not a multiple of TILE_SIZE
- if S_expanded.shape[0] > M:
- S_expanded = S_expanded[:M, :]
+ # S holds one scale per TILE_SIZE x TILE_SIZE tile: ceil(M/TILE_SIZE) rows, ceil(N/TILE_SIZE) cols
+ s_rows = (M + TILE_SIZE - 1) // TILE_SIZE
+ s_cols = (N + TILE_SIZE - 1) // TILE_SIZE
+ assert X.shape == (M, N)
+ assert S.shape == (s_rows, s_cols)
+ assert Y.shape == (M, N)
+ assert X.dtype == torch.float32
+ assert S.dtype == torch.float32
+ assert Y.dtype == torch.float32
- # Expand cols
- S_expanded = S_expanded.repeat_interleave(TILE_SIZE, dim=1)
- # Crop if N is not a multiple of TILE_SIZE
- if S_expanded.shape[1] > N:
- S_expanded = S_expanded[:, :N]
+ # Expand each scale across its tile along both axes, cropping when M or N is not a multiple of TILE_SIZE
+ S_expanded = S.repeat_interleave(TILE_SIZE, dim=0)[:M, :]
+ S_expanded = S_expanded.repeat_interleave(TILE_SIZE, dim=1)[:, :N]
- # Perform element-wise multiplication
- # Ensure Y is updated in-place
- Y.copy_(X.to(Y.dtype) * S_expanded.to(Y.dtype))
+ # Element-wise multiply, writing the result into the pre-allocated Y in place
+ Y.copy_(X * S_expanded)
def get_solve_signature(self) -> Dict[str, tuple]:
return {
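Note: for reviewers, a minimal CPU sketch of the expansion the rewritten reference_impl performs (standalone and illustrative, not part of the diff; tensor names mirror the challenge signature):

    import torch

    M, N, TILE_SIZE = 130, 200, 128            # non-multiple sizes force cropping
    s_rows = (M + TILE_SIZE - 1) // TILE_SIZE  # ceil division -> 2
    s_cols = (N + TILE_SIZE - 1) // TILE_SIZE  # ceil division -> 2

    X = torch.randn(M, N)
    S = torch.randn(s_rows, s_cols)

    # Repeat each scale across its tile along both axes, then crop the ragged edges
    S_expanded = S.repeat_interleave(TILE_SIZE, dim=0)[:M, :]
    S_expanded = S_expanded.repeat_interleave(TILE_SIZE, dim=1)[:, :N]

    Y = X * S_expanded                         # element-wise dequantization
    assert Y.shape == (M, N)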
@@ -65,30 +60,48 @@ def generate_example_test(self) -> Dict[str, Any]:
def generate_functional_test(self) -> List[Dict[str, Any]]:
tests = []
- # Case 1: Perfect Multiple
- M, N = 256, 256
- TILE_SIZE = 128
- tests.append(
- {
- "name": "perfect_multiple",
- "X": torch.randn(M, N, device="cuda", dtype=torch.float32),
- "S": torch.randn(2, 2, device="cuda", dtype=torch.float32),
- "Y": torch.zeros(M, N, device="cuda", dtype=torch.float32),
- "M": M,
- "N": N,
- "TILE_SIZE": TILE_SIZE,
- }
- )
-
- # Case 2: Odd sizes (padding needed)
- M, N = 130, 200
- TILE_SIZE = 128
- # Rows: ceil(130/128) = 2. Cols: ceil(200/128) = 2.
+ test_configs = [
+ # Edge cases - small sizes
+ (1, 1, 16),
+ (2, 3, 16),
+ (4, 4, 16),
+ # Power-of-2 sizes
+ (64, 64, 32),
+ (128, 128, 64),
+ (256, 256, 128),
+ (512, 512, 128),
+ # Non-power-of-2 sizes (padding needed)
+ (30, 50, 16),
+ (100, 100, 32),
+ (130, 200, 128),
+ (255, 255, 64),
+ # Realistic sizes
+ (1024, 1024, 128),
+ (2048, 4096, 128),
+ ]
+
+ for M, N, TILE_SIZE in test_configs:
+ s_rows = (M + TILE_SIZE - 1) // TILE_SIZE
+ s_cols = (N + TILE_SIZE - 1) // TILE_SIZE
+ tests.append(
+ {
+ "X": torch.randn(M, N, device="cuda", dtype=torch.float32),
+ "S": torch.randn(s_rows, s_cols, device="cuda", dtype=torch.float32),
+ "Y": torch.zeros(M, N, device="cuda", dtype=torch.float32),
+ "M": M,
+ "N": N,
+ "TILE_SIZE": TILE_SIZE,
+ }
+ )
+
+ # Zero input
+ M, N, TILE_SIZE = 64, 64, 32
+ s_rows = (M + TILE_SIZE - 1) // TILE_SIZE
+ s_cols = (N + TILE_SIZE - 1) // TILE_SIZE
tests.append(
{
- "name": "irregular_size",
- "X": torch.randn(M, N, device="cuda", dtype=torch.float32),
- "S": torch.randn(2, 2, device="cuda", dtype=torch.float32),
+ "X": torch.zeros(M, N, device="cuda", dtype=torch.float32),
+ "S": torch.randn(s_rows, s_cols, device="cuda", dtype=torch.float32),
"Y": torch.zeros(M, N, device="cuda", dtype=torch.float32),
"M": M,
"N": N,
@@ -96,14 +109,14 @@ def generate_functional_test(self) -> List[Dict[str, Any]]:
}
)
- # Case 3: Small Tile Size
- M, N = 64, 64
- TILE_SIZE = 32
+ # Negative-leaning values (standard normal shifted by -0.5)
+ M, N, TILE_SIZE = 128, 128, 64
+ s_rows = (M + TILE_SIZE - 1) // TILE_SIZE
+ s_cols = (N + TILE_SIZE - 1) // TILE_SIZE
tests.append(
{
- "name": "small_tiles",
- "X": torch.randn(M, N, device="cuda", dtype=torch.float32),
- "S": torch.randn(2, 2, device="cuda", dtype=torch.float32),
+ "X": torch.randn(M, N, device="cuda", dtype=torch.float32).sub_(0.5),
+ "S": torch.randn(s_rows, s_cols, device="cuda", dtype=torch.float32).sub_(0.5),
"Y": torch.zeros(M, N, device="cuda", dtype=torch.float32),
"M": M,
"N": N,
diff --git a/scripts/run_challenge.py b/scripts/run_challenge.py
index e39f8ea..31dc2fc 100644
--- a/scripts/run_challenge.py
+++ b/scripts/run_challenge.py
@@ -64,7 +64,7 @@ def submit_solution(
"gpu": gpu,
"mode": "accelerated",
"public": public,
- "challenge_id": challenge_id,
+ "challengeId": challenge_id,
},
}
ws.send(json.dumps(submission))
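Note: the last hunk only changes the key casing in the WebSocket payload; the Python variable stays snake_case while the JSON field becomes camelCase (presumably the casing the submission endpoint expects). The affected fragment, abbreviated:

    payload_fragment = {
        "gpu": gpu,
        "mode": "accelerated",
        "public": public,
        "challengeId": challenge_id,  # was "challenge_id"
    }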