Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -788,18 +788,18 @@ MLX Ruby has full Metal support through the upstream MLX runtime. On Apple silic
The table below is from:

```bash
bundle exec rake benchmark WARMUP=50 ITERATIONS=1000
BENCHMARK_DEVICES=gpu bundle exec rake benchmark WARMUP=50 ITERATIONS=5000
```

Ratios are shown per column (`py_cpu/gpu`, `rb_cpu/gpu`, `rb/py_cpu`, `rb/py_gpu`). Parity columns come from harness checks (`input_shape`, `input_digest`, `output_shape`, `reference_output_digest`).

| model | py_cpu_s | py_gpu_s | py_cpu/gpu | rb_cpu_s | rb_gpu_s | rb_cpu/gpu | rb/py_cpu | rb/py_gpu | in_shape (cpu/gpu) | in_content (cpu/gpu) | out_shape (cpu/gpu) | out_content (cpu/gpu) |
| --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: | ---: | :---: | :---: | :---: | :---: |
| transformer | 0.034 | 0.008 | 4.48x | 0.029 | 0.008 | 3.67x | 0.84x | 1.03x | /✓ | /✓ | /✓ | /✓ |
| cnn | 0.004 | 0.000 | 9.54x | 0.004 | 0.001 | 4.68x | 1.05x | 2.13x | /✓ | /✓ | /✓ | /✓ |
| mlp | 0.000 | 0.000 | 1.66x | 0.000 | 0.000 | 1.28x | 0.98x | 1.27x | /✓ | /✓ | /✓ | /✓ |
| rnn | 0.006 | 0.004 | 1.39x | 0.007 | 0.007 | 0.91x | 1.16x | 1.79x | /✓ | /✓ | /✓ | /✓ |
| karpathy_gpt2 | 0.058 | 0.011 | 5.06x | 0.060 | 0.016 | 3.84x | 1.03x | 1.36x | /✓ | /✓ | /✓ | /✓ |
| transformer | n/a | 0.006 | n/a | n/a | 0.006 | n/a | n/a | 1.11x | -/✓ | -/✓ | -/✓ | -/✓ |
| cnn | n/a | 0.001 | n/a | n/a | 0.001 | n/a | n/a | 1.15x | -/✓ | -/✓ | -/✓ | -/✓ |
| mlp | n/a | 0.000 | n/a | n/a | 0.000 | n/a | n/a | 1.09x | -/✓ | -/✓ | -/✓ | -/✓ |
| rnn | n/a | 0.004 | n/a | n/a | 0.006 | n/a | n/a | 1.62x | -/✓ | -/✓ | -/✓ | -/✓ |
| karpathy_gpt2 | n/a | 0.013 | n/a | n/a | 0.015 | n/a | n/a | 1.17x | -/✓ | -/✓ | -/✓ | -/✓ |

### Build docs

Expand Down
32 changes: 23 additions & 9 deletions examples/benchmark/cnn_example.rb
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,40 @@ def initialize(batch_size:, dtype:)

@conv1 = MLX::NN::Conv2d.new(CNN_CHANNELS, 16, 3, stride: 1, padding: 1, bias: true)
@conv2 = MLX::NN::Conv2d.new(16, 32, 3, stride: 1, padding: 1, bias: true)
@relu = MLX::NN::ReLU.new
@pool = MLX::NN::MaxPool2d.new(2, stride: 2, padding: 0)
@linear = MLX::NN::Linear.new(@flattened_features, CNN_CLASSES)
BenchmarkDigest.assign_deterministic_parameters!([@conv1, @conv2, @linear])

conv1 = @conv1
conv2 = @conv2
pool = @pool
linear = @linear
input = @input
batch_size = @batch_size
flattened_features = @flattened_features
@run_step = lambda do
y = conv1.call(input)
y = MLX::NN.relu(y)
y = pool.call(y)
y = conv2.call(y)
y = MLX::NN.relu(y)
y = pool.call(y)
y = MLX::Core.reshape(y, [batch_size, flattened_features])
linear.call(y)
end

@input_shape = @input.shape
@input_digest = BenchmarkDigest.digest_array(@input)
@reference_output_digest = BenchmarkDigest.digest_array(run_step)
@path_signature = "forward_only_eval_output"
end

def run_step
y = @conv1.call(@input)
y = @relu.call(y)
y = @pool.call(y)
y = @conv2.call(y)
y = @relu.call(y)
y = @pool.call(y)
y = MLX::Core.reshape(y, [@batch_size, @flattened_features])
@linear.call(y)
@run_step.call
end

# Exposes the prebuilt forward-pass lambda (assembled in #initialize) so a
# benchmark harness can hold and invoke the callable directly; #run_step
# delegates to this same lambda.
def run_step_proc
  @run_step
end

def verification_input_digest
Expand Down
3 changes: 1 addition & 2 deletions examples/benchmark/python/cnn_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from benchmark_digest import assign_deterministic_parameters
from benchmark_digest import deterministic_tensor
from benchmark_digest import digest_array
from mlx.nn.layers.activations import ReLU
from mlx.nn.layers.activations import relu
from mlx.nn.layers.convolution import Conv2d
from mlx.nn.layers.linear import Linear
from mlx.nn.layers.pooling import MaxPool2d
Expand Down Expand Up @@ -48,7 +48,6 @@ def main():

conv1 = Conv2d(CNN_CHANNELS, 16, 3, stride=1, padding=1)
conv2 = Conv2d(16, 32, 3, stride=1, padding=1)
relu = ReLU()
pool = MaxPool2d(2, stride=2)
linear = Linear(flattened, CNN_CLASSES)
assign_deterministic_parameters([conv1, conv2, linear])
Expand Down
10 changes: 9 additions & 1 deletion examples/benchmark/rnn_example.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,22 @@ def initialize(batch_size:, sequence_length:, dims:, dtype:)
@rnn = MLX::NN::RNN.new(dims, hidden_size)
BenchmarkDigest.assign_deterministic_parameters!(@rnn)

rnn = @rnn
input = @input
@run_step = lambda { rnn.call(input) }

@input_shape = @input.shape
@input_digest = BenchmarkDigest.digest_array(@input)
@reference_output_digest = BenchmarkDigest.digest_array(run_step)
@path_signature = "forward_only_eval_output"
end

def run_step
@rnn.call(@input)
@run_step.call
end

# Returns the forward-pass lambda built in #initialize (wraps
# @rnn.call(@input)); lets callers benchmark the callable itself rather
# than going through method dispatch on #run_step.
def run_step_proc
  @run_step
end

def verification_input_digest
Expand Down
13 changes: 12 additions & 1 deletion examples/benchmark/transformer_example.rb
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,23 @@ def initialize(batch_size:, sequence_length:, target_sequence_length:, dims:, nu
)
BenchmarkDigest.assign_deterministic_parameters!(@model)

model = @model
src = @src
tgt = @tgt
src_mask = @src_mask
tgt_mask = @tgt_mask
@run_step = lambda { model.call(src, tgt, src_mask, tgt_mask, nil) }

@reference_output_digest = BenchmarkDigest.digest_array(run_step)
@path_signature = "forward_only_eval_output"
end

def run_step
@model.call(@src, @tgt, @src_mask, @tgt_mask, nil)
@run_step.call
end

# Accessor for the forward-pass lambda created in #initialize (wraps the
# transformer model call with its captured inputs and masks). Provided so
# harnesses can invoke the lambda directly instead of calling #run_step.
def run_step_proc
  @run_step
end

def verification_input_digest
Expand Down
142 changes: 90 additions & 52 deletions lib/mlx/core.rb
Original file line number Diff line number Diff line change
Expand Up @@ -419,39 +419,53 @@ def compile(fun, inputs = nil, outputs = nil, shapeless = false)
flat_inputs = []
input_spec = flatten_tree_spec([args, kwargs], flat_inputs, false)
key = structure_cache_key(input_spec)
rebuilt_once = false

entry = cache[key]
unless entry
output_spec = nil
lifted = lambda do |*flat_vars|
rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0)
unless cursor == flat_vars.length
raise RuntimeError, "internal input reconstruction mismatch"
begin
entry = cache[key]
unless valid_transform_cache_entry?(entry)
cache.delete(key)
entry = nil
end
unless entry
output_spec = nil
lifted = lambda do |*flat_vars|
rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0)
unless cursor == flat_vars.length
raise RuntimeError, "internal input reconstruction mismatch"
end

call_args = rebuilt[0]
call_kwargs = rebuilt[1]
raw_output = fun.call(*call_args, **call_kwargs)

flat_output = []
output_spec = flatten_tree_spec(raw_output, flat_output, false)
flat_output
end

call_args = rebuilt[0]
call_kwargs = rebuilt[1]
raw_output = fun.call(*call_args, **call_kwargs)

flat_output = []
output_spec = flatten_tree_spec(raw_output, flat_output, false)
flat_output
compiled = native_compile(lifted, inputs, outputs, shapeless)
entry = { fn: compiled, output_spec: -> { output_spec } }
cache[key] = entry
end

compiled = native_compile(lifted, inputs, outputs, shapeless)
entry = { fn: compiled, output_spec: -> { output_spec } }
cache[key] = entry
end

flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "compiled output")
spec = entry[:output_spec].call
raise RuntimeError, "missing output structure from compiled function" if spec.nil?
flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "compiled output")
spec = entry[:output_spec].call
raise RuntimeError, "missing output structure from compiled function" if spec.nil?

rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0)
unless cursor == flat_output.length
raise RuntimeError, "internal output reconstruction mismatch"
rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0)
unless cursor == flat_output.length
raise RuntimeError, "internal output reconstruction mismatch"
end
rebuilt
rescue RuntimeError => e
if !rebuilt_once && invalid_transform_callable_error?(e)
rebuilt_once = true
cache.delete(key)
retry
end
raise
end
rebuilt
end
end

Expand All @@ -463,39 +477,53 @@ def checkpoint(fun)
flat_inputs = []
input_spec = flatten_tree_spec([args, kwargs], flat_inputs, false)
key = structure_cache_key(input_spec)
rebuilt_once = false

entry = cache[key]
unless entry
output_spec = nil
lifted = lambda do |*flat_vars|
rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0)
unless cursor == flat_vars.length
raise RuntimeError, "internal input reconstruction mismatch"
begin
entry = cache[key]
unless valid_transform_cache_entry?(entry)
cache.delete(key)
entry = nil
end
unless entry
output_spec = nil
lifted = lambda do |*flat_vars|
rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0)
unless cursor == flat_vars.length
raise RuntimeError, "internal input reconstruction mismatch"
end

call_args = rebuilt[0]
call_kwargs = rebuilt[1]
raw_output = fun.call(*call_args, **call_kwargs)

flat_output = []
output_spec = flatten_tree_spec(raw_output, flat_output, false)
flat_output
end

call_args = rebuilt[0]
call_kwargs = rebuilt[1]
raw_output = fun.call(*call_args, **call_kwargs)

flat_output = []
output_spec = flatten_tree_spec(raw_output, flat_output, false)
flat_output
checkpointed = native_checkpoint(lifted)
entry = { fn: checkpointed, output_spec: -> { output_spec } }
cache[key] = entry
end

checkpointed = native_checkpoint(lifted)
entry = { fn: checkpointed, output_spec: -> { output_spec } }
cache[key] = entry
end

flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "checkpoint output")
spec = entry[:output_spec].call
raise RuntimeError, "missing output structure from checkpoint function" if spec.nil?
flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "checkpoint output")
spec = entry[:output_spec].call
raise RuntimeError, "missing output structure from checkpoint function" if spec.nil?

rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0)
unless cursor == flat_output.length
raise RuntimeError, "internal output reconstruction mismatch"
rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0)
unless cursor == flat_output.length
raise RuntimeError, "internal output reconstruction mismatch"
end
rebuilt
rescue RuntimeError => e
if !rebuilt_once && invalid_transform_callable_error?(e)
rebuilt_once = true
cache.delete(key)
retry
end
raise
end
rebuilt
end
end

Expand Down Expand Up @@ -877,6 +905,16 @@ def normalize_raw_grads(raw)
normalize_array_sequence(raw, "gradient")
end

# Returns true when +entry+ is a usable transform cache record: a Hash
# whose :fn and :output_spec values both respond to #call. Anything else
# (nil, wrong type, missing or non-callable slots) is rejected so the
# caller can evict it and rebuild.
def valid_transform_cache_entry?(entry)
  return false unless entry.is_a?(::Hash)

  callable_fn = entry[:fn].respond_to?(:call)
  callable_spec = entry[:output_spec].respond_to?(:call)
  callable_fn && callable_spec
end

# Heuristic: does +error+'s message say that #call is an undefined
# method? Accepts either a backtick or straight-quote opening and a
# single- or double-quote closing around "call", so it matches both the
# older and newer Ruby NoMethodError quoting styles.
def invalid_transform_callable_error?(error)
  pattern = /undefined method [`']call['"]/
  pattern.match?(error.message)
end

def normalize_array_sequence(raw, context)
return [raw] if raw.is_a?(MLX::Core::Array)

Expand Down
48 changes: 46 additions & 2 deletions lib/mlx/nn/layers/activations.rb
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,31 @@ module MLX
module NN
class << self
def sigmoid(x)
MLX::Core.sigmoid(x)
compiled = compiled_sigmoid
return MLX::Core.sigmoid(x) unless compiled.respond_to?(:call)

compiled.call(x)
rescue RuntimeError => e
if invalid_compiled_activation_call?(e)
@compiled_sigmoid = false
MLX::Core.sigmoid(x)
else
raise
end
end

def relu(x)
MLX::Core.maximum(x, 0.0)
compiled = compiled_relu
return MLX::Core.maximum(x, 0.0) unless compiled.respond_to?(:call)

compiled.call(x)
rescue RuntimeError => e
if invalid_compiled_activation_call?(e)
@compiled_relu = false
MLX::Core.maximum(x, 0.0)
else
raise
end
end

def relu2(x)
Expand Down Expand Up @@ -130,6 +150,30 @@ def softmin(x, axis: -1)
def tanh(x)
MLX::Core.tanh(x)
end

private

# Lazily built compiled sigmoid kernel.
#
# @compiled_sigmoid has three states: unset/nil (not yet attempted or the
# last compile attempt returned nil), false (a caller has disabled the
# compiled path), or the compiled callable. Returns nil when disabled, in
# which case callers fall back to the eager MLX::Core.sigmoid path.
def compiled_sigmoid
  return nil if @compiled_sigmoid == false

  @compiled_sigmoid ||=
    compile_activation_unary(lambda { |v| MLX::Core.sigmoid(v) })
end

# Lazily built compiled ReLU kernel (max(x, 0.0)).
#
# @compiled_relu mirrors @compiled_sigmoid's states: unset/nil (not yet
# attempted or the last compile attempt returned nil), false (compiled
# path disabled by a caller), or the compiled callable. Returns nil when
# disabled so callers use the eager MLX::Core.maximum path instead.
def compiled_relu
  return nil if @compiled_relu == false

  @compiled_relu ||=
    compile_activation_unary(lambda { |v| MLX::Core.maximum(v, 0.0) })
end

# Best-effort shapeless compilation of the unary activation +fun+ via
# MLX::Core.compile(fun, nil, nil, true). Any StandardError raised by the
# backend is swallowed and nil is returned, signalling the caller to use
# the uncompiled implementation.
def compile_activation_unary(fun)
  begin
    MLX::Core.compile(fun, nil, nil, true)
  rescue StandardError
    nil
  end
end

# True when +error+'s message complains that #call is undefined — i.e.
# the memoized compiled activation is no longer callable. The pattern
# allows a backtick or straight quote before "call" and a single or
# double quote after it, covering both Ruby error-message quoting styles.
def invalid_compiled_activation_call?(error)
  message = error.message
  message.match?(/undefined method [`']call['"]/)
end
end

class GLU < Module
Expand Down
Loading
Loading