diff --git a/docs/src/ruby/api_reference.rst b/docs/src/ruby/api_reference.rst index 6d49b1e7..3de4486d 100644 --- a/docs/src/ruby/api_reference.rst +++ b/docs/src/ruby/api_reference.rst @@ -60,7 +60,10 @@ Key methods: - Persistence: ``save_weights``, ``load_weights`` Important pattern: register trainable members with ``self. = ...`` so -they are tracked by module state and optimizer updates. +they are tracked by module state and optimizer updates. Direct ivar assignment +(``@name = ...``) bypasses module-state registration and those members will be +missed by traversal helpers (for example ``children``/``parameters``), +``load_weights``, and quantization updates. See implementation: @@ -127,6 +130,10 @@ Graph IR / ONNX / WebGPU entry points in ``MLX::ONNX``: - ``MLX::ONNX::WebGPUHarness.export_onnx_webgpu_harness`` - ``MLX::ONNX::WebGPUHarness.smoke_test_onnx_webgpu_harness`` +``save_safetensors`` requires a native MLX build with +``MLX_BUILD_SAFETENSORS=ON``. If safetensors support is unavailable, use +``savez``/``savez_compressed`` or a Ruby safetensors serializer fallback. + .. _distributed: Distributed diff --git a/ext/mlx/extconf.rb b/ext/mlx/extconf.rb index b9736c8d..cf8681df 100644 --- a/ext/mlx/extconf.rb +++ b/ext/mlx/extconf.rb @@ -232,7 +232,7 @@ def patch_makefile_include_dirs!(makefile_path, include_dirs) "-DMLX_BUILD_PYTHON_STUBS=OFF", "-DMLX_BUILD_METAL=ON", "-DMLX_BUILD_GGUF=OFF", - "-DMLX_BUILD_SAFETENSORS=OFF", + "-DMLX_BUILD_SAFETENSORS=ON", "-DBUILD_SHARED_LIBS=ON" ] diff --git a/lib/mlx/core.rb b/lib/mlx/core.rb index ed3e8b18..eae6c914 100644 --- a/lib/mlx/core.rb +++ b/lib/mlx/core.rb @@ -335,6 +335,8 @@ class << self alias_method :native_vmap, :vmap if method_defined?(:vmap) && !method_defined?(:native_vmap) alias_method :native_export_to_dot, :export_to_dot if method_defined?(:export_to_dot) && !method_defined?(:native_export_to_dot) + alias_method :native_array, :array if method_defined?(:array) && !method_defined?(:native_array) + alias_method :native_mean, :mean if method_defined?(:mean) && !method_defined?(:native_mean) %i[savez savez_compressed].each do |method_name| if method_defined?(method_name) && instance_method(method_name).owner == self @@ -344,6 +346,24 @@ class << self ARRAY_LEAF = :__mlx_array_leaf__ + def array(value, positional_dtype = nil, dtype: nil) + ensure_native! + target_dtype = resolve_array_dtype(positional_dtype, dtype) + native_array(value, target_dtype) + end + + def mean(array, axis = nil, positional_keepdims = nil, keepdims: nil) + ensure_native! + keepdims_v = resolve_keepdims_argument(positional_keepdims, keepdims) + reduced = reduce_mean(array, axis) + return reduced unless keepdims_v + + normalize_reduction_axes(array, axis).each do |axis_index| + reduced = expand_dims(reduced, axis_index) + end + reduced + end + def load(file, format = nil, return_metadata = false) ensure_native! format_name = (format || infer_format(file)).to_s @@ -561,6 +581,71 @@ def from_dlpack(dlpack_value) private + def resolve_array_dtype(positional_dtype, keyword_dtype) + return keyword_dtype if positional_dtype.nil? + return positional_dtype if keyword_dtype.nil? + + if dtype_name_for_compare(positional_dtype) != dtype_name_for_compare(keyword_dtype) + raise ArgumentError, + "array received conflicting dtype arguments (positional=#{positional_dtype.inspect}, keyword=#{keyword_dtype.inspect})" + end + + positional_dtype + end + + def dtype_name_for_compare(dtype) + return nil if dtype.nil? + + if dtype.respond_to?(:name) + dtype.name.to_s + else + dtype.to_s + end + end + + def resolve_keepdims_argument(positional_keepdims, keyword_keepdims) + if !positional_keepdims.nil? && !keyword_keepdims.nil? && !!positional_keepdims != !!keyword_keepdims + raise ArgumentError, + "mean received conflicting keepdims arguments (positional=#{positional_keepdims.inspect}, keyword=#{keyword_keepdims.inspect})" + end + return !!keyword_keepdims unless keyword_keepdims.nil? + return !!positional_keepdims unless positional_keepdims.nil? + + false + end + + def reduce_mean(array, axis) + if axis.is_a?(::Array) + normalize_reduction_axes(array, axis).reverse_each.reduce(array) do |acc, axis_index| + native_mean(acc, axis_index) + end + else + native_mean(array, axis) + end + end + + def normalize_reduction_axes(array, axis) + ndim = array.ndim + return (0...ndim).to_a if axis.nil? + + raw_axes = axis.is_a?(::Array) ? axis : [axis] + axes = raw_axes.map { |entry| normalize_axis_index(entry, ndim) }.sort + raise ArgumentError, "axis contains duplicate values: #{raw_axes.inspect}" if axes.uniq.length != axes.length + + axes + end + + def normalize_axis_index(axis, ndim) + raise TypeError, "axis entries must be Integer" unless axis.is_a?(::Integer) + + out = axis + out += ndim if out.negative? + if out.negative? || out >= ndim + raise ArgumentError, "axis #{axis} is out of bounds for array of dimension #{ndim}" + end + out + end + def infer_format(file) path = file_path(file) ext = File.extname(path).delete_prefix(".") @@ -1118,8 +1203,8 @@ def cos MLX::Core.cos(self) end - def mean(axis = nil) - MLX::Core.mean(self, axis) + def mean(axis = nil, keepdims_positional = nil, keepdims: nil) + MLX::Core.mean(self, axis, keepdims_positional, keepdims: keepdims) end def sum(axis = nil) @@ -1523,6 +1608,16 @@ def __rfloordiv__(other) MLX::Core.floor_divide(other, self) end + def coerce(other) + if other.is_a?(MLX::Core::Array) + [other, self] + elsif other.is_a?(::Numeric) + [MLX::Core.array(other, dtype), self] + else + raise TypeError, "#{other.class} can't be coerced into MLX::Core::Array" + end + end + def __getitem__(index) self[index] end diff --git a/lib/mlx/nn/base.rb b/lib/mlx/nn/base.rb index d67abe62..5f130fa6 100644 --- a/lib/mlx/nn/base.rb +++ b/lib/mlx/nn/base.rb @@ -323,6 +323,8 @@ def update_modules_impl(dst, modules, strict) current_value = dst[k] if current_value.is_a?(Module) && new_value.is_a?(Module) dst[k] = new_value + elsif current_value.is_a?(Module) && (new_value.is_a?(Hash) || new_value.is_a?(Array)) + update_modules_impl(current_value, new_value, strict) elsif current_value.is_a?(Hash) || current_value.is_a?(Array) update_modules_impl(current_value, new_value, strict) elsif strict && new_value != {} @@ -337,6 +339,8 @@ def update_modules_impl(dst, modules, strict) current_value = dst[i] if current_value.is_a?(Module) && new_value.is_a?(Module) dst[i] = new_value + elsif current_value.is_a?(Module) && (new_value.is_a?(Hash) || new_value.is_a?(Array)) + update_modules_impl(current_value, new_value, strict) elsif current_value.is_a?(Hash) || current_value.is_a?(Array) update_modules_impl(current_value, new_value, strict) elsif strict && new_value != {} diff --git a/prd/2026_02_25_mlx_ruby_gap_fixes_prd.md b/prd/2026_02_25_mlx_ruby_gap_fixes_prd.md new file mode 100644 index 00000000..a27facdd --- /dev/null +++ b/prd/2026_02_25_mlx_ruby_gap_fixes_prd.md @@ -0,0 +1,172 @@ +# MLX-Ruby Gap Fixes PRD + +- Date: 2026-02-25 +- Owner: Codex + @skryl +- Status: Completed (amended 2026-02-25 for safetensors default-on follow-up) +- Scope Type: Upstream `mlx-ruby` compatibility and ergonomics + +## 1) Problem Statement + +`mlx-ruby-lm` integration surfaced API/interop gaps in `mlx-ruby` that force repetitive workarounds, reduce parity with Python MLX, and increase model-porting cost. + +Most impactful categories: + +1. Core API ergonomics (`array(dtype:)`, reduction `keepdims`) +2. Ruby numeric coercion with `MLX::Core::Array` +3. `MLX::NN::Module` traversal/update behavior under nested structures +4. ONNX lowering parity (`GreaterEqual`) + +## 2) Goals + +1. Remove high-friction workarounds required across model codebases. +2. Preserve backward compatibility where possible. +3. Add regression tests for every accepted fix. +4. Keep docs explicit about unsupported/optional features. + +## 3) Non-Goals + +1. Rewriting `mlx-ruby-lm` model architectures in this PRD. +2. Full MoE feature parity beyond scoped API fixes (for example complete `SwitchGLU` implementation in this pass unless separately approved). +3. Changing CI packaging defaults unless explicitly approved. + +## 4) Scope Matrix (From Reported Issues) + +## 4.1 In Scope (this PRD) + +1. Fix `MLX::Core.array(values, dtype: ...)` keyword compatibility. +2. Add `keepdims` support to `mean` (and align reductions where practical). +3. Add Ruby `coerce` support for `MLX::Core::Array` so `Float * array`, `Float + array`, etc. work. +4. Confirm/patch `update_modules_impl` recursion for `Module -> Hash/Array` replacement path. +5. Add ONNX lowering coverage for `GreaterEqual` (Issue 14). +6. Strengthen docs on module child registration (`self.x = ...` vs `@x = ...`). +7. Enable `MLX_BUILD_SAFETENSORS=ON` in this repo's native build defaults and require safetensors roundtrip tests to pass. + +## 4.2 Out of Scope (tracked, not implemented in this PRD unless requested) + +1. Add `SwitchGLU` layer implementation. +2. Dropout constructor API expansion (`Dropout.new(p: ...)`) unless explicitly requested for compatibility. +3. Integer dtype support in `random_uniform` unless upstream MLX behavior and API contract are aligned for deterministic semantics. + +## 5) Detailed Requirements + +1. `array(dtype:)` + - Accept both positional dtype and keyword dtype. + - Reject conflicting dtype values with clear error. + - Keep existing behavior for `array(values)` and `array(values, dtype_obj)`. + +2. `mean(..., keepdims:)` + - Support `MLX::Core.mean(array, axis=nil, keepdims=false)` in Ruby API. + - Support `MLX::Core::Array#mean(axis=nil, keepdims=nil)`. + - Maintain existing no-axis behavior. + +3. Ruby coercion + - Implement `MLX::Core::Array#coerce(other)` so numeric-left operators work: + - `1.5 * arr` + - `1.5 + arr` + - `2 - arr` + - `2 / arr` + - Add tests for `Float` and `Integer` on left-hand side. + +4. Module update recursion + - Ensure `NN.update_modules_impl` recursively handles: + - current `Module` + replacement `Hash` + - current `Module` + replacement `Array` + - Preserve existing behavior for map/map and array/array recursion. + +5. ONNX `GreaterEqual` + - Add lowering support and integration/contract tests. + - Confirm exported graph compatibility report no longer flags missing op in covered paths. + +6. Documentation + - Add explicit guidance in NN docs/README: assigning child modules must go through `self.child = ...` to register in state traversal. + - Document safetensors compile-time optionality and fallback path. + +## 6) Phased Plan and Checklist + +## Phase 0: Baseline + Repro + +- [x] Add or identify failing regression tests for each in-scope item. +- [x] Confirm currently failing behavior before code changes. +- [x] Update this PRD with confirmed repro status. + +## Phase 1: Core API Fixes + +- [x] Implement `array(dtype:)` keyword support in Ruby/native boundary. +- [x] Implement `mean(..., keepdims:)` support in Ruby/native boundary. +- [x] Add targeted unit/parity tests for both. + +## Phase 2: Ruby Coercion + NN Traversal + +- [x] Implement `Array#coerce` for numeric LHS ops. +- [x] Add coercion tests for add/sub/mul/div. +- [x] Patch/confirm `update_modules_impl` recursion behavior. +- [x] Add regression tests for module/hash/array update recursion. + +## Phase 3: ONNX + Docs + +- [x] Add `GreaterEqual` lowering support and tests. +- [x] Add docs for `self.x =` child registration requirement. +- [x] Add docs for safetensors optional build feature and fallback. + +## Phase 4: Validation + Completion + +- [x] Run targeted tests for touched files/features. +- [x] Run broader suite covering core/nn/onnx touched areas. +- [x] Safetensors default-on follow-up: flip build flag and verify native safetensors roundtrip behavior in parity tests. +- [x] Update PRD status to `Completed` only when all checklist items are done. + +## 6.1) Baseline Repro Notes (2026-02-25) + +Observed before fixes: + +1. `MLX::Core.array([1,2,3], dtype: MLX::Core.int32)` raised `ArgumentError`. +2. `MLX::Core.mean(x, axis, keepdims: true)` raised wrong-arity `ArgumentError`. +3. Numeric-left ops (`1.5 + array`) raised `TypeError` (missing `coerce`). +4. `update_modules` raised `ArgumentError: Received invalid type: Hash.` for + `current_value=Module` + `new_value=Hash` recursion paths. + +Regression tests added: + +1. `test/core/core_api_gap_regression_test.rb` +2. `test/nn/module_update_modules_recursion_test.rb` + +## 7) Test Strategy + +Minimum per-change targeted tests: + +1. Core API: + - `test/parity/phase5_core_ops_test.rb` (or dedicated new parity file) + - new unit tests for `array(dtype:)` keyword and `mean keepdims` +2. Coercion: + - new `test/parity` or `test/core` coverage for numeric-left operations +3. NN traversal: + - existing quantization/state traversal tests plus new recursion regression test +4. ONNX: + - ONNX binding/integration tests covering `GreaterEqual` export and runtime path + +Completion sweep: + +1. Run all targeted tests for modified files. +2. Run a broad cross-cutting suite including core + nn + onnx touched domains before marking complete. + +## 8) Risks and Mitigations + +1. Native binding signature changes can break call compatibility. + - Mitigation: keep positional args valid; add explicit keyword parsing tests. +2. `coerce` may affect operator dispatch in edge cases. + - Mitigation: limit support to numeric and raise clear `TypeError` otherwise. +3. ONNX lowering changes can affect compatibility reports. + - Mitigation: add both positive tests and unsupported-op boundary tests. + +## 9) Acceptance Criteria + +1. All in-scope checklist items are checked. +2. New/updated tests for each implemented fix are green. +3. No regressions in touched core/nn/onnx paths. +4. PRD status changed from `Draft` to `Completed` only when all above are true. + +## 10) Open Decisions + +1. Should `Dropout.new(p: 0.5)` keyword support be included now or deferred? +2. Should integer `random_uniform` support be emulated in Ruby or left as explicit unsupported behavior? +3. Should safetensors default build flags be changed in CI/release pipelines in this effort? Resolved for this repo on 2026-02-25: native build default switched to `MLX_BUILD_SAFETENSORS=ON`. diff --git a/submodules/mlx-onnx b/submodules/mlx-onnx index 7802b305..128d7de2 160000 --- a/submodules/mlx-onnx +++ b/submodules/mlx-onnx @@ -1 +1 @@ -Subproject commit 7802b3059234084ba39faa30d09efaa86652afc3 +Subproject commit 128d7de232f9888a16cd58d82bb2a50ab2be21be diff --git a/test/core/core_api_gap_regression_test.rb b/test/core/core_api_gap_regression_test.rb new file mode 100644 index 00000000..494e1aac --- /dev/null +++ b/test/core/core_api_gap_regression_test.rb @@ -0,0 +1,72 @@ +# frozen_string_literal: true + +require_relative "../support/test_helper" + +$LOAD_PATH.unshift(File.join(RUBY_ROOT, "lib")) +require "mlx" + +class CoreApiGapRegressionTest < Minitest::Test + def setup + TestSupport.build_native_extension! + end + + def teardown + $LOAD_PATH.delete(File.join(RUBY_ROOT, "lib")) + end + + def test_array_accepts_dtype_keyword_argument + array = MLX::Core.array([1, 2, 3], dtype: MLX::Core.int32) + + assert_equal :int32, array.dtype.name + assert_equal [1, 2, 3], array.to_a + end + + def test_array_rejects_conflicting_dtype_positional_and_keyword_arguments + error = assert_raises(ArgumentError) do + MLX::Core.array([1, 2, 3], MLX::Core.float32, dtype: MLX::Core.int32) + end + assert_match(/conflicting dtype/i, error.message) + end + + def test_mean_supports_keepdims_keyword + matrix = MLX::Core.array([[1.0, 2.0], [3.0, 4.0]], MLX::Core.float32) + + by_row = MLX::Core.mean(matrix, 1, keepdims: true) + assert_equal [2, 1], by_row.shape + assert_nested_close [[1.5], [3.5]], by_row.to_a + + global = MLX::Core.mean(matrix, keepdims: true) + assert_equal [1, 1], global.shape + assert_nested_close [[2.5]], global.to_a + end + + def test_numeric_left_hand_scalar_ops_work_with_arrays + array = MLX::Core.array([1.0, 2.0], MLX::Core.float32) + + assert_nested_close [2.5, 3.5], (1.5 + array).to_a + assert_nested_close [1.5, 3.0], (1.5 * array).to_a + assert_nested_close [1.0, 0.0], (2 - array).to_a + assert_nested_close [2.0, 1.0], (2 / array).to_a + end + + private + + def assert_nested_close(expected, actual, atol = 1e-5) + assert_equal shape_signature(expected), shape_signature(actual) + flatten(expected).zip(flatten(actual)).each do |exp, got| + assert_in_delta exp, got, atol + end + end + + def flatten(value) + return [value] unless value.is_a?(Array) + + value.flat_map { |item| flatten(item) } + end + + def shape_signature(value) + return :scalar unless value.is_a?(Array) + + [value.length, *(value.map { |item| shape_signature(item) })] + end +end diff --git a/test/nn/module_update_modules_recursion_test.rb b/test/nn/module_update_modules_recursion_test.rb new file mode 100644 index 00000000..a2bfe0cf --- /dev/null +++ b/test/nn/module_update_modules_recursion_test.rb @@ -0,0 +1,58 @@ +# frozen_string_literal: true + +require_relative "../support/test_helper" + +$LOAD_PATH.unshift(File.join(RUBY_ROOT, "lib")) +require "mlx" + +class ModuleUpdateModulesRecursionTest < Minitest::Test + def setup + TestSupport.build_native_extension! + end + + def teardown + $LOAD_PATH.delete(File.join(RUBY_ROOT, "lib")) + end + + def test_update_modules_recurses_when_current_value_is_module_and_new_value_is_hash + root = build_tree + replacement = build_leaf(9.0) + + root.update_modules( + { "child" => { "inner" => replacement } }, + strict: true + ) + + assert_same replacement, root.child.inner + end + + def test_update_modules_recurses_for_array_of_modules_when_new_values_are_hashes + root = build_tree + replacement = build_leaf(11.0) + + root.update_modules( + { "items" => [{ "inner" => replacement }] }, + strict: true + ) + + assert_same replacement, root.items[0].inner + end + + private + + def build_leaf(value) + leaf = MLX::NN::Module.new + leaf.weight = MLX::Core.array([value], MLX::Core.float32) + leaf + end + + def build_tree + root = MLX::NN::Module.new + root.child = MLX::NN::Module.new + root.child.inner = build_leaf(1.0) + + root.items = [MLX::NN::Module.new] + root.items[0].inner = build_leaf(2.0) + root + end +end diff --git a/test/parity/phase171_module_load_save_weights_test.rb b/test/parity/phase171_module_load_save_weights_test.rb index b72e43e1..3452d07b 100644 --- a/test/parity/phase171_module_load_save_weights_test.rb +++ b/test/parity/phase171_module_load_save_weights_test.rb @@ -45,16 +45,18 @@ def test_save_and_load_npz_weights_roundtrip mod = WeightsModule.new TestSupport.mktmpdir("mlx-ruby-weights") do |dir| - path = File.join(dir, "weights.npz") - mod.save_weights(path) - - other = WeightsModule.new - other.weight = MLX::Core.array([[0.0, 0.0], [0.0, 0.0]], MLX::Core.float32) - other.bias = MLX::Core.array([0.0, 0.0], MLX::Core.float32) - other.load_weights(path, strict: true) - - assert_nested_close mod.weight.to_a, other.weight.to_a - assert_nested_close mod.bias.to_a, other.bias.to_a + ["weights.npz", "weights.safetensors"].each do |filename| + path = File.join(dir, filename) + mod.save_weights(path) + + other = WeightsModule.new + other.weight = MLX::Core.array([[0.0, 0.0], [0.0, 0.0]], MLX::Core.float32) + other.bias = MLX::Core.array([0.0, 0.0], MLX::Core.float32) + other.load_weights(path, strict: true) + + assert_nested_close mod.weight.to_a, other.weight.to_a + assert_nested_close mod.bias.to_a, other.bias.to_a + end end end diff --git a/test/parity/phase250_load_save_edge_parity_test.rb b/test/parity/phase250_load_save_edge_parity_test.rb index e663e976..9fefdacd 100644 --- a/test/parity/phase250_load_save_edge_parity_test.rb +++ b/test/parity/phase250_load_save_edge_parity_test.rb @@ -49,12 +49,8 @@ def test_non_contiguous_roundtrip_and_optional_container_formats transposed = MLX::Core.swapaxes(MLX::Core.reshape(MLX::Core.arange(0, 4, 1, MLX::Core.int32), [2, 2]), 0, 1) safetensors = File.join(dir, "a.safetensors") - begin - MLX::Core.save_safetensors(safetensors, {"a" => transposed}) - assert MLX::Core.array_equal(transposed, MLX::Core.load(safetensors)["a"]) - rescue RuntimeError => e - assert_match(/SAFETENSORS|safetensors/i, e.message) - end + MLX::Core.save_safetensors(safetensors, {"a" => transposed}) + assert MLX::Core.array_equal(transposed, MLX::Core.load(safetensors)["a"]) gguf = File.join(dir, "a.gguf") begin diff --git a/test/parity/phase50_io_formats_test.rb b/test/parity/phase50_io_formats_test.rb index bbafa41a..bcfe9a61 100644 --- a/test/parity/phase50_io_formats_test.rb +++ b/test/parity/phase50_io_formats_test.rb @@ -19,14 +19,10 @@ def test_save_safetensors_roundtrip_or_feature_error TestSupport.mktmpdir do |dir| path = File.join(dir, "weights.safetensors") - begin - MLX::Core.save_safetensors(path, { "x" => x }, { "note" => "ok" }) - arrays, metadata = MLX::Core.load(path, "safetensors", true) - assert MLX::Core.array_equal(x, arrays["x"]) - assert_equal "ok", metadata["note"] if metadata.key?("note") - rescue RuntimeError => e - assert_match(/SAFETENSORS|safetensors/i, e.message) - end + MLX::Core.save_safetensors(path, { "x" => x }, { "note" => "ok" }) + arrays, metadata = MLX::Core.load(path, "safetensors", true) + assert MLX::Core.array_equal(x, arrays["x"]) + assert_equal "ok", metadata["note"] end end diff --git a/test/parity/phase81_build_stability_contract_test.rb b/test/parity/phase81_build_stability_contract_test.rb index 0792045f..02524384 100644 --- a/test/parity/phase81_build_stability_contract_test.rb +++ b/test/parity/phase81_build_stability_contract_test.rb @@ -17,7 +17,7 @@ def test_build_stability_contract checks = payload.fetch("checks") assert_equal true, checks.fetch("gguf_disabled") - assert_equal true, checks.fetch("safetensors_disabled") + assert_equal true, checks.fetch("safetensors_enabled") assert_equal true, checks.fetch("configure_retry_present") assert_equal true, checks.fetch("retry_cleans_build_root") end diff --git a/test/support/parity/check_build_stability.rb b/test/support/parity/check_build_stability.rb index 804cdd34..f8e11714 100644 --- a/test/support/parity/check_build_stability.rb +++ b/test/support/parity/check_build_stability.rb @@ -14,7 +14,7 @@ checks = { "gguf_disabled" => source.include?("-DMLX_BUILD_GGUF=OFF"), - "safetensors_disabled" => source.include?("-DMLX_BUILD_SAFETENSORS=OFF"), + "safetensors_enabled" => source.include?("-DMLX_BUILD_SAFETENSORS=ON"), "configure_retry_present" => source.include?("initial CMake configure failed"), "retry_cleans_build_root" => source.include?("FileUtils.rm_rf(build_root)") }