Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion docs/src/ruby/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ Key methods:
- Persistence: ``save_weights``, ``load_weights``

Important pattern: register trainable members with ``self.<name> = ...`` so
they are tracked by module state and optimizer updates.
they are tracked by module state and optimizer updates. Direct instance-variable
assignment (``@name = ...``) bypasses module-state registration, and those
members will be missed by traversal helpers (for example
``children``/``parameters``), ``load_weights``, and quantization updates.

See implementation:

Expand Down Expand Up @@ -127,6 +130,10 @@ Graph IR / ONNX / WebGPU entry points in ``MLX::ONNX``:
- ``MLX::ONNX::WebGPUHarness.export_onnx_webgpu_harness``
- ``MLX::ONNX::WebGPUHarness.smoke_test_onnx_webgpu_harness``

``save_safetensors`` requires a native MLX build with
``MLX_BUILD_SAFETENSORS=ON``. If safetensors support is unavailable, use
``savez``/``savez_compressed`` or a Ruby safetensors serializer fallback.

.. _distributed:

Distributed
Expand Down
2 changes: 1 addition & 1 deletion ext/mlx/extconf.rb
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def patch_makefile_include_dirs!(makefile_path, include_dirs)
"-DMLX_BUILD_PYTHON_STUBS=OFF",
"-DMLX_BUILD_METAL=ON",
"-DMLX_BUILD_GGUF=OFF",
"-DMLX_BUILD_SAFETENSORS=OFF",
"-DMLX_BUILD_SAFETENSORS=ON",
"-DBUILD_SHARED_LIBS=ON"
]

Expand Down
99 changes: 97 additions & 2 deletions lib/mlx/core.rb
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,8 @@ class << self
# Keep the existing implementations reachable under `native_*` names. Each
# guard skips the alias when the source method is missing or the alias has
# already been defined.
alias_method :native_vmap, :vmap if method_defined?(:vmap) && !method_defined?(:native_vmap)
alias_method :native_export_to_dot,
:export_to_dot if method_defined?(:export_to_dot) && !method_defined?(:native_export_to_dot)
# `array` and `mean` are redefined below as Ruby wrappers that delegate to
# `native_array` / `native_mean`.
alias_method :native_array, :array if method_defined?(:array) && !method_defined?(:native_array)
alias_method :native_mean, :mean if method_defined?(:mean) && !method_defined?(:native_mean)

%i[savez savez_compressed].each do |method_name|
if method_defined?(method_name) && instance_method(method_name).owner == self
Expand All @@ -344,6 +346,24 @@ class << self

ARRAY_LEAF = :__mlx_array_leaf__

# Builds an array from `value`, accepting the dtype either positionally or as
# the `dtype:` keyword.
#
# The two dtype spellings are reconciled by `resolve_array_dtype` before the
# call is handed to `native_array`.
def array(value, positional_dtype = nil, dtype: nil)
  ensure_native!
  native_array(value, resolve_array_dtype(positional_dtype, dtype))
end

# Mean reduction with optional `keepdims`, given positionally or as a keyword.
#
# The reduction itself is delegated to `reduce_mean`. When keepdims is
# requested, each reduced axis (as listed by `normalize_reduction_axes`) is
# re-inserted with `expand_dims`, in the order that helper returns them.
def mean(array, axis = nil, positional_keepdims = nil, keepdims: nil)
  ensure_native!
  keep = resolve_keepdims_argument(positional_keepdims, keepdims)
  result = reduce_mean(array, axis)
  return result unless keep

  normalize_reduction_axes(array, axis).reduce(result) do |expanded, axis_index|
    expand_dims(expanded, axis_index)
  end
end

def load(file, format = nil, return_metadata = false)
ensure_native!
format_name = (format || infer_format(file)).to_s
Expand Down Expand Up @@ -561,6 +581,71 @@ def from_dlpack(dlpack_value)

private

# Picks the dtype to use when it may arrive positionally, as a keyword, or both.
#
# Returns whichever one is non-nil (nil when neither is given). When both are
# given they must agree by `dtype_name_for_compare`; the positional value is
# then returned.
#
# Raises ArgumentError when the two dtypes compare as different.
def resolve_array_dtype(positional_dtype, keyword_dtype)
  if positional_dtype.nil?
    keyword_dtype
  elsif keyword_dtype.nil?
    positional_dtype
  elsif dtype_name_for_compare(positional_dtype) == dtype_name_for_compare(keyword_dtype)
    positional_dtype
  else
    raise ArgumentError,
          "array received conflicting dtype arguments (positional=#{positional_dtype.inspect}, keyword=#{keyword_dtype.inspect})"
  end
end

# Reduces a dtype to a String suitable for equality comparison.
#
# Returns nil for nil; otherwise the stringified `name` when the object
# exposes one, or the object's own `to_s`.
def dtype_name_for_compare(dtype)
  return nil if dtype.nil?

  label = dtype.respond_to?(:name) ? dtype.name : dtype
  label.to_s
end

# Reconciles the positional and keyword spellings of `keepdims` into a boolean.
#
# nil means "not supplied". The keyword wins when present, then the
# positional value, and the default is false. Values are compared by
# truthiness, so `true` and `1` do not conflict.
#
# Raises ArgumentError when both are supplied with different truthiness.
def resolve_keepdims_argument(positional_keepdims, keyword_keepdims)
  # Keyword first, so it is the one returned when both are present.
  supplied = [keyword_keepdims, positional_keepdims].reject(&:nil?).map { |flag| !!flag }

  if supplied.length == 2 && supplied.first != supplied.last
    raise ArgumentError,
          "mean received conflicting keepdims arguments (positional=#{positional_keepdims.inspect}, keyword=#{keyword_keepdims.inspect})"
  end

  supplied.fetch(0, false)
end

# Applies `native_mean` over `axis`.
#
# A non-Array axis (including nil) is passed straight through. An Array of
# axes is normalized first and then reduced one axis at a time, walking the
# normalized list from last to first.
def reduce_mean(array, axis)
  return native_mean(array, axis) unless axis.is_a?(::Array)

  result = array
  normalize_reduction_axes(array, axis).reverse_each do |axis_index|
    result = native_mean(result, axis_index)
  end
  result
end

# Expands an `axis` argument into a sorted list of axis indices.
#
# nil selects every axis of `array` (0...ndim). A single value is treated as
# a one-element list. Each entry is resolved through `normalize_axis_index`.
#
# Raises ArgumentError when two entries resolve to the same axis.
def normalize_reduction_axes(array, axis)
  rank = array.ndim
  return rank.times.to_a if axis.nil?

  requested = axis.is_a?(::Array) ? axis : [axis]
  resolved = requested.map { |entry| normalize_axis_index(entry, rank) }.sort
  # After sorting, duplicates can only appear next to each other.
  if resolved.each_cons(2).any? { |left, right| left == right }
    raise ArgumentError, "axis contains duplicate values: #{requested.inspect}"
  end

  resolved
end

# Resolves one axis entry to a non-negative index for an array of `ndim`
# dimensions. Negative values count from the end.
#
# Raises TypeError when `axis` is not an Integer.
# Raises ArgumentError when the resolved index is outside 0...ndim.
def normalize_axis_index(axis, ndim)
  raise TypeError, "axis entries must be Integer" unless axis.is_a?(::Integer)

  resolved = axis.negative? ? axis + ndim : axis
  return resolved if resolved.between?(0, ndim - 1)

  raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{ndim}"
end

def infer_format(file)
path = file_path(file)
ext = File.extname(path).delete_prefix(".")
Expand Down Expand Up @@ -1118,8 +1203,8 @@ def cos
MLX::Core.cos(self)
end

def mean(axis = nil)
MLX::Core.mean(self, axis)
def mean(axis = nil, keepdims_positional = nil, keepdims: nil)
MLX::Core.mean(self, axis, keepdims_positional, keepdims: keepdims)
end

def sum(axis = nil)
Expand Down Expand Up @@ -1523,6 +1608,16 @@ def __rfloordiv__(other)
MLX::Core.floor_divide(other, self)
end

# Ruby numeric coercion hook, so that a Numeric on the left-hand side
# (`1.5 * arr`, `2 - arr`, ...) can operate with this array.
#
# Returns a two-element list `[lhs, self]` where `lhs` is `other` itself when
# it is already an MLX::Core::Array, or `other` wrapped in a new array.
#
# The wrapped scalar adopts this array's dtype only when that dtype can
# represent it:
#   * Integer scalars adopt any dtype except a boolean one;
#   * every other Numeric (e.g. Float) adopts only float/complex dtypes.
# Otherwise the scalar is wrapped without an explicit dtype. Forcing the
# receiver's dtype unconditionally would change the scalar before the
# operation runs (e.g. 1.5 cast to an integer dtype).
#
# NOTE(review): the dtype family is detected from the dtype's `name` (or its
# string form) containing "float", "complex" or "bool" — confirm against the
# Dtype binding.
#
# Raises TypeError when `other` is neither an MLX::Core::Array nor a Numeric.
def coerce(other)
  return [other, self] if other.is_a?(MLX::Core::Array)

  unless other.is_a?(::Numeric)
    raise TypeError, "#{other.class} can't be coerced into MLX::Core::Array"
  end

  own_dtype = dtype
  dtype_label = (own_dtype.respond_to?(:name) ? own_dtype.name : own_dtype).to_s
  representable =
    if other.is_a?(::Integer)
      !dtype_label.match?(/bool/i)
    else
      dtype_label.match?(/float|complex/i)
    end

  lhs = representable ? MLX::Core.array(other, own_dtype) : MLX::Core.array(other)
  [lhs, self]
end

# Method-call spelling of bracket indexing: forwards `index` to `self[index]`
# and returns its result.
def __getitem__(index)
self[index]
end
Expand Down
4 changes: 4 additions & 0 deletions lib/mlx/nn/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def update_modules_impl(dst, modules, strict)
current_value = dst[k]
if current_value.is_a?(Module) && new_value.is_a?(Module)
dst[k] = new_value
elsif current_value.is_a?(Module) && (new_value.is_a?(Hash) || new_value.is_a?(Array))
update_modules_impl(current_value, new_value, strict)
elsif current_value.is_a?(Hash) || current_value.is_a?(Array)
update_modules_impl(current_value, new_value, strict)
elsif strict && new_value != {}
Expand All @@ -337,6 +339,8 @@ def update_modules_impl(dst, modules, strict)
current_value = dst[i]
if current_value.is_a?(Module) && new_value.is_a?(Module)
dst[i] = new_value
elsif current_value.is_a?(Module) && (new_value.is_a?(Hash) || new_value.is_a?(Array))
update_modules_impl(current_value, new_value, strict)
elsif current_value.is_a?(Hash) || current_value.is_a?(Array)
update_modules_impl(current_value, new_value, strict)
elsif strict && new_value != {}
Expand Down
172 changes: 172 additions & 0 deletions prd/2026_02_25_mlx_ruby_gap_fixes_prd.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# MLX-Ruby Gap Fixes PRD

- Date: 2026-02-25
- Owner: Codex + @skryl
- Status: Completed (amended 2026-02-25 for safetensors default-on follow-up)
- Scope Type: Upstream `mlx-ruby` compatibility and ergonomics

## 1) Problem Statement

`mlx-ruby-lm` integration surfaced API/interop gaps in `mlx-ruby` that force repetitive workarounds, reduce parity with Python MLX, and increase model-porting cost.

Most impactful categories:

1. Core API ergonomics (`array(dtype:)`, reduction `keepdims`)
2. Ruby numeric coercion with `MLX::Core::Array`
3. `MLX::NN::Module` traversal/update behavior under nested structures
4. ONNX lowering parity (`GreaterEqual`)

## 2) Goals

1. Remove high-friction workarounds required across model codebases.
2. Preserve backward compatibility where possible.
3. Add regression tests for every accepted fix.
4. Keep docs explicit about unsupported/optional features.

## 3) Non-Goals

1. Rewriting `mlx-ruby-lm` model architectures in this PRD.
2. Full MoE feature parity beyond scoped API fixes (for example complete `SwitchGLU` implementation in this pass unless separately approved).
3. Changing CI packaging defaults unless explicitly approved.

## 4) Scope Matrix (From Reported Issues)

## 4.1 In Scope (this PRD)

1. Fix `MLX::Core.array(values, dtype: ...)` keyword compatibility.
2. Add `keepdims` support to `mean` (and align reductions where practical).
3. Add Ruby `coerce` support for `MLX::Core::Array` so `Float * array`, `Float + array`, etc. work.
4. Confirm/patch `update_modules_impl` recursion for `Module -> Hash/Array` replacement path.
5. Add ONNX lowering coverage for `GreaterEqual` (Issue 14).
6. Strengthen docs on module child registration (`self.x = ...` vs `@x = ...`).
7. Enable `MLX_BUILD_SAFETENSORS=ON` in this repo's native build defaults and require safetensors roundtrip tests to pass.

## 4.2 Out of Scope (tracked, not implemented in this PRD unless requested)

1. Add `SwitchGLU` layer implementation.
2. Dropout constructor API expansion (`Dropout.new(p: ...)`) unless explicitly requested for compatibility.
3. Integer dtype support in `random_uniform` unless upstream MLX behavior and API contract are aligned for deterministic semantics.

## 5) Detailed Requirements

1. `array(dtype:)`
- Accept both positional dtype and keyword dtype.
- Reject conflicting dtype values with clear error.
- Keep existing behavior for `array(values)` and `array(values, dtype_obj)`.

2. `mean(..., keepdims:)`
- Support `MLX::Core.mean(array, axis=nil, keepdims=false)` in Ruby API.
- Support `MLX::Core::Array#mean(axis=nil, keepdims=nil)`.
- Maintain existing no-axis behavior.

3. Ruby coercion
- Implement `MLX::Core::Array#coerce(other)` so numeric-left operators work:
- `1.5 * arr`
- `1.5 + arr`
- `2 - arr`
- `2 / arr`
- Add tests for `Float` and `Integer` on left-hand side.

4. Module update recursion
- Ensure `NN.update_modules_impl` recursively handles:
- current `Module` + replacement `Hash`
- current `Module` + replacement `Array`
- Preserve existing behavior for hash/hash and array/array recursion.

5. ONNX `GreaterEqual`
- Add lowering support and integration/contract tests.
- Confirm exported graph compatibility report no longer flags missing op in covered paths.

6. Documentation
- Add explicit guidance in NN docs/README: assigning child modules must go through `self.child = ...` to register in state traversal.
- Document that safetensors support is a compile-time option and describe the fallback path when it is unavailable.

## 6) Phased Plan and Checklist

## Phase 0: Baseline + Repro

- [x] Add or identify failing regression tests for each in-scope item.
- [x] Confirm currently failing behavior before code changes.
- [x] Update this PRD with confirmed repro status.

## Phase 1: Core API Fixes

- [x] Implement `array(dtype:)` keyword support in Ruby/native boundary.
- [x] Implement `mean(..., keepdims:)` support in Ruby/native boundary.
- [x] Add targeted unit/parity tests for both.

## Phase 2: Ruby Coercion + NN Traversal

- [x] Implement `Array#coerce` for numeric LHS ops.
- [x] Add coercion tests for add/sub/mul/div.
- [x] Patch/confirm `update_modules_impl` recursion behavior.
- [x] Add regression tests for module/hash/array update recursion.

## Phase 3: ONNX + Docs

- [x] Add `GreaterEqual` lowering support and tests.
- [x] Add docs for `self.x =` child registration requirement.
- [x] Add docs for safetensors optional build feature and fallback.

## Phase 4: Validation + Completion

- [x] Run targeted tests for touched files/features.
- [x] Run broader suite covering core/nn/onnx touched areas.
- [x] Safetensors default-on follow-up: flip build flag and verify native safetensors roundtrip behavior in parity tests.
- [x] Update PRD status to `Completed` only when all checklist items are done.

## 6.1) Baseline Repro Notes (2026-02-25)

Observed before fixes:

1. `MLX::Core.array([1,2,3], dtype: MLX::Core.int32)` raised `ArgumentError`.
2. `MLX::Core.mean(x, axis, keepdims: true)` raised wrong-arity `ArgumentError`.
3. Numeric-left ops (`1.5 + array`) raised `TypeError` (missing `coerce`).
4. `update_modules` raised `ArgumentError: Received invalid type: Hash.` for
`current_value=Module` + `new_value=Hash` recursion paths.

Regression tests added:

1. `test/core/core_api_gap_regression_test.rb`
2. `test/nn/module_update_modules_recursion_test.rb`

## 7) Test Strategy

Minimum per-change targeted tests:

1. Core API:
- `test/parity/phase5_core_ops_test.rb` (or dedicated new parity file)
- new unit tests for `array(dtype:)` keyword and `mean keepdims`
2. Coercion:
- new `test/parity` or `test/core` coverage for numeric-left operations
3. NN traversal:
- existing quantization/state traversal tests plus new recursion regression test
4. ONNX:
- ONNX binding/integration tests covering `GreaterEqual` export and runtime path

Completion sweep:

1. Run all targeted tests for modified files.
2. Run a broad cross-cutting suite including core + nn + onnx touched domains before marking complete.

## 8) Risks and Mitigations

1. Native binding signature changes can break call compatibility.
- Mitigation: keep positional args valid; add explicit keyword parsing tests.
2. `coerce` may affect operator dispatch in edge cases.
- Mitigation: limit support to numeric and raise clear `TypeError` otherwise.
3. ONNX lowering changes can affect compatibility reports.
- Mitigation: add both positive tests and unsupported-op boundary tests.

## 9) Acceptance Criteria

1. All in-scope checklist items are checked.
2. New/updated tests for each implemented fix are green.
3. No regressions in touched core/nn/onnx paths.
4. PRD status changed from `Draft` to `Completed` only when all above are true.

## 10) Open Decisions

1. Should `Dropout.new(p: 0.5)` keyword support be included now or deferred?
2. Should integer `random_uniform` support be emulated in Ruby or left as explicit unsupported behavior?
3. Should safetensors default build flags be changed in CI/release pipelines in this effort? Resolved for this repo on 2026-02-25: native build default switched to `MLX_BUILD_SAFETENSORS=ON`.
2 changes: 1 addition & 1 deletion submodules/mlx-onnx
Submodule mlx-onnx updated 79 files
+87 −0 .github/workflows/tests.yml
+1 −1 CMakeLists.txt
+83 −176 README.md
+66 −0 docs/cpp-interface.md
+233 −0 docs/native-architecture.md
+73 −0 docs/python-interface.md
+96 −0 docs/supported-mlx-ops.md
+157 −0 prd/2026-02-23-per-op-mlx-to-onnx-parity.md
+9 −1 pyproject.toml
+72 −0 python/examples/export_tiny_mlp_onnx.py
+61 −4 python/tests/test_examples.py
+57 −0 python/tests/unit/_contract_harness.py
+751 −0 python/tests/unit/_op_cases.py
+111 −0 python/tests/unit/_op_harness.py
+354 −0 python/tests/unit/test_legacy_gap_closure.py
+485 −0 python/tests/unit/test_lowering_rewrite_contracts.py
+5 −0 python/tests/unit/test_op_abs.py
+5 −0 python/tests/unit/test_op_add.py
+5 −0 python/tests/unit/test_op_add_mm.py
+36 −0 python/tests/unit/test_op_arange.py
+57 −0 python/tests/unit/test_op_argreduce.py
+5 −0 python/tests/unit/test_op_as_strided.py
+15 −0 python/tests/unit/test_op_as_type.py
+5 −0 python/tests/unit/test_op_broadcast.py
+15 −0 python/tests/unit/test_op_concatenate.py
+5 −0 python/tests/unit/test_op_convolution.py
+5 −0 python/tests/unit/test_op_convolution_transpose.py
+15 −0 python/tests/unit/test_op_cos.py
+5 −0 python/tests/unit/test_op_divide.py
+15 −0 python/tests/unit/test_op_equal.py
+15 −0 python/tests/unit/test_op_erf.py
+5 −0 python/tests/unit/test_op_erf_inv.py
+5 −0 python/tests/unit/test_op_exp.py
+38 −0 python/tests/unit/test_op_expand_dims.py
+35 −0 python/tests/unit/test_op_flatten.py
+15 −0 python/tests/unit/test_op_floor.py
+5 −0 python/tests/unit/test_op_full.py
+30 −0 python/tests/unit/test_op_gather.py
+5 −0 python/tests/unit/test_op_gather_axis.py
+5 −0 python/tests/unit/test_op_greater.py
+5 −0 python/tests/unit/test_op_greater_equal.py
+5 −0 python/tests/unit/test_op_layer_norm.py
+15 −0 python/tests/unit/test_op_less.py
+5 −0 python/tests/unit/test_op_log.py
+5 −0 python/tests/unit/test_op_log_sum_exp.py
+5 −0 python/tests/unit/test_op_matmul.py
+5 −0 python/tests/unit/test_op_maximum.py
+5 −0 python/tests/unit/test_op_minimum.py
+5 −0 python/tests/unit/test_op_multiply.py
+5 −0 python/tests/unit/test_op_negative.py
+5 −0 python/tests/unit/test_op_pad.py
+5 −0 python/tests/unit/test_op_power.py
+5 −0 python/tests/unit/test_op_random_bits.py
+88 −0 python/tests/unit/test_op_reduce.py
+5 −0 python/tests/unit/test_op_relu.py
+40 −0 python/tests/unit/test_op_reshape.py
+5 −0 python/tests/unit/test_op_rope.py
+5 −0 python/tests/unit/test_op_scan.py
+15 −0 python/tests/unit/test_op_scatter_axis.py
+24 −0 python/tests/unit/test_op_select.py
+5 −0 python/tests/unit/test_op_sigmoid.py
+15 −0 python/tests/unit/test_op_sin.py
+57 −0 python/tests/unit/test_op_slice.py
+5 −0 python/tests/unit/test_op_slice_update.py
+15 −0 python/tests/unit/test_op_softmax.py
+52 −0 python/tests/unit/test_op_split.py
+5 −0 python/tests/unit/test_op_sqrt.py
+5 −0 python/tests/unit/test_op_square.py
+5 −0 python/tests/unit/test_op_squeeze.py
+5 −0 python/tests/unit/test_op_subtract.py
+5 −0 python/tests/unit/test_op_tanh.py
+5 −0 python/tests/unit/test_op_transpose.py
+5 −0 python/tests/unit/test_op_unflatten.py
+54 −0 skills/phased-prd-red-green/SKILL.md
+4 −0 skills/phased-prd-red-green/agents/openai.yaml
+45 −0 skills/phased-prd-red-green/references/prd_red_green_template.md
+1 −1 src/lowering.cpp
+2 −1 src/mappings.cpp
+48 −0 tasks.py
Loading
Loading