From 74e6027a487ae9acf86e9e8e4c1e1c9dabc44ffd Mon Sep 17 00:00:00 2001
From: qflen
Date: Mon, 11 May 2026 10:22:12 +0200
Subject: [PATCH 1/3] Detect int32 shape-product overflow at MLX compute-shape
 boundaries

Issue #3327 reports that shapes whose per-dim values fit in int32 but
whose product exceeds 2^31 silently produced wrapped results:
`reshape(big, (-1,))` returned a negative inferred dim;
`zeros((2^30, 2)).flatten()` returned shape (-2147483648,) and size
18446744071562067968; `take(big, ...)` failed via an internal flatten
with the same wrap; and `conv_general` with an output of more than
2^31 elements either requested an 18 EB allocation on M3 Max or
silently wrote to wrapped offsets in the Metal kernel on M5 (`y[-1]`
read back zeros).

PR #3425 kept `ShapeElem = int32_t` and added a clear diagnostic at
the Python binding for per-dim overflow. This patch extends the same
approach to the internal C++ compute-shape boundaries that produce a
Shape from int64 arithmetic, and to the Metal conv kernel offsets
where the product of valid per-dim values silently wrapped.

- mlx/utils.h: new `check_shape_dim(int64_t, op)` helper using PR
  #3425's error message format.
- Compute-shape sites narrow through the helper: `Flatten` and
  `Reshape` `output_shape`, the `unflatten` infer path, and
  `indices_or_default` (accumulator widened to int64).
  Backend-agnostic -- applies to CPU, Metal, and CUDA.
- mlx/backend/metal/conv.cpp: guard the four dispatcher sites where
  `int implicit_M = out.size() / O` truncates size_t or
  `int implicit_M = N * oS[0] * oS[1]` (times `oS[2]` in 3D) wraps.
  Widen the `inp_large` / `out_large` heuristics to int64 to remove
  signed-overflow UB in the dispatch predicate.
- mlx/backend/metal/kernels/steel/conv/kernels/{steel_conv.h,
  steel_conv_3d.h, steel_conv_general.h}: promote per-thread output
  pointer arithmetic to size_t. With M < 2^31 but M * O > 2^31,
  `c_row * (N * groups) + c_col` overflowed even after the dispatcher
  accepted the shape -- the last batches wrote to wrapped offsets.
  This is the substance of PRs #3294 / #3320, now exercised by an
  end-to-end test.
- mlx/backend/cuda/conv/{gemm_conv,gemm_grouped_conv}.cu: same
  size_t->int truncation pattern as the Metal sites. Apply the
  identical guard. CUDA validation pending CI -- no toolchain on the
  authoring machine.

Adds two regression tests in tests/gpu_tests.cpp. The kernel-offset
test (varying per-batch input, allclose vs a CPU reference) fails on
`y[-1]` without the steel_conv_general.h patch -- verified by
stash/restore. The shape-boundary test exercises each fix path; the
eval branch is guarded by max_buffer_length so it skips on small-GPU
devices.

Closes #3327. Resolves the cross-dim overflow path that #3425
diagnosed but deferred (related #2681).
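For reviewers, a minimal standalone sketch of the wrap this series
guards against and of the widen-then-check pattern the new helper
applies (plain C++, outside MLX; `checked_dim` is an illustrative
stand-in for `check_shape_dim`, not the helper itself):

    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <stdexcept>

    // The widen-then-check pattern applied at every compute-shape
    // boundary: do the arithmetic in int64, narrow only after a
    // range check.
    int32_t checked_dim(int64_t dim) {
      if (dim < std::numeric_limits<int32_t>::min() ||
          dim > std::numeric_limits<int32_t>::max()) {
        throw std::overflow_error("shape dimension out of int32 range");
      }
      return static_cast<int32_t>(dim);
    }

    int main() {
      // (2^30, 2).flatten(): each per-dim value is valid, the product
      // 2147483648 is not.
      int64_t product = int64_t{1 << 30} * 2;
      // Pre-patch the product was formed in int32 and wrapped:
      std::cout << static_cast<int32_t>(product) << "\n"; // -2147483648
      try {
        checked_dim(product); // post-patch: a diagnostic instead of a wrap
      } catch (const std::overflow_error& e) {
        std::cout << e.what() << "\n";
      }
    }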
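And the kernel-offset half, mirroring the steel kernel index math in
plain C++ (the `c_row` / `N` / `groups` values are illustrative, not
taken from a real dispatch):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>

    int main() {
      // All dispatcher-visible values fit comfortably in int32.
      int c_row = 2'000'000; // a late gemm output row
      int N = 1088;          // kernel N
      int groups = 1;
      int c_col = 0;
      // Post-patch arithmetic: promote the first operand, as the steel
      // kernels now do, so the whole expression evaluates in size_t.
      size_t correct = static_cast<size_t>(c_row) * (N * groups) + c_col;
      // What the old all-int expression wrapped to (shown via a defined
      // conversion rather than re-triggering the signed-overflow UB).
      int32_t wrapped = static_cast<int32_t>(correct);
      std::cout << wrapped << " vs " << correct << "\n";
      // -2118967296 vs 2176000000
    }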
---
 mlx/backend/cuda/conv/gemm_conv.cu            |  7 +-
 mlx/backend/cuda/conv/gemm_grouped_conv.cu    |  8 +-
 mlx/backend/metal/conv.cpp                    | 36 ++++---
 .../kernels/steel/conv/kernels/steel_conv.h   |  2 +-
 .../steel/conv/kernels/steel_conv_3d.h        |  2 +-
 .../steel/conv/kernels/steel_conv_general.h   | 10 +-
 mlx/ops.cpp                                   |  9 +-
 mlx/primitives.cpp                            |  7 +-
 mlx/utils.h                                   | 20 ++++
 tests/gpu_tests.cpp                           | 96 +++++++++++++++++++
 10 files changed, 170 insertions(+), 27 deletions(-)

diff --git a/mlx/backend/cuda/conv/gemm_conv.cu b/mlx/backend/cuda/conv/gemm_conv.cu
index fff2445297..129c1d7f52 100644
--- a/mlx/backend/cuda/conv/gemm_conv.cu
+++ b/mlx/backend/cuda/conv/gemm_conv.cu
@@ -4,6 +4,7 @@
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
+#include "mlx/utils.h"
 
 #include
 
@@ -136,8 +137,10 @@ void gemm_conv_nd(
     ConvParams<NDIM>& params,
     Stream s) {
   // Get gemm shapes.
-  int mat_M = out.size() / params.O; // N * H_out * W_out
-  int mat_K = wt.size() / params.O; // C * H_wt * W_wt
+  int mat_M = check_shape_dim(
+      static_cast<int64_t>(out.size() / params.O), "conv"); // N * H_out * W_out
+  int mat_K = check_shape_dim(
+      static_cast<int64_t>(wt.size() / params.O), "conv"); // C * H_wt * W_wt
   int mat_N = params.O; // O
 
   // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
diff --git a/mlx/backend/cuda/conv/gemm_grouped_conv.cu b/mlx/backend/cuda/conv/gemm_grouped_conv.cu
index f2688b3096..a445a4ea28 100644
--- a/mlx/backend/cuda/conv/gemm_grouped_conv.cu
+++ b/mlx/backend/cuda/conv/gemm_grouped_conv.cu
@@ -4,6 +4,7 @@
 #include "mlx/backend/cuda/gemms/cublas_gemm.h"
 #include "mlx/backend/cuda/kernel_utils.cuh"
 #include "mlx/dtype_utils.h"
+#include "mlx/utils.h"
 
 #include
 
@@ -141,8 +142,11 @@ void gemm_grouped_conv_nd(
   // Get gemm shapes.
   int C_per_group = params.C / params.groups;
   int O_per_group = params.O / params.groups;
-  int mat_M = out.size() / params.O; // N * H_out * W_out
-  int mat_K = wt.size() / params.O; // C_per_group * H_wt * W_wt
+  int mat_M = check_shape_dim(
+      static_cast<int64_t>(out.size() / params.O), "conv"); // N * H_out * W_out
+  int mat_K = check_shape_dim(
+      static_cast<int64_t>(wt.size() / params.O),
+      "conv"); // C_per_group * H_wt * W_wt
   int mat_N = O_per_group; // O_per_group
 
   // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index 5d032779d3..f9a2e6c0de 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -39,8 +39,10 @@ void explicit_gemm_conv_ND_gpu(
     array& out,
     const MLXConvParams<N>& conv_params) {
   // Get gemm shapes
-  int implicit_M = out.size() / conv_params.O;
-  int implicit_K = wt.size() / conv_params.O;
+  int implicit_M =
+      check_shape_dim(static_cast<int64_t>(out.size() / conv_params.O), "conv");
+  int implicit_K =
+      check_shape_dim(static_cast<int64_t>(wt.size() / conv_params.O), "conv");
   int implicit_N = conv_params.O;
   // Prepare unfolding array
   Shape unfolded_shape{implicit_M, implicit_K};
@@ -113,8 +115,10 @@ void explicit_gemm_conv_group_ND_gpu(
   const int C_per_group = conv_params.C / conv_params.groups;
   const int O_per_group = conv_params.O / conv_params.groups;
   // Get gemm shapes
-  const int implicit_M = out.size() / conv_params.O;
-  const int implicit_K = wt.size() / conv_params.O;
+  const int implicit_M =
+      check_shape_dim(static_cast<int64_t>(out.size() / conv_params.O), "conv");
+  const int implicit_K =
+      check_shape_dim(static_cast<int64_t>(wt.size() / conv_params.O), "conv");
   const int implicit_N = O_per_group;
 
   int kernel_size = 1;
@@ -200,7 +204,10 @@ void implicit_gemm_conv_2D_gpu(
   const int O_per_group = conv_params.O / conv_params.groups;
 
   // Deduce implicit gemm size
-  const int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
+  const int implicit_M = check_shape_dim(
+      static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
+          conv_params.oS[1],
+      "conv");
   const int implicit_N = O_per_group;
   const int implicit_K = conv_params.wS[0] * conv_params.wS[1] * C_per_group;
 
@@ -329,7 +336,10 @@ void implicit_gemm_conv_2D_general_gpu(
     array& out,
     const MLXConvParams<2>& conv_params) {
   // Deduce implicit gemm size
-  int implicit_M = conv_params.N * conv_params.oS[0] * conv_params.oS[1];
+  int implicit_M = check_shape_dim(
+      static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
+          conv_params.oS[1],
+      "conv");
   int implicit_N = conv_params.O;
   int implicit_K = conv_params.wS[0] * conv_params.wS[1] * conv_params.C;
 
@@ -512,8 +522,10 @@ void implicit_gemm_conv_3D_gpu(
   const int O_per_group = conv_params.O / conv_params.groups;
 
   // Deduce implicit gemm size
-  const int implicit_M =
-      conv_params.N * conv_params.oS[0] * conv_params.oS[1] * conv_params.oS[2];
+  const int implicit_M = check_shape_dim(
+      static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
+          conv_params.oS[1] * conv_params.oS[2],
+      "conv");
   const int implicit_N = O_per_group;
   const int implicit_K =
       conv_params.wS[0] * conv_params.wS[1] * conv_params.wS[2] * C_per_group;
@@ -1001,11 +1013,11 @@ void dispatch_conv_2D_gpu(
   }
 
   // Direct to winograd conv
-  bool inp_large =
-      (conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 4096;
+  bool inp_large = (static_cast<int64_t>(conv_params.N) * conv_params.iS[0] *
+                    conv_params.iS[1]) >= 4096;
   bool channels_large = (conv_params.C + conv_params.O) >= 256;
-  bool out_large =
-      (conv_params.N * conv_params.oS[0] * conv_params.oS[1]) >= 256;
+  bool out_large = (static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
+                    conv_params.oS[1]) >= 256;
   if (!conv_params.flip && is_stride_one && is_kdil_one && is_idil_one &&
       conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
       conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
index 850ec15be6..6bc78dd43e 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
@@ -135,7 +135,7 @@ implicit_gemm_conv_2d(
 
   C += tid.z * N;
   B += c_col * K;
-  C += c_row * (N * params->groups) + c_col;
+  C += static_cast<size_t>(c_row) * size_t(N * params->groups) + size_t(c_col);
 
   const int2 offsets_a(0, c_row);
   const int2 offsets_b(0, c_col);
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h
index d2fbac0fc7..0b2a0d68dc 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h
@@ -94,7 +94,7 @@ implicit_gemm_conv_3d(
 
   C += tid.z * N;
   B += c_col * K;
-  C += c_row * (N * params->groups) + c_col;
+  C += static_cast<size_t>(c_row) * size_t(N * params->groups) + size_t(c_col);
 
   const int2 offsets_a(0, c_row);
   const int2 offsets_b(0, c_col);
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
index 1241f77357..2af042ddd8 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
@@ -200,14 +200,18 @@ implicit_gemm_conv_2d_general(
         (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
 
     if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
-      int offset_cm = n * params->out_strides[0] +
-          oh * params->out_strides[1] + ow * params->out_strides[2];
+      size_t offset_cm = static_cast<size_t>(n) *
+              static_cast<size_t>(params->out_strides[0]) +
+          static_cast<size_t>(oh) *
+              static_cast<size_t>(params->out_strides[1]) +
+          static_cast<size_t>(ow) *
+              static_cast<size_t>(params->out_strides[2]);
 
       STEEL_PRAGMA_UNROLL
       for (int j = 0; j < mma_t::TN; j++) {
         // Get accumulated result and associated offset in C
         thread const auto& accum = mma_op.Ctile.frag_at(i, j);
-        int offset = offset_cm + (j * mma_t::TN_stride);
+        size_t offset = offset_cm + (j * mma_t::TN_stride);
 
         constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;
 
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index defcc2f6e0..4c182e37a6 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -67,8 +67,10 @@ array indices_or_default(
   }
 
   Shape shape(x.shape().begin(), x.shape().end() - 2);
-  int total =
-      std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
+  int total = check_shape_dim(
+      std::reduce(
+          shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>{}),
+      "gather");
   return reshape(arange(total, uint32, s), std::move(shape), s);
 }
 
@@ -433,7 +435,8 @@ array unflatten(
     }
   }
   if (infer_idx >= 0) {
-    shape[infer_idx] = a.shape(ax) / size;
+    shape[infer_idx] =
+        check_shape_dim(static_cast<int64_t>(a.shape(ax) / size), "unflatten");
     size *= shape[infer_idx];
   }
   if (size != a.shape(ax)) {
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index f3acec574b..42fda6bc5b 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2136,12 +2136,12 @@ bool Flatten::is_equivalent(const Primitive& other) const {
 
 Shape Flatten::output_shape(const array& input, int start_axis, int end_axis) {
   Shape shape = input.shape();
-  auto flat_size = input.shape(start_axis);
+  int64_t flat_size = input.shape(start_axis);
   for (int ax = start_axis + 1; ax <= end_axis; ++ax) {
     flat_size *= input.shape(ax);
   }
   shape.erase(shape.begin() + start_axis + 1, shape.begin() + end_axis + 1);
-  shape[start_axis] = flat_size;
+  shape[start_axis] = check_shape_dim(flat_size, "flatten");
   return shape;
 }
 
@@ -3913,7 +3913,8 @@ Shape Reshape::output_shape(const array& input, Shape shape) {
 
   // Infer the shape
   if (size > 0 && infer_idx >= 0) {
-    shape[infer_idx] = input.size() / size;
+    shape[infer_idx] =
+        check_shape_dim(static_cast<int64_t>(input.size() / size), "reshape");
     size *= shape[infer_idx];
   } else if (infer_idx >= 0) {
     throw std::invalid_argument(
diff --git a/mlx/utils.h b/mlx/utils.h
index 7835a97028..4763272f51 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -3,6 +3,10 @@
 #pragma once
 
 #include
+#include <cstdint>
+#include <limits>
+#include <sstream>
+#include <string_view>
 #include
 
 #include "mlx/api.h"
@@ -96,6 +100,22 @@ MLX_API Dtype result_type(const std::vector<array>& arrays);
 
 MLX_API Shape broadcast_shapes(const Shape& s1, const Shape& s2);
 
+inline ShapeElem check_shape_dim(int64_t dim, std::string_view op = "") {
+  constexpr int64_t lo = std::numeric_limits<ShapeElem>::min();
+  constexpr int64_t hi = std::numeric_limits<ShapeElem>::max();
+  if (dim < lo || dim > hi) {
+    std::ostringstream msg;
+    if (!op.empty()) {
+      msg << "[" << op << "] ";
+    }
+    msg << "Shape dimension " << dim << " is outside the supported range ["
+        << lo << ", " << hi
+        << "]. MLX currently uses 32-bit integers for shape dimensions.";
+    throw std::overflow_error(msg.str());
+  }
+  return static_cast<ShapeElem>(dim);
+}
+
 /**
  * Returns the axis normalized to be in the range [0, ndim).
  */
diff --git a/tests/gpu_tests.cpp b/tests/gpu_tests.cpp
index 58cca348e5..1075ff5118 100644
--- a/tests/gpu_tests.cpp
+++ b/tests/gpu_tests.cpp
@@ -3,6 +3,7 @@
 #include
 
 #include "doctest/doctest.h"
+#include "mlx/backend/gpu/device_info.h"
 #include "mlx/mlx.h"
 
 using namespace mlx::core;
@@ -477,6 +478,101 @@ TEST_CASE("test gpu validation") {
   eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
 }
 
+TEST_CASE("test gpu int32 shape overflow errors") {
+  // (2^30, 2).flatten() — product 2^31 doesn't fit in ShapeElem.
+  // Issue #2681 reported wrapped shape (-2147483648,) and a
+  // 2^64 - X reported size. The lazy graph is never evaluated.
+  auto a = zeros({1 << 30, 2});
+  CHECK_THROWS_AS(flatten(a), std::overflow_error);
+
+  // conv_general output > 2^31 elements with each per-dim < 2^31.
+  // Total elements 524290 * 64 * 64 = 2,147,491,840.
+  int n = static_cast<int>((int64_t{1} << 31) / (64 * 64) + 2);
+  auto x = ones({n, 8, 8, 1}, float16);
+  auto w = ones({1, 1, 1, 1}, float16);
+  auto y = conv_general(
+      /* input = */ x,
+      /* weight = */ w,
+      /* stride = */ {1, 1},
+      /* padding_lo = */ {0, 0},
+      /* padding_hi = */ {0, 0},
+      /* kernel_dilation = */ {1, 1},
+      /* input_dilation = */ {9, 9},
+      /* groups = */ 1,
+      /* flip = */ false);
+  CHECK_EQ(y.shape(), Shape{n, 64, 64, 1});
+
+  // reshape with inferred dim that won't fit in ShapeElem — issue #3327.
+  CHECK_THROWS_AS(reshape(y, {-1}), std::overflow_error);
+
+  // take(a, idx) routes through an internal flatten — overflows on flatten.
+  auto idx = array({0u}, uint32);
+  CHECK_THROWS_AS(take(y, idx), std::overflow_error);
+
+  // The conv dispatcher refuses to compute a >2^31-element output. eval
+  // allocates the ~4 GB float16 output before the dispatcher check fires,
+  // so skip on small-GPU devices.
+  size_t needed = size_t(n) * 64 * 64 * sizeof(float16_t);
+  auto max_buf = std::get<size_t>(gpu::device_info().at("max_buffer_length"));
+  if (max_buf >= needed) {
+    CHECK_THROWS_AS(eval(y), std::overflow_error);
+  }
+}
+
+TEST_CASE("test gpu conv2d large output offset") {
+  // Regression for the kernel-offset half of #3327 (originally PR #3294).
+  // Output shape (batch, 64, 64, O) with batch * 64 * 64 * O > 2^31 but
+  // each per-dim and `batch * 64 * 64` fit in int32 — so the dispatcher
+  // accepts the work but each thread's output offset `c_row * O + c_col`
+  // exceeds int32 max. Before the size_t promotion in
+  // steel_conv_general.h, threads wrote to wrapped offsets and the last
+  // batches read back zeros.
+  constexpr int H = 64;
+  constexpr int W = 64;
+  constexpr int O = 17;
+  const int per_batch_output = H * W * O;
+  const int batch_size =
+      static_cast<int>((int64_t{1} << 31) / per_batch_output + 2);
+
+  // Skip if the output array (~4.3 GB fp16) won't fit on this device.
+  size_t needed = size_t(batch_size) * H * W * O * sizeof(float16_t);
+  auto max_buf = std::get<size_t>(gpu::device_info().at("max_buffer_length"));
+  if (max_buf < needed) {
+    return;
+  }
+
+  auto batch_values =
+      astype(remainder(arange(batch_size, int32), array(251)), float16);
+  batch_values = reshape(batch_values, {batch_size, 1, 1, 1});
+  auto x = multiply(ones({batch_size, H, W, 1}, float16), batch_values);
+  auto channel_values =
+      divide(arange(1.0, double(O + 1), float16), array(8.0f, float16));
+  auto w = reshape(channel_values, {O, 1, 1, 1});
+
+  auto y = conv2d(x, w);
+
+  // Expected y[i, h, w, j] = (i % 251) * ((j+1)/8). Spot check first and
+  // last batches; the last batch covers offsets past int32 max.
+  auto expected_first = multiply(
+      slice(x, {0, 0, 0, 0}, {1, H, W, 1}),
+      reshape(channel_values, {1, 1, 1, O}));
+  auto expected_last = multiply(
+      slice(x, {batch_size - 1, 0, 0, 0}, {batch_size, H, W, 1}),
+      reshape(channel_values, {1, 1, 1, O}));
+  CHECK(allclose(
+            slice(y, {0, 0, 0, 0}, {1, H, W, O}),
+            expected_first,
+            /* rtol = */ 1e-3,
+            /* atol = */ 1e-3)
+            .item<bool>());
+  CHECK(allclose(
+            slice(y, {batch_size - 1, 0, 0, 0}, {batch_size, H, W, O}),
+            expected_last,
+            /* rtol = */ 1e-3,
+            /* atol = */ 1e-3)
+            .item<bool>());
+}
+
 TEST_CASE("test memory info") {
   // Test cache limits
   {

From a5c5211329aaa38b7865d4189c73889fbdbe6092 Mon Sep 17 00:00:00 2001
From: qflen
Date: Tue, 12 May 2026 03:57:57 +0200
Subject: [PATCH 2/3] Address review feedback

Rename check_shape_dim to a templated safe_cast that widens its
argument internally, drop the now-redundant static_casts at the call
sites, and let the steel kernel offset arithmetic promote through the
first operand instead of casting every term.
---
 mlx/backend/cuda/conv/gemm_conv.cu             |  6 ++----
 mlx/backend/cuda/conv/gemm_grouped_conv.cu     |  8 +++-----
 mlx/backend/metal/conv.cpp                     | 18 +++++++-----------
 .../kernels/steel/conv/kernels/steel_conv.h    |  2 +-
 .../kernels/steel/conv/kernels/steel_conv_3d.h |  2 +-
 .../steel/conv/kernels/steel_conv_general.h    |  8 ++------
 mlx/ops.cpp                                    |  5 ++---
 mlx/primitives.cpp                             |  5 ++---
 mlx/utils.h                                    | 12 +++++++-----
 9 files changed, 27 insertions(+), 39 deletions(-)

diff --git a/mlx/backend/cuda/conv/gemm_conv.cu b/mlx/backend/cuda/conv/gemm_conv.cu
index 129c1d7f52..6fc9528289 100644
--- a/mlx/backend/cuda/conv/gemm_conv.cu
+++ b/mlx/backend/cuda/conv/gemm_conv.cu
@@ -137,10 +137,8 @@ void gemm_conv_nd(
     ConvParams<NDIM>& params,
     Stream s) {
   // Get gemm shapes.
-  int mat_M = check_shape_dim(
-      static_cast<int64_t>(out.size() / params.O), "conv"); // N * H_out * W_out
-  int mat_K = check_shape_dim(
-      static_cast<int64_t>(wt.size() / params.O), "conv"); // C * H_wt * W_wt
+  int mat_M = safe_cast(out.size() / params.O, "conv"); // N * H_out * W_out
+  int mat_K = safe_cast(wt.size() / params.O, "conv"); // C * H_wt * W_wt
   int mat_N = params.O; // O
 
   // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
diff --git a/mlx/backend/cuda/conv/gemm_grouped_conv.cu b/mlx/backend/cuda/conv/gemm_grouped_conv.cu
index a445a4ea28..4060d744d7 100644
--- a/mlx/backend/cuda/conv/gemm_grouped_conv.cu
+++ b/mlx/backend/cuda/conv/gemm_grouped_conv.cu
@@ -142,11 +142,9 @@ void gemm_grouped_conv_nd(
   // Get gemm shapes.
   int C_per_group = params.C / params.groups;
   int O_per_group = params.O / params.groups;
-  int mat_M = check_shape_dim(
-      static_cast<int64_t>(out.size() / params.O), "conv"); // N * H_out * W_out
-  int mat_K = check_shape_dim(
-      static_cast<int64_t>(wt.size() / params.O),
-      "conv"); // C_per_group * H_wt * W_wt
+  int mat_M = safe_cast(out.size() / params.O, "conv"); // N * H_out * W_out
+  int mat_K =
+      safe_cast(wt.size() / params.O, "conv"); // C_per_group * H_wt * W_wt
   int mat_N = O_per_group; // O_per_group
 
   // Unfold input to (N * H_out * W_out, C * H_wt * W_wt) for gemm.
diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index f9a2e6c0de..ce6f718448 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -39,10 +39,8 @@ void explicit_gemm_conv_ND_gpu(
     array& out,
     const MLXConvParams<N>& conv_params) {
   // Get gemm shapes
-  int implicit_M =
-      check_shape_dim(static_cast<int64_t>(out.size() / conv_params.O), "conv");
-  int implicit_K =
-      check_shape_dim(static_cast<int64_t>(wt.size() / conv_params.O), "conv");
+  int implicit_M = safe_cast(out.size() / conv_params.O, "conv");
+  int implicit_K = safe_cast(wt.size() / conv_params.O, "conv");
   int implicit_N = conv_params.O;
   // Prepare unfolding array
   Shape unfolded_shape{implicit_M, implicit_K};
@@ -115,10 +113,8 @@ void explicit_gemm_conv_group_ND_gpu(
   const int C_per_group = conv_params.C / conv_params.groups;
   const int O_per_group = conv_params.O / conv_params.groups;
   // Get gemm shapes
-  const int implicit_M =
-      check_shape_dim(static_cast<int64_t>(out.size() / conv_params.O), "conv");
-  const int implicit_K =
-      check_shape_dim(static_cast<int64_t>(wt.size() / conv_params.O), "conv");
+  const int implicit_M = safe_cast(out.size() / conv_params.O, "conv");
+  const int implicit_K = safe_cast(wt.size() / conv_params.O, "conv");
   const int implicit_N = O_per_group;
 
   int kernel_size = 1;
@@ -204,7 +200,7 @@ void implicit_gemm_conv_2D_gpu(
   const int O_per_group = conv_params.O / conv_params.groups;
 
   // Deduce implicit gemm size
-  const int implicit_M = check_shape_dim(
+  const int implicit_M = safe_cast(
       static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
           conv_params.oS[1],
       "conv");
@@ -336,7 +332,7 @@ void implicit_gemm_conv_2D_general_gpu(
     array& out,
     const MLXConvParams<2>& conv_params) {
   // Deduce implicit gemm size
-  int implicit_M = check_shape_dim(
+  int implicit_M = safe_cast(
       static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
           conv_params.oS[1],
       "conv");
@@ -522,7 +518,7 @@ void implicit_gemm_conv_3D_gpu(
   const int O_per_group = conv_params.O / conv_params.groups;
 
   // Deduce implicit gemm size
-  const int implicit_M = check_shape_dim(
+  const int implicit_M = safe_cast(
       static_cast<int64_t>(conv_params.N) * conv_params.oS[0] *
           conv_params.oS[1] * conv_params.oS[2],
       "conv");
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
index 6bc78dd43e..f559596b73 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h
@@ -135,7 +135,7 @@ implicit_gemm_conv_2d(
 
   C += tid.z * N;
   B += c_col * K;
-  C += static_cast<size_t>(c_row) * size_t(N * params->groups) + size_t(c_col);
+  C += static_cast<size_t>(c_row) * N * params->groups + c_col;
 
   const int2 offsets_a(0, c_row);
   const int2 offsets_b(0, c_col);
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h
index 0b2a0d68dc..f2ccc1c03d 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_3d.h
@@ -94,7 +94,7 @@ implicit_gemm_conv_3d(
 
   C += tid.z * N;
   B += c_col * K;
-  C += static_cast<size_t>(c_row) * size_t(N * params->groups) + size_t(c_col);
+  C += static_cast<size_t>(c_row) * N * params->groups + c_col;
 
   const int2 offsets_a(0, c_row);
   const int2 offsets_b(0, c_col);
diff --git a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
index 2af042ddd8..38250f9b81 100644
--- a/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
+++ b/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
@@ -200,12 +200,8 @@ implicit_gemm_conv_2d_general(
         (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow;
 
     if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) {
-      size_t offset_cm = static_cast<size_t>(n) *
-              static_cast<size_t>(params->out_strides[0]) +
-          static_cast<size_t>(oh) *
-              static_cast<size_t>(params->out_strides[1]) +
-          static_cast<size_t>(ow) *
-              static_cast<size_t>(params->out_strides[2]);
+      size_t offset_cm = static_cast<size_t>(n) * params->out_strides[0] +
+          oh * params->out_strides[1] + ow * params->out_strides[2];
 
       STEEL_PRAGMA_UNROLL
       for (int j = 0; j < mma_t::TN; j++) {
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 4c182e37a6..6ad41e2e38 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -67,7 +67,7 @@ array indices_or_default(
   }
 
   Shape shape(x.shape().begin(), x.shape().end() - 2);
-  int total = check_shape_dim(
+  int total = safe_cast(
       std::reduce(
           shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>{}),
       "gather");
@@ -435,8 +435,7 @@ array unflatten(
     }
   }
   if (infer_idx >= 0) {
-    shape[infer_idx] =
-        check_shape_dim(static_cast<int64_t>(a.shape(ax) / size), "unflatten");
+    shape[infer_idx] = safe_cast(a.shape(ax) / size, "unflatten");
     size *= shape[infer_idx];
   }
   if (size != a.shape(ax)) {
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 42fda6bc5b..62460a3d1b 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -2141,7 +2141,7 @@ Shape Flatten::output_shape(const array& input, int start_axis, int end_axis) {
     flat_size *= input.shape(ax);
   }
   shape.erase(shape.begin() + start_axis + 1, shape.begin() + end_axis + 1);
-  shape[start_axis] = check_shape_dim(flat_size, "flatten");
+  shape[start_axis] = safe_cast(flat_size, "flatten");
   return shape;
 }
 
@@ -3913,8 +3913,7 @@ Shape Reshape::output_shape(const array& input, Shape shape) {
 
   // Infer the shape
   if (size > 0 && infer_idx >= 0) {
-    shape[infer_idx] =
-        check_shape_dim(static_cast<int64_t>(input.size() / size), "reshape");
+    shape[infer_idx] = safe_cast(input.size() / size, "reshape");
     size *= shape[infer_idx];
   } else if (infer_idx >= 0) {
     throw std::invalid_argument(
diff --git a/mlx/utils.h b/mlx/utils.h
index 4763272f51..d8b4c7ac99 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -100,20 +100,22 @@ MLX_API Dtype result_type(const std::vector<array>& arrays);
 
 MLX_API Shape broadcast_shapes(const Shape& s1, const Shape& s2);
 
-inline ShapeElem check_shape_dim(int64_t dim, std::string_view op = "") {
+template <typename T>
+inline ShapeElem safe_cast(T dim, std::string_view op = "") {
   constexpr int64_t lo = std::numeric_limits<ShapeElem>::min();
   constexpr int64_t hi = std::numeric_limits<ShapeElem>::max();
-  if (dim < lo || dim > hi) {
+  auto v = static_cast<int64_t>(dim);
+  if (v < lo || v > hi) {
     std::ostringstream msg;
     if (!op.empty()) {
       msg << "[" << op << "] ";
     }
-    msg << "Shape dimension " << dim << " is outside the supported range ["
-        << lo << ", " << hi
+    msg << "Shape dimension " << v << " is outside the supported range [" << lo
+        << ", " << hi
         << "]. MLX currently uses 32-bit integers for shape dimensions.";
     throw std::overflow_error(msg.str());
   }
-  return static_cast<ShapeElem>(dim);
+  return static_cast<ShapeElem>(v);
 }
 
 /**

From 1bb7bb6b26ee4f6380b7f5e6ea637444f923e971 Mon Sep 17 00:00:00 2001
From: qflen
Date: Tue, 12 May 2026 04:19:29 +0200
Subject: [PATCH 3/3] Drop 4GB allocation tests

Remove the eval overflow check and the conv2d large-output offset
regression test; both allocate a ~4 GB fp16 output buffer. The lazy
shape-boundary checks remain.
---
 tests/gpu_tests.cpp | 64 ---------------------------------------------
 1 file changed, 64 deletions(-)

diff --git a/tests/gpu_tests.cpp b/tests/gpu_tests.cpp
index 1075ff5118..52b5a3f3a6 100644
--- a/tests/gpu_tests.cpp
+++ b/tests/gpu_tests.cpp
@@ -3,7 +3,6 @@
 #include
 
 #include "doctest/doctest.h"
-#include "mlx/backend/gpu/device_info.h"
 #include "mlx/mlx.h"
 
 using namespace mlx::core;
@@ -508,69 +507,6 @@ TEST_CASE("test gpu int32 shape overflow errors") {
   // take(a, idx) routes through an internal flatten — overflows on flatten.
   auto idx = array({0u}, uint32);
   CHECK_THROWS_AS(take(y, idx), std::overflow_error);
-
-  // The conv dispatcher refuses to compute a >2^31-element output. eval
-  // allocates the ~4 GB float16 output before the dispatcher check fires,
-  // so skip on small-GPU devices.
-  size_t needed = size_t(n) * 64 * 64 * sizeof(float16_t);
-  auto max_buf = std::get<size_t>(gpu::device_info().at("max_buffer_length"));
-  if (max_buf >= needed) {
-    CHECK_THROWS_AS(eval(y), std::overflow_error);
-  }
-}
-
-TEST_CASE("test gpu conv2d large output offset") {
-  // Regression for the kernel-offset half of #3327 (originally PR #3294).
-  // Output shape (batch, 64, 64, O) with batch * 64 * 64 * O > 2^31 but
-  // each per-dim and `batch * 64 * 64` fit in int32 — so the dispatcher
-  // accepts the work but each thread's output offset `c_row * O + c_col`
-  // exceeds int32 max. Before the size_t promotion in
-  // steel_conv_general.h, threads wrote to wrapped offsets and the last
-  // batches read back zeros.
-  constexpr int H = 64;
-  constexpr int W = 64;
-  constexpr int O = 17;
-  const int per_batch_output = H * W * O;
-  const int batch_size =
-      static_cast<int>((int64_t{1} << 31) / per_batch_output + 2);
-
-  // Skip if the output array (~4.3 GB fp16) won't fit on this device.
-  size_t needed = size_t(batch_size) * H * W * O * sizeof(float16_t);
-  auto max_buf = std::get<size_t>(gpu::device_info().at("max_buffer_length"));
-  if (max_buf < needed) {
-    return;
-  }
-
-  auto batch_values =
-      astype(remainder(arange(batch_size, int32), array(251)), float16);
-  batch_values = reshape(batch_values, {batch_size, 1, 1, 1});
-  auto x = multiply(ones({batch_size, H, W, 1}, float16), batch_values);
-  auto channel_values =
-      divide(arange(1.0, double(O + 1), float16), array(8.0f, float16));
-  auto w = reshape(channel_values, {O, 1, 1, 1});
-
-  auto y = conv2d(x, w);
-
-  // Expected y[i, h, w, j] = (i % 251) * ((j+1)/8). Spot check first and
-  // last batches; the last batch covers offsets past int32 max.
-  auto expected_first = multiply(
-      slice(x, {0, 0, 0, 0}, {1, H, W, 1}),
-      reshape(channel_values, {1, 1, 1, O}));
-  auto expected_last = multiply(
-      slice(x, {batch_size - 1, 0, 0, 0}, {batch_size, H, W, 1}),
-      reshape(channel_values, {1, 1, 1, O}));
-  CHECK(allclose(
-            slice(y, {0, 0, 0, 0}, {1, H, W, O}),
-            expected_first,
-            /* rtol = */ 1e-3,
-            /* atol = */ 1e-3)
-            .item<bool>());
-  CHECK(allclose(
-            slice(y, {batch_size - 1, 0, 0, 0}, {batch_size, H, W, O}),
-            expected_last,
-            /* rtol = */ 1e-3,
-            /* atol = */ 1e-3)
-            .item<bool>());
 }
 
 TEST_CASE("test memory info") {
   // Test cache limits
   {