mlx/backend/cuda/custom_kernel.cpp (5 additions, 3 deletions)
@@ -223,7 +223,7 @@ CustomKernelFunction cuda_kernel(
}

return array::make_arrays(
- std::move(output_shapes),
+ output_shapes,
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
@@ -236,7 +236,8 @@ CustomKernelFunction cuda_kernel(
init_value,
std::vector<ScalarArg>{},
false,
- shared_memory),
+ shared_memory,
+ output_shapes),
std::move(inputs));
};
}
@@ -270,7 +271,8 @@ std::vector<array> precompiled_cuda_kernel(
init_value,
scalars,
true,
- shared_memory),
+ shared_memory,
+ output_shapes),
inputs);
}

mlx/backend/metal/custom_kernel.cpp (3 additions, 2 deletions)
@@ -306,7 +306,7 @@ CustomKernelFunction metal_kernel(
}

return array::make_arrays(
- std::move(output_shapes),
+ output_shapes,
std::move(output_dtypes),
std::make_shared<CustomKernel>(
s,
@@ -319,7 +319,8 @@ CustomKernelFunction metal_kernel(
init_value,
std::vector<ScalarArg>{},
false,
- 0),
+ 0,
+ output_shapes),
std::move(inputs));
};
}
mlx/fast_primitives.h (12 additions, 2 deletions)
@@ -375,7 +375,8 @@ class CustomKernel : public Primitive {
std::optional<float> init_value,
std::vector<ScalarArg> scalar_arguments,
bool is_precompiled,
- int shared_memory)
+ int shared_memory,
+ std::vector<Shape> output_shapes = {})
: Primitive(stream),
name_(std::move(name)),
source_(std::move(source)),
@@ -386,7 +387,8 @@ class CustomKernel : public Primitive {
init_value_(init_value),
scalar_arguments_(std::move(scalar_arguments)),
is_precompiled_(is_precompiled),
- shared_memory_(shared_memory) {}
+ shared_memory_(shared_memory),
+ output_shapes_(std::move(output_shapes)) {}

void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
@@ -397,6 +399,13 @@ class CustomKernel : public Primitive {
override;

DEFINE_NAME(CustomKernel);

std::vector<Shape> output_shapes(const std::vector<array>&) override {
if (output_shapes_.empty())
return Primitive::output_shapes({});
return output_shapes_;
}

auto state() const {
return std::make_tuple(
name_,
@@ -422,6 +431,7 @@ class CustomKernel : public Primitive {
std::vector<ScalarArg> scalar_arguments_;
bool is_precompiled_;
int shared_memory_;
std::vector<Shape> output_shapes_;
};

} // namespace mlx::core::fast
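The override simply replays output_shapes_ as captured when the kernel call was traced, and falls back to the base Primitive (which raises) when no shapes were recorded. One consequence worth spelling out: shapeless compilation is only safe for custom kernels whose output shapes do not depend on the input shapes. Below is a minimal Python sketch of that caveat, assuming a Metal device is available; the copy kernel, its launch sizes, and the helper name are illustrative and not part of this PR.

import mlx.core as mx

copy_kernel = mx.fast.metal_kernel(
    name="copy_all",
    input_names=["inp"],
    output_names=["out"],
    source="uint i = thread_position_in_grid.x; out[i] = inp[i];",
)

def copy(x):
    # The declared output shape tracks the input shape here.
    return copy_kernel(
        inputs=[x],
        grid=(x.size, 1, 1),
        threadgroup=(min(x.size, 256), 1, 1),
        output_shapes=[x.shape],
        output_dtypes=[x.dtype],
    )[0]

# Fine under shape-specialized compile: a new input shape triggers a re-trace,
# so a fresh CustomKernel with fresh output_shapes_ is recorded.
ccopy = mx.compile(copy)
assert ccopy(mx.arange(4.0)).shape == (4,)
assert ccopy(mx.arange(8.0)).shape == (8,)

# With shapeless=True the recorded graph is reused and output_shapes() replays
# the trace-time shapes, so that mode should be reserved for kernels whose
# output shape is fixed, like the first_elem kernel in the test below.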
mlx/primitives.h (15 additions, 0 deletions)
@@ -1700,6 +1700,21 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_GRADS()
DEFINE_NAME(GatherQMM)
bool is_equivalent(const Primitive& other) const override;

// inputs layout: Affine → {x, w, scales, biases, lhs_idx, rhs_idx}
// other → {x, w, scales, lhs_idx, rhs_idx}
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override {
const auto& x = inputs[0];
const auto& w = inputs[1];
const auto& lhs_idx =
(mode_ == QuantizationMode::Affine) ? inputs[4] : inputs[3];
int w_outer = transpose_ ? w.shape(-2) : w.shape(-1) * 32 / bits_;
auto out_shape = lhs_idx.shape();
out_shape.push_back(x.shape(-2));
out_shape.push_back(w_outer);
return {out_shape};
}

auto state() const {
return std::make_tuple(
group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_);
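For reference, here is the shape arithmetic the new override encodes, worked through in Python with illustrative sizes that mirror the new test below (M=6 chosen for concreteness). Default affine quantization (bits=4, group_size=64) is assumed, so the packed dimension is K * 4 / 32.

import mlx.core as mx

E, M, N, K = 4, 6, 32, 64                    # experts, rows per expert, out dim, inner dim

w = mx.random.normal((E, N, K))
wq, scales, biases = mx.quantize(w)          # defaults: bits=4, group_size=64
# wq.shape == (4, 32, 8): the last axis packs K * 4 / 32 values per uint32

x = mx.ones((E, M, K))
idx = mx.array([0, 1, 2, 3])                 # lhs_indices == rhs_indices, shape (4,)

out = mx.gather_qmm(
    x, wq, scales, biases, lhs_indices=idx, rhs_indices=idx, transpose=True
)

# output_shapes(): lhs_idx.shape + [x.shape(-2), w_outer]
#   transpose=True  -> w_outer = wq.shape(-2) = N = 32
#   transpose=False -> w_outer = wq.shape(-1) * 32 / bits = 8 * 32 / 4 = 64
print(out.shape)                             # (4, 6, 32)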
python/tests/test_compile.py (66 additions, 0 deletions)
@@ -504,6 +504,72 @@ def ones_fun(x):
self.assertEqual(compiled_zero_like(y).shape, y_shape)
self.assertEqual(compiled_ones_like(y).shape, y_shape)

def test_shapeless_compile_custom_kernel(self):
# CustomKernel must implement output_shapes() so shapeless compile can
# reuse the compiled graph without throwing "CustomKernel cannot infer
# output shapes". The kernel here has a fixed output shape (1,) that
# does not depend on the input shape, so output_shapes_ stays correct
# across calls with different input sizes.
if not mx.metal.is_available():
return

kernel = mx.fast.metal_kernel(
name="first_elem",
input_names=["inp"],
output_names=["out"],
source="if (thread_position_in_grid.x == 0) out[0] = inp[0];",
)

def fn(x):
return kernel(
inputs=[x],
grid=(1, 1, 1),
threadgroup=(1, 1, 1),
output_shapes=[(1,)],
output_dtypes=[x.dtype],
stream=mx.gpu,
)[0]

cfn = mx.compile(fn, shapeless=True)

x = mx.array([5.0, 6.0, 7.0, 8.0])
self.assertEqual(cfn(x).item(), 5.0)

# Different input shape — shapeless compile must reuse the graph without
# throwing and return the fixed output shape (1,) with the correct value.
x = mx.array([9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0])
result = cfn(x)
self.assertEqual(result.shape, (1,))
self.assertEqual(result.item(), 9.0)

def test_shapeless_compile_gather_qmm(self):
# GatherQMM must implement output_shapes() so shapeless compile can
# re-trace without throwing "GatherQMM cannot infer output shapes".
K, N, num_experts = 64, 32, 4

w = mx.random.normal((num_experts, N, K))
qw, s, b = mx.quantize(w)
mx.eval(qw, s, b)

# x has shape (num_experts, M, K): the batch dim is indexed by idx, which
# stays fixed (it simply enumerates w's batch dimension) so that lhs_indices
# and rhs_indices always broadcast. Only M changes between calls.
idx = mx.array([0, 1, 2, 3])
x4 = mx.ones((num_experts, 4, K))
x8 = mx.ones((num_experts, 8, K))

def fn(x):
return mx.gather_qmm(
x, qw, s, b, lhs_indices=idx, rhs_indices=idx, transpose=True
)

cfn = mx.compile(fn, shapeless=True)

self.assertEqual(cfn(x4).shape, fn(x4).shape)

# Different M — must reuse compiled graph without throwing.
self.assertEqual(cfn(x8).shape, fn(x8).shape)

def test_compile_with_constant(self):
# Test float
@partial(mx.compile)