From d59c1b4040630aa4ad27fd3e220a82700e2b5ce4 Mon Sep 17 00:00:00 2001 From: Alex Skryl Date: Wed, 18 Feb 2026 19:18:03 -0600 Subject: [PATCH] performance improvements --- README.md | 12 +- examples/benchmark/cnn_example.rb | 32 +- examples/benchmark/python/cnn_example.py | 3 +- examples/benchmark/rnn_example.rb | 10 +- examples/benchmark/transformer_example.rb | 13 +- lib/mlx/core.rb | 142 +- lib/mlx/nn/layers/activations.rb | 48 +- lib/mlx/nn/layers/convolution.rb | 38 +- lib/mlx/nn/layers/dropout.rb | 6 +- lib/mlx/nn/layers/linear.rb | 33 +- lib/mlx/nn/layers/normalization.rb | 4 +- lib/mlx/nn/layers/pooling.rb | 36 +- lib/mlx/nn/layers/recurrent.rb | 112 +- lib/mlx/nn/layers/transformer.rb | 124 +- rfp/2026_02_13_dsl_implementation.md | 1592 +++++++++++++++++ ..._benchmark_performance_remediation_plan.md | 142 ++ tasks/benchmark_task.rb | 9 +- .../phase190_activations_parity_test.rb | 27 + .../phase195_pooling_layers_parity_test.rb | 16 + 19 files changed, 2234 insertions(+), 165 deletions(-) create mode 100644 rfp/2026_02_13_dsl_implementation.md create mode 100644 rfp/2026_02_17_benchmark_performance_remediation_plan.md diff --git a/README.md b/README.md index 606ca9cf..b76ec846 100644 --- a/README.md +++ b/README.md @@ -788,18 +788,18 @@ MLX Ruby has full Metal support through the upstream MLX runtime. On Apple silic The table below is from: ```bash -bundle exec rake benchmark WARMUP=50 ITERATIONS=1000 +BENCHMARK_DEVICES=gpu bundle exec rake benchmark WARMUP=50 ITERATIONS=5000 ``` Ratios are shown per column (`py_cpu/gpu`, `rb_cpu/gpu`, `rb/py_cpu`, `rb/py_gpu`). Parity columns come from harness checks (`input_shape`, `input_digest`, `output_shape`, `reference_output_digest`). 
| model | py_cpu_s | py_gpu_s | py_cpu/gpu | rb_cpu_s | rb_gpu_s | rb_cpu/gpu | rb/py_cpu | rb/py_gpu | in_shape (cpu/gpu) | in_content (cpu/gpu) | out_shape (cpu/gpu) | out_content (cpu/gpu) | | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | :---: | :---: | :---: | :---: | -| transformer | 0.034 | 0.008 | 4.48x | 0.029 | 0.008 | 3.67x | 0.84x | 1.03x | ✓/✓ | ✓/✓ | ✓/✓ | ✓/✓ | -| cnn | 0.004 | 0.000 | 9.54x | 0.004 | 0.001 | 4.68x | 1.05x | 2.13x | ✓/✓ | ✓/✓ | ✓/✓ | ✓/✓ | -| mlp | 0.000 | 0.000 | 1.66x | 0.000 | 0.000 | 1.28x | 0.98x | 1.27x | ✓/✓ | ✓/✓ | ✓/✓ | ✓/✓ | -| rnn | 0.006 | 0.004 | 1.39x | 0.007 | 0.007 | 0.91x | 1.16x | 1.79x | ✓/✓ | ✓/✓ | ✓/✓ | ✓/✓ | -| karpathy_gpt2 | 0.058 | 0.011 | 5.06x | 0.060 | 0.016 | 3.84x | 1.03x | 1.36x | ✓/✓ | ✓/✓ | ✓/✓ | ✓/✓ | +| transformer | n/a | 0.006 | n/a | n/a | 0.006 | n/a | n/a | 1.11x | -/✓ | -/✓ | -/✓ | -/✓ | +| cnn | n/a | 0.001 | n/a | n/a | 0.001 | n/a | n/a | 1.15x | -/✓ | -/✓ | -/✓ | -/✓ | +| mlp | n/a | 0.000 | n/a | n/a | 0.000 | n/a | n/a | 1.09x | -/✓ | -/✓ | -/✓ | -/✓ | +| rnn | n/a | 0.004 | n/a | n/a | 0.006 | n/a | n/a | 1.62x | -/✓ | -/✓ | -/✓ | -/✓ | +| karpathy_gpt2 | n/a | 0.013 | n/a | n/a | 0.015 | n/a | n/a | 1.17x | -/✓ | -/✓ | -/✓ | -/✓ | ### Build docs diff --git a/examples/benchmark/cnn_example.rb b/examples/benchmark/cnn_example.rb index 6a1f95a9..b1ce64ba 100644 --- a/examples/benchmark/cnn_example.rb +++ b/examples/benchmark/cnn_example.rb @@ -21,11 +21,28 @@ def initialize(batch_size:, dtype:) @conv1 = MLX::NN::Conv2d.new(CNN_CHANNELS, 16, 3, stride: 1, padding: 1, bias: true) @conv2 = MLX::NN::Conv2d.new(16, 32, 3, stride: 1, padding: 1, bias: true) - @relu = MLX::NN::ReLU.new @pool = MLX::NN::MaxPool2d.new(2, stride: 2, padding: 0) @linear = MLX::NN::Linear.new(@flattened_features, CNN_CLASSES) BenchmarkDigest.assign_deterministic_parameters!([@conv1, @conv2, @linear]) + conv1 = @conv1 + conv2 = @conv2 + pool = @pool + linear = @linear + input = @input + batch_size = 
@batch_size + flattened_features = @flattened_features + @run_step = lambda do + y = conv1.call(input) + y = MLX::NN.relu(y) + y = pool.call(y) + y = conv2.call(y) + y = MLX::NN.relu(y) + y = pool.call(y) + y = MLX::Core.reshape(y, [batch_size, flattened_features]) + linear.call(y) + end + @input_shape = @input.shape @input_digest = BenchmarkDigest.digest_array(@input) @reference_output_digest = BenchmarkDigest.digest_array(run_step) @@ -33,14 +50,11 @@ def initialize(batch_size:, dtype:) end def run_step - y = @conv1.call(@input) - y = @relu.call(y) - y = @pool.call(y) - y = @conv2.call(y) - y = @relu.call(y) - y = @pool.call(y) - y = MLX::Core.reshape(y, [@batch_size, @flattened_features]) - @linear.call(y) + @run_step.call + end + + def run_step_proc + @run_step end def verification_input_digest diff --git a/examples/benchmark/python/cnn_example.py b/examples/benchmark/python/cnn_example.py index 0d550d94..eef9093f 100644 --- a/examples/benchmark/python/cnn_example.py +++ b/examples/benchmark/python/cnn_example.py @@ -6,7 +6,7 @@ from benchmark_digest import assign_deterministic_parameters from benchmark_digest import deterministic_tensor from benchmark_digest import digest_array -from mlx.nn.layers.activations import ReLU +from mlx.nn.layers.activations import relu from mlx.nn.layers.convolution import Conv2d from mlx.nn.layers.linear import Linear from mlx.nn.layers.pooling import MaxPool2d @@ -48,7 +48,6 @@ def main(): conv1 = Conv2d(CNN_CHANNELS, 16, 3, stride=1, padding=1) conv2 = Conv2d(16, 32, 3, stride=1, padding=1) - relu = ReLU() pool = MaxPool2d(2, stride=2) linear = Linear(flattened, CNN_CLASSES) assign_deterministic_parameters([conv1, conv2, linear]) diff --git a/examples/benchmark/rnn_example.rb b/examples/benchmark/rnn_example.rb index 477148fe..fafeaa5c 100644 --- a/examples/benchmark/rnn_example.rb +++ b/examples/benchmark/rnn_example.rb @@ -18,6 +18,10 @@ def initialize(batch_size:, sequence_length:, dims:, dtype:) @rnn = MLX::NN::RNN.new(dims, 
hidden_size) BenchmarkDigest.assign_deterministic_parameters!(@rnn) + rnn = @rnn + input = @input + @run_step = lambda { rnn.call(input) } + @input_shape = @input.shape @input_digest = BenchmarkDigest.digest_array(@input) @reference_output_digest = BenchmarkDigest.digest_array(run_step) @@ -25,7 +29,11 @@ def initialize(batch_size:, sequence_length:, dims:, dtype:) end def run_step - @rnn.call(@input) + @run_step.call + end + + def run_step_proc + @run_step end def verification_input_digest diff --git a/examples/benchmark/transformer_example.rb b/examples/benchmark/transformer_example.rb index 612579b2..c8a3962d 100644 --- a/examples/benchmark/transformer_example.rb +++ b/examples/benchmark/transformer_example.rb @@ -36,12 +36,23 @@ def initialize(batch_size:, sequence_length:, target_sequence_length:, dims:, nu ) BenchmarkDigest.assign_deterministic_parameters!(@model) + model = @model + src = @src + tgt = @tgt + src_mask = @src_mask + tgt_mask = @tgt_mask + @run_step = lambda { model.call(src, tgt, src_mask, tgt_mask, nil) } + @reference_output_digest = BenchmarkDigest.digest_array(run_step) @path_signature = "forward_only_eval_output" end def run_step - @model.call(@src, @tgt, @src_mask, @tgt_mask, nil) + @run_step.call + end + + def run_step_proc + @run_step end def verification_input_digest diff --git a/lib/mlx/core.rb b/lib/mlx/core.rb index 2668695e..bd2f4745 100644 --- a/lib/mlx/core.rb +++ b/lib/mlx/core.rb @@ -419,39 +419,53 @@ def compile(fun, inputs = nil, outputs = nil, shapeless = false) flat_inputs = [] input_spec = flatten_tree_spec([args, kwargs], flat_inputs, false) key = structure_cache_key(input_spec) + rebuilt_once = false - entry = cache[key] - unless entry - output_spec = nil - lifted = lambda do |*flat_vars| - rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0) - unless cursor == flat_vars.length - raise RuntimeError, "internal input reconstruction mismatch" + begin + entry = cache[key] + unless 
valid_transform_cache_entry?(entry) + cache.delete(key) + entry = nil + end + unless entry + output_spec = nil + lifted = lambda do |*flat_vars| + rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0) + unless cursor == flat_vars.length + raise RuntimeError, "internal input reconstruction mismatch" + end + + call_args = rebuilt[0] + call_kwargs = rebuilt[1] + raw_output = fun.call(*call_args, **call_kwargs) + + flat_output = [] + output_spec = flatten_tree_spec(raw_output, flat_output, false) + flat_output end - call_args = rebuilt[0] - call_kwargs = rebuilt[1] - raw_output = fun.call(*call_args, **call_kwargs) - - flat_output = [] - output_spec = flatten_tree_spec(raw_output, flat_output, false) - flat_output + compiled = native_compile(lifted, inputs, outputs, shapeless) + entry = { fn: compiled, output_spec: -> { output_spec } } + cache[key] = entry end - compiled = native_compile(lifted, inputs, outputs, shapeless) - entry = { fn: compiled, output_spec: -> { output_spec } } - cache[key] = entry - end - - flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "compiled output") - spec = entry[:output_spec].call - raise RuntimeError, "missing output structure from compiled function" if spec.nil? + flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "compiled output") + spec = entry[:output_spec].call + raise RuntimeError, "missing output structure from compiled function" if spec.nil? 
- rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0) - unless cursor == flat_output.length - raise RuntimeError, "internal output reconstruction mismatch" + rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0) + unless cursor == flat_output.length + raise RuntimeError, "internal output reconstruction mismatch" + end + rebuilt + rescue RuntimeError => e + if !rebuilt_once && invalid_transform_callable_error?(e) + rebuilt_once = true + cache.delete(key) + retry + end + raise end - rebuilt end end @@ -463,39 +477,53 @@ def checkpoint(fun) flat_inputs = [] input_spec = flatten_tree_spec([args, kwargs], flat_inputs, false) key = structure_cache_key(input_spec) + rebuilt_once = false - entry = cache[key] - unless entry - output_spec = nil - lifted = lambda do |*flat_vars| - rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0) - unless cursor == flat_vars.length - raise RuntimeError, "internal input reconstruction mismatch" + begin + entry = cache[key] + unless valid_transform_cache_entry?(entry) + cache.delete(key) + entry = nil + end + unless entry + output_spec = nil + lifted = lambda do |*flat_vars| + rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0) + unless cursor == flat_vars.length + raise RuntimeError, "internal input reconstruction mismatch" + end + + call_args = rebuilt[0] + call_kwargs = rebuilt[1] + raw_output = fun.call(*call_args, **call_kwargs) + + flat_output = [] + output_spec = flatten_tree_spec(raw_output, flat_output, false) + flat_output end - call_args = rebuilt[0] - call_kwargs = rebuilt[1] - raw_output = fun.call(*call_args, **call_kwargs) - - flat_output = [] - output_spec = flatten_tree_spec(raw_output, flat_output, false) - flat_output + checkpointed = native_checkpoint(lifted) + entry = { fn: checkpointed, output_spec: -> { output_spec } } + cache[key] = entry end - checkpointed = native_checkpoint(lifted) - entry = { fn: checkpointed, output_spec: -> { output_spec } } - cache[key] = 
entry - end - - flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "checkpoint output") - spec = entry[:output_spec].call - raise RuntimeError, "missing output structure from checkpoint function" if spec.nil? + flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "checkpoint output") + spec = entry[:output_spec].call + raise RuntimeError, "missing output structure from checkpoint function" if spec.nil? - rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0) - unless cursor == flat_output.length - raise RuntimeError, "internal output reconstruction mismatch" + rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0) + unless cursor == flat_output.length + raise RuntimeError, "internal output reconstruction mismatch" + end + rebuilt + rescue RuntimeError => e + if !rebuilt_once && invalid_transform_callable_error?(e) + rebuilt_once = true + cache.delete(key) + retry + end + raise end - rebuilt end end @@ -877,6 +905,16 @@ def normalize_raw_grads(raw) normalize_array_sequence(raw, "gradient") end + def valid_transform_cache_entry?(entry) + entry.is_a?(::Hash) && + entry[:fn].respond_to?(:call) && + entry[:output_spec].respond_to?(:call) + end + + def invalid_transform_callable_error?(error) + error.message.match?(/undefined method [`']call['"]/) + end + def normalize_array_sequence(raw, context) return [raw] if raw.is_a?(MLX::Core::Array) diff --git a/lib/mlx/nn/layers/activations.rb b/lib/mlx/nn/layers/activations.rb index c473d7b2..cbc25cf4 100644 --- a/lib/mlx/nn/layers/activations.rb +++ b/lib/mlx/nn/layers/activations.rb @@ -4,11 +4,31 @@ module MLX module NN class << self def sigmoid(x) - MLX::Core.sigmoid(x) + compiled = compiled_sigmoid + return MLX::Core.sigmoid(x) unless compiled.respond_to?(:call) + + compiled.call(x) + rescue RuntimeError => e + if invalid_compiled_activation_call?(e) + @compiled_sigmoid = false + MLX::Core.sigmoid(x) + else + raise + end end def relu(x) - MLX::Core.maximum(x, 0.0) + 
compiled = compiled_relu + return MLX::Core.maximum(x, 0.0) unless compiled.respond_to?(:call) + + compiled.call(x) + rescue RuntimeError => e + if invalid_compiled_activation_call?(e) + @compiled_relu = false + MLX::Core.maximum(x, 0.0) + else + raise + end end def relu2(x) @@ -130,6 +150,30 @@ def softmin(x, axis: -1) def tanh(x) MLX::Core.tanh(x) end + + private + + def compiled_sigmoid + return nil if @compiled_sigmoid == false + + @compiled_sigmoid ||= compile_activation_unary(->(v) { MLX::Core.sigmoid(v) }) + end + + def compiled_relu + return nil if @compiled_relu == false + + @compiled_relu ||= compile_activation_unary(->(v) { MLX::Core.maximum(v, 0.0) }) + end + + def compile_activation_unary(fun) + MLX::Core.compile(fun, nil, nil, true) + rescue StandardError + nil + end + + def invalid_compiled_activation_call?(error) + error.message.match?(/undefined method [`']call['"]/) + end end class GLU < Module diff --git a/lib/mlx/nn/layers/convolution.rb b/lib/mlx/nn/layers/convolution.rb index c2a7659e..79045aed 100644 --- a/lib/mlx/nn/layers/convolution.rb +++ b/lib/mlx/nn/layers/convolution.rb @@ -31,8 +31,10 @@ def initialize( end def call(x) - y = MLX::Core.conv1d(x, weight, @stride, @padding, @dilation, @groups) - state.key?("bias") ? MLX::Core.add(y, bias) : y + weight_param = @state["weight"] + bias_param = @state["bias"] + y = MLX::Core.conv1d(x, weight_param, @stride, @padding, @dilation, @groups) + bias_param.nil? ? y : (y + bias_param) end end @@ -65,12 +67,17 @@ def initialize( @stride = stride @padding = padding @dilation = dilation + @stride_arg = collapse_uniform(@stride) + @padding_arg = collapse_uniform(@padding) + @dilation_arg = collapse_uniform(@dilation) @groups = groups end def call(x) - y = MLX::Core.conv2d(x, weight, @stride, @padding, @dilation, @groups) - state.key?("bias") ? 
MLX::Core.add(y, bias) : y + weight_param = @state["weight"] + bias_param = @state["bias"] + y = MLX::Core.conv2d(x, weight_param, @stride_arg, @padding_arg, @dilation_arg, @groups) + bias_param.nil? ? y : (y + bias_param) end private @@ -78,6 +85,13 @@ def call(x) def pair(value) value.is_a?(Integer) ? [value, value] : value end + + def collapse_uniform(value) + return value unless value.is_a?(Array) + return value if value.empty? + + value.all? { |entry| entry == value[0] } ? value[0] : value + end end class Conv3d < Module @@ -103,11 +117,16 @@ def initialize( @stride = stride @padding = padding @dilation = dilation + @stride_arg = collapse_uniform(@stride) + @padding_arg = collapse_uniform(@padding) + @dilation_arg = collapse_uniform(@dilation) end def call(x) - y = MLX::Core.conv3d(x, weight, @stride, @padding, @dilation) - state.key?("bias") ? MLX::Core.add(y, bias) : y + weight_param = @state["weight"] + bias_param = @state["bias"] + y = MLX::Core.conv3d(x, weight_param, @stride_arg, @padding_arg, @dilation_arg) + bias_param.nil? ? y : (y + bias_param) end private @@ -115,6 +134,13 @@ def call(x) def triple(value) value.is_a?(Integer) ? [value, value, value] : value end + + def collapse_uniform(value) + return value unless value.is_a?(Array) + return value if value.empty? + + value.all? { |entry| entry == value[0] } ? value[0] : value + end end end end diff --git a/lib/mlx/nn/layers/dropout.rb b/lib/mlx/nn/layers/dropout.rb index c3723bab..b798a5fa 100644 --- a/lib/mlx/nn/layers/dropout.rb +++ b/lib/mlx/nn/layers/dropout.rb @@ -13,7 +13,7 @@ def initialize(p = 0.5) end def call(x) - return x if @p_keep == 1.0 || !training + return x if @p_keep == 1.0 || !@training mask = MLX::Core.bernoulli(@p_keep, x.shape) MLX::Core.multiply(MLX::Core.multiply(mask, x), 1.0 / @p_keep) @@ -35,7 +35,7 @@ def call(x) raise ArgumentError, "Received input with #{x.ndim} dimensions. Expected 3 or 4 dimensions." 
end - return x if @p_keep == 1.0 || !training + return x if @p_keep == 1.0 || !@training mask_shape = x.shape.dup mask_shape[-2] = 1 @@ -60,7 +60,7 @@ def call(x) raise ArgumentError, "Received input with #{x.ndim} dimensions. Expected 4 or 5 dimensions." end - return x if @p_keep == 1.0 || !training + return x if @p_keep == 1.0 || !@training mask_shape = x.shape.dup mask_shape[-2] = 1 diff --git a/lib/mlx/nn/layers/linear.rb b/lib/mlx/nn/layers/linear.rb index 8e3a9d26..89461a13 100644 --- a/lib/mlx/nn/layers/linear.rb +++ b/lib/mlx/nn/layers/linear.rb @@ -20,16 +20,35 @@ def initialize(input_dims, output_dims, bias: true) scale = Math.sqrt(1.0 / input_dims) self.weight = MLX::Core.uniform([output_dims, input_dims], -scale, scale) self.bias = MLX::Core.uniform([output_dims], -scale, scale) if bias + @has_bias = bias + @cached_weight_t = nil + @cached_weight_id = nil end def call(x) - if state.key?("bias") - MLX::Core.addmm(bias, x, weight.T) + weight_param = @state["weight"] + weight_t = cached_weight_transpose(weight_param) + if @has_bias + bias_param = @state["bias"] + MLX::Core.addmm(bias_param, x, weight_t) else - MLX::Core.matmul(x, weight.T) + MLX::Core.matmul(x, weight_t) end end + private + + def cached_weight_transpose(weight_param) + weight_id = weight_param.object_id + if @cached_weight_t.nil? 
|| @cached_weight_id != weight_id + @cached_weight_t = weight_param.T + @cached_weight_id = weight_id + end + @cached_weight_t + end + + public + def to_quantized(group_size: nil, bits: nil, mode: "affine", quantize_input: false) if quantize_input unless %w[nvfp4 mxfp8].include?(mode.to_s) @@ -53,14 +72,16 @@ def initialize(input1_dims, input2_dims, output_dims, bias: true) end def call(x1, x2) - out_dims, in2_dims, in1_dims = weight.shape + weight_param = @state["weight"] + bias_param = @state["bias"] + out_dims, in2_dims, in1_dims = weight_param.shape x_shape = x1.shape[0...-1] batch = x1.size / in1_dims x1_2d = MLX::Core.reshape(x1, [batch, in1_dims]) x2_3d = MLX::Core.reshape(x2, [batch, 1, in2_dims]) - w = MLX::Core.reshape(weight, [out_dims * in2_dims, in1_dims]) + w = MLX::Core.reshape(weight_param, [out_dims * in2_dims, in1_dims]) y = MLX::Core.matmul(x1_2d, w.T) y = MLX::Core.reshape(y, [batch, out_dims, in2_dims]) y = MLX::Core.swapaxes(y, -2, -1) @@ -69,7 +90,7 @@ def call(x1, x2) out_shape = x_shape.empty? ? [out_dims] : x_shape + [out_dims] y = MLX::Core.reshape(y, out_shape) - y = MLX::Core.add(y, bias) if state.key?("bias") + y = y + bias_param unless bias_param.nil? y end end diff --git a/lib/mlx/nn/layers/normalization.rb b/lib/mlx/nn/layers/normalization.rb index 0a7b8f6a..1e329799 100644 --- a/lib/mlx/nn/layers/normalization.rb +++ b/lib/mlx/nn/layers/normalization.rb @@ -58,8 +58,8 @@ def initialize(dims, eps: 1e-5, affine: true, bias: true) end def call(x) - w = state.key?("weight") ? weight : nil - b = state.key?("bias") ? 
bias : nil + w = @state["weight"] + b = @state["bias"] MLX::Core.layer_norm(x, w, b, @eps) end end diff --git a/lib/mlx/nn/layers/pooling.rb b/lib/mlx/nn/layers/pooling.rb index 61565a32..1eb4349d 100644 --- a/lib/mlx/nn/layers/pooling.rb +++ b/lib/mlx/nn/layers/pooling.rb @@ -10,6 +10,7 @@ def initialize(pooling_symbol, kernel_size, stride, padding, padding_value) @stride = stride @padding = padding @padding_value = padding_value + @axes = (-@kernel_size.length - 1...-1).to_a end def call(x) @@ -37,15 +38,14 @@ def value_or_list(value, n, message) end def reduce_windows(windows) + if @pooling_symbol == :max + return MLX::Core.max(windows, @axes) + end + result = windows - window_dims = @kernel_size.length - window_dims.times do + @kernel_size.length.times do axis = result.ndim - 2 - result = if @pooling_symbol == :max - MLX::Core.max(result, axis) - else - MLX::Core.mean(result, axis) - end + result = MLX::Core.mean(result, axis) end result end @@ -63,6 +63,10 @@ def sliding_windows(x, window_shape, window_strides) end shape = x.shape + if spatial_dims.zip(window_shape, window_strides).all? { |size, window, stride| window == stride && (size % window).zero? 
} + return non_overlapping_sliding_windows(x, shape, window_shape) + end + strides = Array.new(shape.length) running = 1 (shape.length - 1).downto(0) do |i| @@ -88,6 +92,24 @@ def sliding_windows(x, window_shape, window_strides) MLX::Core.as_strided(x, final_shape, final_strides) end + + def non_overlapping_sliding_windows(x, shape, window_shape) + new_shape = [shape[0]] + shape[1...-1].zip(window_shape).each do |size, window| + new_shape << (size / window) + new_shape << window + end + new_shape << shape[-1] + + last_axis = new_shape.length - 1 + axis_order = [0] + axis_order.concat((1...last_axis).step(2).to_a) + axis_order.concat((2...last_axis).step(2).to_a) + axis_order << last_axis + + reshaped = MLX::Core.reshape(x, new_shape) + MLX::Core.transpose(reshaped, axis_order) + end end class Pool1dBase < PoolBase diff --git a/lib/mlx/nn/layers/recurrent.rb b/lib/mlx/nn/layers/recurrent.rb index f9c86514..cc0a013f 100644 --- a/lib/mlx/nn/layers/recurrent.rb +++ b/lib/mlx/nn/layers/recurrent.rb @@ -7,6 +7,8 @@ def initialize(input_size, hidden_size, bias: true, nonlinearity: nil) super() @nonlinearity = nonlinearity || lambda { |z| MLX::NN.tanh(z) } + @fast_tanh = nonlinearity.nil? + @compiled_hidden_step = build_compiled_hidden_step if @fast_tanh unless @nonlinearity.respond_to?(:call) raise ArgumentError, "Nonlinearity must be callable. Current value: #{nonlinearity}." end @@ -19,20 +21,61 @@ def initialize(input_size, hidden_size, bias: true, nonlinearity: nil) end def call(x, hidden = nil) - x = MLX::Core.matmul(x, self.Wxh.T) - x = MLX::Core.add(x, self.bias) unless self.bias.nil? + wxh = @state["Wxh"] + whh = @state["Whh"] + bias = @state["bias"] + wxh_t = wxh.T + whh_t = whh.T + + x = if bias.nil? 
+ MLX::Core.matmul(x, wxh_t) + else + MLX::Core.addmm(bias, x, wxh_t) + end - all_hidden = [] sequence_axis = x.ndim - 2 - x.shape[sequence_axis].times do |idx| + sequence_length = x.shape[sequence_axis] + all_hidden = Array.new(sequence_length) + idx = 0 + + while idx < sequence_length step = MLX::Core.take(x, idx, sequence_axis) - step = MLX::Core.add(step, MLX::Core.matmul(hidden, self.Whh.T)) unless hidden.nil? - hidden = @nonlinearity.call(step) - all_hidden << hidden + hidden = if hidden.nil? + if @fast_tanh + MLX::Core.tanh(step) + else + @nonlinearity.call(step) + end + elsif @fast_tanh + compiled_hidden_step(step, hidden, whh_t) + else + @nonlinearity.call(MLX::Core.addmm(step, hidden, whh_t)) + end + all_hidden[idx] = hidden + idx += 1 end MLX::Core.stack(all_hidden, -2) end + + private + + def build_compiled_hidden_step + MLX::Core.compile(lambda do |step, hidden, whh_t| + MLX::Core.tanh(MLX::Core.addmm(step, hidden, whh_t)) + end) + rescue StandardError + nil + end + + def compiled_hidden_step(step, hidden, whh_t) + return MLX::Core.tanh(MLX::Core.addmm(step, hidden, whh_t)) if @compiled_hidden_step.nil? + + @compiled_hidden_step.call(step, hidden, whh_t) + rescue StandardError + @compiled_hidden_step = nil + MLX::Core.tanh(MLX::Core.addmm(step, hidden, whh_t)) + end end class GRU < Module @@ -48,21 +91,33 @@ def initialize(input_size, hidden_size, bias: true) end def call(x, hidden = nil) - x = MLX::Core.matmul(x, self.Wx.T) - x = MLX::Core.add(x, self.b) unless self.b.nil? + wx = @state["Wx"] + wh = @state["Wh"] + b = @state["b"] + bhn = @state["bhn"] + wx_t = wx.T + wh_t = wh.T + + x = if b.nil? 
+ MLX::Core.matmul(x, wx_t) + else + MLX::Core.addmm(b, x, wx_t) + end x_rz, x_n = MLX::Core.split(x, [2 * @hidden_size], x.ndim - 1) - all_hidden = [] sequence_axis = x.ndim - 2 + sequence_length = x.shape[sequence_axis] + all_hidden = Array.new(sequence_length) + idx = 0 - x.shape[sequence_axis].times do |idx| + while idx < sequence_length rz = MLX::Core.take(x_rz, idx, sequence_axis) h_proj_n = nil unless hidden.nil? - h_proj = MLX::Core.matmul(hidden, self.Wh.T) + h_proj = MLX::Core.matmul(hidden, wh_t) h_proj_rz, h_proj_n = MLX::Core.split(h_proj, [2 * @hidden_size], h_proj.ndim - 1) - h_proj_n = MLX::Core.add(h_proj_n, self.bhn) unless self.bhn.nil? + h_proj_n = MLX::Core.add(h_proj_n, bhn) unless bhn.nil? rz = MLX::Core.add(rz, h_proj_rz) end @@ -81,7 +136,8 @@ def call(x, hidden = nil) ) end - all_hidden << hidden + all_hidden[idx] = hidden + idx += 1 end MLX::Core.stack(all_hidden, -2) @@ -100,16 +156,27 @@ def initialize(input_size, hidden_size, bias: true) end def call(x, hidden = nil, cell = nil) - x = MLX::Core.matmul(x, self.Wx.T) - x = MLX::Core.add(x, self.bias) unless self.bias.nil? + wx = @state["Wx"] + wh = @state["Wh"] + bias = @state["bias"] + wx_t = wx.T + wh_t = wh.T + + x = if bias.nil? + MLX::Core.matmul(x, wx_t) + else + MLX::Core.addmm(bias, x, wx_t) + end - all_hidden = [] - all_cell = [] sequence_axis = x.ndim - 2 + sequence_length = x.shape[sequence_axis] + all_hidden = Array.new(sequence_length) + all_cell = Array.new(sequence_length) + idx = 0 - x.shape[sequence_axis].times do |idx| + while idx < sequence_length ifgo = MLX::Core.take(x, idx, sequence_axis) - ifgo = MLX::Core.add(ifgo, MLX::Core.matmul(hidden, self.Wh.T)) unless hidden.nil? + ifgo = MLX::Core.addmm(ifgo, hidden, wh_t) unless hidden.nil? 
i, f, g, o = MLX::Core.split(ifgo, 4, ifgo.ndim - 1) i = MLX::Core.sigmoid(i) @@ -124,8 +191,9 @@ def call(x, hidden = nil, cell = nil) end hidden = MLX::Core.multiply(o, MLX::Core.tanh(cell)) - all_cell << cell - all_hidden << hidden + all_cell[idx] = cell + all_hidden[idx] = hidden + idx += 1 end [MLX::Core.stack(all_hidden, -2), MLX::Core.stack(all_cell, -2)] diff --git a/lib/mlx/nn/layers/transformer.rb b/lib/mlx/nn/layers/transformer.rb index 93a3a955..ce6580b5 100644 --- a/lib/mlx/nn/layers/transformer.rb +++ b/lib/mlx/nn/layers/transformer.rb @@ -29,6 +29,10 @@ def initialize( value_output_dims ||= dims @num_heads = num_heads + @head_dims = dims / num_heads + @scale = Math.sqrt(1.0 / @head_dims) + @heads_layout = [0, 2, 1, 3].freeze + @split_shape = [@num_heads, @head_dims].freeze self.query_proj = Linear.new(query_input_dims, dims, bias: bias) self.key_proj = Linear.new(key_input_dims, dims, bias: bias) self.value_proj = Linear.new(value_input_dims, value_dims, bias: bias) @@ -36,9 +40,17 @@ def initialize( end def call(queries, keys, values, mask = nil) - queries, q_was_2d = maybe_batch(queries) - keys, = maybe_batch(keys) - values, = maybe_batch(values) + q_was_2d = queries.ndim == 2 + if q_was_2d + queries, = maybe_batch(queries) + keys, = maybe_batch(keys) + values, = maybe_batch(values) + end + + query_proj = @state["query_proj"] + key_proj = @state["key_proj"] + value_proj = @state["value_proj"] + out_proj = @state["out_proj"] queries = query_proj.call(queries) keys = key_proj.call(keys) @@ -48,10 +60,9 @@ def call(queries, keys, values, mask = nil) keys = split_heads(keys) values = split_heads(values) - scale = Math.sqrt(1.0 / queries.shape[-1]) - output = MLX::Core.scaled_dot_product_attention(queries, keys, values, scale, mask) - output = MLX::Core.transpose(output, [0, 2, 1, 3]) - output = output.flatten(-2, -1) + output = MLX::Core.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = MLX::Core.transpose(output, 
@heads_layout) + output = MLX::Core.flatten(output, -2, -1) output = out_proj.call(output) q_was_2d ? MLX::Core.squeeze(output, 0) : output end @@ -67,10 +78,8 @@ def self.create_additive_causal_mask(n, dtype = MLX::Core.float32) private def split_heads(x) - batch, length, dims = x.shape - head_dim = dims / @num_heads - x = MLX::Core.reshape(x, [batch, length, @num_heads, head_dim]) - MLX::Core.transpose(x, [0, 2, 1, 3]) + x = MLX::Core.unflatten(x, -1, @split_shape) + MLX::Core.transpose(x, @heads_layout) end def maybe_batch(x) @@ -93,7 +102,6 @@ def initialize( ) super() mlp_dims ||= dims * 4 - activation ||= lambda { |x| MLX::NN.relu(x) } self.attention = MultiHeadAttention.new(dims, num_heads) self.ln1 = LayerNorm.new(dims) @@ -103,32 +111,42 @@ def initialize( self.dropout1 = Dropout.new(dropout) self.dropout2 = Dropout.new(dropout) @activation = activation + @dropout1_identity = dropout == 0.0 + @dropout2_identity = dropout == 0.0 @norm_first = norm_first end def call(x, mask) + attention = @state["attention"] + ln1 = @state["ln1"] + ln2 = @state["ln2"] + linear1 = @state["linear1"] + linear2 = @state["linear2"] + dropout1 = @dropout1_identity ? nil : @state["dropout1"] + dropout2 = @dropout2_identity ? nil : @state["dropout2"] + if @norm_first y = ln1.call(x) y = attention.call(y, y, y, mask) - y = dropout1.call(y) - x = MLX::Core.add(x, y) + y = dropout1.call(y) unless dropout1.nil? + x = x + y y = ln2.call(x) y = linear1.call(y) - y = @activation.call(y) - y = dropout2.call(y) + y = @activation.nil? ? MLX::Core.maximum(y, 0.0) : @activation.call(y) + y = dropout2.call(y) unless dropout2.nil? y = linear2.call(y) - y = MLX::Core.add(x, y) + y = x + y else y = attention.call(x, x, x, mask) - y = dropout1.call(y) - x = ln1.call(MLX::Core.add(x, y)) + y = dropout1.call(y) unless dropout1.nil? + x = ln1.call(x + y) y = linear1.call(x) - y = @activation.call(y) - y = dropout2.call(y) + y = @activation.nil? ? 
MLX::Core.maximum(y, 0.0) : @activation.call(y) + y = dropout2.call(y) unless dropout2.nil? y = linear2.call(y) - y = ln2.call(MLX::Core.add(x, y)) + y = ln2.call(x + y) end y @@ -147,7 +165,6 @@ def initialize( checkpoint: false ) super() - activation ||= lambda { |x| MLX::NN.relu(x) } self.layers = Array.new(num_layers) do TransformerEncoderLayer.new( dims, @@ -168,8 +185,11 @@ def initialize( end def call(x, mask) - @layer_fns.each do |layer_fn| - x = layer_fn.call(x, mask) + ln = @state["ln"] + idx = 0 + while idx < @layer_fns.length + x = @layer_fns[idx].call(x, mask) + idx += 1 end ln.call(x) end @@ -186,7 +206,6 @@ def initialize( ) super() mlp_dims ||= dims * 4 - activation ||= lambda { |x| MLX::NN.relu(x) } self.self_attention = MultiHeadAttention.new(dims, num_heads) self.cross_attention = MultiHeadAttention.new(dims, num_heads) @@ -199,41 +218,55 @@ def initialize( self.dropout2 = Dropout.new(dropout) self.dropout3 = Dropout.new(dropout) @activation = activation + @dropout1_identity = dropout == 0.0 + @dropout2_identity = dropout == 0.0 + @dropout3_identity = dropout == 0.0 @norm_first = norm_first end def call(x, memory, x_mask, memory_mask) + self_attention = @state["self_attention"] + cross_attention = @state["cross_attention"] + ln1 = @state["ln1"] + ln2 = @state["ln2"] + ln3 = @state["ln3"] + linear1 = @state["linear1"] + linear2 = @state["linear2"] + dropout1 = @dropout1_identity ? nil : @state["dropout1"] + dropout2 = @dropout2_identity ? nil : @state["dropout2"] + dropout3 = @dropout3_identity ? nil : @state["dropout3"] + if @norm_first y = ln1.call(x) y = self_attention.call(y, y, y, x_mask) - y = dropout1.call(y) - x = MLX::Core.add(x, y) + y = dropout1.call(y) unless dropout1.nil? + x = x + y y = ln2.call(x) y = cross_attention.call(y, memory, memory, memory_mask) - y = dropout2.call(y) - x = MLX::Core.add(x, y) + y = dropout2.call(y) unless dropout2.nil? 
+ x = x + y y = ln3.call(x) y = linear1.call(y) - y = @activation.call(y) - y = dropout3.call(y) + y = @activation.nil? ? MLX::Core.maximum(y, 0.0) : @activation.call(y) + y = dropout3.call(y) unless dropout3.nil? y = linear2.call(y) - y = MLX::Core.add(x, y) + y = x + y else y = self_attention.call(x, x, x, x_mask) - y = dropout1.call(y) - x = ln1.call(MLX::Core.add(x, y)) + y = dropout1.call(y) unless dropout1.nil? + x = ln1.call(x + y) y = cross_attention.call(y, memory, memory, memory_mask) - y = dropout2.call(y) - x = ln2.call(MLX::Core.add(x, y)) + y = dropout2.call(y) unless dropout2.nil? + x = ln2.call(x + y) y = linear1.call(x) - y = @activation.call(y) - y = dropout3.call(y) + y = @activation.nil? ? MLX::Core.maximum(y, 0.0) : @activation.call(y) + y = dropout3.call(y) unless dropout3.nil? y = linear2.call(y) - y = ln3.call(MLX::Core.add(x, y)) + y = ln3.call(x + y) end y @@ -252,7 +285,6 @@ def initialize( checkpoint: false ) super() - activation ||= lambda { |x| MLX::NN.relu(x) } self.layers = Array.new(num_layers) do TransformerDecoderLayer.new( dims, @@ -275,8 +307,11 @@ def initialize( end def call(x, memory, x_mask, memory_mask) - @layer_fns.each do |layer_fn| - x = layer_fn.call(x, memory, x_mask, memory_mask) + ln = @state["ln"] + idx = 0 + while idx < @layer_fns.length + x = @layer_fns[idx].call(x, memory, x_mask, memory_mask) + idx += 1 end ln.call(x) end @@ -298,7 +333,6 @@ def initialize( ) super() - activation ||= lambda { |x| MLX::NN.relu(x) } self.encoder = custom_encoder || TransformerEncoder.new( num_encoder_layers, dims, @@ -322,6 +356,8 @@ def initialize( end def call(src, tgt, src_mask, tgt_mask, memory_mask) + encoder = @state["encoder"] + decoder = @state["decoder"] memory = encoder.call(src, src_mask) decoder.call(tgt, memory, tgt_mask, memory_mask) end diff --git a/rfp/2026_02_13_dsl_implementation.md b/rfp/2026_02_13_dsl_implementation.md new file mode 100644 index 00000000..3f2d93cc --- /dev/null +++ 
b/rfp/2026_02_13_dsl_implementation.md @@ -0,0 +1,1592 @@ +# Ruby DSL Phased Implementation Plan + +## Goal + +Add a Ruby-native DSL for MLX bindings under `lib/mlx/dsl` that preserves compatibility with existing `MLX::Core`, `MLX::NN`, and `MLX::Optimizers` behavior while improving ergonomics for model definition and training. + +This plan is intentionally separate from parity-report artifacts. The DSL is a Ruby-only extension and should not be counted in Python package-parity file checks. + +## Design Principles + +1. Keep existing APIs stable. +2. Build DSL features as sugar over existing primitives. +3. Keep all DSL state interoperable with `MLX::NN::Module#parameters`, `#trainable_parameters`, `#update`, and optimizer flows. +4. Implement with red/green testing: + - Red: add failing tests for new behavior. + - Green: implement minimal code to pass. + - Refactor: tighten internals without changing behavior. + +## Phase 1: Foundation (`lib/mlx/dsl`) + +### Deliverables + +1. Add: + - `lib/mlx/dsl.rb` + - `lib/mlx/dsl/model.rb` + - `lib/mlx/dsl/model_mixin.rb` + - `lib/mlx/dsl/builder.rb` + - `lib/mlx/dsl/train_step.rb` +2. Load DSL from `lib/mlx.rb`. +3. Add initial tests in `test/dsl/dsl_test.rb`. + +### Ergonomics Target + +```ruby +class Classifier < MLX::DSL::Model + layer :net do + sequential do + linear 784, 256 + relu + dropout 0.1 + linear 256, 10 + end + end + + def call(x) = net.call(x) +end +``` + +## Phase 2: Model Declaration DSL + +### Deliverables + +1. `option` class macro with required/default behavior. +2. `layer` and `network` macros for submodule declaration. +3. `param` and `buffer` macros for array declaration and initialization. +4. Ensure declarations write through `self.<name>=` so module state tracking remains consistent. 
+ +### Ergonomics Target + +```ruby +class Affine < MLX::DSL::Model + option :in_dim + option :out_dim + option :use_bias, default: true + + param :weight, shape: -> { [out_dim, in_dim] }, init: ->(shape) { MLX::Core.normal(shape, 0.0, 0.02) } + buffer :scale, shape: -> { [out_dim] }, init: ->(shape) { MLX::Core.ones(shape, MLX::Core.float32) } + + def call(x) + y = MLX::Core.matmul(x, weight.T) + y = MLX::Core.add(y, bias) if use_bias + MLX::Core.multiply(y, scale) + end +end +``` + +## Phase 3: Mixin Support for Existing Modules + +### Deliverables + +1. `MLX::DSL::ModelMixin` for classes that already inherit `MLX::NN::Module`. +2. Shared declaration behavior between `Model` and `ModelMixin`. + +### Ergonomics Target + +```ruby +class ResidualBlock < MLX::NN::Module + include MLX::DSL::ModelMixin + + option :dims, default: 256 + + layer(:proj) { linear dims, dims, bias: false } + layer(:norm) { layer_norm dims } + + def call(x) + MLX::Core.add(x, norm.call(proj.call(x))) + end +end +``` + +## Phase 4: Builder/Composition Ergonomics + +### Deliverables + +1. Builder methods for common layers: + - `linear`, `relu`, `dropout`, `conv2d`, `layer_norm`, etc. +2. `sequential` block collector. +3. Baseline branch-friendly structure for later graph helpers. + +### Ergonomics Target + +```ruby +layer :encoder do + sequential do + linear 512, 512 + relu + dropout 0.1 + linear 512, 256 + end +end +``` + +## Phase 5: Training Ergonomics + +### Deliverables + +1. `train_step` helper that wraps: + - `MLX::NN.value_and_grad` + - optional `MLX::Optimizers.clip_grad_norm` + - `optimizer.update` +2. Hook support for training lifecycle events: + - `:before_step` + - `:after_backward` + - `:after_step` +3. Mode helpers: + - `train_mode { ... }` + - `eval_mode { ... 
}` + +### Ergonomics Target + +```ruby +step = model.train_step(optimizer: optimizer, clip_grad_norm: 1.0) do |x:, y:| + logits = model.call(x) + MLX::NN.cross_entropy(logits, y, reduction: "mean") +end + +step.on(:after_step) do |ctx| + puts "step=#{ctx[:step]} loss=#{ctx[:loss].item}" +end + +loss = step.call(x: batch_x, y: batch_y) +``` + +## Phase 6: Optimizer Group Ergonomics + +### Deliverables + +1. `optimizer_groups` builder on model instances. +2. `group(matcher)` declarations producing `MLX::Optimizers::MultiOptimizer`. +3. Path matcher support (`Regexp`, `String`, `Proc`). + +### Ergonomics Target + +```ruby +opt = model.optimizer_groups do + group(/^encoder\./) { MLX::Optimizers::AdamW.new(learning_rate: 1e-4) } + group(nil) { MLX::Optimizers::SGD.new(learning_rate: 5e-3) } +end +``` + +## Phase 7: Parameter Selection Helpers + +### Deliverables + +1. `freeze_paths!(matcher)` and `unfreeze_paths!(matcher)` helpers. +2. Match by full flattened parameter path. + +### Ergonomics Target + +```ruby +model.freeze_paths!(/^encoder\./) +model.unfreeze_paths!(/^head\./) +``` + +## Phase 8: Parity Tooling Compatibility + +### Deliverables + +1. Update `test/parity/scripts/generate_package_inventory.rb` ignore list so: + - `lib/mlx/dsl.rb` + - `lib/mlx/dsl/**/*.rb` + are excluded from package file parity diffs. +2. Keep existing parity checks unchanged for Python-equivalent surfaces. + +## Phase 9: Documentation and Examples + +### Deliverables + +1. Add DSL section to `README.md`. +2. Provide concise end-to-end examples: + - MLP classification model + - Mix-in based module + - `train_step` with hooks + - optimizer groups + +## Phase 10: V2 Extensions (Post-V1) + +### Deliverables + +1. Graph helpers: + - `residual` + - `branch` + - `concat` +2. Checkpoint helpers for model + optimizer state. +3. Lightweight trainer wrapper. + +## Phase 11: Trainer UX Enhancements + +### Deliverables + +1. 
Hook shorthand methods (in addition to `on`): + - `TrainStep`: `before_step`, `after_backward`, `after_step` + - `Trainer`: `before_fit`, `before_epoch`, `after_batch`, `after_epoch`, `checkpoint`, `after_fit` +2. Custom monitor metrics for fit reports: + - `monitor:` label in report output + - `metric:` callable receiving epoch context +3. Checkpoint path templating: + - `%{epoch}`, `%{monitor}`, `%{monitor_name}`, `%{epoch_loss}`, `%{improved}` +4. Early stopping controls: + - `patience:` + - `min_delta:` + - report fields `epochs_ran` and `stopped_early` + +### Ergonomics Target + +```ruby +trainer = model.trainer(optimizer: optimizer) do |x:, y:| + logits = model.call(x) + MLX::NN.cross_entropy(logits, y, reduction: "mean") +end + +trainer.before_epoch { |ctx| puts "epoch=#{ctx[:epoch]}" } +trainer.after_batch { |ctx| puts "batch=#{ctx[:batch_index]} loss=#{ctx[:loss_value]}" } + +report = trainer.fit_report( + dataset, + epochs: 20, + monitor: :peak_loss, + monitor_mode: :max, + metric: ->(ctx) { ctx.fetch(:epoch_losses).max }, + checkpoint_path: "ckpts/ep-%{epoch}-m-%{monitor}.bin", + save_best: true, + patience: 2, + min_delta: 1e-4 +) +``` + +## Phase 12: Data Ergonomics + +### Deliverables + +1. Streaming-friendly training datasets: + - accept `Enumerable` without forcing `to_a` + - accept dataset factories (`Proc`) that return per-epoch enumerables +2. Optional validation dataset loop: + - `validation_data:` and `validation_reduce:` + - include `val_loss` and `validation_batches` in epoch reports +3. 
Native monitoring for validation: + - `monitor: :val_loss` without requiring a custom metric proc + +### Ergonomics Target + +```ruby +train_data = ->(epoch:) { shuffled_batches_for(epoch) } +val_data = ->(epoch:) { heldout_batches_for(epoch) } + +report = trainer.fit_report( + train_data, + epochs: 10, + reduce: :mean, + validation_data: val_data, + validation_reduce: :mean, + monitor: :val_loss, + monitor_mode: :min, + save_best: true, + checkpoint_path: "checkpoints/epoch-%{epoch}-val-%{monitor}.bin" +) +``` + +## Phase 13: Enumerable and Batch Pipeline Ergonomics + +### Deliverables + +1. Safer multi-epoch behavior for single-pass datasets: + - `strict_data_reuse:` option to detect exhausted non-rewindable datasets across epochs + - clear error pointing users to dataset factories +2. Batch transforms/collation hooks: + - `train_transform:` + - `validation_transform:` + - supports transforms that receive `batch`, `epoch`, `batch_index`, `kind`, and `trainer` +3. Hardened dataset factory signatures: + - support positional, keyword, and mixed signatures (e.g. `->(epoch, kind:)`) + - explicit errors when required parameters are unsupported +4. Optional per-batch loss retention: + - `keep_losses:` control for long-running jobs + - preserves epoch-level reporting while avoiding unbounded `losses` arrays + +### Ergonomics Target + +```ruby +train_data = ->(epoch, kind:) { stream_train_batches(epoch, kind: kind) } +val_data = ->(epoch:) { stream_validation_batches(epoch) } + +report = trainer.fit_report( + train_data, + epochs: 20, + strict_data_reuse: true, + train_transform: ->(batch, epoch:, batch_index:) { collate_train(batch, epoch, batch_index) }, + validation_data: val_data, + validation_transform: ->(batch, epoch:) { collate_val(batch, epoch) }, + monitor: :val_loss, + keep_losses: false +) +``` + +## Phase 14: Validation Lifecycle Hooks + +### Deliverables + +1. 
Add validation hook events on `MLX::DSL::Trainer`: + - `before_validation` + - `after_validation_batch` + - `after_validation` +2. Expose hook shorthand methods matching the above events. +3. Include validation hook context fields: + - `epoch` + - `batch_index` (for per-batch hook) + - `loss` and `loss_value` + - reduced `val_loss` for `after_validation` + +### Ergonomics Target + +```ruby +trainer.before_validation { |ctx| puts "epoch=#{ctx[:epoch]} val:start" } +trainer.after_validation_batch { |ctx| puts "val_batch=#{ctx[:batch_index]} loss=#{ctx[:loss_value]}" } +trainer.after_validation { |ctx| puts "val_loss=#{ctx[:val_loss]}" } +``` + +## Phase 15: Native Integration Coverage For Data Ergonomics + +### Deliverables + +1. Add integration tests to `test/dsl/dsl_test.rb` for: + - `train_transform` + - `validation_transform` + - `strict_data_reuse` + - `keep_losses: false` +2. Ensure new assertions run against real `MLX::Core::Array` and optimizer flows. +3. Keep tests deterministic and fast (small toy datasets). + +## Phase 16: Runnable DSL Examples + +### Deliverables + +1. Add `examples/dsl/` scripts for: + - streaming per-epoch dataset factory + - validation monitoring (`monitor: :val_loss`) + - long-running memory-friendly reporting (`keep_losses: false`) +2. Keep examples executable via `bundle exec ruby examples/dsl/<name>.rb`. +3. Reference examples from `README.md`. + +## Phase 17: No-Native Load Resilience + +### Deliverables + +1. Remove eager native-dependent defaults in DSL class macros where possible. +2. Ensure requiring DSL files does not raise when native extension is unavailable. +3. Prefer lazy dtype/default resolution at runtime instead of class definition time. + +## Phase 18: Test Helper Rebuild Policy + +### Deliverables + +1. Reduce unnecessary forced rebuilds in `test/test_helper.rb`. +2. When a loadable native bundle already exists, avoid rebuild attempts that require unavailable source trees. +3. 
Preserve explicit rebuild behavior when `MLX_RUBY_FORCE_REBUILD=1`. + +## Red/Green Execution Plan + +1. Add tests for each DSL behavior in `test/dsl/dsl_test.rb` (red). +2. Implement minimum code in `lib/mlx/dsl/**` to pass tests (green). +3. Refactor and tighten internals while maintaining passing tests. +4. Run targeted suite first, then broader suite. + +## Ergonomics Execution Track (Phases 19+) + +This track focuses on reducing friction when building and running models with Ruby-native DSL behavior and dynamic data pipelines. + +### Phase 19: DSL Declaration Safety and Error Quality + +#### Problem + +The DSL currently accepts some invalid declarations and surfaces unclear initialization errors for unknown options. + +#### Deliverables + +1. Reject non-module results from `layer` / `network` declarations. +2. Raise explicit unknown option errors for declared DSL options instead of falling through to generic Ruby argument errors. +3. Keep compatibility for classes that intentionally handle extra kwargs in their own initializer. + +#### Red (tests first) + +1. Add failing test for unknown DSL option handling in `test/dsl/dsl_test.rb`. +2. Add failing test proving layer declarations must materialize `MLX::NN::Module`. + +#### Green (minimum implementation) + +1. Tighten option extraction / initializer validation in `lib/mlx/dsl/model_mixin.rb`. +2. Validate layer declaration materialization type in `lib/mlx/dsl/model_mixin.rb`. + +#### Exit Criteria + +1. New tests pass. +2. Existing DSL declaration tests remain green. + +### Phase 20: Dataset Factory Invocation Reliability + +#### Problem + +Dataset factories with zero-arity blocks are mis-invoked, and non-rewind errors can be mislabeled as rewind failures. + +#### Deliverables + +1. Correct dataset factory invocation for: + - `-> { ... }` + - `->(epoch) { ... }` + - `->(epoch:) { ... }` + - mixed signatures +2. Limit rewind-specific error wrapping to rewind failures only. + +#### Red (tests first) + +1. 
Add failing trainer unit test for zero-arity dataset factory support. + +#### Green (minimum implementation) + +1. Fix factory invocation branch in `lib/mlx/dsl/trainer.rb`. +2. Narrow error handling in dataset/rewind path in `lib/mlx/dsl/trainer.rb`. + +#### Exit Criteria + +1. New trainer tests pass. +2. Existing dataset factory and strict-data-reuse tests stay green. + +### Phase 21: Variadic Composition Through Sequential + +#### Problem + +`MLX::NN::Sequential` currently forces a single positional input, reducing composability for multi-arg modules and Ruby-style dynamic wiring. + +#### Deliverables + +1. Support first-layer invocation with `*args, **kwargs`. +2. Support forwarding intermediary payloads as: + - positional args when payload is `Array` + - kwargs when payload is `Hash` + - single positional argument otherwise +3. Preserve existing single-tensor flow behavior. + +#### Red (tests first) + +1. Add failing unit test in `test/dsl/dsl_graph_unit_test.rb` covering multi-arg + kwargs + array forwarding across sequential layers. + +#### Green (minimum implementation) + +1. Update `MLX::NN::Sequential#call` in `lib/mlx/nn/layers/containers.rb`. + +#### Exit Criteria + +1. New composition test passes. +2. Existing graph/builder tests remain green. + +### Phase 22: Compile-Aware Training Step + +#### Deliverables + +1. Add optional `compile:` mode for `train_step` and trainer internals. +2. Add explicit sync policy controls (`:none`, `:step`, `:epoch`) to avoid implicit runtime behavior. +3. Add focused parity/integration tests for compile + hooks + checkpoint interactions. + +### Phase 23: Checkpoint Format Interoperability + +#### Deliverables + +1. Introduce native checkpoint formats (`.npz` / `.safetensors`) for model parameters. +2. Preserve metadata and optional optimizer state with versioned schema. +3. Keep current marshal path as compatibility fallback during migration window. 
+ +### Phase 24: Builder Surface + Batch Collation Ergonomics + +#### Deliverables + +1. Expand DSL builder coverage to match practical `MLX::NN` usage (recurrent/transformer/transpose-conv paths). +2. Add first-class batch collation helpers for common train/eval datasets. +3. Add end-to-end examples demonstrating lower boilerplate and clearer training loops. + +### Phase 25: Docs + Example Refresh For New Ergonomics + +#### Deliverables + +1. Update README DSL section with: + - `compile:` / `sync:` usage on `train_step` and `trainer` + - `collate:` / `validation_collate:` examples + - native checkpoint usage (`.npz` / `.safetensors`) and metadata behavior +2. Add runnable example scripts for: + - built-in collate schemas and mapping-based collation + - compile/sync controls and native checkpoint roundtrip +3. Keep examples executable via `bundle exec ruby examples/dsl/<name>.rb`. + +### Phase 26: Context-Aware Collation Callables + +#### Problem + +Collation Procs are currently batch-only, which limits Ruby-style dynamic shaping based on epoch and trainer runtime context. + +#### Deliverables + +1. Allow `collate:` and `validation_collate:` callables to accept dynamic signatures: + - positional batch input + - optional keyword context (`epoch`, `batch_index`, `kind`, `trainer`) +2. Allow mapping-collate `Proc` selectors to use the same dynamic signature support. +3. Preserve existing simple `->(batch) { ... }` behavior. + +#### Red (tests first) + +1. Add failing trainer unit test for context-aware `collate` callable signatures. +2. Add failing trainer unit test for context-aware mapping selector Procs. +3. Add failing trainer unit test for context-aware `validation_collate`. + +#### Green (minimum implementation) + +1. Thread epoch/batch context through trainer collation call sites. +2. Add signature-aware callable invocation helper in `lib/mlx/dsl/trainer.rb`. +3. Route mapping selector Procs through the same callable helper. + +#### Exit Criteria + +1. 
New context-aware collation tests pass. +2. Existing collation + transform tests remain green. + +### Phase 27: Keyword-Normalized Hash Batch Dispatch + +#### Problem + +Ruby keyword lambdas (for `train_step` and validation loss blocks) fail when dataset hashes use string keys, even when keys are semantically correct. + +#### Deliverables + +1. Normalize top-level hash keys to symbols before keyword dispatch in train and validation execution paths. +2. Raise clear errors when normalization causes duplicate keyword collisions (for example `"x"` and `:x` in the same batch). +3. Keep array and scalar batch dispatch behavior unchanged. + +#### Red (tests first) + +1. Add failing trainer unit test proving train batches with string keys are accepted for keyword step call signatures. +2. Add failing trainer unit test proving validation batches with string keys are accepted for keyword loss signatures. +3. Add failing trainer unit test asserting duplicate normalized keys raise explicit error. + +#### Green (minimum implementation) + +1. Normalize hash keys in `__dsl_run_batch` and `__dsl_run_validation_batch`. +2. Add duplicate-key detection and error reporting helper in `lib/mlx/dsl/trainer.rb`. + +#### Exit Criteria + +1. New key-normalization tests pass. +2. Existing trainer behavior stays green across full DSL test suite. + +### Phase 28: Inline Callable Layer Ergonomics + +#### Problem + +Users currently need ad-hoc module classes for simple one-off tensor transforms in builder graphs, which adds boilerplate and slows experimentation. + +#### Deliverables + +1. Add a dedicated DSL callable-layer wrapper module for inline Ruby lambdas/procs. +2. Add builder helpers for inline callable layers: + - `fn(callable = nil, &block)` + - `lambda_layer` alias +3. Preserve full variadic call forwarding (`*args, **kwargs`) through callable layers. + +#### Red (tests first) + +1. Add failing graph unit test for `Builder#fn` returning a DSL module wrapper. +2. 
Add failing graph unit test for variadic arg/kwarg forwarding through callable layers. +3. Add failing graph unit test that missing callable/block raises a clear argument error. + +#### Green (minimum implementation) + +1. Add `MLX::DSL::Callable` module in `lib/mlx/dsl/graph_modules.rb`. +2. Add `fn` and `lambda_layer` builder methods in `lib/mlx/dsl/builder.rb`. + +#### Exit Criteria + +1. New callable-layer tests pass. +2. Existing graph/builder/trainer DSL tests remain green. + +### Phase 29: Composition Input Normalization + +#### Problem + +Composition helpers (`sequential`, `branch`, `concat`, `sum`, `residual`) currently accept heterogeneous inputs without normalization, which can leak raw classes/callables into layer stacks and break module tracking. + +#### Deliverables + +1. Normalize composition entries to `MLX::NN::Module` instances: + - instantiate module classes + - wrap callables with `MLX::DSL::Callable` +2. Raise clear `TypeError` for invalid composition entries at build time. +3. Preserve existing module-instance behavior. + +#### Red (tests first) + +1. Add failing graph unit tests proving `sequential` and `branch` normalize class + callable entries into module instances. +2. Add failing graph unit test proving invalid composition entries are rejected early. + +#### Green (minimum implementation) + +1. Add module-entry normalization path in `lib/mlx/dsl/builder.rb`. +2. Update module collection fallback to include non-nil block returns so normalization applies consistently. + +#### Exit Criteria + +1. New normalization tests pass. +2. Existing graph and trainer DSL suites remain green. + +### Phase 30: Checkpoint Path Directory Ergonomics + +#### Problem + +Saving checkpoints to nested output paths currently fails when parent directories do not already exist, adding avoidable filesystem setup boilerplate. + +#### Deliverables + +1. Automatically create parent directories for marshal checkpoint writes. +2. 
Automatically create parent directories for native checkpoint writes (`.npz` / `.safetensors`). +3. Preserve existing checkpoint payload/schema behavior. + +#### Red (tests first) + +1. Add failing DSL integration test proving marshal checkpoint save succeeds for non-existent nested directories. +2. Add failing DSL integration test proving native checkpoint save succeeds for non-existent nested directories. + +#### Green (minimum implementation) + +1. Add checkpoint parent-directory helper in `lib/mlx/dsl/model_mixin.rb`. +2. Invoke helper from both marshal and native save paths. + +#### Exit Criteria + +1. New checkpoint directory tests pass. +2. Existing checkpoint and trainer tests remain green. + +### Phase 31: Polymorphic `Builder#layer` Input Ergonomics + +#### Problem + +`Builder#layer` currently assumes class-only inputs, which makes it inconsistent with the rest of the composition DSL and prevents direct reuse of module instances/callables. + +#### Deliverables + +1. Support `layer` inputs as: + - `MLX::NN::Module` instance + - `MLX::NN::Module` class (with constructor args/kwargs) + - callable (`Proc`/`lambda`) wrapped as `MLX::DSL::Callable` + - block form (`layer { |...| ... }`) as callable layer +2. Add clear errors for: + - missing entry/block + - ambiguous entry + block usage + - constructor args passed to instance/callable entries +3. Keep existing class-constructor behavior intact. + +#### Red (tests first) + +1. Add failing graph unit tests covering instance/class/callable/block `layer` forms. +2. Add failing graph unit tests for invalid argument and missing-entry error paths. + +#### Green (minimum implementation) + +1. Expand `Builder#layer` dispatch logic in `lib/mlx/dsl/builder.rb`. +2. Add focused helper for constructor-argument validation on non-class entries. + +#### Exit Criteria + +1. New `layer` polymorphism tests pass. +2. Existing graph/builder/trainer DSL suites remain green. 
+ +### Phase 32: Declarative Layer Factory Arguments + +#### Problem + +Model declaration macros require extra lambda boilerplate when layer factories need constructor arguments derived from DSL options. + +#### Deliverables + +1. Extend `layer` / `network` declaration macros to accept factory constructor args/kwargs: + - `layer :proj, MLX::NN::Linear, -> { dims }, -> { dims }, bias: false` +2. Resolve callable constructor args/kwargs in model context (same semantics as other DSL callables). +3. Keep clear rejection for ambiguous factory+block declarations. + +#### Red (tests first) + +1. Add failing DSL integration test for class factory with dynamic args/kwargs. +2. Add failing DSL integration test for callable factory with dynamic args/kwargs. +3. Add failing DSL integration test for factory+block ambiguity. + +#### Green (minimum implementation) + +1. Expand `ClassMethods#layer` / `#network` signatures and stored declaration payload. +2. Update layer materialization path in `lib/mlx/dsl/model_mixin.rb` to apply resolved factory args/kwargs. + +#### Exit Criteria + +1. New declaration factory-arg tests pass. +2. Existing DSL suite remains green. + +### Phase 33: Batch Failure Diagnostics With Epoch/Index Context + +#### Problem + +When train/validation batch execution raises, errors currently surface without trainer location context, making dynamic pipeline failures slow to debug. + +#### Deliverables + +1. Include `kind`, `epoch`, and `batch_index` context in train batch execution errors. +2. Include `kind`, `epoch`, and `batch_index` context in validation batch execution errors. +3. Preserve original exception class and original message details. + +#### Red (tests first) + +1. Add failing trainer unit test for train-step failure message context. +2. Add failing trainer unit test for validation loss failure message context. + +#### Green (minimum implementation) + +1. Thread batch location metadata into internal train/validation batch runner calls. +2. 
Add shared error re-raise helper in `lib/mlx/dsl/trainer.rb`. + +#### Exit Criteria + +1. New batch-diagnostic tests pass. +2. Existing DSL suite remains green. + +### Phase 34: Validation Loop Limit Control + +#### Problem + +Trainer exposes `limit:` for train batches but not validation batches, forcing custom dataset wrappers for quick validation sampling. + +#### Deliverables + +1. Add `validation_limit:` option to `Trainer#fit` / `fit_report`. +2. Apply per-epoch validation loop cap without changing existing reducer/checkpoint semantics. +3. Preserve default behavior when `validation_limit` is omitted. + +#### Red (tests first) + +1. Add failing trainer unit test proving validation batches can be capped and metrics are computed from the capped set. + +#### Green (minimum implementation) + +1. Extend `fit` keyword signature with `validation_limit`. +2. Add break condition in validation epoch iteration path. + +#### Exit Criteria + +1. New validation-limit test passes. +2. Existing DSL suite remains green. + +### Phase 35: Extensionless Native Checkpoint Load Autodetection + +#### Problem + +Native checkpoints saved with explicit `format:` and extensionless base names require callers to remember the generated extension when loading. + +#### Deliverables + +1. For `load_checkpoint(path, format: nil)` where `path` has no extension and does not exist: + - autodetect `#{path}.npz` + - fallback autodetect `#{path}.safetensors` +2. Preserve existing explicit `format:` behavior and marshal path semantics. +3. Keep native payload/metadata loading unchanged. + +#### Red (tests first) + +1. Add failing DSL integration test proving extensionless load path auto-detects an existing native `.npz` checkpoint. + +#### Green (minimum implementation) + +1. Add load-path resolution helper in `lib/mlx/dsl/model_mixin.rb`. +2. Route `load_checkpoint` through resolved path before format dispatch. + +#### Exit Criteria + +1. New autodetect integration test passes. +2. 
Existing DSL suite remains green. + +### Phase 36: Resumable Trainer Runs From Checkpoints + +#### Problem + +Long-running experiments need restart safety, but trainer state (`epoch`, `best_metric`, `stale_epochs`) is currently not restored into `fit` / `fit_report`. + +#### Deliverables + +1. Add `resume_from:` option to `Trainer#fit` and `fit_report`. +2. Restore checkpoint-backed trainer state: + - next start epoch (`epoch + 1`) + - `best_metric` + - `stale_epochs` + - monitor consistency via `monitor_name` +3. Persist resume-relevant metadata during checkpoint saves for future resumes. +4. Include resume fields in report payload for observability: + - `resume_from` + - `resumed_from_epoch` + - `start_epoch` + +#### Red (tests first) + +1. Add failing trainer unit test for successful continuation from checkpoint metadata. +2. Add failing trainer unit test proving early-stopping stale counter is restored on resume. +3. Add failing trainer unit test rejecting monitor mismatch between requested monitor and checkpoint monitor metadata. + +#### Green (minimum implementation) + +1. Extend trainer fit signature and epoch loop to begin from resumed epoch. +2. Add resume-state loader helper with backward-compatible `load_checkpoint` keyword negotiation. +3. Add additional checkpoint metadata fields (`stale_epochs`, `best_metric`, `next_epoch`) in trainer checkpoint writes. + +#### Exit Criteria + +1. Resume tests pass. +2. Full DSL test suite remains green. + +### Phase 37: Inline Resume Payload Ergonomics + +#### Problem + +`resume_from:` currently assumes filesystem-backed checkpoints; dynamic Ruby workflows often already hold parsed checkpoint payloads in memory. + +#### Deliverables + +1. Allow `resume_from:` to accept an inline checkpoint payload `Hash`. +2. Bypass model `load_checkpoint` when inline payload is provided. +3. Preserve existing path-based resume behavior unchanged. + +#### Red (tests first) + +1. 
Add failing trainer unit test proving `resume_from: { ... }` resumes epochs/metrics without invoking `load_checkpoint`. + +#### Green (minimum implementation) + +1. Branch resume source handling in `lib/mlx/dsl/trainer.rb` for hash payloads vs path inputs. +2. Keep report resume metadata consistent (`resume_from` nil for inline payloads, `resumed_from_epoch` retained). + +#### Exit Criteria + +1. Inline resume payload tests pass. +2. Full DSL suite remains green. + +### Phase 38: Callable Resume Loader Support + +#### Problem + +Dynamic orchestration layers often need to compute resume state at runtime (for example, choose latest checkpoint per monitor), which path-only resume wiring does not express cleanly. + +#### Deliverables + +1. Allow `resume_from:` to accept a callable loader. +2. Support dynamic callable signatures for loader context: + - `trainer` + - `model` + - `optimizer` + - `monitor_name` +3. Allow loader return values as: + - inline checkpoint payload hash + - checkpoint path + - `nil` (no resume) + +#### Red (tests first) + +1. Add failing trainer unit test proving callable loader receives context and can return an inline payload without invoking `load_checkpoint`. + +#### Green (minimum implementation) + +1. Add callable resume loader invocation helper in `lib/mlx/dsl/trainer.rb`. +2. Route resume-source normalization through callable/hash/path branches. + +#### Exit Criteria + +1. Callable resume loader tests pass. +2. Full DSL suite remains green. + +### Phase 39: Resume Progress Telemetry In Fit Reports + +#### Problem + +`epochs_ran` alone is ambiguous in resumed runs because it reports only epochs executed in the current invocation. + +#### Deliverables + +1. Add explicit fit-report progress fields: + - `epochs_target` (requested total epoch target) + - `epochs_completed` (effective total progress including resumed offset) +2. Preserve existing `epochs_ran` semantics for backward compatibility. + +#### Red (tests first) + +1. 
Add failing trainer unit assertions for new progress fields in fresh runs. +2. Add failing trainer unit assertions for new progress fields in resumed runs. + +#### Green (minimum implementation) + +1. Track `total_epochs` in `Trainer#fit`. +2. Emit `epochs_target` / `epochs_completed` in report payload. + +#### Exit Criteria + +1. New report telemetry tests pass. +2. Full DSL suite remains green. + +### Phase 40: Checkpoint Template `next_epoch` Placeholder + +#### Problem + +Checkpoint naming templates support `%{epoch}` but resume flows often need forward-looking names aligned with the next epoch index. + +#### Deliverables + +1. Add `%{next_epoch}` template token support to trainer checkpoint path rendering. +2. Keep existing checkpoint template tokens and error behavior unchanged. + +#### Red (tests first) + +1. Add failing trainer unit test proving `%{next_epoch}` renders as `epoch + 1`. + +#### Green (minimum implementation) + +1. Extend `__dsl_checkpoint_path` interpolation map with `next_epoch`. + +#### Exit Criteria + +1. New template test passes. +2. Full DSL suite remains green. + +### Phase 41: First-Class Dataset Pipeline DSL + +#### Problem + +Training data preparation still relies on ad-hoc Ruby enumerator wiring, which makes common transforms repetitive and less composable. + +#### Deliverables + +1. Add a dataset pipeline wrapper under `MLX::DSL::Data` with chainable transforms: + - `map` + - `filter` + - `batch` + - `take` + - `repeat` +2. Keep pipeline output compatible with existing trainer dataset expectations (`#each`, rewind/factory behavior). +3. Preserve lazy semantics by default to avoid eager memory blowups. + +#### Red (tests first) + +1. Add failing unit tests for transform chaining and stable iteration semantics across epochs. +2. Add failing trainer integration test proving pipeline output works with `fit`/`fit_report`. + +#### Green (minimum implementation) + +1. Add pipeline wrapper implementation in `lib/mlx/dsl/data_pipeline.rb`. 
+2. Load from `lib/mlx/dsl.rb`. +3. Integrate with trainer data paths without changing existing dataset APIs. + +#### Exit Criteria + +1. Pipeline tests pass. +2. Existing trainer/data ergonomics tests remain green. + +### Phase 42: Collate Registry and Schema Composition + +#### Problem + +Collate logic is powerful but repetitive across train/validation flows when the same mapping/callable schemas are reused. + +#### Deliverables + +1. Add trainer-level collate registry: + - `register_collate(name, spec = nil, &block)` + - support named reuse in `collate:` / `validation_collate:` +2. Support schema composition helpers for named collates (for example, extending base mapping selectors). +3. Keep built-in schemas (`:x`, `:xy`) backward compatible. + +#### Red (tests first) + +1. Add failing trainer unit tests for registering and resolving named collate schemas. +2. Add failing tests for train/validation parity using shared named collates. + +#### Green (minimum implementation) + +1. Add collate registry storage and lookup path in `lib/mlx/dsl/trainer.rb`. +2. Route collate normalization through registry before existing dispatch. + +#### Exit Criteria + +1. Named collate registry tests pass. +2. Existing collate behavior remains green. + +### Phase 43: Ordered and Conditional Hook Middleware + +#### Problem + +Hooks currently execute in registration order only, with no built-in scheduling controls (`every N`, once-only, or priority ordering). + +#### Deliverables + +1. Extend hook registration to support options: + - `priority:` + - `every:` + - `once:` + - optional `if:` predicate +2. Ensure deterministic ordering for hooks with mixed priorities. +3. Preserve existing `on` and shorthand hook APIs without options. + +#### Red (tests first) + +1. Add failing trainer/train-step unit tests for hook ordering by priority. +2. Add failing tests for `every:` and `once:` scheduling semantics. +3. Add failing tests for conditional hook predicates. 
+ +#### Green (minimum implementation) + +1. Add hook wrapper normalization and ordered execution in `lib/mlx/dsl/train_step.rb` and `lib/mlx/dsl/trainer.rb`. +2. Keep no-option hooks on current behavior path. + +#### Exit Criteria + +1. Hook ordering/scheduling tests pass. +2. Existing hook consumers remain green. + +### Phase 44: Model Introspection and Debug Ergonomics + +#### Problem + +As DSL graphs get more dynamic, users need fast visibility into module composition, parameter paths, and trainable counts without custom scripts. + +#### Deliverables + +1. Add model introspection helpers: + - `summary` + - `parameter_count` + - `trainable_parameter_count` + - `parameter_paths(matcher: nil)` +2. Ensure output reflects DSL-built graphs (including callable/composed modules). +3. Provide machine-friendly summary payloads (hash) and human-readable formatting. + +#### Red (tests first) + +1. Add failing DSL tests for parameter/path counts on composed models. +2. Add failing tests for matcher-filtered path reporting. + +#### Green (minimum implementation) + +1. Add introspection helpers in `lib/mlx/dsl/model_mixin.rb`. +2. Reuse existing tree flatten utilities for consistent path semantics. + +#### Exit Criteria + +1. Introspection tests pass. +2. Existing freeze/unfreeze and optimizer-group behavior stays green. + +### Phase 45: Reproducible Run Bundle and Resume Metadata Standardization + +#### Problem + +Experiment reproducibility remains fragmented across ad-hoc metadata, making restart/debug workflows less reliable over long-running iterations. + +#### Deliverables + +1. Add trainer run bundle export (JSON) containing: + - fit report + - trainer config (`monitor`, reducer, limits, resume source kind) + - checkpoint metadata snapshot +2. Add versioned metadata schema key for trainer-resume fields. +3. Add helper for loading run bundle metadata into `resume_from` callable/hash flows. + +#### Red (tests first) + +1. 
Add failing integration tests for run bundle export payload shape. +2. Add failing tests for resume compatibility using exported metadata. + +#### Green (minimum implementation) + +1. Add run bundle serialization helper in `lib/mlx/dsl/trainer.rb`. +2. Document schema version and compatibility expectations. + +#### Exit Criteria + +1. Run bundle tests pass. +2. Resume and checkpoint compatibility tests remain green. + +### Phase 46: Multi-Base Collate Schema Composition + +#### Problem + +Named collate composition currently supports only a single `extends:` base, which makes shared train/eval schema layering repetitive in real projects. + +#### Deliverables + +1. Allow `register_collate(..., extends: [...])` with ordered base names. +2. Compose base schemas in-order before applying the overlay schema. +3. Preserve clear unknown-base errors for each missing base name. + +#### Red (tests first) + +1. Add failing trainer unit test proving multi-base `extends` composition order. +2. Add failing trainer unit test proving unknown base names raise explicit errors in multi-base mode. + +#### Green (minimum implementation) + +1. Update `Trainer#register_collate` to normalize `extends` as single-name or array. +2. Compose ordered base schemas before overlay merge in `lib/mlx/dsl/trainer.rb`. + +#### Exit Criteria + +1. Multi-base collate tests pass. +2. Existing named collate behavior remains green. + +### Phase 47: Dynamic Per-Epoch Loop Limits + +#### Problem + +`limit:` and `validation_limit:` are static values today, forcing external wrappers for common Ruby dynamic scheduling patterns. + +#### Deliverables + +1. Allow `limit:` to accept callables resolved per epoch. +2. Allow `validation_limit:` to accept callables resolved per epoch. +3. Support callable signatures with runtime context (`epoch`, `kind`, `trainer`). +4. Validate negative/invalid callable returns with clear errors. + +#### Red (tests first) + +1. 
Add failing trainer unit test for callable train limits varying by epoch. +2. Add failing trainer unit test for callable validation limits varying by epoch. + +#### Green (minimum implementation) + +1. Add loop-limit resolution helper in `lib/mlx/dsl/trainer.rb`. +2. Resolve per-epoch limit values before train and validation iteration loops. + +#### Exit Criteria + +1. Callable limit tests pass. +2. Existing integer limit behavior remains unchanged. + +### Phase 48: Callable Checkpoint Path Builders + +#### Problem + +Checkpoint naming currently depends on string templates only, which underuses Ruby’s dynamic DSL style for contextual path generation. + +#### Deliverables + +1. Allow `checkpoint_path:` to accept a callable path builder. +2. Provide path-builder context (`epoch`, `next_epoch`, monitor fields, `trainer`, `model`, `optimizer`). +3. Validate callable return values as string-compatible paths. +4. Preserve existing `%{...}` template support. + +#### Red (tests first) + +1. Add failing trainer unit test proving callable checkpoint paths receive full runtime context. + +#### Green (minimum implementation) + +1. Extend checkpoint-path resolver in `lib/mlx/dsl/trainer.rb` for callable dispatch. +2. Keep template interpolation fallback for string paths. + +#### Exit Criteria + +1. Callable checkpoint path tests pass. +2. Existing template behavior remains green. + +### Phase 49: Run Bundle Resume Source Autodetection + +#### Problem + +Resume flows currently require manual conversion from run bundles to checkpoint metadata payloads, adding avoidable boilerplate. + +#### Deliverables + +1. Accept run-bundle hashes directly in `resume_from:`. +2. Accept run-bundle JSON paths directly in `resume_from:`. +3. Route detected run bundles through `resume_payload_from_bundle` automatically. +4. Preserve existing path-based checkpoint resume behavior for non-bundle paths. + +#### Red (tests first) + +1. Add failing trainer unit test for run-bundle hash resume. +2. 
Add failing trainer unit test for run-bundle path resume. + +#### Green (minimum implementation) + +1. Add run-bundle source detection helpers in `lib/mlx/dsl/trainer.rb`. +2. Normalize resume source before checkpoint-loading fallback. + +#### Exit Criteria + +1. Run-bundle resume tests pass. +2. Existing inline/hash/callable/path resume behavior remains green. + +### Phase 50: Index-Aware Dataset Pipeline Transforms + +#### Problem + +`Data::Pipeline#map` and `#filter` currently receive only the batch item, limiting Ruby-idiomatic index-aware transform logic. + +#### Deliverables + +1. Add index-aware callable invocation for pipeline `map`. +2. Add index-aware callable invocation for pipeline `filter`. +3. Support positional and keyword signatures with `item`, `index`, and `pipeline` context. + +#### Red (tests first) + +1. Add failing pipeline unit test for positional `(item, index)` mapping. +2. Add failing pipeline unit test for keyword `index:` filtering. + +#### Green (minimum implementation) + +1. Add signature-aware callable dispatcher for pipeline transforms in `lib/mlx/dsl/data_pipeline.rb`. +2. Route `map`/`filter` through the new dispatcher while preserving lazy iteration semantics. + +#### Exit Criteria + +1. Index-aware pipeline tests pass. +2. Existing pipeline chaining/laziness tests remain green. + +### Phase 51: Trainer Fit Presets and Defaults Registry + +#### Problem + +`fit` / `fit_report` calls repeat large keyword argument sets (`monitor`, `reduce`, checkpoint policy, limits, resume policy), creating copy/paste boilerplate across scripts. + +#### Deliverables + +1. Add trainer-level preset registry: + - `register_fit_preset(name, **defaults)` + - `fit_with(name, dataset, **overrides)` + - `fit_report_with(name, dataset, **overrides)` +2. Add immutable trainer defaults helper: + - `with_fit_defaults(**defaults)` returning a configured trainer wrapper/clone. +3. 
Merge precedence: + - explicit call overrides > preset defaults > trainer defaults > current method defaults. +4. Keep existing `fit` / `fit_report` APIs unchanged. + +#### Red (tests first) + +1. Add failing trainer unit tests for preset registration and `fit_with` / `fit_report_with` execution. +2. Add failing tests for precedence/merge semantics across trainer defaults, preset defaults, and call overrides. +3. Add failing tests proving existing direct `fit` usage remains unaffected. + +#### Green (minimum implementation) + +1. Add preset/default storage and merge normalization in `lib/mlx/dsl/trainer.rb`. +2. Route `fit_with` / `fit_report_with` through existing `fit` execution path. + +#### Exit Criteria + +1. Preset/default tests pass. +2. Existing trainer behavior remains green. + +### Phase 52: Declarative Batch Schema and Auto-Collate + +#### Problem + +Users repeatedly write equivalent `collate`/`validation_collate` mappings for common `(x, y)` and nested hash batches. + +#### Deliverables + +1. Add declarative batch schema API: + - `batch_schema(spec)` at trainer level + - optional split-specific schemas (`train_schema`, `validation_schema`). +2. Add auto-collate mode: + - `collate: :auto` / `validation_collate: :auto` using declared schema or inferred defaults. +3. Inference behavior for common batch shapes: + - hash with `x`/`y` keys (symbol or string) + - two-item arrays -> `{x:, y:}` +4. Keep explicit collate specs taking precedence over schema/auto behavior. + +#### Red (tests first) + +1. Add failing trainer unit tests for schema-driven auto-collation. +2. Add failing tests for split-specific schema overrides. +3. Add failing tests for precedence when explicit `collate` is provided. + +#### Green (minimum implementation) + +1. Add schema storage and auto-collate resolver in `lib/mlx/dsl/trainer.rb`. +2. Reuse existing collate mapping and callable dispatch internals where possible. + +#### Exit Criteria + +1. Auto-collate/schema tests pass. +2. 
Existing manual collate behavior remains green. + +### Phase 53: Reusable Dataflow Specs (Collate + Transform + Limits) + +#### Problem + +Even with collate reuse, users still repeat the same train/validation loop wiring (`collate`, transforms, limits, reducers) across runs. + +#### Deliverables + +1. Add composable dataflow profiles: + - `register_dataflow(name, train: {...}, validation: {...})` + - `use_dataflow(name, **overrides)` on fit calls. +2. Dataflow fields should support existing dynamic callables: + - `collate`, `transform`, `limit`, `reduce` / `validation_reduce`. +3. Support profile inheritance/composition: + - `extends:` for dataflow profiles (same merge semantics as fit presets). +4. Keep direct keyword usage fully backward compatible. + +#### Red (tests first) + +1. Add failing trainer unit tests for applying named dataflow profiles. +2. Add failing tests for profile inheritance and override precedence. +3. Add failing tests proving direct per-call kwargs override profile values. + +#### Green (minimum implementation) + +1. Add dataflow registry + profile merge in `lib/mlx/dsl/trainer.rb`. +2. Resolve profile-derived fit kwargs before main fit execution. + +#### Exit Criteria + +1. Dataflow profile tests pass. +2. Existing dataset/collate/transform behavior remains green. + +### Phase 54: Stack/Repeat Builder Macros + +#### Problem + +Model declaration blocks repeat similar layer sequences manually (e.g., `linear + relu + dropout` N times), adding noise and error-prone copy/paste. + +#### Deliverables + +1. Add builder repetition helpers: + - `repeat_layers(count) { |i| ... }` + - `stack(count, layer_class = nil, *args, **kwargs, &block)` for common repeated patterns. +2. Ensure repeated entries are normalized through existing module/callable resolution paths. +3. Support index-aware blocks for dynamic dimensions (`i`-dependent construction). +4. Preserve existing `sequential`/`branch` semantics and output module tracking. 
+ +#### Red (tests first) + +1. Add failing graph unit tests for repeated layer construction and index-aware block behavior. +2. Add failing tests proving module tracking and parameter paths remain correct for repeated stacks. + +#### Green (minimum implementation) + +1. Implement repeat/stack helpers in `lib/mlx/dsl/builder.rb`. +2. Reuse existing composition normalization internals for consistency. + +#### Exit Criteria + +1. Stack/repeat tests pass. +2. Existing graph/builder tests remain green. + +### Phase 55: Hook and Metric Packs + +#### Problem + +Hook instrumentation and monitor metric setup are frequently duplicated between experiments (`logging`, `checkpoint telemetry`, `early-stop traces`). + +#### Deliverables + +1. Add reusable hook packs: + - `register_hook_pack(name) { ... }` + - `use_hook_pack(name, **options)` +2. Add reusable metric packs: + - `register_metric(name, callable = nil, &block)` + - reference metric by name in `fit_report(monitor:, metric:)`. +3. Allow hook/metric packs to receive runtime context and user options. +4. Preserve direct inline `on` hooks and callable `metric:` behavior. + +#### Red (tests first) + +1. Add failing trainer unit tests for applying named hook packs. +2. Add failing tests for named metric registration and monitor integration. +3. Add failing tests for per-use options/context propagation. + +#### Green (minimum implementation) + +1. Add hook/metric registries in `lib/mlx/dsl/trainer.rb`. +2. Resolve named packs/metrics through existing emit and monitor execution paths. + +#### Exit Criteria + +1. Hook/metric pack tests pass. +2. Existing hook and metric behavior remains green. + +### Phase 56: Task-Level Training API (`fit_task`) + +#### Problem + +Training setup still repeats task boilerplate (`loss`, `monitor`, default collate shape, metric wiring) for common workflows like classification and regression. + +#### Deliverables + +1. 
Add task-level fit entrypoints: + - `fit_task(task, dataset, **kwargs)` + - `fit_task_report(task, dataset, **kwargs)` +2. Add built-in task presets: + - `:classification` + - `:regression` + - `:language_modeling` +3. Task presets should provide sane defaults for: + - loss callable + - `monitor` / `monitor_mode` + - common collate schema assumptions +4. Preserve current `fit` / `fit_report` and custom loss-block workflows unchanged. + +#### Red (tests first) + +1. Add failing trainer unit tests for built-in classification task wiring. +2. Add failing tests for task defaults being overridable per call. +3. Add failing tests proving non-task fit flows remain unchanged. + +#### Green (minimum implementation) + +1. Add task registry/resolution in `lib/mlx/dsl/trainer.rb`. +2. Route `fit_task` APIs through existing fit execution path via normalized kwargs/loss callable. + +#### Exit Criteria + +1. Task-level fit tests pass. +2. Existing trainer APIs remain green. + +### Phase 57: Signature and KeyPath Auto-Binding for Batches + +#### Problem + +Users still write repetitive `collate` mappings to bridge dataset batch shapes into loss/train-step signatures (`x:`, `y:`, nested targets). + +#### Deliverables + +1. Add auto-binding option for train and validation: + - `bind:` + - `validation_bind:` +2. Support binding modes: + - argument-name inference from loss/train-step signature + - explicit key-path mappings (for example, `{ x: [:input, :x], y: [:target, 0] }`) +3. Keep duplicate-key and missing-key diagnostics explicit and context-rich. +4. Preserve explicit `collate` precedence over auto-binding. + +#### Red (tests first) + +1. Add failing trainer unit tests for signature-inferred binding on hash batches. +2. Add failing tests for nested key-path binding. +3. Add failing tests for precedence and error diagnostics. + +#### Green (minimum implementation) + +1. Add bind normalization and extraction helpers in `lib/mlx/dsl/trainer.rb`. +2. 
Integrate binding into batch preparation before train/validation dispatch. + +#### Exit Criteria + +1. Auto-binding tests pass. +2. Existing collate and batch dispatch flows remain green. + +### Phase 58: Unified Experiment DSL (`experiment do ... end`) + +#### Problem + +Experiment scripts still spread setup across model/trainer/optimizer/dataflow/checkpoint blocks, requiring repeated orchestration scaffolding. + +#### Deliverables + +1. Add top-level experiment DSL helper: + - `MLX::DSL.experiment(name = nil) { ... }` +2. Support declarative sections: + - `model` + - `optimizer` + - `trainer` + - `dataflow` / datasets + - `artifacts` / resume settings +3. Return a runnable experiment object with: + - `run` + - `report` + - `save_run_bundle` +4. Keep all existing lower-level APIs available and unchanged. + +#### Red (tests first) + +1. Add failing integration test for end-to-end experiment declaration and execution. +2. Add failing tests for section override precedence and explicit object injection. + +#### Green (minimum implementation) + +1. Add experiment builder/runtime under `lib/mlx/dsl/`. +2. Reuse existing trainer/model helpers instead of duplicating training logic. + +#### Exit Criteria + +1. Experiment DSL tests pass. +2. Existing DSL entrypoints remain green. + +### Phase 59: Dataset Split Plan DSL + +#### Problem + +Train/validation/test split wiring still requires repeated ad-hoc lambdas and transform plumbing across scripts. + +#### Deliverables + +1. Add split-plan DSL: + - `splits do ... end` + - `train`, `validation`, `test` declarations +2. Support shared and split-specific transforms/collate/limits. +3. Provide reusable split plan objects consumable by trainer fit/report APIs. +4. Preserve compatibility with raw enumerable and factory datasets. + +#### Red (tests first) + +1. Add failing trainer unit tests for split plan train/validation consumption. +2. Add failing tests for shared transform inheritance plus split overrides. +3. 
Add failing tests for backward compatibility with current dataset inputs. + +#### Green (minimum implementation) + +1. Add split plan object and resolver in `lib/mlx/dsl/`. +2. Integrate plan expansion into trainer call paths as optional sugar. + +#### Exit Criteria + +1. Split-plan tests pass. +2. Existing dataset and transform behaviors remain green. + +### Phase 60: Artifact Policy DSL (Checkpoints and Run Bundles) + +#### Problem + +Checkpoint/run-bundle lifecycle policies (`latest`, `best`, retention count, naming, resume target) are still configured manually per script. + +#### Deliverables + +1. Add declarative artifact policy API: + - checkpoint strategy (`save_latest`, `save_best`, `save_every`) + - retention (`keep_last_n`) + - resume strategy (`:latest`, `:best`, `:path`, callable) +2. Add run-bundle export policy toggles and output conventions. +3. Keep policy evaluation deterministic and report-visible. +4. Preserve current direct checkpoint/run-bundle APIs. + +#### Red (tests first) + +1. Add failing trainer unit/integration tests for retention and strategy behavior. +2. Add failing tests for resume strategy resolution from policy state. +3. Add failing tests for policy metadata appearing in reports/bundles. + +#### Green (minimum implementation) + +1. Add artifact policy object and enforcement hooks in `lib/mlx/dsl/trainer.rb`. +2. Route checkpoint/bundle paths through policy resolver while keeping existing explicit options. + +#### Exit Criteria + +1. Artifact policy tests pass. +2. Existing checkpoint and run-bundle behavior remains green. + +## Immediate Implementation Scope + +Implemented in this execution stream: Phases 19-60. +Next planning queue: Phase 61+. + +Implementation approach for each phase remains strict red/green sequencing: + +1. Red tests for each phase. +2. Minimum green code changes. +3. Targeted suite run, then broader DSL regression run. 
diff --git a/rfp/2026_02_17_benchmark_performance_remediation_plan.md b/rfp/2026_02_17_benchmark_performance_remediation_plan.md new file mode 100644 index 00000000..b24aa8f9 --- /dev/null +++ b/rfp/2026_02_17_benchmark_performance_remediation_plan.md @@ -0,0 +1,142 @@ +# Benchmark Performance Remediation Plan (CNN + RNN) + +## Goal + +Close the Ruby vs Python performance gap for benchmark `cnn` and `rnn` while preserving output parity and benchmark path equivalence. + +Current benchmark gap (GPU, `WARMUP=50`, `ITERATIONS=1000`): + +- `cnn`: Ruby `0.929 ms` vs Python `0.436 ms` (`2.13x` slower) +- `rnn`: Ruby `7.468 ms` vs Python `4.178 ms` (`1.79x` slower) + +## Issue Map + +1. End-to-end benchmark gap in CNN/RNN. +2. CNN gap is mostly execution-path cost, not graph-build cost. +3. Ruby pooling implementation is slower than upstream path. +4. Ruby RNN has high graph-build and execution overhead. +5. Ruby recurrent kernels use less efficient op patterns than upstream. +6. Ruby activations are eager-only while Python uses compiled activation helpers. + +## Phased Plan + +### Phase 0: Baseline and Targets (Issues 1, 2, 4) + +- Lock benchmark protocol and environment for reproducibility. +- Capture baseline: + - full benchmark matrix (`benchmark`) + - per-op microbench (`conv`, `relu`, `pool`, `linear`, `rnn`) + - split timings (`build_only`, `build_plus_eval`) +- Define pass thresholds: + - GPU `cnn` `rb/py <= 1.30x` + - GPU `rnn` `rb/py <= 1.30x` + +Deliverable: + +- Committed baseline report and target table in `rfp/`. + +### Phase 1: Pooling Fast Path Port (Issue 3) + +- Port upstream pooling optimizations to Ruby: + - non-overlapping window fast path + - pooled reduction over all window axes in one operation +- Keep parity behavior identical. + +Deliverable: + +- Refactored `lib/mlx/nn/layers/pooling.rb` + test coverage. + +Exit criteria: + +- `maxpool2d` microbench improves by at least `25%`. +- CNN benchmark improves from Phase 0 baseline. 
+ +### Phase 2: Recurrent Op-Path Refactor (Issue 5) + +- Refactor `RNN`, `GRU`, `LSTM` to use more fused operations where possible (`addmm` paths). +- Reduce per-step op count and avoid avoidable `take`/extra intermediate materialization patterns. + +Deliverable: + +- Refactored `lib/mlx/nn/layers/recurrent.rb` + parity tests. + +Exit criteria: + +- `rnn_full` microbench improves by at least `20%`. +- No parity regressions in benchmark checks. + +### Phase 3: Ruby Graph-Build Overhead Reduction (Issue 4) + +- Reduce Ruby-side loop and object churn in hot model paths. +- Cache static metadata and eliminate repeated Ruby work in inner loops where safe. + +Deliverable: + +- Profiling-backed changes with before/after timing report. + +Exit criteria: + +- `rnn_build_only` improves by at least `30%`. +- `cnn_build_only` improves where measurable. + +### Phase 4: Activation Compile Parity (Issue 6) + +- Add compiled activation path in Ruby analogous to upstream activation usage. +- Preserve eager fallback behavior. + +Deliverable: + +- Activation runtime improvements + regression tests. + +Exit criteria: + +- Activation microbench near Python parity. +- Measurable CNN improvement versus Phase 0. + +### Phase 5: End-to-End Validation and Guardrails (Issue 1 Closure) + +- Re-run full benchmark matrix with Phase 0 protocol. +- Publish before/after summary. +- Update `README.md` benchmark table. +- Add non-flaky perf guardrails in CI (warning-first, then fail on sustained regression). + +Deliverable: + +- Final benchmark report + README update + CI threshold checks. + +Exit criteria: + +- CNN and RNN meet target ratios or have documented residual blockers with quantified delta. + +## Execution Strategy + +- Ship by phase in small PRs to isolate regressions. +- Require parity check pass (`input_shape`, `input_digest`, `output_shape`, `reference_output_digest`) for each phase. +- Treat benchmark speedup claims as valid only when measured with fixed protocol from Phase 0. 
+ +## Progress Update (2026-02-17) + +Completed so far: + +- Phase 1 (partial): pooling non-overlapping sliding window fast path and pooled-axis reduction for max. +- Phase 2 (partial): recurrent refactor (`addmm` paths, transpose caching, while-loop/preallocated hidden buffers, direct state fetches). +- Phase 3 (partial): reduced Ruby dispatch overhead in hot paths (`Conv*`, `Linear`, `Bilinear`, benchmark runner proc fast path). +- Additional RNN optimization: compiled hidden-state update fast path for default tanh recurrence with safe eager fallback. + +Validation: + +- Parity tests passing for updated areas: + - `test/parity/phase187_linear_layers_parity_test.rb` + - `test/parity/phase193_convolution_layers_parity_test.rb` + - `test/parity/phase195_pooling_layers_parity_test.rb` + - `test/parity/phase197_recurrent_layers_parity_test.rb` + - `test/parity/phase190_activations_parity_test.rb` + +Latest GPU benchmark samples (`WARMUP=50`, `ITERATIONS=1000`): + +- `rnn`: Ruby `3.687 ms`, Python `4.219 ms` (Ruby faster in current run) +- `cnn`: Ruby `0.647 ms`, Python `0.397 ms` (Ruby slower; residual gap remains) + +Current blocker: + +- CNN still trails despite op-path and dispatch optimizations. Residual gap appears concentrated in `conv + relu + pool` execution cost under Ruby wrapper calls; needs deeper profiling/fusion strategy in Phase 3/4 follow-up. diff --git a/tasks/benchmark_task.rb b/tasks/benchmark_task.rb index 27375677..17014149 100644 --- a/tasks/benchmark_task.rb +++ b/tasks/benchmark_task.rb @@ -212,11 +212,16 @@ def benchmark_ruby_loop(example) finish = nil output = nil label = example.label + runner = if example.respond_to?(:run_step_proc) + example.run_step_proc + else + -> { example.run_step } + end warmup_every = log_interval(@warmup) iter_every = log_interval(@iterations) @warmup.times do |idx| - output = example.run_step + output = runner.call MLX::Core.eval(output) if (idx + 1) == @warmup || ((idx + 1) % warmup_every).zero? 
puts "[ruby/#{label}] warmup #{idx + 1}/#{@warmup}" @@ -225,7 +230,7 @@ def benchmark_ruby_loop(example) start = Process.clock_gettime(Process::CLOCK_MONOTONIC) @iterations.times do |idx| - output = example.run_step + output = runner.call MLX::Core.eval(output) if (idx + 1) == @iterations || ((idx + 1) % iter_every).zero? puts "[ruby/#{label}] iter #{idx + 1}/#{@iterations}" diff --git a/test/parity/phase190_activations_parity_test.rb b/test/parity/phase190_activations_parity_test.rb index 1129edc1..a9704edf 100644 --- a/test/parity/phase190_activations_parity_test.rb +++ b/test/parity/phase190_activations_parity_test.rb @@ -58,6 +58,33 @@ def test_softplus_sigmoid_family assert_nested_close mish_expected.to_a, MLX::NN.mish(x).to_a end + def test_relu_and_sigmoid_fallback_when_compiled_activation_cache_is_invalid + x = MLX::Core.array([-1.0, 0.0, 1.0], MLX::Core.float32) + sentinel = Object.new + ivars = %i[@compiled_relu @compiled_sigmoid] + backups = {} + + ivars.each do |ivar| + backups[ivar] = if MLX::NN.instance_variable_defined?(ivar) + MLX::NN.instance_variable_get(ivar) + else + sentinel + end + MLX::NN.instance_variable_set(ivar, MLX::Core) + end + + assert_nested_close [0.0, 0.0, 1.0], MLX::NN.relu(x).to_a + assert_nested_close [0.26894143, 0.5, 0.7310586], MLX::NN.sigmoid(x).to_a + ensure + ivars.each do |ivar| + if backups[ivar].equal?(sentinel) + MLX::NN.remove_instance_variable(ivar) if MLX::NN.instance_variable_defined?(ivar) + else + MLX::NN.instance_variable_set(ivar, backups[ivar]) + end + end + end + def test_glu_prelu_and_gelu_module_behavior x_glu = MLX::Core.array([[1.0, 2.0, -1.0, -2.0]], MLX::Core.float32) glu = MLX::NN::GLU.new diff --git a/test/parity/phase195_pooling_layers_parity_test.rb b/test/parity/phase195_pooling_layers_parity_test.rb index bee09bda..cc7a2846 100644 --- a/test/parity/phase195_pooling_layers_parity_test.rb +++ b/test/parity/phase195_pooling_layers_parity_test.rb @@ -34,6 +34,22 @@ def test_pool2d_max_and_avg 
assert_nested_close [[[[2.5]]]], avg_pool.call(x).to_a end + def test_pool2d_non_overlapping_windows + x = MLX::Core.array( + [[[[1.0], [2.0], [3.0], [4.0]], + [[5.0], [6.0], [7.0], [8.0]], + [[9.0], [10.0], [11.0], [12.0]], + [[13.0], [14.0], [15.0], [16.0]]]], + MLX::Core.float32 + ) + + max_pool = MLX::NN::MaxPool2d.new([2, 2], stride: [2, 2]) + avg_pool = MLX::NN::AvgPool2d.new([2, 2], stride: [2, 2]) + + assert_nested_close [[[[6.0], [8.0]], [[14.0], [16.0]]]], max_pool.call(x).to_a + assert_nested_close [[[[3.5], [5.5]], [[11.5], [13.5]]]], avg_pool.call(x).to_a + end + def test_pool3d_max_and_avg x = MLX::Core.array( [[[[[1.0], [2.0]], [[3.0], [4.0]]], [[[5.0], [6.0]], [[7.0], [8.0]]]]],