diff --git a/LIMITATIONS.md b/coreml/LIMITATIONS.md similarity index 100% rename from LIMITATIONS.md rename to coreml/LIMITATIONS.md diff --git a/gomlx/additional_ops_test.go b/coreml/gomlx/additional_ops_test.go similarity index 100% rename from gomlx/additional_ops_test.go rename to coreml/gomlx/additional_ops_test.go diff --git a/gomlx/backend.go b/coreml/gomlx/backend.go similarity index 97% rename from gomlx/backend.go rename to coreml/gomlx/backend.go index 46f2f58..f7c1540 100644 --- a/gomlx/backend.go +++ b/coreml/gomlx/backend.go @@ -11,8 +11,8 @@ import ( "strings" "sync" - "github.com/gomlx/go-coreml/model" - "github.com/gomlx/go-coreml/runtime" + "github.com/gomlx/go-darwinml/coreml/model" + "github.com/gomlx/go-darwinml/coreml/runtime" "github.com/gomlx/gomlx/backends" "github.com/pkg/errors" ) diff --git a/gomlx/benchmark_test.go b/coreml/gomlx/benchmark_test.go similarity index 100% rename from gomlx/benchmark_test.go rename to coreml/gomlx/benchmark_test.go diff --git a/gomlx/buffer.go b/coreml/gomlx/buffer.go similarity index 100% rename from gomlx/buffer.go rename to coreml/gomlx/buffer.go diff --git a/gomlx/builder.go b/coreml/gomlx/builder.go similarity index 99% rename from gomlx/builder.go rename to coreml/gomlx/builder.go index 75cbbf2..c1a0a83 100644 --- a/gomlx/builder.go +++ b/coreml/gomlx/builder.go @@ -8,7 +8,7 @@ import ( "fmt" "reflect" - "github.com/gomlx/go-coreml/model" + "github.com/gomlx/go-darwinml/coreml/model" "github.com/gomlx/gomlx/backends" "github.com/gomlx/gomlx/backends/notimplemented" "github.com/gomlx/gomlx/pkg/core/dtypes" diff --git a/gomlx/capabilities.go b/coreml/gomlx/capabilities.go similarity index 100% rename from gomlx/capabilities.go rename to coreml/gomlx/capabilities.go diff --git a/gomlx/coreml_test.go b/coreml/gomlx/coreml_test.go similarity index 100% rename from gomlx/coreml_test.go rename to coreml/gomlx/coreml_test.go diff --git a/gomlx/doc.go b/coreml/gomlx/doc.go similarity index 100% rename from gomlx/doc.go rename to coreml/gomlx/doc.go diff --git a/gomlx/dynamic_update_slice_test.go b/coreml/gomlx/dynamic_update_slice_test.go similarity index 100% rename from gomlx/dynamic_update_slice_test.go rename to coreml/gomlx/dynamic_update_slice_test.go diff --git a/gomlx/executable.go b/coreml/gomlx/executable.go similarity index 99% rename from gomlx/executable.go rename to coreml/gomlx/executable.go index 39178e7..7a2f0ee 100644 --- a/gomlx/executable.go +++ b/coreml/gomlx/executable.go @@ -6,7 +6,7 @@ import ( "reflect" "sync" - "github.com/gomlx/go-coreml/runtime" + "github.com/gomlx/go-darwinml/coreml/runtime" "github.com/gomlx/gomlx/backends" "github.com/gomlx/gomlx/pkg/core/shapes" "github.com/pkg/errors" diff --git a/gomlx/function.go b/coreml/gomlx/function.go similarity index 99% rename from gomlx/function.go rename to coreml/gomlx/function.go index 766dc31..e1000f3 100644 --- a/gomlx/function.go +++ b/coreml/gomlx/function.go @@ -9,7 +9,7 @@ import ( "math" "slices" - "github.com/gomlx/go-coreml/model" + "github.com/gomlx/go-darwinml/coreml/model" "github.com/gomlx/gomlx/backends" "github.com/gomlx/gomlx/backends/notimplemented" "github.com/gomlx/gomlx/backends/shapeinference" diff --git a/gomlx/gather_test.go b/coreml/gomlx/gather_test.go similarity index 100% rename from gomlx/gather_test.go rename to coreml/gomlx/gather_test.go diff --git a/gomlx/go.mod b/coreml/gomlx/go.mod similarity index 59% rename from gomlx/go.mod rename to coreml/gomlx/go.mod index ebd5ae9..cc71b46 100644 --- a/gomlx/go.mod +++ b/coreml/gomlx/go.mod @@ -1,10 +1,10 @@ -module github.com/gomlx/go-coreml/gomlx +module github.com/gomlx/go-darwinml/coreml/gomlx go 1.26.0 require ( - github.com/gomlx/go-coreml v0.0.0-20260218230850-b757a40d32e7 - github.com/gomlx/gomlx v0.26.1-0.20260220075116-8da82ca8aaad + github.com/gomlx/go-darwinml v0.0.0-20260218230850-b757a40d32e7 + github.com/gomlx/gomlx v0.27.0 github.com/pkg/errors v0.9.1 ) @@ -15,7 +15,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa // indirect google.golang.org/protobuf v1.36.11 // indirect - k8s.io/klog/v2 v2.130.1 // indirect + k8s.io/klog/v2 v2.140.0 // indirect ) -replace github.com/gomlx/go-coreml => ../ +replace github.com/gomlx/go-darwinml => ../../ diff --git a/gomlx/go.sum b/coreml/gomlx/go.sum similarity index 85% rename from gomlx/go.sum rename to coreml/gomlx/go.sum index a9f4d82..8514d35 100644 --- a/gomlx/go.sum +++ b/coreml/gomlx/go.sum @@ -8,8 +8,11 @@ github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw= github.com/gofrs/flock v0.13.0/go.mod h1:jxeyy9R1auM5S6JYDBhDt+E2TCo7DkratH4Pgi8P+Z0= github.com/gomlx/go-xla v0.1.5-0.20260219173412-338774b2e7a7 h1:lZSXhh4/1962rXKwEvnhkRPYXPFe0zWn57JzZaQgEOQ= github.com/gomlx/go-xla v0.1.5-0.20260219173412-338774b2e7a7/go.mod h1:gwv58mA3ih33rgue4CuZ/ERxj5P2ZKKSFaD0s6Rt090= +github.com/gomlx/go-xla v0.2.0 h1:vPRgGjKUaN4Pq58ZswWtiKNGUkqcbUj95YmvELZNvTA= github.com/gomlx/gomlx v0.26.1-0.20260220075116-8da82ca8aaad h1:5sVrSlklIfVDaAm4GhfHHIkmeNeWgifI0T5OeoFP//c= github.com/gomlx/gomlx v0.26.1-0.20260220075116-8da82ca8aaad/go.mod h1:qh8htE4nkodsshIvgIiqkwW2IIzbOkJBT3QXI4i93Z8= +github.com/gomlx/gomlx v0.27.0 h1:tkmzft4WR+6oncPUT4jCHQYtfW2r6QkfFxUEBy7RmxA= +github.com/gomlx/gomlx v0.27.0/go.mod h1:9fOMnTb7YMs/6zYVR6diliq25x9oSjMUPMlHAbcJWd4= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -26,11 +29,15 @@ golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDME golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= +k8s.io/klog/v2 v2.140.0 h1:Tf+J3AH7xnUzZyVVXhTgGhEKnFqye14aadWv7bzXdzc= +k8s.io/klog/v2 v2.140.0/go.mod h1:o+/RWfJ6PwpnFn7OyAG3QnO47BFsymfEfrz6XyYSSp0= diff --git a/gomlx/int_ops_test.go b/coreml/gomlx/int_ops_test.go similarity index 99% rename from gomlx/int_ops_test.go rename to coreml/gomlx/int_ops_test.go index c39b9b4..cfdf7b5 100644 --- a/gomlx/int_ops_test.go +++ b/coreml/gomlx/int_ops_test.go @@ -5,7 +5,7 @@ package coreml import ( "testing" - "github.com/gomlx/go-coreml/model" + "github.com/gomlx/go-darwinml/coreml/model" "github.com/gomlx/gomlx/pkg/core/dtypes" "github.com/gomlx/gomlx/pkg/core/graph" "github.com/gomlx/gomlx/pkg/core/shapes" diff --git a/gomlx/integration_test.go b/coreml/gomlx/integration_test.go similarity index 100% rename from gomlx/integration_test.go rename to coreml/gomlx/integration_test.go diff --git a/gomlx/iota_test.go b/coreml/gomlx/iota_test.go similarity index 100% rename from gomlx/iota_test.go rename to coreml/gomlx/iota_test.go diff --git a/gomlx/nocgo_darwin.go b/coreml/gomlx/nocgo_darwin.go similarity index 100% rename from gomlx/nocgo_darwin.go rename to coreml/gomlx/nocgo_darwin.go diff --git a/gomlx/register_darwin.go b/coreml/gomlx/register_darwin.go similarity index 100% rename from gomlx/register_darwin.go rename to coreml/gomlx/register_darwin.go diff --git a/gomlx/stub_other.go b/coreml/gomlx/stub_other.go similarity index 100% rename from gomlx/stub_other.go rename to coreml/gomlx/stub_other.go diff --git a/internal/bridge/bridge.go b/coreml/internal/bridge/bridge.go similarity index 100% rename from internal/bridge/bridge.go rename to coreml/internal/bridge/bridge.go diff --git a/internal/bridge/bridge.h b/coreml/internal/bridge/bridge.h similarity index 100% rename from internal/bridge/bridge.h rename to coreml/internal/bridge/bridge.h diff --git a/internal/bridge/bridge.m b/coreml/internal/bridge/bridge.m similarity index 100% rename from internal/bridge/bridge.m rename to coreml/internal/bridge/bridge.m diff --git a/internal/bridge/bridge_test.go b/coreml/internal/bridge/bridge_test.go similarity index 100% rename from internal/bridge/bridge_test.go rename to coreml/internal/bridge/bridge_test.go diff --git a/internal/bridge/scalar_test.go b/coreml/internal/bridge/scalar_test.go similarity index 100% rename from internal/bridge/scalar_test.go rename to coreml/internal/bridge/scalar_test.go diff --git a/model/builder.go b/coreml/model/builder.go similarity index 99% rename from model/builder.go rename to coreml/model/builder.go index 858c38a..23ce413 100644 --- a/model/builder.go +++ b/coreml/model/builder.go @@ -17,7 +17,7 @@ package model import ( "fmt" - "github.com/gomlx/go-coreml/proto/coreml/milspec" + "github.com/gomlx/go-darwinml/proto/coreml/milspec" ) // DType represents a data type for tensors. diff --git a/model/builder_test.go b/coreml/model/builder_test.go similarity index 98% rename from model/builder_test.go rename to coreml/model/builder_test.go index 9688b92..9e7a1e3 100644 --- a/model/builder_test.go +++ b/coreml/model/builder_test.go @@ -3,7 +3,7 @@ package model import ( "testing" - "github.com/gomlx/go-coreml/proto/coreml/milspec" + "github.com/gomlx/go-darwinml/proto/coreml/milspec" "google.golang.org/protobuf/proto" ) diff --git a/model/example_concat_test.go b/coreml/model/example_concat_test.go similarity index 98% rename from model/example_concat_test.go rename to coreml/model/example_concat_test.go index 753aa4c..1a0fb5c 100644 --- a/model/example_concat_test.go +++ b/coreml/model/example_concat_test.go @@ -3,7 +3,7 @@ package model_test import ( "fmt" - "github.com/gomlx/go-coreml/model" + "github.com/gomlx/go-darwinml/coreml/model" ) // ExampleBuilder_Concat demonstrates concatenating multiple tensors along an axis. diff --git a/model/example_conv_test.go b/coreml/model/example_conv_test.go similarity index 100% rename from model/example_conv_test.go rename to coreml/model/example_conv_test.go diff --git a/model/example_einsum_test.go b/coreml/model/example_einsum_test.go similarity index 98% rename from model/example_einsum_test.go rename to coreml/model/example_einsum_test.go index 24ed902..88e936b 100644 --- a/model/example_einsum_test.go +++ b/coreml/model/example_einsum_test.go @@ -3,7 +3,7 @@ package model_test import ( "fmt" - "github.com/gomlx/go-coreml/model" + "github.com/gomlx/go-darwinml/coreml/model" ) // ExampleBuilder_Einsum demonstrates how to use the Einsum operation for diff --git a/model/ops.go b/coreml/model/ops.go similarity index 100% rename from model/ops.go rename to coreml/model/ops.go diff --git a/model/ops_clamp_cast_test.go b/coreml/model/ops_clamp_cast_test.go similarity index 100% rename from model/ops_clamp_cast_test.go rename to coreml/model/ops_clamp_cast_test.go diff --git a/model/ops_concat_test.go b/coreml/model/ops_concat_test.go similarity index 98% rename from model/ops_concat_test.go rename to coreml/model/ops_concat_test.go index fca561e..693b7c9 100644 --- a/model/ops_concat_test.go +++ b/coreml/model/ops_concat_test.go @@ -3,7 +3,7 @@ package model import ( "testing" - "github.com/gomlx/go-coreml/proto/coreml/milspec" + "github.com/gomlx/go-darwinml/proto/coreml/milspec" ) func TestConcat(t *testing.T) { diff --git a/model/ops_control_flow.go b/coreml/model/ops_control_flow.go similarity index 99% rename from model/ops_control_flow.go rename to coreml/model/ops_control_flow.go index 114fa58..2a1d76a 100644 --- a/model/ops_control_flow.go +++ b/coreml/model/ops_control_flow.go @@ -3,7 +3,7 @@ package model import ( "fmt" - "github.com/gomlx/go-coreml/proto/coreml/milspec" + "github.com/gomlx/go-darwinml/proto/coreml/milspec" ) // BlockBuilder builds a nested block within an operation. diff --git a/model/ops_control_flow_test.go b/coreml/model/ops_control_flow_test.go similarity index 99% rename from model/ops_control_flow_test.go rename to coreml/model/ops_control_flow_test.go index 20c25e3..139e535 100644 --- a/model/ops_control_flow_test.go +++ b/coreml/model/ops_control_flow_test.go @@ -3,7 +3,7 @@ package model import ( "testing" - "github.com/gomlx/go-coreml/proto/coreml/milspec" + "github.com/gomlx/go-darwinml/proto/coreml/milspec" ) func TestBlockBuilder_BasicOps(t *testing.T) { diff --git a/model/ops_conv_test.go b/coreml/model/ops_conv_test.go similarity index 100% rename from model/ops_conv_test.go rename to coreml/model/ops_conv_test.go diff --git a/model/ops_einsum_test.go b/coreml/model/ops_einsum_test.go similarity index 99% rename from model/ops_einsum_test.go rename to coreml/model/ops_einsum_test.go index a23e1e0..78585f9 100644 --- a/model/ops_einsum_test.go +++ b/coreml/model/ops_einsum_test.go @@ -3,7 +3,7 @@ package model import ( "testing" - "github.com/gomlx/go-coreml/proto/coreml/milspec" + "github.com/gomlx/go-darwinml/proto/coreml/milspec" ) func TestEinsumRank4(t *testing.T) { diff --git a/model/ops_logical_test.go b/coreml/model/ops_logical_test.go similarity index 100% rename from model/ops_logical_test.go rename to coreml/model/ops_logical_test.go diff --git a/model/ops_pooling_test.go b/coreml/model/ops_pooling_test.go similarity index 100% rename from model/ops_pooling_test.go rename to coreml/model/ops_pooling_test.go diff --git a/model/ops_range_test.go b/coreml/model/ops_range_test.go similarity index 97% rename from model/ops_range_test.go rename to coreml/model/ops_range_test.go index 9d359d6..71af37e 100644 --- a/model/ops_range_test.go +++ b/coreml/model/ops_range_test.go @@ -3,7 +3,7 @@ package model import ( "testing" - "github.com/gomlx/go-coreml/proto/coreml/milspec" + "github.com/gomlx/go-darwinml/proto/coreml/milspec" ) func TestRange1D(t *testing.T) { diff --git a/model/ops_sort_test.go b/coreml/model/ops_sort_test.go similarity index 100% rename from model/ops_sort_test.go rename to coreml/model/ops_sort_test.go diff --git a/model/optimize.go b/coreml/model/optimize.go similarity index 99% rename from model/optimize.go rename to coreml/model/optimize.go index 7d1a2fa..aabc435 100644 --- a/model/optimize.go +++ b/coreml/model/optimize.go @@ -3,7 +3,7 @@ package model import ( "sort" - "github.com/gomlx/go-coreml/proto/coreml/milspec" + "github.com/gomlx/go-darwinml/proto/coreml/milspec" ) // optimizeHighRankOps applies optimization passes to eliminate operations diff --git a/model/optimize_test.go b/coreml/model/optimize_test.go similarity index 100% rename from model/optimize_test.go rename to coreml/model/optimize_test.go diff --git a/model/serialize.go b/coreml/model/serialize.go similarity index 98% rename from model/serialize.go rename to coreml/model/serialize.go index 51947e1..d7d12fe 100644 --- a/model/serialize.go +++ b/coreml/model/serialize.go @@ -6,8 +6,8 @@ import ( "os" "path/filepath" - "github.com/gomlx/go-coreml/proto/coreml/milspec" - "github.com/gomlx/go-coreml/proto/coreml/spec" + "github.com/gomlx/go-darwinml/proto/coreml/milspec" + "github.com/gomlx/go-darwinml/proto/coreml/spec" "github.com/google/uuid" "google.golang.org/protobuf/proto" ) diff --git a/model/serialize_blob.go b/coreml/model/serialize_blob.go similarity index 98% rename from model/serialize_blob.go rename to coreml/model/serialize_blob.go index 7464fc4..c9ca192 100644 --- a/model/serialize_blob.go +++ b/coreml/model/serialize_blob.go @@ -9,9 +9,9 @@ import ( "path/filepath" "github.com/google/uuid" - "github.com/gomlx/go-coreml/blob" - "github.com/gomlx/go-coreml/proto/coreml/milspec" - "github.com/gomlx/go-coreml/proto/coreml/spec" + "github.com/gomlx/go-darwinml/blob" + "github.com/gomlx/go-darwinml/proto/coreml/milspec" + "github.com/gomlx/go-darwinml/proto/coreml/spec" "google.golang.org/protobuf/proto" ) diff --git a/model/serialize_blob_test.go b/coreml/model/serialize_blob_test.go similarity index 99% rename from model/serialize_blob_test.go rename to coreml/model/serialize_blob_test.go index c8311ce..c5b7134 100644 --- a/model/serialize_blob_test.go +++ b/coreml/model/serialize_blob_test.go @@ -5,7 +5,7 @@ import ( "path/filepath" "testing" - "github.com/gomlx/go-coreml/blob" + "github.com/gomlx/go-darwinml/blob" ) func TestSaveMLPackageWithBlobs(t *testing.T) { diff --git a/model/serialize_test.go b/coreml/model/serialize_test.go similarity index 100% rename from model/serialize_test.go rename to coreml/model/serialize_test.go diff --git a/runtime/blob_e2e_test.go b/coreml/runtime/blob_e2e_test.go similarity index 97% rename from runtime/blob_e2e_test.go rename to coreml/runtime/blob_e2e_test.go index a9e14c5..6a13309 100644 --- a/runtime/blob_e2e_test.go +++ b/coreml/runtime/blob_e2e_test.go @@ -6,8 +6,8 @@ import ( "path/filepath" "testing" - "github.com/gomlx/go-coreml/blob" - "github.com/gomlx/go-coreml/model" + "github.com/gomlx/go-darwinml/blob" + "github.com/gomlx/go-darwinml/coreml/model" ) func TestBlobStorageE2E(t *testing.T) { diff --git a/runtime/runtime.go b/coreml/runtime/runtime.go similarity index 98% rename from runtime/runtime.go rename to coreml/runtime/runtime.go index 01790ff..ab86d23 100644 --- a/runtime/runtime.go +++ b/coreml/runtime/runtime.go @@ -27,8 +27,8 @@ import ( "sync" "unsafe" - "github.com/gomlx/go-coreml/internal/bridge" - "github.com/gomlx/go-coreml/model" + "github.com/gomlx/go-darwinml/coreml/internal/bridge" + "github.com/gomlx/go-darwinml/coreml/model" ) // Runtime manages CoreML model compilation and execution. diff --git a/runtime/runtime_test.go b/coreml/runtime/runtime_test.go similarity index 97% rename from runtime/runtime_test.go rename to coreml/runtime/runtime_test.go index 748275e..587336c 100644 --- a/runtime/runtime_test.go +++ b/coreml/runtime/runtime_test.go @@ -3,8 +3,8 @@ package runtime import ( "testing" - "github.com/gomlx/go-coreml/internal/bridge" - "github.com/gomlx/go-coreml/model" + "github.com/gomlx/go-darwinml/coreml/internal/bridge" + "github.com/gomlx/go-darwinml/coreml/model" ) // skipIfNotMacOS skips the test if not running on macOS. diff --git a/doc.go b/doc.go index d93a316..1a5085a 100644 --- a/doc.go +++ b/doc.go @@ -1,56 +1,21 @@ -// Package gocoreml provides Go bindings to Apple's CoreML framework. +// Package godarwinml provides Go backends for machine learning on Apple Silicon. // -// This package enables running machine learning models on Apple's Neural Engine (ANE), -// Metal GPU, and CPU. It is designed to be used as a backend for GoMLX, providing -// high-performance inference on Apple Silicon. +// This project provides multiple GoMLX backends targeting Apple hardware: // -// # Architecture +// - coreml: CoreML backend supporting Apple Neural Engine (ANE), Metal GPU, and CPU +// - mpsgraph: MPSGraph backend for direct Metal GPU computation // -// The package is organized into several sub-packages: +// Each backend implements the GoMLX backends.Backend interface and can be used +// interchangeably for inference on Apple Silicon. // -// - internal/bridge: Low-level cgo bindings to CoreML via Objective-C++ -// - model: High-level model building and management -// - runtime: Model loading and execution -// - ops: MIL (Model Intermediate Language) operation builders +// # Shared Packages // -// # Usage -// -// This package is primarily intended to be used through the GoMLX CoreML backend. -// Direct usage is also possible for loading and running pre-built CoreML models: -// -// import "github.com/gomlx/go-coreml/runtime" -// -// model, err := runtime.LoadModel("path/to/model.mlmodelc") -// if err != nil { -// log.Fatal(err) -// } -// defer model.Close() -// -// // Create input tensor -// input := runtime.NewTensor([]int64{1, 3, 224, 224}, runtime.Float32) -// // ... fill input data ... -// -// // Run inference -// output, err := model.Predict(map[string]*runtime.Tensor{"input": input}) -// if err != nil { -// log.Fatal(err) -// } +// - blob: Weight blob storage format +// - proto: CoreML protobuf specifications // // # Requirements // // - macOS 12.0+ (Monterey or later) // - Xcode Command Line Tools -// - Go 1.21+ -// -// # Compute Units -// -// CoreML can run on different compute units: -// -// - All: Let CoreML decide (default, usually best performance) -// - CPU Only: Force CPU-only execution -// - CPU and GPU: Use CPU and Metal GPU -// - CPU and ANE: Use CPU and Apple Neural Engine -// -// The Neural Engine (ANE) provides the best performance and power efficiency -// for supported operations on Apple Silicon. -package gocoreml +// - Go 1.25+ +package godarwinml diff --git a/go.mod b/go.mod index c62f652..b565b2e 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/gomlx/go-coreml +module github.com/gomlx/go-darwinml go 1.25.5 diff --git a/mpsgraph/gomlx/backend.go b/mpsgraph/gomlx/backend.go index f2df974..030692b 100644 --- a/mpsgraph/gomlx/backend.go +++ b/mpsgraph/gomlx/backend.go @@ -9,7 +9,7 @@ package mpsgraph import ( "sync" - "github.com/gomlx/go-coreml/mpsgraph/gomlx/internal/bridge" + "github.com/gomlx/go-darwinml/mpsgraph/gomlx/internal/bridge" "github.com/gomlx/gomlx/backends" "github.com/gomlx/gomlx/pkg/core/shapes" "github.com/pkg/errors" diff --git a/mpsgraph/gomlx/builder.go b/mpsgraph/gomlx/builder.go index 82cb4e3..682cfb9 100644 --- a/mpsgraph/gomlx/builder.go +++ b/mpsgraph/gomlx/builder.go @@ -5,7 +5,7 @@ package mpsgraph import ( - "github.com/gomlx/go-coreml/mpsgraph/gomlx/internal/bridge" + "github.com/gomlx/go-darwinml/mpsgraph/gomlx/internal/bridge" "github.com/gomlx/gomlx/backends" "github.com/gomlx/gomlx/backends/notimplemented" "github.com/gomlx/gomlx/pkg/core/shapes" diff --git a/mpsgraph/gomlx/executable.go b/mpsgraph/gomlx/executable.go index c6b0825..0a6840e 100644 --- a/mpsgraph/gomlx/executable.go +++ b/mpsgraph/gomlx/executable.go @@ -9,7 +9,7 @@ import ( "sync" "unsafe" - "github.com/gomlx/go-coreml/mpsgraph/gomlx/internal/bridge" + "github.com/gomlx/go-darwinml/mpsgraph/gomlx/internal/bridge" "github.com/gomlx/gomlx/backends" "github.com/gomlx/gomlx/pkg/core/shapes" "github.com/pkg/errors" diff --git a/mpsgraph/gomlx/function.go b/mpsgraph/gomlx/function.go index c320130..3caebdd 100644 --- a/mpsgraph/gomlx/function.go +++ b/mpsgraph/gomlx/function.go @@ -10,7 +10,7 @@ import ( "runtime" "unsafe" - "github.com/gomlx/go-coreml/mpsgraph/gomlx/internal/bridge" + "github.com/gomlx/go-darwinml/mpsgraph/gomlx/internal/bridge" "github.com/gomlx/gomlx/backends" "github.com/gomlx/gomlx/backends/notimplemented" "github.com/gomlx/gomlx/backends/shapeinference" @@ -2672,6 +2672,7 @@ func (f *Function) FusedScaledDotProductAttention( axesLayout backends.AxesLayout, scale float64, causal bool, + options *backends.ScaledDotProductAttentionConfig, ) (backends.Value, error) { qNode, err := f.resolveNode(query) if err != nil { diff --git a/mpsgraph/gomlx/gather.go b/mpsgraph/gomlx/gather.go index d7bc556..97db26b 100644 --- a/mpsgraph/gomlx/gather.go +++ b/mpsgraph/gomlx/gather.go @@ -7,7 +7,7 @@ package mpsgraph import ( "slices" - "github.com/gomlx/go-coreml/mpsgraph/gomlx/internal/bridge" + "github.com/gomlx/go-darwinml/mpsgraph/gomlx/internal/bridge" "github.com/gomlx/gomlx/backends" "github.com/gomlx/gomlx/backends/shapeinference" "github.com/gomlx/gomlx/pkg/core/shapes" diff --git a/mpsgraph/gomlx/go.mod b/mpsgraph/gomlx/go.mod index 78ac2fe..b665376 100644 --- a/mpsgraph/gomlx/go.mod +++ b/mpsgraph/gomlx/go.mod @@ -1,4 +1,4 @@ -module github.com/gomlx/go-coreml/mpsgraph/gomlx +module github.com/gomlx/go-darwinml/mpsgraph/gomlx go 1.25 diff --git a/mpsgraph/gomlx/mpsgraph_test.go b/mpsgraph/gomlx/mpsgraph_test.go index 3ae324b..f33fc69 100644 --- a/mpsgraph/gomlx/mpsgraph_test.go +++ b/mpsgraph/gomlx/mpsgraph_test.go @@ -2974,7 +2974,7 @@ func TestFusedScaledDotProductAttention(t *testing.T) { result := graph.MustExecOnce(backend, func(q, k, v *graph.Node) *graph.Node { return graph.BackendFusedScaledDotProductAttention( - q, k, v, nil, 1, 1, backends.AxesLayoutBHSD, scale, false) + q, k, v, nil, 1, 1, backends.AxesLayoutBHSD, scale, false, nil) }, tensors.FromFlatDataAndDimensions(data, 1, 1, 2, 2), tensors.FromFlatDataAndDimensions(data, 1, 1, 2, 2), @@ -3003,7 +3003,7 @@ func TestFusedScaledDotProductAttention(t *testing.T) { result := graph.MustExecOnce(backend, func(q, k, v *graph.Node) *graph.Node { return graph.BackendFusedScaledDotProductAttention( - q, k, v, nil, 1, 1, backends.AxesLayoutBHSD, scale, true) + q, k, v, nil, 1, 1, backends.AxesLayoutBHSD, scale, true, nil) }, tensors.FromFlatDataAndDimensions(qData, 1, 1, 3, 2), tensors.FromFlatDataAndDimensions(kData, 1, 1, 3, 2), @@ -3021,7 +3021,7 @@ func TestFusedScaledDotProductAttention(t *testing.T) { result := graph.MustExecOnce(backend, func(q, k, v *graph.Node) *graph.Node { return graph.BackendFusedScaledDotProductAttention( - q, k, v, nil, 1, 1, backends.AxesLayoutBSHD, scale, false) + q, k, v, nil, 1, 1, backends.AxesLayoutBSHD, scale, false, nil) }, tensors.FromFlatDataAndDimensions(data, 1, 2, 1, 2), tensors.FromFlatDataAndDimensions(data, 1, 2, 1, 2), @@ -3044,7 +3044,7 @@ func TestFusedScaledDotProductAttention(t *testing.T) { result := graph.MustExecOnce(backend, func(q, k, v, m *graph.Node) *graph.Node { return graph.BackendFusedScaledDotProductAttention( - q, k, v, m, 1, 1, backends.AxesLayoutBHSD, 1.0, false) + q, k, v, m, 1, 1, backends.AxesLayoutBHSD, 1.0, false, nil) }, tensors.FromFlatDataAndDimensions(qData, 1, 1, 2, 2), tensors.FromFlatDataAndDimensions(kData, 1, 1, 2, 2), diff --git a/proto/coreml/ArrayFeatureExtractor.proto b/proto/coreml/ArrayFeatureExtractor.proto index 3979c60..9806f18 100644 --- a/proto/coreml/ArrayFeatureExtractor.proto +++ b/proto/coreml/ArrayFeatureExtractor.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification; diff --git a/proto/coreml/AudioFeaturePrint.proto b/proto/coreml/AudioFeaturePrint.proto index e891d14..f22e6bb 100644 --- a/proto/coreml/AudioFeaturePrint.proto +++ b/proto/coreml/AudioFeaturePrint.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification.CoreMLModels; diff --git a/proto/coreml/BayesianProbitRegressor.proto b/proto/coreml/BayesianProbitRegressor.proto index 7555c65..e4c81a4 100644 --- a/proto/coreml/BayesianProbitRegressor.proto +++ b/proto/coreml/BayesianProbitRegressor.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification; diff --git a/proto/coreml/CategoricalMapping.proto b/proto/coreml/CategoricalMapping.proto index cbd00e2..a1f8470 100644 --- a/proto/coreml/CategoricalMapping.proto +++ b/proto/coreml/CategoricalMapping.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/ClassConfidenceThresholding.proto b/proto/coreml/ClassConfidenceThresholding.proto index 3f2d84d..72b1bdb 100644 --- a/proto/coreml/ClassConfidenceThresholding.proto +++ b/proto/coreml/ClassConfidenceThresholding.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/CustomModel.proto b/proto/coreml/CustomModel.proto index 36cf83c..1525883 100644 --- a/proto/coreml/CustomModel.proto +++ b/proto/coreml/CustomModel.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification; diff --git a/proto/coreml/DataStructures.proto b/proto/coreml/DataStructures.proto index b095601..dda5e2f 100644 --- a/proto/coreml/DataStructures.proto +++ b/proto/coreml/DataStructures.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "FeatureTypes.proto"; diff --git a/proto/coreml/DictVectorizer.proto b/proto/coreml/DictVectorizer.proto index a9450aa..228852a 100644 --- a/proto/coreml/DictVectorizer.proto +++ b/proto/coreml/DictVectorizer.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/FeatureTypes.proto b/proto/coreml/FeatureTypes.proto index 9b5069e..5b857de 100644 --- a/proto/coreml/FeatureTypes.proto +++ b/proto/coreml/FeatureTypes.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification; diff --git a/proto/coreml/FeatureVectorizer.proto b/proto/coreml/FeatureVectorizer.proto index 924c738..01f95a4 100644 --- a/proto/coreml/FeatureVectorizer.proto +++ b/proto/coreml/FeatureVectorizer.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification; diff --git a/proto/coreml/GLMClassifier.proto b/proto/coreml/GLMClassifier.proto index de601e5..c151511 100644 --- a/proto/coreml/GLMClassifier.proto +++ b/proto/coreml/GLMClassifier.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/GLMRegressor.proto b/proto/coreml/GLMRegressor.proto index c3617e4..166bd5c 100644 --- a/proto/coreml/GLMRegressor.proto +++ b/proto/coreml/GLMRegressor.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification; diff --git a/proto/coreml/Gazetteer.proto b/proto/coreml/Gazetteer.proto index ee6c2d4..4e94236 100644 --- a/proto/coreml/Gazetteer.proto +++ b/proto/coreml/Gazetteer.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/Identity.proto b/proto/coreml/Identity.proto index 3a9c66f..7f4a047 100644 --- a/proto/coreml/Identity.proto +++ b/proto/coreml/Identity.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification; diff --git a/proto/coreml/Imputer.proto b/proto/coreml/Imputer.proto index 397f492..5cb1416 100644 --- a/proto/coreml/Imputer.proto +++ b/proto/coreml/Imputer.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/ItemSimilarityRecommender.proto b/proto/coreml/ItemSimilarityRecommender.proto index 06b47de..9aeba38 100644 --- a/proto/coreml/ItemSimilarityRecommender.proto +++ b/proto/coreml/ItemSimilarityRecommender.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/LinkedModel.proto b/proto/coreml/LinkedModel.proto index afcd37e..1105f49 100644 --- a/proto/coreml/LinkedModel.proto +++ b/proto/coreml/LinkedModel.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "Parameters.proto"; package CoreML.Specification; diff --git a/proto/coreml/MIL.proto b/proto/coreml/MIL.proto index 7b0671c..653676f 100644 --- a/proto/coreml/MIL.proto +++ b/proto/coreml/MIL.proto @@ -26,7 +26,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/milspec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/milspec"; package CoreML.Specification.MILSpec; diff --git a/proto/coreml/Model.proto b/proto/coreml/Model.proto index 53c1863..f6caca6 100644 --- a/proto/coreml/Model.proto +++ b/proto/coreml/Model.proto @@ -59,7 +59,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "VisionFeaturePrint.proto"; import public "AudioFeaturePrint.proto"; diff --git a/proto/coreml/NearestNeighbors.proto b/proto/coreml/NearestNeighbors.proto index 6ab1e33..4695ed1 100644 --- a/proto/coreml/NearestNeighbors.proto +++ b/proto/coreml/NearestNeighbors.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification; diff --git a/proto/coreml/NeuralNetwork.proto b/proto/coreml/NeuralNetwork.proto index 5bc0176..bd7f741 100644 --- a/proto/coreml/NeuralNetwork.proto +++ b/proto/coreml/NeuralNetwork.proto @@ -57,7 +57,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; import public "Parameters.proto"; diff --git a/proto/coreml/NonMaximumSuppression.proto b/proto/coreml/NonMaximumSuppression.proto index 4301d79..30ee89f 100644 --- a/proto/coreml/NonMaximumSuppression.proto +++ b/proto/coreml/NonMaximumSuppression.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/Normalizer.proto b/proto/coreml/Normalizer.proto index a29d766..78aba40 100644 --- a/proto/coreml/Normalizer.proto +++ b/proto/coreml/Normalizer.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification; diff --git a/proto/coreml/OneHotEncoder.proto b/proto/coreml/OneHotEncoder.proto index 9f1123f..77adf9f 100644 --- a/proto/coreml/OneHotEncoder.proto +++ b/proto/coreml/OneHotEncoder.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/Parameters.proto b/proto/coreml/Parameters.proto index 8ec47fa..1d4054f 100644 --- a/proto/coreml/Parameters.proto +++ b/proto/coreml/Parameters.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/SVM.proto b/proto/coreml/SVM.proto index 69c25a8..82b92fe 100644 --- a/proto/coreml/SVM.proto +++ b/proto/coreml/SVM.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/Scaler.proto b/proto/coreml/Scaler.proto index 1e0abe1..386db92 100644 --- a/proto/coreml/Scaler.proto +++ b/proto/coreml/Scaler.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification; diff --git a/proto/coreml/SoundAnalysisPreprocessing.proto b/proto/coreml/SoundAnalysisPreprocessing.proto index f8b12c7..67df8d1 100644 --- a/proto/coreml/SoundAnalysisPreprocessing.proto +++ b/proto/coreml/SoundAnalysisPreprocessing.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification.CoreMLModels; diff --git a/proto/coreml/TextClassifier.proto b/proto/coreml/TextClassifier.proto index dcfe0e0..4f97fa8 100644 --- a/proto/coreml/TextClassifier.proto +++ b/proto/coreml/TextClassifier.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/TreeEnsemble.proto b/proto/coreml/TreeEnsemble.proto index 481996c..0bf62b3 100644 --- a/proto/coreml/TreeEnsemble.proto +++ b/proto/coreml/TreeEnsemble.proto @@ -24,7 +24,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/VisionFeaturePrint.proto b/proto/coreml/VisionFeaturePrint.proto index 0a8595f..8ce7170 100644 --- a/proto/coreml/VisionFeaturePrint.proto +++ b/proto/coreml/VisionFeaturePrint.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; package CoreML.Specification.CoreMLModels; diff --git a/proto/coreml/WordEmbedding.proto b/proto/coreml/WordEmbedding.proto index 5b4efc8..c76e96b 100644 --- a/proto/coreml/WordEmbedding.proto +++ b/proto/coreml/WordEmbedding.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/WordTagger.proto b/proto/coreml/WordTagger.proto index 1a7b5cc..d245809 100644 --- a/proto/coreml/WordTagger.proto +++ b/proto/coreml/WordTagger.proto @@ -5,7 +5,7 @@ syntax = "proto3"; option optimize_for = LITE_RUNTIME; -option go_package = "github.com/gomlx/go-coreml/proto/coreml/spec"; +option go_package = "github.com/gomlx/go-darwinml/proto/coreml/spec"; import public "DataStructures.proto"; diff --git a/proto/coreml/milspec/MIL.pb.go b/proto/coreml/milspec/MIL.pb.go index 218c66a..6c733f5 100644 --- a/proto/coreml/milspec/MIL.pb.go +++ b/proto/coreml/milspec/MIL.pb.go @@ -26,7 +26,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: MIL.proto package milspec @@ -2367,7 +2367,7 @@ const file_MIL_proto_rawDesc = "" + "\x05UINT2\x10$\x12\t\n" + "\x05UINT1\x10%\x12\t\n" + "\x05UINT6\x10&\x12\t\n" + - "\x05UINT3\x10'B3H\x03Z/github.com/gomlx/go-coreml/proto/coreml/milspecb\x06proto3" + "\x05UINT3\x10'B5H\x03Z1github.com/gomlx/go-darwinml/proto/coreml/milspecb\x06proto3" var ( file_MIL_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/ArrayFeatureExtractor.pb.go b/proto/coreml/spec/ArrayFeatureExtractor.pb.go index d115aa8..664f800 100644 --- a/proto/coreml/spec/ArrayFeatureExtractor.pb.go +++ b/proto/coreml/spec/ArrayFeatureExtractor.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: ArrayFeatureExtractor.proto package spec @@ -80,7 +80,7 @@ const file_ArrayFeatureExtractor_proto_rawDesc = "" + "\n" + "\x1bArrayFeatureExtractor.proto\x12\x14CoreML.Specification\";\n" + "\x15ArrayFeatureExtractor\x12\"\n" + - "\fextractIndex\x18\x01 \x03(\x04R\fextractIndexB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\fextractIndex\x18\x01 \x03(\x04R\fextractIndexB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_ArrayFeatureExtractor_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/AudioFeaturePrint.pb.go b/proto/coreml/spec/AudioFeaturePrint.pb.go index 3908a0b..a4ace8a 100644 --- a/proto/coreml/spec/AudioFeaturePrint.pb.go +++ b/proto/coreml/spec/AudioFeaturePrint.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: AudioFeaturePrint.proto package spec @@ -202,7 +202,7 @@ const file_AudioFeaturePrint_proto_rawDesc = "" + "\fSoundVersion\x12\x19\n" + "\x15SOUND_VERSION_INVALID\x10\x00\x12\x13\n" + "\x0fSOUND_VERSION_1\x10\x01B\x17\n" + - "\x15AudioFeaturePrintTypeB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\x15AudioFeaturePrintTypeB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_AudioFeaturePrint_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/BayesianProbitRegressor.pb.go b/proto/coreml/spec/BayesianProbitRegressor.pb.go index 24e2ff4..c33b8b1 100644 --- a/proto/coreml/spec/BayesianProbitRegressor.pb.go +++ b/proto/coreml/spec/BayesianProbitRegressor.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: BayesianProbitRegressor.proto package spec @@ -353,7 +353,7 @@ const file_BayesianProbitRegressor_proto_rawDesc = "" + "\rfeatureWeight\x18\x02 \x01(\v26.CoreML.Specification.BayesianProbitRegressor.GaussianR\rfeatureWeight\x1a\x89\x01\n" + "\rFeatureWeight\x12\x1c\n" + "\tfeatureId\x18\x01 \x01(\rR\tfeatureId\x12Z\n" + - "\aweights\x18\x02 \x03(\v2@.CoreML.Specification.BayesianProbitRegressor.FeatureValueWeightR\aweightsB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\aweights\x18\x02 \x03(\v2@.CoreML.Specification.BayesianProbitRegressor.FeatureValueWeightR\aweightsB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_BayesianProbitRegressor_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/CategoricalMapping.pb.go b/proto/coreml/spec/CategoricalMapping.pb.go index 09b322a..6941e7f 100644 --- a/proto/coreml/spec/CategoricalMapping.pb.go +++ b/proto/coreml/spec/CategoricalMapping.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: CategoricalMapping.proto package spec @@ -177,7 +177,7 @@ const file_CategoricalMapping_proto_rawDesc = "" + "int64Value\x18f \x01(\x03H\x01R\n" + "int64ValueB\r\n" + "\vMappingTypeB\x10\n" + - "\x0eValueOnUnknownB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\x0eValueOnUnknownB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_CategoricalMapping_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/ClassConfidenceThresholding.pb.go b/proto/coreml/spec/ClassConfidenceThresholding.pb.go index 654a679..660be1d 100644 --- a/proto/coreml/spec/ClassConfidenceThresholding.pb.go +++ b/proto/coreml/spec/ClassConfidenceThresholding.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: ClassConfidenceThresholding.proto package spec @@ -81,7 +81,7 @@ const file_ClassConfidenceThresholding_proto_rawDesc = "" + "\n" + "!ClassConfidenceThresholding.proto\x12\x14CoreML.Specification\x1a\x14DataStructures.proto\"\x7f\n" + "\x1bClassConfidenceThresholding\x12`\n" + - "\x15precisionRecallCurves\x18d \x03(\v2*.CoreML.Specification.PrecisionRecallCurveR\x15precisionRecallCurvesB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\x15precisionRecallCurves\x18d \x03(\v2*.CoreML.Specification.PrecisionRecallCurveR\x15precisionRecallCurvesB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_ClassConfidenceThresholding_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/CustomModel.pb.go b/proto/coreml/spec/CustomModel.pb.go index 6b7646a..31c1a2b 100644 --- a/proto/coreml/spec/CustomModel.pb.go +++ b/proto/coreml/spec/CustomModel.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: CustomModel.proto package spec @@ -258,7 +258,7 @@ const file_CustomModel_proto_rawDesc = "" + "\x05value\x1av\n" + "\x0fParametersEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12M\n" + - "\x05value\x18\x02 \x01(\v27.CoreML.Specification.CustomModel.CustomModelParamValueR\x05value:\x028\x01B0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\x05value\x18\x02 \x01(\v27.CoreML.Specification.CustomModel.CustomModelParamValueR\x05value:\x028\x01B2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_CustomModel_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/DataStructures.pb.go b/proto/coreml/spec/DataStructures.pb.go index 949a744..c1506a2 100644 --- a/proto/coreml/spec/DataStructures.pb.go +++ b/proto/coreml/spec/DataStructures.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: DataStructures.proto package spec @@ -679,7 +679,7 @@ const file_DataStructures_proto_rawDesc = "" + "\x0fprecisionValues\x18\x01 \x01(\v2!.CoreML.Specification.FloatVectorR\x0fprecisionValues\x12g\n" + "\x1dprecisionConfidenceThresholds\x18\x02 \x01(\v2!.CoreML.Specification.FloatVectorR\x1dprecisionConfidenceThresholds\x12E\n" + "\frecallValues\x18\x03 \x01(\v2!.CoreML.Specification.FloatVectorR\frecallValues\x12a\n" + - "\x1arecallConfidenceThresholds\x18\x04 \x01(\v2!.CoreML.Specification.FloatVectorR\x1arecallConfidenceThresholdsB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\x1arecallConfidenceThresholds\x18\x04 \x01(\v2!.CoreML.Specification.FloatVectorR\x1arecallConfidenceThresholdsB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_DataStructures_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/DictVectorizer.pb.go b/proto/coreml/spec/DictVectorizer.pb.go index 27d445d..64e67b7 100644 --- a/proto/coreml/spec/DictVectorizer.pb.go +++ b/proto/coreml/spec/DictVectorizer.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: DictVectorizer.proto package spec @@ -130,7 +130,7 @@ const file_DictVectorizer_proto_rawDesc = "" + "\x0eDictVectorizer\x12J\n" + "\rstringToIndex\x18\x01 \x01(\v2\".CoreML.Specification.StringVectorH\x00R\rstringToIndex\x12G\n" + "\fint64ToIndex\x18\x02 \x01(\v2!.CoreML.Specification.Int64VectorH\x00R\fint64ToIndexB\x05\n" + - "\x03MapB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\x03MapB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_DictVectorizer_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/FeatureTypes.pb.go b/proto/coreml/spec/FeatureTypes.pb.go index 3df2bd6..466ccbb 100644 --- a/proto/coreml/spec/FeatureTypes.pb.go +++ b/proto/coreml/spec/FeatureTypes.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: FeatureTypes.proto package spec @@ -1388,7 +1388,7 @@ const file_FeatureTypes_proto_rawDesc = "" + "\n" + "isOptional\x18\xe8\a \x01(\bR\n" + "isOptionalB\x06\n" + - "\x04TypeB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\x04TypeB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_FeatureTypes_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/FeatureVectorizer.pb.go b/proto/coreml/spec/FeatureVectorizer.pb.go index f5217a4..76cf2d6 100644 --- a/proto/coreml/spec/FeatureVectorizer.pb.go +++ b/proto/coreml/spec/FeatureVectorizer.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: FeatureVectorizer.proto package spec @@ -137,7 +137,7 @@ const file_FeatureVectorizer_proto_rawDesc = "" + "\tinputList\x18\x01 \x03(\v23.CoreML.Specification.FeatureVectorizer.InputColumnR\tinputList\x1aY\n" + "\vInputColumn\x12 \n" + "\vinputColumn\x18\x01 \x01(\tR\vinputColumn\x12(\n" + - "\x0finputDimensions\x18\x02 \x01(\x04R\x0finputDimensionsB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\x0finputDimensions\x18\x02 \x01(\x04R\x0finputDimensionsB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_FeatureVectorizer_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/GLMClassifier.pb.go b/proto/coreml/spec/GLMClassifier.pb.go index fd28e71..2f8f6dc 100644 --- a/proto/coreml/spec/GLMClassifier.pb.go +++ b/proto/coreml/spec/GLMClassifier.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: GLMClassifier.proto package spec @@ -300,7 +300,7 @@ const file_GLMClassifier_proto_rawDesc = "" + "\rClassEncoding\x12\x12\n" + "\x0eReferenceClass\x10\x00\x12\r\n" + "\tOneVsRest\x10\x01B\r\n" + - "\vClassLabelsB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\vClassLabelsB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_GLMClassifier_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/GLMRegressor.pb.go b/proto/coreml/spec/GLMRegressor.pb.go index fb60d35..436a64c 100644 --- a/proto/coreml/spec/GLMRegressor.pb.go +++ b/proto/coreml/spec/GLMRegressor.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: GLMRegressor.proto package spec @@ -195,7 +195,7 @@ const file_GLMRegressor_proto_rawDesc = "" + "\vNoTransform\x10\x00\x12\t\n" + "\x05Logit\x10\x01\x12\n" + "\n" + - "\x06Probit\x10\x02B0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\x06Probit\x10\x02B2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_GLMRegressor_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/Gazetteer.pb.go b/proto/coreml/spec/Gazetteer.pb.go index 37c2750..bf3559c 100644 --- a/proto/coreml/spec/Gazetteer.pb.go +++ b/proto/coreml/spec/Gazetteer.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: Gazetteer.proto package spec @@ -136,7 +136,7 @@ const file_Gazetteer_proto_rawDesc = "" + " \x01(\tR\blanguage\x12.\n" + "\x12modelParameterData\x18d \x01(\fR\x12modelParameterData\x12S\n" + "\x11stringClassLabels\x18\xc8\x01 \x01(\v2\".CoreML.Specification.StringVectorH\x00R\x11stringClassLabelsB\r\n" + - "\vClassLabelsB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\vClassLabelsB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_Gazetteer_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/Identity.pb.go b/proto/coreml/spec/Identity.pb.go index bc49ad2..5bb1b0f 100644 --- a/proto/coreml/spec/Identity.pb.go +++ b/proto/coreml/spec/Identity.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: Identity.proto package spec @@ -72,7 +72,7 @@ const file_Identity_proto_rawDesc = "" + "\n" + "\x0eIdentity.proto\x12\x14CoreML.Specification\"\n" + "\n" + - "\bIdentityB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\bIdentityB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_Identity_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/Imputer.pb.go b/proto/coreml/spec/Imputer.pb.go index c3a1bb1..6ae1ba4 100644 --- a/proto/coreml/spec/Imputer.pb.go +++ b/proto/coreml/spec/Imputer.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: Imputer.proto package spec @@ -279,7 +279,7 @@ const file_Imputer_proto_rawDesc = "" + "\x11replaceInt64Value\x18\f \x01(\x03H\x01R\x11replaceInt64Value\x120\n" + "\x12replaceStringValue\x18\r \x01(\tH\x01R\x12replaceStringValueB\x0e\n" + "\fImputedValueB\x0e\n" + - "\fReplaceValueB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\fReplaceValueB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_Imputer_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/ItemSimilarityRecommender.pb.go b/proto/coreml/spec/ItemSimilarityRecommender.pb.go index 4ff9b01..3e1f2d9 100644 --- a/proto/coreml/spec/ItemSimilarityRecommender.pb.go +++ b/proto/coreml/spec/ItemSimilarityRecommender.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: ItemSimilarityRecommender.proto package spec @@ -297,7 +297,7 @@ const file_ItemSimilarityRecommender_proto_rawDesc = "" + "\fSimilarItems\x12\x16\n" + "\x06itemId\x18\x01 \x01(\x04R\x06itemId\x12g\n" + "\x0fsimilarItemList\x18\x02 \x03(\v2=.CoreML.Specification.ItemSimilarityRecommender.ConnectedItemR\x0fsimilarItemList\x120\n" + - "\x13itemScoreAdjustment\x18\x03 \x01(\x01R\x13itemScoreAdjustmentB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\x13itemScoreAdjustment\x18\x03 \x01(\x01R\x13itemScoreAdjustmentB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_ItemSimilarityRecommender_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/LinkedModel.pb.go b/proto/coreml/spec/LinkedModel.pb.go index 7b25c48..c753020 100644 --- a/proto/coreml/spec/LinkedModel.pb.go +++ b/proto/coreml/spec/LinkedModel.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: LinkedModel.proto package spec @@ -169,7 +169,7 @@ const file_LinkedModel_proto_rawDesc = "" + "\bLinkType\"\xc7\x01\n" + "\x0fLinkedModelFile\x12W\n" + "\x13linkedModelFileName\x18\x01 \x01(\v2%.CoreML.Specification.StringParameterR\x13linkedModelFileName\x12[\n" + - "\x15linkedModelSearchPath\x18\x02 \x01(\v2%.CoreML.Specification.StringParameterR\x15linkedModelSearchPathB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\x15linkedModelSearchPath\x18\x02 \x01(\v2%.CoreML.Specification.StringParameterR\x15linkedModelSearchPathB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_LinkedModel_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/Model.pb.go b/proto/coreml/spec/Model.pb.go index b10a387..ba4b292 100644 --- a/proto/coreml/spec/Model.pb.go +++ b/proto/coreml/spec/Model.pb.go @@ -59,13 +59,13 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: Model.proto package spec import ( - milspec "github.com/gomlx/go-coreml/proto/coreml/milspec" + milspec "github.com/gomlx/go-darwinml/proto/coreml/milspec" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" @@ -1622,7 +1622,7 @@ const file_Model_proto_rawDesc = "" + "\rwordEmbedding\x18\xd5\x0f \x01(\v20.CoreML.Specification.CoreMLModels.WordEmbeddingH\x00R\rwordEmbedding\x12e\n" + "\x11audioFeaturePrint\x18\xd6\x0f \x01(\v24.CoreML.Specification.CoreMLModels.AudioFeaturePrintH\x00R\x11audioFeaturePrint\x12R\n" + "\x0fserializedModel\x18\xb8\x17 \x01(\v2%.CoreML.Specification.SerializedModelH\x00R\x0fserializedModelB\x06\n" + - "\x04TypeB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00P\x01P\x02P\x03P\x04P\x05P\x06P\aP\bP\tP\n" + + "\x04TypeB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00P\x01P\x02P\x03P\x04P\x05P\x06P\aP\bP\tP\n" + "P\vP\fP\rP\x0eP\x0fP\x10P\x11P\x12P\x13P\x14P\x15P\x16P\x17P\x18P\x19P\x1aP\x1bP\x1cP\x1dP\x1eb\x06proto3" var ( diff --git a/proto/coreml/spec/NearestNeighbors.pb.go b/proto/coreml/spec/NearestNeighbors.pb.go index fcaddd3..9c9697c 100644 --- a/proto/coreml/spec/NearestNeighbors.pb.go +++ b/proto/coreml/spec/NearestNeighbors.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: NearestNeighbors.proto package spec @@ -595,7 +595,7 @@ const file_NearestNeighbors_proto_rawDesc = "" + "\vLinearIndex\"/\n" + "\x11SingleKdTreeIndex\x12\x1a\n" + "\bleafSize\x18\x01 \x01(\x05R\bleafSize\"\x1a\n" + - "\x18SquaredEuclideanDistanceB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00P\x01b\x06proto3" + "\x18SquaredEuclideanDistanceB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00P\x01b\x06proto3" var ( file_NearestNeighbors_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/NeuralNetwork.pb.go b/proto/coreml/spec/NeuralNetwork.pb.go index 3c62c54..6b8b0ac 100644 --- a/proto/coreml/spec/NeuralNetwork.pb.go +++ b/proto/coreml/spec/NeuralNetwork.pb.go @@ -57,7 +57,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: NeuralNetwork.proto package spec @@ -20085,7 +20085,7 @@ const file_NeuralNetwork_proto_rawDesc = "" + "\vSCATTER_MUL\x10\x03\x12\x0f\n" + "\vSCATTER_DIV\x10\x04\x12\x0f\n" + "\vSCATTER_MAX\x10\x05\x12\x0f\n" + - "\vSCATTER_MIN\x10\x06B0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00P\x01b\x06proto3" + "\vSCATTER_MIN\x10\x06B2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00P\x01b\x06proto3" var ( file_NeuralNetwork_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/NonMaximumSuppression.pb.go b/proto/coreml/spec/NonMaximumSuppression.pb.go index 53d6ce8..33ce6a2 100644 --- a/proto/coreml/spec/NonMaximumSuppression.pb.go +++ b/proto/coreml/spec/NonMaximumSuppression.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: NonMaximumSuppression.proto package spec @@ -387,7 +387,7 @@ const file_NonMaximumSuppression_proto_rawDesc = "" + "\aPickTop\x12\x1a\n" + "\bperClass\x18\x01 \x01(\bR\bperClassB\x13\n" + "\x11SuppressionMethodB\r\n" + - "\vClassLabelsB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\vClassLabelsB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_NonMaximumSuppression_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/Normalizer.pb.go b/proto/coreml/spec/Normalizer.pb.go index f516f39..f597f1f 100644 --- a/proto/coreml/spec/Normalizer.pb.go +++ b/proto/coreml/spec/Normalizer.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: Normalizer.proto package spec @@ -148,7 +148,7 @@ const file_Normalizer_proto_rawDesc = "" + "\bNormType\x12\b\n" + "\x04LMax\x10\x00\x12\x06\n" + "\x02L1\x10\x01\x12\x06\n" + - "\x02L2\x10\x02B0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\x02L2\x10\x02B2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_Normalizer_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/OneHotEncoder.pb.go b/proto/coreml/spec/OneHotEncoder.pb.go index 34eeb1b..293543c 100644 --- a/proto/coreml/spec/OneHotEncoder.pb.go +++ b/proto/coreml/spec/OneHotEncoder.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: OneHotEncoder.proto package spec @@ -195,7 +195,7 @@ const file_OneHotEncoder_proto_rawDesc = "" + "\rHandleUnknown\x12\x12\n" + "\x0eErrorOnUnknown\x10\x00\x12\x11\n" + "\rIgnoreUnknown\x10\x01B\x0e\n" + - "\fCategoryTypeB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\fCategoryTypeB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_OneHotEncoder_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/Parameters.pb.go b/proto/coreml/spec/Parameters.pb.go index 12367f2..459305b 100644 --- a/proto/coreml/spec/Parameters.pb.go +++ b/proto/coreml/spec/Parameters.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: Parameters.proto package spec @@ -307,7 +307,7 @@ const file_Parameters_proto_rawDesc = "" + "\x0fStringParameter\x12\"\n" + "\fdefaultValue\x18\x01 \x01(\tR\fdefaultValue\"3\n" + "\rBoolParameter\x12\"\n" + - "\fdefaultValue\x18\x01 \x01(\bR\fdefaultValueB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\fdefaultValue\x18\x01 \x01(\bR\fdefaultValueB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_Parameters_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/SVM.pb.go b/proto/coreml/spec/SVM.pb.go index 7d5531c..2c84ac3 100644 --- a/proto/coreml/spec/SVM.pb.go +++ b/proto/coreml/spec/SVM.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: SVM.proto package spec @@ -1000,7 +1000,7 @@ const file_SVM_proto_rawDesc = "" + "\x11stringClassLabels\x18d \x01(\v2\".CoreML.Specification.StringVectorH\x01R\x11stringClassLabels\x12O\n" + "\x10int64ClassLabels\x18e \x01(\v2!.CoreML.Specification.Int64VectorH\x01R\x10int64ClassLabelsB\x10\n" + "\x0esupportVectorsB\r\n" + - "\vClassLabelsB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\vClassLabelsB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_SVM_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/Scaler.pb.go b/proto/coreml/spec/Scaler.pb.go index 38932f7..4ca24cb 100644 --- a/proto/coreml/spec/Scaler.pb.go +++ b/proto/coreml/spec/Scaler.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: Scaler.proto package spec @@ -108,7 +108,7 @@ const file_Scaler_proto_rawDesc = "" + "shiftValue\x12\x1e\n" + "\n" + "scaleValue\x18\x02 \x03(\x01R\n" + - "scaleValueB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "scaleValueB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_Scaler_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/SoundAnalysisPreprocessing.pb.go b/proto/coreml/spec/SoundAnalysisPreprocessing.pb.go index 3e5faeb..c1d2ba6 100644 --- a/proto/coreml/spec/SoundAnalysisPreprocessing.pb.go +++ b/proto/coreml/spec/SoundAnalysisPreprocessing.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: SoundAnalysisPreprocessing.proto package spec @@ -141,7 +141,7 @@ const file_SoundAnalysisPreprocessing_proto_rawDesc = "" + "\x1aSoundAnalysisPreprocessing\x12^\n" + "\x06vggish\x18\x14 \x01(\v2D.CoreML.Specification.CoreMLModels.SoundAnalysisPreprocessing.VggishH\x00R\x06vggish\x1a\b\n" + "\x06VggishB \n" + - "\x1eSoundAnalysisPreprocessingTypeB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\x1eSoundAnalysisPreprocessingTypeB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_SoundAnalysisPreprocessing_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/TextClassifier.pb.go b/proto/coreml/spec/TextClassifier.pb.go index b3e8c8e..88115f6 100644 --- a/proto/coreml/spec/TextClassifier.pb.go +++ b/proto/coreml/spec/TextClassifier.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: TextClassifier.proto package spec @@ -136,7 +136,7 @@ const file_TextClassifier_proto_rawDesc = "" + " \x01(\tR\blanguage\x12.\n" + "\x12modelParameterData\x18d \x01(\fR\x12modelParameterData\x12S\n" + "\x11stringClassLabels\x18\xc8\x01 \x01(\v2\".CoreML.Specification.StringVectorH\x00R\x11stringClassLabelsB\r\n" + - "\vClassLabelsB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\vClassLabelsB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_TextClassifier_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/TreeEnsemble.pb.go b/proto/coreml/spec/TreeEnsemble.pb.go index bfec2f5..dc91dd5 100644 --- a/proto/coreml/spec/TreeEnsemble.pb.go +++ b/proto/coreml/spec/TreeEnsemble.pb.go @@ -24,7 +24,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: TreeEnsemble.proto package spec @@ -630,7 +630,7 @@ const file_TreeEnsemble_proto_rawDesc = "" + "\vNoTransform\x10\x00\x12\x1a\n" + "\x16Classification_SoftMax\x10\x01\x12\x17\n" + "\x13Regression_Logistic\x10\x02\x120\n" + - ",Classification_SoftMaxWithZeroClassReference\x10\x03B0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + ",Classification_SoftMaxWithZeroClassReference\x10\x03B2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_TreeEnsemble_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/VisionFeaturePrint.pb.go b/proto/coreml/spec/VisionFeaturePrint.pb.go index 2d9af2d..e40ee1d 100644 --- a/proto/coreml/spec/VisionFeaturePrint.pb.go +++ b/proto/coreml/spec/VisionFeaturePrint.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: VisionFeaturePrint.proto package spec @@ -341,7 +341,7 @@ const file_VisionFeaturePrint_proto_rawDesc = "" + "\x0eObjectsVersion\x12\x1b\n" + "\x17OBJECTS_VERSION_INVALID\x10\x00\x12\x15\n" + "\x11OBJECTS_VERSION_1\x10\x01B\x18\n" + - "\x16VisionFeaturePrintTypeB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specb\x06proto3" + "\x16VisionFeaturePrintTypeB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specb\x06proto3" var ( file_VisionFeaturePrint_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/WordEmbedding.pb.go b/proto/coreml/spec/WordEmbedding.pb.go index a16a9f8..8aa599a 100644 --- a/proto/coreml/spec/WordEmbedding.pb.go +++ b/proto/coreml/spec/WordEmbedding.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: WordEmbedding.proto package spec @@ -101,7 +101,7 @@ const file_WordEmbedding_proto_rawDesc = "" + "\brevision\x18\x01 \x01(\rR\brevision\x12\x1a\n" + "\blanguage\x18\n" + " \x01(\tR\blanguage\x12.\n" + - "\x12modelParameterData\x18d \x01(\fR\x12modelParameterDataB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\x12modelParameterData\x18d \x01(\fR\x12modelParameterDataB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_WordEmbedding_proto_rawDescOnce sync.Once diff --git a/proto/coreml/spec/WordTagger.pb.go b/proto/coreml/spec/WordTagger.pb.go index 28627e8..3756a02 100644 --- a/proto/coreml/spec/WordTagger.pb.go +++ b/proto/coreml/spec/WordTagger.pb.go @@ -6,7 +6,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v7.34.0 // source: WordTagger.proto package spec @@ -188,7 +188,7 @@ const file_WordTagger_proto_rawDesc = "" + "\n" + "stringTags\x18\xc8\x01 \x01(\v2\".CoreML.Specification.StringVectorH\x00R\n" + "stringTagsB\x06\n" + - "\x04TagsB0H\x03Z,github.com/gomlx/go-coreml/proto/coreml/specP\x00b\x06proto3" + "\x04TagsB2H\x03Z.github.com/gomlx/go-darwinml/proto/coreml/specP\x00b\x06proto3" var ( file_WordTagger_proto_rawDescOnce sync.Once diff --git a/specs/001-initial-plan.md b/specs/001-initial-plan.md deleted file mode 100644 index 4979957..0000000 --- a/specs/001-initial-plan.md +++ /dev/null @@ -1,234 +0,0 @@ -# Plan: go-coreml Backend for GoMLX - -## Summary - -Create a CoreML backend for GoMLX, analogous to how go-xla provides the XLA backend. This would enable GoMLX models to run on Apple's Neural Engine (ANE), Metal GPU, and CPU with Apple-optimized performance. - -**Scope**: -- Inference-only, macOS -- Run GoMLX computation graphs on CoreML (no .mlmodel import needed) -- Goal is to beat XLA CPU on Apple Silicon - -## Architecture Overview - -GoMLX has a clean, pluggable backend architecture with four core interfaces: -- `Backend` - main entry point, device management -- `Builder` - graph construction (symbolic operations) -- `Executable` - compiled computation -- `DataInterface` - buffer/tensor management - -A CoreML backend would implement these interfaces using CoreML's APIs. - -## Two-Component Approach - -### Component 1: go-coreml (Low-Level Bindings) - -**Location**: New repository `github.com/gomlx/go-coreml` - -**Purpose**: Go bindings to CoreML, similar to how go-xla wraps PJRT/StableHLO - -**Implementation Strategy**: - -1. **Protobuf Generation for MLModel** - - Generate Go types from CoreML's `.proto` files (Model.proto, MIL.proto) - - Enables programmatic model generation in pure Go - - No cgo required for model construction - -2. **Objective-C++ Bridge for Runtime** - - Create thin Objective-C++ shim for CoreML runtime operations - - Wrap `MLModel`, `MLMultiArray`, `MLShapedArray` types - - Expose C-compatible functions callable via cgo - -**Key Files**: -``` -go-coreml/ -├── proto/ -│ └── coreml/ # Generated Go types from Apple's .proto files -├── internal/ -│ └── bridge/ -│ ├── bridge.h # C-compatible function declarations -│ ├── bridge.mm # Objective-C++ implementation -│ └── bridge.go # cgo wrapper -├── model/ -│ └── builder.go # Programmatic model construction (MIL) -├── runtime/ -│ ├── model.go # MLModel wrapper -│ └── tensor.go # MLShapedArray/MLMultiArray wrapper -└── ops/ - └── mil_ops.go # MIL operation builders -``` - -### Component 2: GoMLX CoreML Backend - -**Location**: `github.com/gomlx/gomlx/backends/coreml` - -**Purpose**: Implement GoMLX's backend interfaces using go-coreml - -**Key Types**: - -```go -// Backend implementation -type Backend struct { - device DeviceType // ANE, GPU, CPU, or All - computeUnits ComputeUnits -} - -// Builder implementation (wraps MIL program builder) -type Builder struct { - backend *Backend - program *gocoreml.MILProgram - ops []*gocoreml.MILOperation -} - -// Executable implementation (wraps compiled MLModel) -type Executable struct { - model *gocoreml.Model - inputNames []string - outputNames []string -} - -// Node wrapping MIL values -type Node struct { - value *gocoreml.MILValue - shape shapes.Shape - builder *Builder -} -``` - -## Operation Mapping - -### Supported Operations (CoreML has 500+ ops) - -| GoMLX Operation | MIL Operation | Notes | -|-----------------|---------------|-------| -| Add, Sub, Mul, Div | add, sub, mul, real_div | Direct mapping | -| MatMul | matmul | Direct mapping | -| Conv2D | conv | Requires weight format conversion | -| MaxPool, AvgPool | max_pool, avg_pool | Direct mapping | -| Relu, Sigmoid, Tanh | relu, sigmoid, tanh | Direct mapping | -| Softmax | softmax | Direct mapping | -| BatchNorm | batch_norm | Direct mapping | -| Reshape, Transpose | reshape, transpose | Direct mapping | -| Concat | concat | Direct mapping | -| Reduce* | reduce_* | Direct mapping | - -### Operations Requiring Composite Implementation - -| GoMLX Operation | Implementation Strategy | -|-----------------|------------------------| -| Custom einsum | Decompose to matmul/transpose | -| Scatter/Gather | Multiple MIL ops | -| Complex dtypes | Split real/imag channels | - -### Unsupported Operations - -- Distributed/SPMD operations (CoreML is single-device) -- Some advanced XLA-specific operations - -## Execution Flow - -``` -GoMLX User Code - ↓ -graph.Exec() or context.Exec() - ↓ -CoreMLBuilder.Parameter(), Add(), MatMul(), etc. - ↓ (builds MIL program) -CoreMLBuilder.Compile() - ↓ (generates .mlmodel protobuf, loads via CoreML runtime) -CoreMLExecutable - ↓ -CoreMLExecutable.Execute(inputs) - ↓ (MLModel.prediction()) -Returns GoMLX Buffer (wrapping MLShapedArray) -``` - -## Implementation Phases - -### Phase 1: go-coreml Foundation -- Create Objective-C++ bridge for CoreML runtime operations -- Implement MLShapedArray wrapper for tensor I/O -- Test: create simple model in Obj-C++, call from Go -- Validate we can execute basic ops via CoreML from Go - -### Phase 2: MIL Program Builder -- Generate Go protobuf types from CoreML MIL.proto -- Implement MIL operation builders in pure Go -- Support basic ops: arithmetic, matmul, activations -- Generate valid MIL program, load via CoreML runtime - -### Phase 3: GoMLX Backend Integration -- Implement `Backend`, `Builder`, `Executable` interfaces -- Map GoMLX operations to MIL operations -- Handle shape inference and type conversion -- Register backend with GoMLX - -### Phase 4: Operation Completeness -- Implement remaining standard operations -- Add convolution, pooling, normalization -- Implement composite operations for unsupported ops -- Add comprehensive tests - -### Phase 5: Optimization & Benchmarking -- Optimize memory management (buffer reuse) -- Add compute unit selection (ANE/GPU/CPU) -- Performance benchmarking vs XLA CPU backend -- Ensure we beat XLA CPU on common workloads - -## Technical Challenges - -1. **No Official C/C++ API**: CoreML only has Objective-C/Swift APIs. Requires Objective-C++ bridge with cgo. - -2. **Graph Compilation Model**: CoreML compiles to .mlmodel files, not JIT like XLA. May need to cache compiled models. - -3. **Tensor Format Differences**: CoreML uses NHWC or NCHW depending on operation. Need careful format handling. - -4. **Dynamic Shapes**: CoreML supports dynamic shapes but requires explicit range specification. May limit flexibility. - -5. **Platform Lock-in**: CoreML only runs on Apple platforms (macOS, iOS, tvOS, watchOS). - -## Alternatives Considered - -1. **ONNX Runtime with CoreML EP**: Could use ONNX as intermediate format. More mature but adds complexity. - -2. **Metal Performance Shaders**: Lower-level but doesn't leverage ANE. Would be similar to writing a custom GPU backend. - -3. **coremltools Python**: Generate models via Python subprocess. Works but adds Python dependency. - -## Success Criteria - -- [ ] GoMLX computation graphs compile to CoreML and execute -- [ ] 80%+ of GoMLX standard operations supported -- [ ] Performance beats XLA CPU backend on Apple Silicon -- [ ] ANE/Metal acceleration works (measurable via Instruments) - -## Confirmed Requirements - -- **Use case**: Inference only (no training/gradients needed) -- **Platform**: macOS only -- **Model source**: GoMLX graphs only (no .mlmodel import needed) -- **Performance goal**: Beat XLA CPU backend on Apple Silicon - -## Critical Files to Reference - -**GoMLX Backend Interfaces** (implement these): -- `gomlx/backends/backends.go` - Backend interface -- `gomlx/backends/builder.go` - Builder interface -- `gomlx/backends/executable.go` - Executable interface -- `gomlx/backends/data.go` - DataInterface -- `gomlx/backends/standard_ops.go` - Operations to implement - -**Reference Implementations**: -- `gomlx/backends/xla/` - Full XLA backend (production reference) -- `gomlx/backends/simplego/` - Pure Go interpreter (simpler reference) - -**CoreML Specification** (for protobuf generation): -- https://github.com/apple/coremltools/tree/main/mlmodel/format - -## Next Steps - -1. Create minimal Objective-C++ bridge proof-of-concept -2. Validate we can execute CoreML operations from Go (e.g., matmul) -3. Generate Go protobuf types from CoreML MIL.proto -4. Implement MIL program builder for graph construction -5. Create GoMLX backend wrapper implementing Backend/Builder/Executable diff --git a/specs/002-gomlx-backend.md b/specs/002-gomlx-backend.md deleted file mode 100644 index b62f576..0000000 --- a/specs/002-gomlx-backend.md +++ /dev/null @@ -1,747 +0,0 @@ -# GoMLX CoreML Backend Implementation Plan - -## Status - -**Phases 1-2: Complete** -- Bridge package with cgo bindings to CoreML -- MIL program builder with fluent API -- Model serialization (.mlpackage format) -- Runtime compilation and execution -- Basic operations (Add, Sub, Mul, Div, MatMul, Relu, Sigmoid, etc.) - -**Phase 3: Complete** ✅ -- GoMLX Backend interface implementation -- Builder, Executable, Buffer/DataInterface -- 16 operations mapped to CoreML MIL -- Integration tests passing -- See [Implementation Notes](#phase-3-implementation-notes) below - -**Phases 4-5: Pending** - ---- - -## Phase 3: GoMLX Backend Integration - -### Overview - -Implement GoMLX's backend interfaces using go-coreml. This creates the bridge between GoMLX computation graphs and CoreML execution. - -### 3.1 Study GoMLX Backend Interfaces - -**Files to Read**: -``` -gomlx/backends/backends.go - Backend interface -gomlx/backends/builder.go - Builder interface -gomlx/backends/executable.go - Executable interface -gomlx/backends/data.go - DataInterface (buffers) -gomlx/backends/standard_ops.go - StandardOps enum -gomlx/backends/simplego/ - Reference implementation (pure Go) -gomlx/backends/xla/ - Production reference (XLA) -``` - -**Key Questions to Answer**: -1. What methods does `Backend` interface require? -2. What is the lifecycle of `Builder` → `Executable`? -3. How does `DataInterface` handle buffer management? -4. What is `Node` and how does it relate to operations? -5. How does shape inference work? - -### 3.2 Implement Backend Interface - -**Location**: `github.com/gomlx/gomlx/backends/coreml/backend.go` - -```go -type Backend struct { - computeUnits bridge.ComputeUnits - cacheDir string -} - -// Required methods (study interface to confirm): -func (b *Backend) Name() string -func (b *Backend) NewBuilder(name string) backends.Builder -func (b *Backend) NewBuffer(shape shapes.Shape) backends.Buffer -func (b *Backend) Platform() string -func (b *Backend) Close() error -``` - -**Tasks**: -- [ ] Read GoMLX Backend interface definition -- [ ] Implement all required methods -- [ ] Add compute unit configuration (ANE, GPU, CPU, All) -- [ ] Handle platform detection (macOS only) -- [ ] Register backend with GoMLX registry - -### 3.3 Implement Builder Interface - -**Location**: `github.com/gomlx/gomlx/backends/coreml/builder.go` - -```go -type Builder struct { - backend *Backend - milBuilder *model.Builder - nodeMap map[backends.NodeID]*Node - nextNodeID backends.NodeID -} - -type Node struct { - id backends.NodeID - value *model.Value - shape shapes.Shape - dtype dtypes.DType -} - -// Key methods: -func (b *Builder) Parameter(name string, shape shapes.Shape) backends.Node -func (b *Builder) Constant(value interface{}, shape shapes.Shape) backends.Node -func (b *Builder) Op(op backends.StandardOp, inputs ...backends.Node) backends.Node -func (b *Builder) Compile(outputs []backends.Node) backends.Executable -``` - -**Tasks**: -- [ ] Study Builder interface requirements -- [ ] Implement Parameter (maps to model.Input) -- [ ] Implement Constant (maps to model.Const) -- [ ] Create operation dispatch table -- [ ] Implement shape inference helpers -- [ ] Implement dtype conversion (GoMLX dtypes → CoreML dtypes) - -### 3.4 Implement Executable Interface - -**Location**: `github.com/gomlx/gomlx/backends/coreml/executable.go` - -```go -type Executable struct { - backend *Backend - runtime *runtime.Executable - inputNames []string - outputNames []string - inputShapes []shapes.Shape - outputShapes []shapes.Shape -} - -// Key methods: -func (e *Executable) Execute(inputs []backends.Buffer) ([]backends.Buffer, error) -func (e *Executable) Close() error -func (e *Executable) InputShapes() []shapes.Shape -func (e *Executable) OutputShapes() []shapes.Shape -``` - -**Tasks**: -- [ ] Study Executable interface requirements -- [ ] Implement Execute with buffer conversion -- [ ] Handle input/output shape validation -- [ ] Implement proper cleanup in Close() - -### 3.5 Implement Buffer/DataInterface - -**Location**: `github.com/gomlx/gomlx/backends/coreml/buffer.go` - -```go -type Buffer struct { - data []byte - shape shapes.Shape - dtype dtypes.DType -} - -// Methods: -func (buf *Buffer) Shape() shapes.Shape -func (buf *Buffer) DType() dtypes.DType -func (buf *Buffer) Bytes() []byte -func (buf *Buffer) CopyFrom(data interface{}) -func (buf *Buffer) CopyTo(dst interface{}) -``` - -**Tasks**: -- [ ] Study DataInterface requirements -- [ ] Implement buffer creation and management -- [ ] Handle type conversions (Go types ↔ CoreML types) -- [ ] Implement efficient memory copying - -### 3.6 Operation Mapping Layer - -**Location**: `github.com/gomlx/gomlx/backends/coreml/ops.go` - -Create dispatch table mapping GoMLX StandardOps to MIL operations: - -```go -type opHandler func(b *Builder, inputs []*Node) (*Node, error) - -var opTable = map[backends.StandardOp]opHandler{ - backends.OpAdd: handleAdd, - backends.OpSub: handleSub, - backends.OpMul: handleMul, - backends.OpDiv: handleDiv, - backends.OpMatMul: handleMatMul, - backends.OpRelu: handleRelu, - // ... etc -} - -func (b *Builder) dispatchOp(op backends.StandardOp, inputs []*Node) (*Node, error) { - handler, ok := opTable[op] - if !ok { - return nil, fmt.Errorf("unsupported operation: %v", op) - } - return handler(b, inputs) -} -``` - -**Tasks**: -- [ ] List all StandardOps from GoMLX -- [ ] Create handler for each supported op -- [ ] Implement shape inference per operation -- [ ] Handle broadcasting rules -- [ ] Return clear errors for unsupported ops - -### 3.7 Backend Registration - -Use build tags to conditionally compile the backend only on macOS. - -**Location**: `github.com/gomlx/gomlx/backends/coreml/register_darwin.go` - -```go -//go:build darwin - -package coreml - -func init() { - backends.Register("coreml", func() backends.Backend { - return New() - }) -} -``` - -**Location**: `github.com/gomlx/gomlx/backends/coreml/register_other.go` - -```go -//go:build !darwin - -package coreml - -// No-op on non-macOS platforms - backend not registered -``` - -This approach: -- Keeps the package importable on all platforms -- Avoids runtime checks -- Lets the compiler exclude CoreML code entirely on non-Apple platforms -- Prevents linker errors from missing Objective-C frameworks - -### 3.8 Integration Testing - -Create tests that use GoMLX API with CoreML backend: - -```go -func TestGoMLXIntegration(t *testing.T) { - backend := coreml.New() - - // Use GoMLX graph API - g := graph.New(backend) - x := g.Parameter("x", shapes.Make(dtypes.Float32, 2, 3)) - y := g.Relu(x) - - exec := g.Compile(y) - defer exec.Close() - - input := []float32{-1, 2, -3, 4, -5, 6} - output, err := exec.Execute(input) - // ... verify output -} -``` - -**Tasks**: -- [x] Test basic operations through GoMLX API -- [x] Test multi-operation graphs -- [x] Test parameter passing -- [x] Test shape inference -- [ ] Verify numerical correctness vs simplego backend - ---- - -## Phase 3 Implementation Notes - -### Files Created - -Location: `github.com/gomlx/gomlx/backends/coreml/` - -| File | Lines | Description | -|------|-------|-------------| -| `backend.go` | 134 | Backend interface, runtime integration, configuration | -| `builder.go` | 284 | Builder interface wrapping go-coreml's model.Builder | -| `executable.go` | 302 | Executable interface for model execution | -| `buffer.go` | 271 | Buffer/DataInterface with sync.Pool buffer pooling | -| `ops.go` | 401 | Operation mapping layer (16 ops implemented) | -| `capabilities.go` | 57 | Supported operations and dtypes | -| `register_darwin.go` | 11 | Backend registration on macOS | -| `register_other.go` | 8 | No-op stub for other platforms | -| `doc.go` | 37 | Package documentation | -| `coreml_test.go` | 453 | Integration tests | - -### Implemented Operations - -| Category | Operations | -|----------|------------| -| **Unary** | Abs, Neg*, Exp, Log, Sqrt, Tanh, Logistic (Sigmoid) | -| **Binary** | Add, Sub, Mul, Div | -| **Shape** | Reshape, Transpose | -| **Reduction** | ReduceSum, ReduceMax | -| **Matrix** | DotGeneral (simple MatMul case) | - -*Neg is implemented as `mul(x, -1)` since CoreML lacks a native neg operator. - -### Key Implementation Discoveries - -1. **CoreML requires "main" function name** - - CoreML MIL programs must have a function named "main" - - Fixed by always using `model.NewBuilder("main")` regardless of user-provided name - - The user's name is preserved in the GoMLX Builder for identification - -2. **Missing CoreML operators** - - CoreML MIL doesn't have a native `neg` operator - - Implemented negation as `mul(x, -1)` with a scalar constant - - This pattern may apply to other missing operators - -3. **Buffer pooling** - - Implemented efficient memory reuse using `sync.Pool` keyed by (dtype, length) - - Follows the same pattern as simplego backend - - Reduces GC pressure for repeated executions - -4. **Build tags** - - `//go:build darwin` on all implementation files - - Stub `register_other.go` allows importing on non-macOS platforms without errors - -5. **Shape inference** - - Uses `github.com/gomlx/gomlx/backends/shapeinference` package - - Consistent with other backends - -### Test Results - -``` -=== RUN TestBackendCreation --- PASS -=== RUN TestBufferOperations --- PASS -=== RUN TestSharedBuffer --- PASS -=== RUN TestBuilderParameterAndConstant --- PASS -=== RUN TestAddOperation --- PASS -=== RUN TestUnaryOperations --- PASS (Abs, Neg, Exp, Sqrt) -=== RUN TestBinaryOperations --- PASS (Add, Sub, Mul, Div) -=== RUN TestReshape --- PASS -=== RUN TestReduceSum --- PASS -=== RUN TestChainedOperations --- PASS -PASS ok github.com/gomlx/gomlx/backends/coreml 0.765s -``` - -### Usage - -```go -import _ "github.com/gomlx/gomlx/backends/coreml" - -// Or set environment variable: -// export GOMLX_BACKEND=coreml - -// Or create directly: -backend, err := coreml.New("") -``` - ---- - -## Phase 4: Operation Completeness - -### 4.1 Core Operations (Priority 1) - -These operations are essential for most models: - -| GoMLX Op | MIL Op | Status | Notes | -|----------|--------|--------|-------| -| Add | add | Done | go-coreml | -| Sub | sub | Done | go-coreml | -| Mul | mul | Done | go-coreml | -| Div | real_div | Done | go-coreml | -| MatMul | matmul | Done | go-coreml | -| Relu | relu | Done | go-coreml | -| Sigmoid | sigmoid | Done | go-coreml | -| Tanh | tanh | Done | go-coreml | -| Exp | exp | Done | go-coreml | -| Log | log | Done | go-coreml | -| Sqrt | sqrt | Done | go-coreml | -| Neg | neg | Done | go-coreml | -| Abs | abs | Done | go-coreml | -| Softmax | softmax | Done | go-coreml | -| Reshape | reshape | Done | go-coreml | -| Transpose | transpose | Done | go-coreml | -| ReduceSum | reduce_sum | Done | go-coreml | -| ReduceMean | reduce_mean | Done | go-coreml | -| ReduceMax | reduce_max | Done | go-coreml | - -### 4.2 Convolution Operations (Priority 2) - -Essential for vision models: - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Conv2D | conv | Need to handle padding, strides, dilation | -| ConvTranspose2D | conv_transpose | Deconvolution | -| MaxPool2D | max_pool | Need to handle padding, strides | -| AvgPool2D | avg_pool | Need to handle padding, strides | - -**Tasks**: -- [ ] Implement Conv2D with all parameters -- [ ] Handle NHWC vs NCHW format conversion -- [ ] Implement pooling operations -- [ ] Test with simple CNN architectures - -### 4.3 Normalization Operations (Priority 2) - -Essential for deep networks: - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| BatchNorm | batch_norm | Need running mean/var | -| LayerNorm | layer_norm | | -| InstanceNorm | instance_norm | | - -**Tasks**: -- [ ] Implement BatchNorm with training/inference modes -- [ ] Implement LayerNorm -- [ ] Test normalization numerical accuracy - -### 4.4 Tensor Manipulation (Priority 2) - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Concat | concat | Along specified axis | -| Split | split | | -| Slice | slice_by_index | | -| Gather | gather | | -| Scatter | scatter | May need decomposition | -| Stack | stack | | -| Squeeze | squeeze | | -| ExpandDims | expand_dims | | -| Tile | tile | | -| Pad | pad | | - -**Tasks**: -- [ ] Implement each operation -- [ ] Handle axis/dimension parameters correctly -- [ ] Test edge cases (empty tensors, single elements) - -### 4.5 Comparison Operations (Priority 3) - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Equal | equal | | -| NotEqual | not_equal | | -| Less | less | | -| LessEqual | less_equal | | -| Greater | greater | | -| GreaterEqual | greater_equal | | -| Where | select | Conditional selection | - -### 4.6 Additional Math Operations (Priority 3) - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Pow | pow | | -| Sin, Cos, Tan | sin, cos, tan | | -| Floor, Ceil | floor, ceil | | -| Clip | clip | | -| Gelu | gelu | | -| Erf | erf | | - -### 4.7 Attention Operations (Priority 3) - -For transformer models: - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Einsum | Decompose | Break into matmul/transpose | -| ScaledDotProductAttention | scaled_dot_product_attention | CoreML 7+ | - -**Tasks**: -- [ ] Implement einsum decomposition -- [ ] Use native attention op when available -- [ ] Test with transformer architectures - -### 4.8 Composite Operations - -Some GoMLX operations need multiple MIL ops: - -```go -// Example: Einsum decomposition -func handleEinsum(b *Builder, equation string, inputs []*Node) (*Node, error) { - // Parse einsum equation - // Decompose into sequence of: - // - transpose - // - reshape - // - matmul - // - reduce -} -``` - -**Tasks**: -- [ ] Identify operations needing decomposition -- [ ] Implement decomposition strategies -- [ ] Verify numerical equivalence - -### 4.9 Test Suite - -Create comprehensive tests for each operation: - -```go -func TestOp_Add(t *testing.T) { - testCases := []struct{ - name string - a, b []float32 - aShape, bShape shapes.Shape - want []float32 - }{ - {"simple", []float32{1,2}, []float32{3,4}, ...}, - {"broadcast", ...}, - {"scalar", ...}, - } - // Test each case -} -``` - -**Tasks**: -- [ ] Create test cases for each operation -- [ ] Test broadcasting behavior -- [ ] Test edge cases (empty, scalar, large) -- [ ] Compare results vs simplego backend -- [ ] Add fuzzing tests for numerical stability - ---- - -## Phase 5: Optimization & Benchmarking - -### 5.1 Memory Management - -**Buffer Pool**: -```go -type BufferPool struct { - pools map[int]*sync.Pool // keyed by size -} - -func (p *BufferPool) Get(size int) *Buffer -func (p *BufferPool) Put(buf *Buffer) -``` - -**Tasks**: -- [ ] Implement buffer pooling -- [ ] Reduce allocations in hot paths -- [ ] Profile memory usage -- [ ] Add memory usage metrics - -### 5.2 Model Caching - -Cache compiled CoreML models to avoid recompilation: - -```go -type ModelCache struct { - dir string - mu sync.RWMutex - entries map[string]*cacheEntry -} - -type cacheEntry struct { - path string - hash string - lastUsed time.Time - executable *runtime.Executable -} - -func (c *ModelCache) Get(programHash string) (*runtime.Executable, bool) -func (c *ModelCache) Put(programHash string, exec *runtime.Executable) -``` - -**Tasks**: -- [ ] Compute stable hash of MIL programs -- [ ] Cache compiled .mlmodelc directories -- [ ] Implement LRU eviction -- [ ] Add cache statistics - -### 5.3 Compute Unit Selection - -```go -type ComputeConfig struct { - Units ComputeUnits // All, CPUOnly, CPUAndGPU, CPUAndANE - AllowANE bool // Allow Neural Engine - AllowGPU bool // Allow Metal GPU - Fallback bool // Allow fallback to CPU -} - -func (b *Backend) WithComputeConfig(cfg ComputeConfig) *Backend -``` - -**Tasks**: -- [ ] Expose compute unit configuration -- [ ] Test each configuration -- [ ] Document performance characteristics -- [ ] Auto-select based on model characteristics - -### 5.4 Benchmarking Infrastructure - -Create comprehensive benchmarks: - -```go -func BenchmarkMatMul(b *testing.B) { - sizes := []int{64, 128, 256, 512, 1024, 2048} - backends := []string{"coreml", "xla", "simplego"} - - for _, size := range sizes { - for _, backend := range backends { - b.Run(fmt.Sprintf("%s/%d", backend, size), func(b *testing.B) { - // Benchmark matmul of size x size - }) - } - } -} -``` - -**Benchmark Categories**: -1. **Micro-benchmarks**: Individual operations -2. **Layer benchmarks**: Conv, Attention, FFN blocks -3. **Model benchmarks**: Full model inference - -### 5.5 Performance Targets - -| Workload | XLA CPU | CoreML Target | Notes | -|----------|---------|---------------|-------| -| MatMul 1024x1024 | X ms | < X ms | ANE should excel | -| Conv2D 224x224 | X ms | < X ms | Vision workload | -| Transformer layer | X ms | < X ms | Attention + FFN | -| BERT inference | X ms | < X ms | Full model | -| ResNet-50 | X ms | < X ms | Vision model | - -**Tasks**: -- [ ] Establish XLA CPU baselines on Apple Silicon -- [ ] Measure CoreML performance per workload -- [ ] Identify and fix performance regressions -- [ ] Document performance vs XLA - -### 5.6 Profiling Tools - -```go -type Profiler struct { - enabled bool - traces []TraceEvent -} - -type TraceEvent struct { - Name string - Start time.Time - Duration time.Duration - Op string - Shape shapes.Shape -} - -func (p *Profiler) Start(name string) func() -func (p *Profiler) Report() string -``` - -**Tasks**: -- [ ] Add timing instrumentation -- [ ] Integrate with Instruments.app -- [ ] Create performance reports -- [ ] Add ANE/GPU utilization tracking - -### 5.7 Optimization Techniques - -1. **Operation Fusion**: Combine sequential operations - ```go - // Fuse: MatMul + Add + Relu → fused_matmul_add_relu - func fuseOperations(ops []*Operation) []*Operation - ``` - -2. **Layout Optimization**: Choose optimal tensor layout - ```go - // Prefer NHWC for ANE, NCHW for GPU - func optimizeLayout(model *Model, target ComputeUnits) *Model - ``` - -3. **Constant Folding**: Pre-compute constant expressions - ```go - func foldConstants(ops []*Operation) []*Operation - ``` - -**Tasks**: -- [ ] Implement operation fusion for common patterns -- [ ] Add layout optimization pass -- [ ] Implement constant folding -- [ ] Measure impact of each optimization - -### 5.8 CI/CD Integration - -```yaml -# .github/workflows/benchmark.yml -name: Benchmarks -on: [push, pull_request] -jobs: - benchmark: - runs-on: macos-latest - steps: - - uses: actions/checkout@v4 - - name: Run benchmarks - run: go test -bench=. -benchmem ./... - - name: Compare to baseline - run: ./scripts/compare-benchmarks.sh -``` - -**Tasks**: -- [ ] Set up macOS CI runners -- [ ] Automate benchmark comparison -- [ ] Track performance over time -- [ ] Alert on regressions - ---- - -## Deliverables Summary - -### Phase 3 Deliverables -- [ ] Working GoMLX CoreML backend -- [ ] Backend registration with GoMLX -- [ ] Basic operations through GoMLX API -- [ ] Integration tests passing - -### Phase 4 Deliverables -- [ ] 80%+ GoMLX operations supported -- [ ] Convolution operations working -- [ ] Normalization operations working -- [ ] Comprehensive test suite - -### Phase 5 Deliverables -- [ ] Performance exceeds XLA CPU on Apple Silicon -- [ ] Memory-efficient buffer management -- [ ] Model caching implemented -- [ ] Benchmark suite and reports -- [ ] Documentation complete - ---- - -## Dependencies - -**Go Dependencies**: -- `github.com/gomlx/gomlx` - GoMLX framework -- `github.com/gomlx/go-coreml` - This package - -**System Requirements**: -- macOS 12+ (for CoreML 5+) -- Xcode Command Line Tools -- Apple Silicon or Intel Mac - ---- - -## Timeline Estimates - -| Phase | Scope | Complexity | -|-------|-------|------------| -| Phase 3 | Backend integration | Medium - requires understanding GoMLX interfaces | -| Phase 4 | Operation completeness | Medium-High - many operations, edge cases | -| Phase 5 | Optimization | High - performance tuning is iterative | - ---- - -## Open Questions - -1. **Dynamic Shapes**: How does GoMLX handle dynamic shapes? CoreML requires shape ranges. -2. **Gradient Support**: If training is ever needed, CoreML has limited gradient support. -3. **Multi-device**: Should we support running on multiple compute units simultaneously? -4. **iOS/tvOS**: Should the backend support non-macOS Apple platforms? diff --git a/specs/003-additional-tensor-manipulations.md b/specs/003-additional-tensor-manipulations.md deleted file mode 100644 index d1b07b3..0000000 --- a/specs/003-additional-tensor-manipulations.md +++ /dev/null @@ -1,946 +0,0 @@ -# Additional Tensor Manipulations for CoreML Backend - -## Overview - -This document details the implementation plan for expanding the CoreML backend's operation coverage. Phase 3 established the foundation with 16 operations. This phase adds the remaining operations needed for practical ML workloads. - -## Current State - -**Implemented (Phase 3):** -- Unary: Abs, Neg, Exp, Log, Sqrt, Tanh, Logistic -- Binary: Add, Sub, Mul, Div -- Shape: Reshape, Transpose -- Reduction: ReduceSum, ReduceMax -- Matrix: DotGeneral (simple case) - -**Target:** 80%+ of GoMLX StandardOps - ---- - -## Priority 1: Core Missing Operations - -These are frequently used and block common model patterns. - -### 1.1 Comparison Operations - -Required for control flow, masking, and conditional logic. - -| GoMLX Op | MIL Op | Implementation Notes | -|----------|--------|---------------------| -| Equal | `equal` | Returns bool tensor | -| NotEqual | `not_equal` | Returns bool tensor | -| Less | `less` | Returns bool tensor | -| LessEqual | `less_equal` | Returns bool tensor | -| Greater | `greater` | Returns bool tensor | -| GreaterEqual | `greater_equal` | Returns bool tensor | - -**Implementation:** - -```go -// In ops.go - add comparison helper -func (b *Builder) addComparisonOp( - opType backends.OpType, - milOp func(*model.Value, *model.Value) *model.Value, - lhs, rhs backends.Op, -) (*Node, error) { - inputs, err := b.checkOps(opType.String(), lhs, rhs) - if err != nil { - return nil, err - } - lhsNode, rhsNode := inputs[0], inputs[1] - - // Output shape follows broadcasting, but dtype is always Bool - outputShape, err := shapeinference.ComparisonOp(opType, lhsNode.shape, rhsNode.shape) - if err != nil { - return nil, err - } - - resultValue := milOp(lhsNode.milValue, rhsNode.milValue) - node := b.newNode(opType, outputShape, resultValue, lhsNode, rhsNode) - return node, nil -} - -func (b *Builder) Equal(lhs, rhs backends.Op) (backends.Op, error) { - return b.addComparisonOp(backends.OpTypeEqual, b.milBuilder.Equal, lhs, rhs) -} -// ... similar for other comparison ops -``` - -**go-coreml additions needed:** - -```go -// In model/ops.go -func (b *Builder) Equal(x, y *Value) *Value { - outShape := broadcastShape(x.shape, y.shape) - return b.addOp("equal", map[string]*Value{ - "x": x, - "y": y, - }, b.genName("equal"), Bool, outShape) -} - -func (b *Builder) Less(x, y *Value) *Value { - outShape := broadcastShape(x.shape, y.shape) - return b.addOp("less", map[string]*Value{ - "x": x, - "y": y, - }, b.genName("less"), Bool, outShape) -} -// ... similar for NotEqual, LessEqual, Greater, GreaterEqual -``` - -**Tasks:** -- [ ] Add comparison ops to go-coreml/model/ops.go -- [ ] Add comparison ops to gomlx/backends/coreml/ops.go -- [ ] Update capabilities.go -- [ ] Add tests for each comparison op -- [ ] Test broadcasting with comparisons - ---- - -### 1.2 Select/Where Operation - -Critical for conditional tensor operations. - -| GoMLX Op | MIL Op | Implementation Notes | -|----------|--------|---------------------| -| Where | `select` | `select(cond, a, b)` - returns a where cond else b | - -**Implementation:** - -```go -// GoMLX's Where takes (condition, onTrue, onFalse) -func (b *Builder) Where(condition, onTrue, onFalse backends.Op) (backends.Op, error) { - inputs, err := b.checkOps("Where", condition, onTrue, onFalse) - if err != nil { - return nil, err - } - cond, trueVal, falseVal := inputs[0], inputs[1], inputs[2] - - // Validate condition is bool - if cond.shape.DType != dtypes.Bool { - return nil, errors.Errorf("Where: condition must be bool, got %s", cond.shape.DType) - } - - // Output shape is broadcast of onTrue and onFalse - outputShape, err := shapeinference.WhereOp(cond.shape, trueVal.shape, falseVal.shape) - if err != nil { - return nil, err - } - - resultValue := b.milBuilder.Select(cond.milValue, trueVal.milValue, falseVal.milValue) - node := b.newNode(backends.OpTypeWhere, outputShape, resultValue, cond, trueVal, falseVal) - return node, nil -} -``` - -**go-coreml addition:** - -```go -func (b *Builder) Select(cond, a, bVal *Value) *Value { - outShape := broadcastShape(a.shape, bVal.shape) - return b.addOp("select", map[string]*Value{ - "cond": cond, - "a": a, - "b": bVal, - }, b.genName("select"), a.dtype, outShape) -} -``` - -**Tasks:** -- [ ] Add Select to go-coreml/model/ops.go -- [ ] Add Where to gomlx/backends/coreml/ops.go -- [ ] Add tests for Where with various shapes -- [ ] Test Where with broadcasting - ---- - -### 1.3 Additional Math Operations - -| GoMLX Op | MIL Op | Implementation Notes | -|----------|--------|---------------------| -| Pow | `pow` | x^y element-wise | -| Max (binary) | `maximum` | Element-wise max | -| Min (binary) | `minimum` | Element-wise min | -| Floor | `floor` | Round down | -| Ceil | `ceil` | Round up | -| Round | `round` | Round to nearest | -| Sign | `sign` | -1, 0, or 1 | -| Cos | `cos` | Cosine | -| Sin | `sin` | Sine | -| Acos | `acos` | Arc cosine | -| Asin | `asin` | Arc sine | -| Atan | `atan` | Arc tangent | -| Cosh | `cosh` | Hyperbolic cosine | -| Sinh | `sinh` | Hyperbolic sine | -| Erf | `erf` | Error function | - -**go-coreml additions:** - -```go -func (b *Builder) Pow(x, y *Value) *Value { - outShape := broadcastShape(x.shape, y.shape) - return b.addOp("pow", map[string]*Value{ - "x": x, - "y": y, - }, b.genName("pow"), x.dtype, outShape) -} - -func (b *Builder) Maximum(x, y *Value) *Value { - outShape := broadcastShape(x.shape, y.shape) - return b.addOp("maximum", map[string]*Value{ - "x": x, - "y": y, - }, b.genName("maximum"), x.dtype, outShape) -} - -func (b *Builder) Minimum(x, y *Value) *Value { - outShape := broadcastShape(x.shape, y.shape) - return b.addOp("minimum", map[string]*Value{ - "x": x, - "y": y, - }, b.genName("minimum"), x.dtype, outShape) -} - -func (b *Builder) Floor(x *Value) *Value { - return b.addOp("floor", map[string]*Value{"x": x}, b.genName("floor"), x.dtype, x.shape) -} - -func (b *Builder) Ceil(x *Value) *Value { - return b.addOp("ceil", map[string]*Value{"x": x}, b.genName("ceil"), x.dtype, x.shape) -} - -func (b *Builder) Round(x *Value) *Value { - return b.addOp("round", map[string]*Value{"x": x}, b.genName("round"), x.dtype, x.shape) -} - -func (b *Builder) Sign(x *Value) *Value { - return b.addOp("sign", map[string]*Value{"x": x}, b.genName("sign"), x.dtype, x.shape) -} - -func (b *Builder) Cos(x *Value) *Value { - return b.addOp("cos", map[string]*Value{"x": x}, b.genName("cos"), x.dtype, x.shape) -} - -func (b *Builder) Sin(x *Value) *Value { - return b.addOp("sin", map[string]*Value{"x": x}, b.genName("sin"), x.dtype, x.shape) -} - -func (b *Builder) Erf(x *Value) *Value { - return b.addOp("erf", map[string]*Value{"x": x}, b.genName("erf"), x.dtype, x.shape) -} -``` - -**Tasks:** -- [ ] Add all math ops to go-coreml/model/ops.go -- [ ] Add corresponding ops to gomlx/backends/coreml/ops.go -- [ ] Update capabilities.go -- [ ] Add tests for each operation - ---- - -## Priority 2: Tensor Manipulation Operations - -### 2.1 Concatenate - -Join tensors along an axis. - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Concatenate | `concat` | Multiple inputs along axis | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) Concat(values []*Value, axis int) *Value { - if len(values) == 0 { - panic("Concat requires at least one input") - } - - // Build input argument with multiple bindings - inputs := make(map[string]*Value) - for i, v := range values { - inputs[fmt.Sprintf("values_%d", i)] = v - } - - // Compute output shape - outShape := make([]int64, len(values[0].shape)) - copy(outShape, values[0].shape) - for i := 1; i < len(values); i++ { - outShape[axis] += values[i].shape[axis] - } - - axisVal := b.Const(b.genName("axis"), Int32, []int64{}, []int32{int32(axis)}) - - // Note: concat takes a tuple of values, need special handling - return b.addOpWithTuple("concat", values, map[string]*Value{ - "axis": axisVal, - }, b.genName("concat"), values[0].dtype, outShape) -} -``` - -**Note:** MIL's `concat` takes a tuple of values, which requires special serialization. May need to add tuple support to the builder. - -**Tasks:** -- [ ] Research MIL tuple syntax for concat -- [ ] Add Concat to go-coreml/model/ops.go -- [ ] Add Concatenate to gomlx/backends/coreml/ops.go -- [ ] Test with 2, 3, and many tensors -- [ ] Test with different axes - ---- - -### 2.2 Slice Operations - -Extract sub-tensors. - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Slice | `slice_by_index` | Extract range along each axis | -| DynamicSlice | `slice_by_index` | With dynamic start indices | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) SliceByIndex(x *Value, begin, end, strides []int64) *Value { - // Compute output shape - outShape := make([]int64, len(x.shape)) - for i := range outShape { - start := begin[i] - stop := end[i] - stride := strides[i] - if stride == 0 { - stride = 1 - } - outShape[i] = (stop - start + stride - 1) / stride - } - - beginVal := b.Const(b.genName("begin"), Int32, []int64{int64(len(begin))}, toInt32Slice(begin)) - endVal := b.Const(b.genName("end"), Int32, []int64{int64(len(end))}, toInt32Slice(end)) - stridesVal := b.Const(b.genName("strides"), Int32, []int64{int64(len(strides))}, toInt32Slice(strides)) - - return b.addOp("slice_by_index", map[string]*Value{ - "x": x, - "begin": beginVal, - "end": endVal, - "strides": stridesVal, - }, b.genName("slice"), x.dtype, outShape) -} -``` - -**Tasks:** -- [ ] Add SliceByIndex to go-coreml/model/ops.go -- [ ] Add Slice to gomlx/backends/coreml/ops.go -- [ ] Handle negative indices -- [ ] Test with various slice patterns -- [ ] Test with strides - ---- - -### 2.3 Gather and Scatter - -Index-based tensor operations. - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Gather | `gather` | Gather along axis using indices | -| GatherNd | `gather_nd` | N-dimensional gather | -| Scatter | `scatter` | May need decomposition | -| ScatterNd | `scatter_nd` | N-dimensional scatter | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) Gather(x *Value, indices *Value, axis int) *Value { - // Output shape: x.shape with axis dimension replaced by indices shape - outShape := make([]int64, 0, len(x.shape)-1+len(indices.shape)) - outShape = append(outShape, x.shape[:axis]...) - outShape = append(outShape, indices.shape...) - outShape = append(outShape, x.shape[axis+1:]...) - - axisVal := b.Const(b.genName("axis"), Int32, []int64{}, []int32{int32(axis)}) - - return b.addOp("gather", map[string]*Value{ - "x": x, - "indices": indices, - "axis": axisVal, - }, b.genName("gather"), x.dtype, outShape) -} -``` - -**Tasks:** -- [ ] Add Gather to go-coreml/model/ops.go -- [ ] Add GatherNd to go-coreml/model/ops.go -- [ ] Add Gather to gomlx/backends/coreml/ops.go -- [ ] Research Scatter implementation in MIL -- [ ] Add comprehensive tests - ---- - -### 2.4 Shape Manipulation - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Squeeze | `squeeze` | Remove size-1 dimensions | -| ExpandDims | `expand_dims` | Add size-1 dimension | -| Tile | `tile` | Repeat tensor | -| Pad | `pad` | Add padding | -| Reverse | `reverse` | Reverse along axes | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) Squeeze(x *Value, axes []int64) *Value { - // Compute output shape by removing specified axes - outShape := make([]int64, 0) - axisSet := make(map[int64]bool) - for _, a := range axes { - if a < 0 { - a = int64(len(x.shape)) + a - } - axisSet[a] = true - } - for i, dim := range x.shape { - if !axisSet[int64(i)] { - outShape = append(outShape, dim) - } - } - - axesVal := b.Const(b.genName("axes"), Int32, []int64{int64(len(axes))}, toInt32Slice(axes)) - - return b.addOp("squeeze", map[string]*Value{ - "x": x, - "axes": axesVal, - }, b.genName("squeeze"), x.dtype, outShape) -} - -func (b *Builder) ExpandDims(x *Value, axes []int64) *Value { - // Compute output shape by inserting size-1 dimensions - outRank := len(x.shape) + len(axes) - outShape := make([]int64, outRank) - - // Normalize and sort axes - normalizedAxes := make([]int64, len(axes)) - for i, a := range axes { - if a < 0 { - a = int64(outRank) + a - } - normalizedAxes[i] = a - } - sort.Slice(normalizedAxes, func(i, j int) bool { return normalizedAxes[i] < normalizedAxes[j] }) - - // Insert dimensions - axisSet := make(map[int64]bool) - for _, a := range normalizedAxes { - axisSet[a] = true - } - - srcIdx := 0 - for i := 0; i < outRank; i++ { - if axisSet[int64(i)] { - outShape[i] = 1 - } else { - outShape[i] = x.shape[srcIdx] - srcIdx++ - } - } - - axesVal := b.Const(b.genName("axes"), Int32, []int64{int64(len(axes))}, toInt32Slice(axes)) - - return b.addOp("expand_dims", map[string]*Value{ - "x": x, - "axes": axesVal, - }, b.genName("expand_dims"), x.dtype, outShape) -} - -func (b *Builder) Tile(x *Value, reps []int64) *Value { - outShape := make([]int64, len(x.shape)) - for i := range outShape { - outShape[i] = x.shape[i] * reps[i] - } - - repsVal := b.Const(b.genName("reps"), Int32, []int64{int64(len(reps))}, toInt32Slice(reps)) - - return b.addOp("tile", map[string]*Value{ - "x": x, - "reps": repsVal, - }, b.genName("tile"), x.dtype, outShape) -} - -func (b *Builder) Pad(x *Value, padBefore, padAfter []int64, mode string, constantValue float32) *Value { - outShape := make([]int64, len(x.shape)) - for i := range outShape { - outShape[i] = x.shape[i] + padBefore[i] + padAfter[i] - } - - // Pad specification as [before0, after0, before1, after1, ...] - padSpec := make([]int32, len(padBefore)*2) - for i := range padBefore { - padSpec[i*2] = int32(padBefore[i]) - padSpec[i*2+1] = int32(padAfter[i]) - } - - padVal := b.Const(b.genName("pad"), Int32, []int64{int64(len(padSpec))}, padSpec) - modeVal := b.Const(b.genName("mode"), Int32, []int64{}, []int32{padModeToInt(mode)}) - constVal := b.Const(b.genName("const_val"), Float32, []int64{}, []float32{constantValue}) - - return b.addOp("pad", map[string]*Value{ - "x": x, - "pad": padVal, - "mode": modeVal, - "constant_val": constVal, - }, b.genName("pad"), x.dtype, outShape) -} - -func padModeToInt(mode string) int32 { - switch mode { - case "constant": - return 0 - case "reflect": - return 1 - case "replicate": - return 2 - default: - return 0 - } -} -``` - -**Tasks:** -- [ ] Add Squeeze, ExpandDims, Tile, Pad, Reverse to go-coreml -- [ ] Add corresponding ops to gomlx/backends/coreml -- [ ] Test edge cases (empty axes, negative axes) -- [ ] Test Pad with different modes - ---- - -## Priority 3: Reduction Operations - -### 3.1 Additional Reductions - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| ReduceMin | `reduce_min` | Minimum reduction | -| ReduceProd | `reduce_prod` | Product reduction | -| ReduceMean | `reduce_mean` | Already in go-coreml | -| ReduceAnd | `reduce_and` | Logical AND (bool) | -| ReduceOr | `reduce_or` | Logical OR (bool) | -| ArgMax | `reduce_argmax` | Index of maximum | -| ArgMin | `reduce_argmin` | Index of minimum | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) ReduceMin(x *Value, axes []int64, keepDims bool) *Value { - axesVal := b.Const(b.genName("axes"), Int32, []int64{int64(len(axes))}, toInt32Slice(axes)) - keepVal := b.Const(b.genName("keep"), Bool, []int64{}, []bool{keepDims}) - outShape := computeReduceShape(x.shape, axes, keepDims) - - return b.addOp("reduce_min", map[string]*Value{ - "x": x, - "axes": axesVal, - "keep_dims": keepVal, - }, b.genName("reduce_min"), x.dtype, outShape) -} - -func (b *Builder) ReduceProd(x *Value, axes []int64, keepDims bool) *Value { - axesVal := b.Const(b.genName("axes"), Int32, []int64{int64(len(axes))}, toInt32Slice(axes)) - keepVal := b.Const(b.genName("keep"), Bool, []int64{}, []bool{keepDims}) - outShape := computeReduceShape(x.shape, axes, keepDims) - - return b.addOp("reduce_prod", map[string]*Value{ - "x": x, - "axes": axesVal, - "keep_dims": keepVal, - }, b.genName("reduce_prod"), x.dtype, outShape) -} - -func (b *Builder) ArgMax(x *Value, axis int, keepDims bool) *Value { - axisVal := b.Const(b.genName("axis"), Int32, []int64{}, []int32{int32(axis)}) - keepVal := b.Const(b.genName("keep"), Bool, []int64{}, []bool{keepDims}) - - outShape := computeReduceShape(x.shape, []int64{int64(axis)}, keepDims) - - return b.addOp("reduce_argmax", map[string]*Value{ - "x": x, - "axis": axisVal, - "keep_dims": keepVal, - }, b.genName("argmax"), Int32, outShape) // ArgMax returns indices -} -``` - -**Tasks:** -- [ ] Add ReduceMin, ReduceProd to go-coreml -- [ ] Add ArgMax, ArgMin to go-coreml -- [ ] Add corresponding ops to gomlx/backends/coreml -- [ ] Handle keepDims properly -- [ ] Test with various axes combinations - ---- - -## Priority 4: Activation Functions - -### 4.1 Common Activations - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Relu | `relu` | Already implemented via custom | -| Gelu | `gelu` | Gaussian Error Linear Unit | -| Silu/Swish | `silu` | x * sigmoid(x) | -| LeakyRelu | `leaky_relu` | With negative slope | -| Elu | `elu` | Exponential Linear Unit | -| Selu | `selu` | Scaled ELU | -| Softplus | `softplus` | log(1 + exp(x)) | -| Softsign | `softsign` | x / (1 + |x|) | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) Gelu(x *Value) *Value { - // GELU mode: "TANH_APPROXIMATION" or "EXACT" - modeVal := b.Const(b.genName("mode"), Int32, []int64{}, []int32{0}) // 0 = EXACT - return b.addOp("gelu", map[string]*Value{ - "x": x, - "mode": modeVal, - }, b.genName("gelu"), x.dtype, x.shape) -} - -func (b *Builder) Silu(x *Value) *Value { - return b.addOp("silu", map[string]*Value{ - "x": x, - }, b.genName("silu"), x.dtype, x.shape) -} - -func (b *Builder) LeakyRelu(x *Value, alpha float32) *Value { - alphaVal := b.Const(b.genName("alpha"), Float32, []int64{}, []float32{alpha}) - return b.addOp("leaky_relu", map[string]*Value{ - "x": x, - "alpha": alphaVal, - }, b.genName("leaky_relu"), x.dtype, x.shape) -} - -func (b *Builder) Elu(x *Value, alpha float32) *Value { - alphaVal := b.Const(b.genName("alpha"), Float32, []int64{}, []float32{alpha}) - return b.addOp("elu", map[string]*Value{ - "x": x, - "alpha": alphaVal, - }, b.genName("elu"), x.dtype, x.shape) -} - -func (b *Builder) Softplus(x *Value) *Value { - return b.addOp("softplus", map[string]*Value{ - "x": x, - }, b.genName("softplus"), x.dtype, x.shape) -} -``` - -**Tasks:** -- [ ] Add activation functions to go-coreml -- [ ] Add corresponding ops to gomlx/backends/coreml -- [ ] Test numerical accuracy against known values - ---- - -## Priority 5: Convolution and Pooling - -### 5.1 Convolution Operations - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| ConvGeneral | `conv` | General N-D convolution | -| ConvTranspose | `conv_transpose` | Transposed/deconvolution | - -**Implementation:** - -This is complex due to the many parameters (padding, strides, dilation, groups). - -```go -// go-coreml/model/ops.go -func (b *Builder) Conv( - x, weight *Value, - strides, dilations, padBefore, padAfter []int64, - groups int, -) *Value { - // Compute output shape - // H_out = (H_in + pad_before + pad_after - dilation * (kernel - 1) - 1) / stride + 1 - - stridesVal := b.Const(b.genName("strides"), Int32, []int64{int64(len(strides))}, toInt32Slice(strides)) - dilationsVal := b.Const(b.genName("dilations"), Int32, []int64{int64(len(dilations))}, toInt32Slice(dilations)) - padBeforeVal := b.Const(b.genName("pad_before"), Int32, []int64{int64(len(padBefore))}, toInt32Slice(padBefore)) - padAfterVal := b.Const(b.genName("pad_after"), Int32, []int64{int64(len(padAfter))}, toInt32Slice(padAfter)) - groupsVal := b.Const(b.genName("groups"), Int32, []int64{}, []int32{int32(groups)}) - - // ... compute outShape based on conv formula - - return b.addOp("conv", map[string]*Value{ - "x": x, - "weight": weight, - "strides": stridesVal, - "dilations": dilationsVal, - "pad": padBeforeVal, // MIL uses specific padding format - "groups": groupsVal, - }, b.genName("conv"), x.dtype, outShape) -} -``` - -**Tasks:** -- [ ] Research MIL conv parameter format (NCHW vs NHWC) -- [ ] Implement Conv in go-coreml -- [ ] Implement ConvTranspose in go-coreml -- [ ] Add to gomlx/backends/coreml with shape inference -- [ ] Test with simple CNN patterns - -### 5.2 Pooling Operations - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| MaxPool | `max_pool` | Max pooling | -| AvgPool | `avg_pool` | Average pooling | -| GlobalAvgPool | `reduce_mean` | Can use reduce_mean with spatial axes | - -**Tasks:** -- [ ] Implement MaxPool, AvgPool in go-coreml -- [ ] Add to gomlx/backends/coreml -- [ ] Handle padding modes - ---- - -## Priority 6: Normalization - -### 6.1 Normalization Layers - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| BatchNorm | `batch_norm` | Batch normalization | -| LayerNorm | `layer_norm` | Layer normalization | -| InstanceNorm | `instance_norm` | Instance normalization | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) BatchNorm( - x, mean, variance, gamma, beta *Value, - epsilon float32, -) *Value { - epsVal := b.Const(b.genName("epsilon"), Float32, []int64{}, []float32{epsilon}) - - return b.addOp("batch_norm", map[string]*Value{ - "x": x, - "mean": mean, - "variance": variance, - "gamma": gamma, - "beta": beta, - "epsilon": epsVal, - }, b.genName("batch_norm"), x.dtype, x.shape) -} - -func (b *Builder) LayerNorm( - x, gamma, beta *Value, - axes []int64, - epsilon float32, -) *Value { - axesVal := b.Const(b.genName("axes"), Int32, []int64{int64(len(axes))}, toInt32Slice(axes)) - epsVal := b.Const(b.genName("epsilon"), Float32, []int64{}, []float32{epsilon}) - - return b.addOp("layer_norm", map[string]*Value{ - "x": x, - "gamma": gamma, - "beta": beta, - "axes": axesVal, - "epsilon": epsVal, - }, b.genName("layer_norm"), x.dtype, x.shape) -} -``` - -**Tasks:** -- [ ] Implement BatchNorm, LayerNorm, InstanceNorm in go-coreml -- [ ] Add to gomlx/backends/coreml -- [ ] Test numerical accuracy - ---- - -## Implementation Order - -### Sprint 1: Comparisons and Select (High Impact) -1. Add comparison ops to go-coreml -2. Add Select to go-coreml -3. Integrate into gomlx/backends/coreml -4. Tests - -### Sprint 2: Math Operations -1. Add Pow, Maximum, Minimum -2. Add Floor, Ceil, Round, Sign -3. Add trig functions -4. Add Erf -5. Integrate and test - -### Sprint 3: Shape Manipulations -1. Add Squeeze, ExpandDims -2. Add Slice operations -3. Add Gather -4. Add Tile, Pad -5. Integrate and test - -### Sprint 4: Reductions and Activations -1. Add ReduceMin, ReduceProd -2. Add ArgMax, ArgMin -3. Add Gelu, Silu, LeakyRelu, etc. -4. Integrate and test - -### Sprint 5: Conv and Normalization -1. Research MIL conv format -2. Implement Conv, ConvTranspose -3. Implement pooling -4. Implement normalization -5. Integrate and test - ---- - -## Testing Strategy - -1. **Unit tests per operation**: Test basic functionality -2. **Broadcasting tests**: Verify correct broadcasting behavior -3. **Edge case tests**: Empty tensors, scalars, large tensors -4. **Numerical accuracy**: Compare against simplego backend -5. **Integration tests**: Multi-operation graphs - ---- - -## Success Criteria - -- [x] 60+ operations implemented (40+ new operations added) -- [x] All tests passing (17 test cases) -- [x] Numerical accuracy within 1e-5 of simplego -- [ ] Common model patterns work (MLP, CNN, Transformer attention) - ---- - -## Implementation Notes (Phase 4) - -**Completed: 2026-01-03** - -### Summary - -Successfully implemented 40+ additional operations across 4 sprints, bringing the CoreML backend to significantly improved operation coverage. - -### Operations Implemented - -#### Sprint 1: Comparison & Select Operations -**go-coreml/model/ops.go:** -- `Equal(x, y)` - Element-wise equality comparison -- `NotEqual(x, y)` - Element-wise inequality comparison -- `Less(x, y)` - Element-wise less-than comparison -- `LessEqual(x, y)` - Element-wise less-or-equal comparison -- `Greater(x, y)` - Element-wise greater-than comparison -- `GreaterEqual(x, y)` - Element-wise greater-or-equal comparison -- `Select(cond, a, b)` - Element-wise conditional selection - -**gomlx/backends/coreml/ops.go:** -- `Equal`, `NotEqual`, `LessThan`, `LessOrEqual`, `GreaterThan`, `GreaterOrEqual` - All using `addComparisonOp` helper -- `Where(condition, onTrue, onFalse)` - Conditional selection with Bool validation - -**Note:** Comparison operations produce Bool outputs which cannot be used directly as CoreML model outputs. They must be consumed by other operations like Where. - -#### Sprint 2: Math Operations -**go-coreml/model/ops.go:** -- `Pow(x, y)` - Element-wise power -- `Maximum(x, y)` - Element-wise maximum -- `Minimum(x, y)` - Element-wise minimum -- `Floor(x)` - Round down -- `Ceil(x)` - Round up -- `Round(x)` - Round to nearest -- `Sign(x)` - Sign function (-1, 0, 1) -- `Cos(x)`, `Sin(x)`, `Acos(x)`, `Asin(x)`, `Atan(x)` - Trigonometric functions -- `Cosh(x)`, `Sinh(x)` - Hyperbolic functions -- `Erf(x)` - Error function - -**gomlx/backends/coreml/ops.go:** -- `Pow`, `Max`, `Min` - Binary operations using `addBinaryOp` -- `Floor`, `Ceil`, `Round`, `Sign` - Unary operations using `addUnaryOp` -- `Cos`, `Sin`, `Erf` - Trig operations (only those with OpType constants defined) - -#### Sprint 3: Shape Manipulation Operations -**go-coreml/model/ops.go:** -- `Squeeze(x, axes)` - Remove size-1 dimensions -- `ExpandDims(x, axes)` - Add size-1 dimensions -- `SliceByIndex(x, begin, end, strides)` - Extract sub-tensors -- `Gather(x, indices, axis)` - Gather along axis - -**gomlx/backends/coreml/ops.go:** -- `Slice(x, starts, limits, strides)` - Full slice support with stride handling -- `Gather` - Partial implementation for simple single-axis gather cases - -**Note:** Squeeze and ExpandDims added to go-coreml but not exposed in gomlx backend (no OpType constants defined). - -#### Sprint 4: Reductions & Activations -**go-coreml/model/ops.go:** -- `ReduceMin(x, axes, keepDims)` - Minimum reduction -- `ReduceProd(x, axes, keepDims)` - Product reduction -- `ArgMax(x, axis, keepDims)` - Index of maximum (returns Int32) -- `ArgMin(x, axis, keepDims)` - Index of minimum (returns Int32) -- `Gelu(x, mode)` - Gaussian Error Linear Unit with EXACT or TANH_APPROXIMATION modes -- `Silu(x)` - Sigmoid Linear Unit (Swish) -- `LeakyRelu(x, alpha)` - Leaky ReLU -- `Elu(x, alpha)` - Exponential Linear Unit -- `Softplus(x)` - Smooth ReLU approximation - -**gomlx/backends/coreml/ops.go:** -- `ReduceMin`, `ReduceProduct` - Following existing reduce pattern -- `ArgMinMax(isMin)` - Combined argmin/argmax implementation - -**Note:** Activation functions added to go-coreml but not exposed in gomlx backend (no OpType constants defined). - -### Key Discoveries - -1. **Bool Output Limitation**: CoreML does not allow Bool-typed outputs directly from models. Comparison operations must be consumed by other operations (like Where/Select) before being used as outputs. - -2. **String Constants**: Added String dtype support to go-coreml builder for operations like Gelu that take string mode parameters. - -3. **Gather Complexity**: GoMLX's Gather interface (XLA-style) is significantly more complex than CoreML's simple gather. Implementation supports common single-axis cases but not all XLA Gather semantics. - -4. **OpType Coverage**: Several operations (Squeeze, ExpandDims, inverse trig, hyperbolic trig, activations) were implemented in go-coreml but not exposed in gomlx backend due to missing OpType constants. These can be used directly via go-coreml. - -### Files Modified - -**go-coreml:** -- `model/ops.go` - Added 35+ new MIL operations -- `model/builder.go` - Added String dtype support - -**gomlx/backends/coreml:** -- `ops.go` - Added 20+ new backend wrapper functions -- `capabilities.go` - Updated to advertise new operation support -- `coreml_test.go` - Added 7 new test functions (17 total test cases) - -### Test Results - -All 17 tests passing: -- TestBackendCreation -- TestBufferOperations -- TestSharedBuffer -- TestBuilderParameterAndConstant -- TestAddOperation -- TestUnaryOperations (Abs, Neg, Exp, Sqrt) -- TestBinaryOperations (Add, Sub, Mul, Div) -- TestReshape -- TestReduceSum -- TestComparisonOperationsViaWhere (Equal, LessThan, GreaterThan) -- TestWhereOperation -- TestMathOperations (Pow, Max, Floor, Ceil) -- TestTrigOperations (Cos, Sin) -- TestReduceMin -- TestSlice -- TestChainedOperations - -### Remaining Work - -1. **Priority 5: Convolution and Pooling** - Not yet implemented -2. **Priority 6: Normalization Layers** - Not yet implemented -3. **Concat Operation** - Requires tuple support in MIL serialization -4. **Full DotGeneral** - Currently only simple matmul cases supported -5. **Tile, Pad, Reverse** - Not yet implemented diff --git a/specs/004-more-tensor-ops.md b/specs/004-more-tensor-ops.md deleted file mode 100644 index bf50764..0000000 --- a/specs/004-more-tensor-ops.md +++ /dev/null @@ -1,1034 +0,0 @@ -# Additional Tensor Operations for CoreML Backend - Phase 5 - -## Overview - -This document details the implementation plan for the remaining operations needed to achieve comprehensive CoreML backend coverage. Phase 4 added 40+ operations. This phase focuses on: - -1. Missing tensor manipulation ops (Concat, Tile, Pad, Reverse) -2. Convolution and Pooling operations -3. Normalization layers -4. Enhanced DotGeneral support -5. Broadcast operations - -## Current State - -**Implemented (Phases 3-4):** -- Unary: Abs, Neg, Exp, Log, Sqrt, Tanh, Logistic, Floor, Ceil, Round, Sign, Cos, Sin, Erf -- Binary: Add, Sub, Mul, Div, Pow, Max, Min -- Comparison: Equal, NotEqual, Less, LessEqual, Greater, GreaterEqual -- Select: Where/Select -- Shape: Reshape, Transpose, Slice, Gather (partial), Squeeze*, ExpandDims* -- Reduction: ReduceSum, ReduceMax, ReduceMin, ReduceProd*, ArgMax*, ArgMin* -- Matrix: DotGeneral (simple matmul only) -- Activation: Gelu*, Silu*, LeakyRelu*, Elu*, Softplus* - -*Only in go-coreml, not exposed in gomlx backend - -**Target:** Full coverage for common ML model patterns (MLP, CNN, Transformer) - ---- - -## Priority 1: Tensor Manipulation Ops - -### 1.1 Concatenate - -Join multiple tensors along an axis. Critical for residual connections and multi-head attention. - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Concatenate | `concat` | Takes tuple of values + axis | - -**Challenge:** MIL's `concat` takes a tuple/list of values, not individual named arguments. This requires special handling in the serialization layer. - -**Research Required:** -1. Investigate how MIL represents value tuples in the protobuf -2. Check if there's an alternative representation (e.g., variadic inputs) -3. Look at coremltools Python implementation for reference - -**Implementation Plan:** - -```go -// Option 1: Add tuple support to builder -// In model/builder.go - add new method for tuple arguments - -type TupleArg struct { - Values []*Value -} - -func (b *Builder) addOpWithTuple(opType string, tupleArg []*Value, namedArgs map[string]*Value, name string, dtype DataType, shape []int64) *Value { - // Create a special "tuple" argument binding - // Serialize as a list of value references -} - -// In model/ops.go -func (b *Builder) Concat(values []*Value, axis int64) *Value { - if len(values) == 0 { - panic("Concat requires at least one input") - } - if len(values) == 1 { - return values[0] // No-op for single input - } - - // Compute output shape - outShape := make([]int64, len(values[0].shape)) - copy(outShape, values[0].shape) - for i := 1; i < len(values); i++ { - outShape[axis] += values[i].shape[axis] - } - - axisVal := b.Const(b.genName("axis"), Int32, []int64{}, []int32{int32(axis)}) - - return b.addOpWithTuple("concat", values, map[string]*Value{ - "axis": axisVal, - }, b.genName("concat"), values[0].dtype, outShape) -} -``` - -**Protobuf Investigation:** -```protobuf -// From MIL spec - need to verify exact format -message Operation { - string type = 1; - repeated Argument inputs = 2; - repeated Output outputs = 3; -} - -message Argument { - oneof argument { - ListValue list_value = 1; // For tuple arguments - NamedValue named_value = 2; - } -} -``` - -**gomlx/backends/coreml/ops.go:** -```go -func (b *Builder) Concatenate(operands []backends.Op, axis int) (backends.Op, error) { - opType := backends.OpTypeConcatenate - - // Validate and convert operands - nodes := make([]*Node, len(operands)) - milValues := make([]*model.Value, len(operands)) - for i, op := range operands { - node, err := b.checkOp(opType.String(), op) - if err != nil { - return nil, err - } - nodes[i] = node - milValues[i] = node.milValue - } - - // Compute output shape - outputShape, err := shapeinference.ConcatenateOp(axis, shapes...) - if err != nil { - return nil, err - } - - resultValue := b.milBuilder.Concat(milValues, int64(axis)) - node := b.newNodeMultiInput(opType, outputShape, resultValue, nodes) - - return node, nil -} -``` - -**Tasks:** -- [ ] Research MIL protobuf format for tuple/list arguments -- [ ] Add tuple argument support to go-coreml/model/builder.go -- [ ] Implement Concat in go-coreml/model/ops.go -- [ ] Add Concatenate to gomlx/backends/coreml/ops.go -- [ ] Add tests for 2, 3, and many tensor concatenation -- [ ] Test with different axes - ---- - -### 1.2 Tile - -Repeat a tensor along each axis. - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Tile | `tile` | Takes reps parameter | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) Tile(x *Value, reps []int64) *Value { - // Compute output shape - outShape := make([]int64, len(x.shape)) - for i := range outShape { - outShape[i] = x.shape[i] * reps[i] - } - - repsVal := b.Const(b.genName("reps"), Int32, []int64{int64(len(reps))}, toInt32Slice(reps)) - - return b.addOp("tile", map[string]*Value{ - "x": x, - "reps": repsVal, - }, b.genName("tile"), x.dtype, outShape) -} -``` - -**gomlx/backends/coreml/ops.go:** -```go -func (b *Builder) Tile(operandOp backends.Op, multiples []int) (backends.Op, error) { - opType := backends.OpTypeTile - inputs, err := b.checkOps(opType.String(), operandOp) - if err != nil { - return nil, err - } - operand := inputs[0] - - // Convert to int64 - multiplesInt64 := make([]int64, len(multiples)) - for i, m := range multiples { - multiplesInt64[i] = int64(m) - } - - outputShape, err := shapeinference.TileOp(operand.shape, multiples) - if err != nil { - return nil, err - } - - resultValue := b.milBuilder.Tile(operand.milValue, multiplesInt64) - node := b.newNode(opType, outputShape, resultValue, operand) - - return node, nil -} -``` - -**Tasks:** -- [ ] Implement Tile in go-coreml/model/ops.go -- [ ] Check if OpTypeTile exists in gomlx backends -- [ ] Add Tile to gomlx/backends/coreml/ops.go (if OpType exists) -- [ ] Add tests - ---- - -### 1.3 Pad - -Add padding to a tensor. - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Pad | `pad` | Supports constant, reflect, replicate modes | - -**Implementation:** - -```go -// go-coreml/model/ops.go - -// PadMode represents padding mode -type PadMode int - -const ( - PadConstant PadMode = iota - PadReflect - PadReplicate -) - -func (b *Builder) Pad(x *Value, padBefore, padAfter []int64, mode PadMode, constantValue float32) *Value { - // Compute output shape - outShape := make([]int64, len(x.shape)) - for i := range outShape { - outShape[i] = x.shape[i] + padBefore[i] + padAfter[i] - } - - // MIL pad format: [before_0, after_0, before_1, after_1, ...] - padSpec := make([]int32, len(padBefore)*2) - for i := range padBefore { - padSpec[i*2] = int32(padBefore[i]) - padSpec[i*2+1] = int32(padAfter[i]) - } - - padVal := b.Const(b.genName("pad"), Int32, []int64{int64(len(padSpec))}, padSpec) - - inputs := map[string]*Value{ - "x": x, - "pad": padVal, - } - - // Add mode-specific parameters - switch mode { - case PadConstant: - modeVal := b.Const(b.genName("mode"), String, []int64{}, "constant") - constVal := b.Const(b.genName("constant_val"), x.dtype, []int64{}, []float32{constantValue}) - inputs["mode"] = modeVal - inputs["constant_val"] = constVal - case PadReflect: - modeVal := b.Const(b.genName("mode"), String, []int64{}, "reflect") - inputs["mode"] = modeVal - case PadReplicate: - modeVal := b.Const(b.genName("mode"), String, []int64{}, "replicate") - inputs["mode"] = modeVal - } - - return b.addOp("pad", inputs, b.genName("pad"), x.dtype, outShape) -} -``` - -**gomlx/backends/coreml/ops.go:** -```go -func (b *Builder) Pad(operandOp backends.Op, low, high, interior []int) (backends.Op, error) { - opType := backends.OpTypePad - inputs, err := b.checkOps(opType.String(), operandOp) - if err != nil { - return nil, err - } - operand := inputs[0] - - // GoMLX Pad also has interior padding (between elements) - // CoreML doesn't support interior padding directly - for _, i := range interior { - if i != 0 { - return nil, errors.Errorf("Pad: CoreML backend does not support interior padding") - } - } - - // Convert to int64 - lowInt64 := make([]int64, len(low)) - highInt64 := make([]int64, len(high)) - for i := range low { - lowInt64[i] = int64(low[i]) - highInt64[i] = int64(high[i]) - } - - outputShape, err := shapeinference.PadOp(operand.shape, low, high, interior) - if err != nil { - return nil, err - } - - resultValue := b.milBuilder.Pad(operand.milValue, lowInt64, highInt64, model.PadConstant, 0.0) - node := b.newNode(opType, outputShape, resultValue, operand) - - return node, nil -} -``` - -**Tasks:** -- [ ] Implement Pad in go-coreml/model/ops.go with all modes -- [ ] Check if OpTypePad exists in gomlx backends -- [ ] Add Pad to gomlx/backends/coreml/ops.go (handle interior padding limitation) -- [ ] Add tests for each padding mode - ---- - -### 1.4 Reverse - -Reverse tensor along specified axes. - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| Reverse | `reverse` | Reverses along specified axes | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) Reverse(x *Value, axes []int64) *Value { - axesVal := b.Const(b.genName("axes"), Int32, []int64{int64(len(axes))}, toInt32Slice(axes)) - - return b.addOp("reverse", map[string]*Value{ - "x": x, - "axes": axesVal, - }, b.genName("reverse"), x.dtype, x.shape) // Shape unchanged -} -``` - -**gomlx/backends/coreml/ops.go:** -```go -func (b *Builder) Reverse(operandOp backends.Op, axes []int) (backends.Op, error) { - opType := backends.OpTypeReverse - inputs, err := b.checkOps(opType.String(), operandOp) - if err != nil { - return nil, err - } - operand := inputs[0] - - axesInt64 := make([]int64, len(axes)) - for i, a := range axes { - axesInt64[i] = int64(a) - } - - // Output shape is same as input - resultValue := b.milBuilder.Reverse(operand.milValue, axesInt64) - node := b.newNode(opType, operand.shape, resultValue, operand) - - return node, nil -} -``` - -**Tasks:** -- [ ] Implement Reverse in go-coreml/model/ops.go -- [ ] Check if OpTypeReverse exists in gomlx backends -- [ ] Add Reverse to gomlx/backends/coreml/ops.go -- [ ] Add tests - ---- - -## Priority 2: Convolution Operations - -Critical for CNN models. - -### 2.1 Conv2D - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| ConvGeneral | `conv` | General N-D convolution | - -**MIL Conv Parameters:** -- `x`: Input tensor [N, C_in, H, W] (NCHW format) -- `weight`: Filter tensor [C_out, C_in/groups, kH, kW] -- `strides`: [stride_h, stride_w] -- `pad_type`: "same", "valid", or "custom" -- `pad`: Custom padding [pad_h_before, pad_h_after, pad_w_before, pad_w_after] -- `dilations`: [dilation_h, dilation_w] -- `groups`: Number of groups for grouped convolution - -**Implementation:** - -```go -// go-coreml/model/ops.go - -// ConvPadType represents convolution padding type -type ConvPadType int - -const ( - ConvPadValid ConvPadType = iota - ConvPadSame - ConvPadCustom -) - -func (b *Builder) Conv( - x, weight *Value, - strides, dilations []int64, - padType ConvPadType, - padBefore, padAfter []int64, // Only used if padType == ConvPadCustom - groups int64, -) *Value { - // Validate dimensions - // x: [N, C_in, H, W] for 2D conv - // weight: [C_out, C_in/groups, kH, kW] - - xShape := x.shape - wShape := weight.shape - - N := xShape[0] - C_out := wShape[0] - - // Compute output spatial dimensions - // H_out = (H_in + pad_h_before + pad_h_after - dilation_h * (kH - 1) - 1) / stride_h + 1 - - var H_out, W_out int64 - switch padType { - case ConvPadSame: - H_out = (xShape[2] + strides[0] - 1) / strides[0] - W_out = (xShape[3] + strides[1] - 1) / strides[1] - case ConvPadValid: - kH := wShape[2] - kW := wShape[3] - H_out = (xShape[2] - dilations[0]*(kH-1) - 1) / strides[0] + 1 - W_out = (xShape[3] - dilations[1]*(kW-1) - 1) / strides[1] + 1 - case ConvPadCustom: - kH := wShape[2] - kW := wShape[3] - H_out = (xShape[2] + padBefore[0] + padAfter[0] - dilations[0]*(kH-1) - 1) / strides[0] + 1 - W_out = (xShape[3] + padBefore[1] + padAfter[1] - dilations[1]*(kW-1) - 1) / strides[1] + 1 - } - - outShape := []int64{N, C_out, H_out, W_out} - - // Build arguments - stridesVal := b.Const(b.genName("strides"), Int32, []int64{int64(len(strides))}, toInt32Slice(strides)) - dilationsVal := b.Const(b.genName("dilations"), Int32, []int64{int64(len(dilations))}, toInt32Slice(dilations)) - groupsVal := b.Const(b.genName("groups"), Int32, []int64{}, []int32{int32(groups)}) - - inputs := map[string]*Value{ - "x": x, - "weight": weight, - "strides": stridesVal, - "dilations": dilationsVal, - "groups": groupsVal, - } - - switch padType { - case ConvPadSame: - padTypeVal := b.Const(b.genName("pad_type"), String, []int64{}, "same") - inputs["pad_type"] = padTypeVal - case ConvPadValid: - padTypeVal := b.Const(b.genName("pad_type"), String, []int64{}, "valid") - inputs["pad_type"] = padTypeVal - case ConvPadCustom: - padTypeVal := b.Const(b.genName("pad_type"), String, []int64{}, "custom") - // Combine before/after: [h_before, h_after, w_before, w_after] - padSpec := make([]int32, len(padBefore)*2) - for i := range padBefore { - padSpec[i*2] = int32(padBefore[i]) - padSpec[i*2+1] = int32(padAfter[i]) - } - padVal := b.Const(b.genName("pad"), Int32, []int64{int64(len(padSpec))}, padSpec) - inputs["pad_type"] = padTypeVal - inputs["pad"] = padVal - } - - return b.addOp("conv", inputs, b.genName("conv"), x.dtype, outShape) -} - -// ConvWithBias adds bias after convolution -func (b *Builder) ConvWithBias( - x, weight, bias *Value, - strides, dilations []int64, - padType ConvPadType, - padBefore, padAfter []int64, - groups int64, -) *Value { - conv := b.Conv(x, weight, strides, dilations, padType, padBefore, padAfter, groups) - // Reshape bias for broadcasting: [C_out] -> [1, C_out, 1, 1] - biasReshaped := b.Reshape(bias, []int64{1, bias.shape[0], 1, 1}) - return b.Add(conv, biasReshaped) -} -``` - -**gomlx/backends/coreml/ops.go:** - -GoMLX uses `ConvGeneralDilated` with dimension numbers and feature group count. - -```go -func (b *Builder) ConvGeneralDilated( - operandOp, kernelOp backends.Op, - strides, dilations, paddingLow, paddingHigh []int, - inputBatchDim, inputFeatureDim int, - inputSpatialDims []int, - kernelInputFeatureDim, kernelOutputFeatureDim int, - kernelSpatialDims []int, - outputBatchDim, outputFeatureDim int, - outputSpatialDims []int, - featureGroupCount, batchGroupCount int, -) (backends.Op, error) { - opType := backends.OpTypeConvGeneralDilated - - inputs, err := b.checkOps(opType.String(), operandOp, kernelOp) - if err != nil { - return nil, err - } - operand, kernel := inputs[0], inputs[1] - - // Validate dimension numbers - // CoreML expects NCHW format for both input and output - // Need to transpose if format differs - - if inputBatchDim != 0 || inputFeatureDim != 1 { - return nil, errors.Errorf("ConvGeneralDilated: CoreML requires NCHW format (batch=0, feature=1)") - } - - // Convert parameters - stridesInt64 := toInt64Slice(strides) - dilationsInt64 := toInt64Slice(dilations) - - // Determine padding type - var padType model.ConvPadType - var padBefore, padAfter []int64 - - allZero := true - for i := range paddingLow { - if paddingLow[i] != 0 || paddingHigh[i] != 0 { - allZero = false - break - } - } - - if allZero { - padType = model.ConvPadValid - } else { - padType = model.ConvPadCustom - padBefore = toInt64Slice(paddingLow) - padAfter = toInt64Slice(paddingHigh) - } - - // Compute output shape using shapeinference - outputShape, err := shapeinference.ConvGeneralDilatedOp(...) - if err != nil { - return nil, err - } - - resultValue := b.milBuilder.Conv( - operand.milValue, kernel.milValue, - stridesInt64, dilationsInt64, - padType, padBefore, padAfter, - int64(featureGroupCount), - ) - - node := b.newNode(opType, outputShape, resultValue, operand, kernel) - return node, nil -} -``` - -**Tasks:** -- [ ] Research MIL conv parameter format in detail -- [ ] Implement Conv in go-coreml/model/ops.go -- [ ] Add ConvTranspose for deconvolution -- [ ] Add ConvGeneralDilated to gomlx/backends/coreml/ops.go -- [ ] Handle NHWC to NCHW format conversion if needed -- [ ] Add tests for various conv configurations - ---- - -### 2.2 Conv Transpose (Deconvolution) - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| ConvTranspose | `conv_transpose` | Transposed convolution | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) ConvTranspose( - x, weight *Value, - strides, dilations []int64, - padType ConvPadType, - padBefore, padAfter []int64, - outputPadding []int64, // Additional padding for output shape - groups int64, -) *Value { - // Similar to Conv but output shape calculation is different - // H_out = (H_in - 1) * stride - 2*pad + dilation*(kH - 1) + output_padding + 1 - - // ... implementation -} -``` - -**Tasks:** -- [ ] Research MIL conv_transpose parameters -- [ ] Implement ConvTranspose in go-coreml -- [ ] Add to gomlx backend if ConvTranspose op exists - ---- - -## Priority 3: Pooling Operations - -### 3.1 MaxPool and AvgPool - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| MaxPool | `max_pool` | Max pooling | -| AvgPool | `avg_pool` | Average pooling | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) MaxPool( - x *Value, - kernelSize, strides []int64, - padType ConvPadType, - padBefore, padAfter []int64, -) *Value { - // Compute output shape similar to conv - // ... - - kernelVal := b.Const(b.genName("kernel_sizes"), Int32, []int64{int64(len(kernelSize))}, toInt32Slice(kernelSize)) - stridesVal := b.Const(b.genName("strides"), Int32, []int64{int64(len(strides))}, toInt32Slice(strides)) - - inputs := map[string]*Value{ - "x": x, - "kernel_sizes": kernelVal, - "strides": stridesVal, - } - - // Add padding parameters... - - return b.addOp("max_pool", inputs, b.genName("max_pool"), x.dtype, outShape) -} - -func (b *Builder) AvgPool( - x *Value, - kernelSize, strides []int64, - padType ConvPadType, - padBefore, padAfter []int64, - excludePaddingFromAverage bool, -) *Value { - // Similar to MaxPool - // ... - - excludeVal := b.Const(b.genName("exclude_padding"), Bool, []int64{}, []bool{excludePaddingFromAverage}) - inputs["exclude_padding_from_average"] = excludeVal - - return b.addOp("avg_pool", inputs, b.genName("avg_pool"), x.dtype, outShape) -} -``` - -**Tasks:** -- [ ] Implement MaxPool in go-coreml -- [ ] Implement AvgPool in go-coreml -- [ ] Add to gomlx backend (check for OpType) -- [ ] Add tests - -### 3.2 Global Pooling - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| GlobalAvgPool | `reduce_mean` | Use reduce_mean on spatial dims | -| GlobalMaxPool | `reduce_max` | Use reduce_max on spatial dims | - -Global pooling can be implemented using existing reduce operations: - -```go -// Helper functions in ops.go -func (b *Builder) GlobalAvgPool2D(x *Value) *Value { - // Reduce over H, W dimensions (axes 2, 3 for NCHW) - return b.ReduceMean(x, []int64{2, 3}, true) // keepDims=true -} - -func (b *Builder) GlobalMaxPool2D(x *Value) *Value { - return b.ReduceMax(x, []int64{2, 3}, true) -} -``` - -**Tasks:** -- [ ] Add GlobalAvgPool2D, GlobalMaxPool2D convenience functions -- [ ] Add tests - ---- - -## Priority 4: Normalization Layers - -### 4.1 Batch Normalization - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| BatchNorm | `batch_norm` | Batch normalization | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) BatchNorm( - x, mean, variance, gamma, beta *Value, - epsilon float32, -) *Value { - epsVal := b.Const(b.genName("epsilon"), Float32, []int64{}, []float32{epsilon}) - - return b.addOp("batch_norm", map[string]*Value{ - "x": x, - "mean": mean, - "variance": variance, - "gamma": gamma, - "beta": beta, - "epsilon": epsVal, - }, b.genName("batch_norm"), x.dtype, x.shape) -} -``` - -**Tasks:** -- [ ] Implement BatchNorm in go-coreml -- [ ] Check gomlx BatchNorm interface -- [ ] Add tests - -### 4.2 Layer Normalization - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| LayerNorm | `layer_norm` | Layer normalization | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) LayerNorm( - x, gamma, beta *Value, - axes []int64, - epsilon float32, -) *Value { - axesVal := b.Const(b.genName("axes"), Int32, []int64{int64(len(axes))}, toInt32Slice(axes)) - epsVal := b.Const(b.genName("epsilon"), Float32, []int64{}, []float32{epsilon}) - - return b.addOp("layer_norm", map[string]*Value{ - "x": x, - "gamma": gamma, - "beta": beta, - "axes": axesVal, - "epsilon": epsVal, - }, b.genName("layer_norm"), x.dtype, x.shape) -} -``` - -**Tasks:** -- [ ] Implement LayerNorm in go-coreml -- [ ] Check gomlx LayerNorm interface -- [ ] Add tests - -### 4.3 Instance Normalization - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| InstanceNorm | `instance_norm` | Instance normalization | - -**Implementation:** - -```go -// go-coreml/model/ops.go -func (b *Builder) InstanceNorm( - x, gamma, beta *Value, - epsilon float32, -) *Value { - epsVal := b.Const(b.genName("epsilon"), Float32, []int64{}, []float32{epsilon}) - - return b.addOp("instance_norm", map[string]*Value{ - "x": x, - "gamma": gamma, - "beta": beta, - "epsilon": epsVal, - }, b.genName("instance_norm"), x.dtype, x.shape) -} -``` - -**Tasks:** -- [ ] Implement InstanceNorm in go-coreml -- [ ] Add tests - ---- - -## Priority 5: Enhanced DotGeneral - -The current DotGeneral only supports simple matrix multiplication. We need to support: -- Batch dimensions -- Arbitrary contracting axes -- Transposed inputs - -**Implementation Strategy:** - -Use Transpose and Reshape to normalize inputs to standard matmul format. - -```go -// gomlx/backends/coreml/ops.go - -func (b *Builder) DotGeneral( - lhsOp backends.Op, - lhsContractingAxes, lhsBatchAxes []int, - rhsOp backends.Op, - rhsContractingAxes, rhsBatchAxes []int, -) (backends.Op, error) { - // ... existing validation ... - - // For complex cases, normalize to standard batched matmul format: - // lhs: [B1, B2, ..., M, K] - // rhs: [B1, B2, ..., K, N] - // out: [B1, B2, ..., M, N] - - // Step 1: Transpose to move batch dims first, then M/K - lhsTransposed := transposeToBatchedMatmul(lhs, lhsBatchAxes, lhsContractingAxes) - rhsTransposed := transposeToBatchedMatmul(rhs, rhsBatchAxes, rhsContractingAxes) - - // Step 2: Reshape to merge batch dimensions if needed - // ... - - // Step 3: Call batched matmul - // ... - - // Step 4: Reshape/transpose output to expected shape - // ... -} -``` - -**Tasks:** -- [ ] Implement helper functions for normalizing DotGeneral inputs -- [ ] Support batch dimensions -- [ ] Support arbitrary contracting axes via transpose -- [ ] Add comprehensive tests for various DotGeneral configurations - ---- - -## Priority 6: Broadcast Operations - -### 6.1 BroadcastTo - -Explicitly broadcast a tensor to a larger shape. - -| GoMLX Op | MIL Op | Notes | -|----------|--------|-------| -| BroadcastTo | Could use `tile` or native broadcast | Explicit broadcasting | - -**Implementation:** - -CoreML handles broadcasting implicitly in most operations. For explicit broadcast, we can use Tile or a combination of ExpandDims and Tile. - -```go -// go-coreml/model/ops.go -func (b *Builder) BroadcastTo(x *Value, targetShape []int64) *Value { - // Option 1: Use tile to broadcast - // Calculate reps needed for each dimension - - xRank := len(x.shape) - targetRank := len(targetShape) - - // First expand dims if needed - if xRank < targetRank { - newAxes := make([]int64, targetRank-xRank) - for i := range newAxes { - newAxes[i] = int64(i) - } - x = b.ExpandDims(x, newAxes) - } - - // Then tile to match target shape - reps := make([]int64, len(targetShape)) - for i := range reps { - if x.shape[i] == 1 && targetShape[i] > 1 { - reps[i] = targetShape[i] - } else { - reps[i] = 1 - } - } - - return b.Tile(x, reps) -} -``` - -**Tasks:** -- [ ] Implement BroadcastTo in go-coreml -- [ ] Check if gomlx has explicit broadcast op -- [ ] Add tests - ---- - -## Implementation Order - -### Sprint 1: Tensor Manipulation (1 week) -1. Research MIL tuple format for Concat -2. Implement Tile, Pad, Reverse in go-coreml -3. Add gomlx wrappers for ops with defined OpTypes -4. Tests for each operation - -### Sprint 2: Convolution (1 week) -1. Research MIL conv parameters thoroughly -2. Implement Conv2D with all padding modes -3. Implement ConvTranspose -4. Add gomlx ConvGeneralDilated wrapper -5. Tests - -### Sprint 3: Pooling (0.5 weeks) -1. Implement MaxPool, AvgPool -2. Add global pooling helpers -3. Add gomlx wrappers -4. Tests - -### Sprint 4: Normalization (0.5 weeks) -1. Implement BatchNorm, LayerNorm, InstanceNorm -2. Add gomlx wrappers -3. Tests - -### Sprint 5: Enhanced DotGeneral (1 week) -1. Implement transpose/reshape normalization -2. Support batch dimensions -3. Comprehensive tests - -### Sprint 6: Concat & Cleanup (0.5 weeks) -1. Implement Concat with tuple support -2. Code cleanup and documentation -3. Integration tests - ---- - -## Testing Strategy - -1. **Unit tests per operation**: Test basic functionality with known inputs/outputs -2. **Shape inference tests**: Verify output shapes are computed correctly -3. **Edge cases**: Empty tensors, scalars, single-element tensors -4. **Numerical accuracy**: Compare against simplego or XLA backend -5. **Integration tests**: Build and run simple models: - - MLP with Tile for batch broadcast - - Simple CNN (Conv + Pool + Norm) - - Attention mechanism (DotGeneral + Concatenate) - ---- - -## Success Criteria - -- [x] All tensor manipulation ops implemented (Concat, Tile, Pad, Reverse) -- [x] Convolution operations working for 2D case -- [x] Pooling operations working -- [x] Normalization layers working -- [ ] DotGeneral supports batch dimensions (future work) -- [x] All tests passing -- [ ] Can run simple CNN model end-to-end (integration test pending) -- [ ] Can run transformer attention block (integration test pending) - ---- - -## Phase 5 Implementation Notes (January 2026) - -### Completed Operations - -**go-coreml/model/ops.go:** -1. **Tile** - Repeat tensor along each axis -2. **Pad** - Add padding with constant/reflect/replicate modes -3. **Reverse** - Reverse along specified axes -4. **Concat** - Concatenate multiple tensors (required adding list argument support to builder) -5. **Conv** - 2D convolution with all padding modes -6. **ConvTranspose** - Transposed/deconvolution -7. **ConvWithBias** - Convenience helper -8. **MaxPool** - Max pooling with all padding modes -9. **AvgPool** - Average pooling with exclude_padding option -10. **GlobalAvgPool2D** - Global average pooling (via ReduceMean) -11. **GlobalMaxPool2D** - Global max pooling (via ReduceMax) -12. **BatchNorm** - Batch normalization -13. **LayerNorm** - Layer normalization -14. **InstanceNorm** - Instance normalization - -**gomlx/backends/coreml/ops.go:** -1. **Pad** (OpTypePad) - Constant padding only, no interior padding -2. **Reverse** (OpTypeReverse) - Full support -3. **ConvGeneral** (OpTypeConvGeneral) - NCHW layout, standard convolutions -4. **BatchNormForInference** (OpTypeBatchNormForInference) - Feature axis=1 - -### Key Technical Decisions - -1. **Concat Tuple Support**: MIL's concat requires tuple/list arguments. Added `addOpWithListArg()` to builder.go to handle this pattern. - -2. **Padding Type Constants**: Shared `ConvPadType` enum (Valid/Same/Custom) between convolution and pooling operations. - -3. **CoreML Limitations**: - - Pad: No interior padding support (use XLA for this) - - ConvGeneral: NCHW layout required, no batch group count - - BatchNorm: Feature axis must be 1 - -### Test Coverage - -- go-coreml: All model tests pass (35+ tests) -- gomlx/backends/coreml: All tests pass (23 tests including new Pad, Reverse, ConvGeneral, BatchNorm tests) - -### Remaining Work - -- Enhanced DotGeneral with batch dimensions (Priority 5) -- Integration tests for full CNN and transformer models -- BroadcastTo operation (can use Tile + ExpandDims as workaround) -- L2 normalization -- Linear (fused matmul + bias) -- Einsum - ---- - -## Appendix: MIL Operation Reference - -Useful MIL documentation: -- https://apple.github.io/coremltools/docs-guides/source/ops-reference.html -- https://github.com/apple/coremltools/tree/main/coremltools/converters/mil/mil/ops - -Key MIL ops implemented in this phase: -- `concat` - Concatenate tensors ✓ -- `tile` - Tile/repeat tensor ✓ -- `pad` - Pad tensor ✓ -- `reverse` - Reverse along axes ✓ -- `conv` - Convolution ✓ -- `conv_transpose` - Transposed convolution ✓ -- `max_pool` - Max pooling ✓ -- `avg_pool` - Average pooling ✓ -- `batch_norm` - Batch normalization ✓ -- `layer_norm` - Layer normalization ✓ -- `instance_norm` - Instance normalization ✓ - -Remaining MIL ops: -- `l2_norm` - L2 normalization -- `linear` - Linear/dense layer (fused matmul + bias) -- `einsum` - Einstein summation diff --git a/specs/005-remaining-ops-and-enhancements.md b/specs/005-remaining-ops-and-enhancements.md deleted file mode 100644 index eeea517..0000000 --- a/specs/005-remaining-ops-and-enhancements.md +++ /dev/null @@ -1,482 +0,0 @@ -# Remaining Operations and Enhancements - Phase 6 - -## Overview - -This document outlines the remaining work to achieve comprehensive CoreML backend coverage for GoMLX. Phases 3-5 implemented 60+ operations. This phase focuses on: - -1. Enhanced DotGeneral (batch dimensions, arbitrary axes) -2. Missing gomlx backend wrappers -3. Additional MIL operations -4. Integration testing with real models - -## Current Implementation Status - -### Implemented in gomlx/backends/coreml (40+ ops) - -| Category | Operations | -|----------|-----------| -| Unary Math | Abs, Neg, Exp, Log, Sqrt, Floor, Ceil, Round, Sign, Tanh, Logistic, Cos, Sin, Erf | -| Binary Math | Add, Sub, Mul, Div, Pow, Max, Min | -| Comparison | Equal, NotEqual, LessThan, LessOrEqual, GreaterThan, GreaterOrEqual | -| Shape | Reshape, Transpose, Pad, Reverse | -| Reduction | ReduceSum, ReduceMax, ReduceMin, ReduceProduct, ArgMinMax | -| Matrix | DotGeneral (simple matmul only) | -| Conditional | Where | -| Indexing | Slice, Gather (partial) | -| Convolution | ConvGeneral | -| Normalization | BatchNormForInference | - -### Implemented in go-coreml only (not exposed to gomlx) - -| Category | Operations | -|----------|-----------| -| Tensor Manipulation | Tile, Concat | -| Convolution | Conv, ConvTranspose, ConvWithBias | -| Pooling | MaxPool, AvgPool, GlobalAvgPool2D, GlobalMaxPool2D | -| Normalization | LayerNorm, InstanceNorm | -| Activations | Gelu, Silu, LeakyRelu, Elu, Softplus, Relu, Sigmoid, Softmax | -| Trig/Hyperbolic | Acos, Asin, Atan, Cosh, Sinh | - ---- - -## Priority 1: Enhanced DotGeneral - -The current DotGeneral only supports simple 2D matrix multiplication. Full support requires: - -### 1.1 Batch Dimensions - -Support batch matmul: `[B, M, K] @ [B, K, N] -> [B, M, N]` - -**Implementation Strategy:** - -```go -// gomlx/backends/coreml/ops.go -func (b *Builder) DotGeneral( - lhsOp backends.Op, - lhsContractingAxes, lhsBatchAxes []int, - rhsOp backends.Op, - rhsContractingAxes, rhsBatchAxes []int, -) (backends.Op, error) { - // Case 1: Simple matmul (existing code) - // Case 2: Batched matmul - CoreML's matmul handles batch dims natively - // Case 3: Complex cases - use transpose/reshape to normalize -} -``` - -**CoreML MatMul Batch Support:** -CoreML's `matmul` already supports batch dimensions: -- `[..., M, K] @ [..., K, N] -> [..., M, N]` -- Batch dimensions must match or broadcast - -**Tasks:** -- [ ] Add batch dimension detection in DotGeneral -- [ ] Handle broadcasting of batch dimensions -- [ ] Add tests for batched matmul -- [ ] Test attention mechanism (Q @ K^T @ V) - -### 1.2 Arbitrary Contracting Axes - -Support contracting on any axis, not just the last axis of LHS. - -**Implementation Strategy:** -1. Transpose inputs to move contracting axes to standard positions -2. Call CoreML matmul -3. Transpose output if needed - -```go -func normalizeDotGeneralInputs( - lhs *Node, lhsContractingAxes, lhsBatchAxes []int, - rhs *Node, rhsContractingAxes, rhsBatchAxes []int, -) (*Node, *Node, []int) { - // Step 1: Transpose lhs to [batch..., M, K] - // Step 2: Transpose rhs to [batch..., K, N] - // Step 3: Return transposed nodes and output permutation -} -``` - -**Tasks:** -- [ ] Implement axis normalization helper -- [ ] Handle output shape reconstruction -- [ ] Add tests for non-standard contracting axes - ---- - -## Priority 2: Missing GoMLX Backend Wrappers - -These operations exist in go-coreml but need gomlx wrappers (where OpTypes exist): - -### 2.1 Concatenate (OpTypeConcatenate) - -```go -func (b *Builder) Concatenate(operands []backends.Op, axis int) (backends.Op, error) { - opType := backends.OpTypeConcatenate - - // Validate operands - nodes := make([]*Node, len(operands)) - milValues := make([]*model.Value, len(operands)) - for i, op := range operands { - node, err := b.checkOp(opType.String(), op) - if err != nil { - return nil, err - } - nodes[i] = node - milValues[i] = node.milValue - } - - // Use shapeinference.ConcatenateOp - outputShape, err := shapeinference.ConcatenateOp(axis, shapes...) - if err != nil { - return nil, err - } - - resultValue := b.milBuilder.Concat(milValues, int64(axis)) - node := b.newNodeMultiInput(opType, outputShape, resultValue, nodes) - return node, nil -} -``` - -**Tasks:** -- [ ] Add `newNodeMultiInput` helper for multi-input operations -- [ ] Implement Concatenate wrapper -- [ ] Add tests - -### 2.2 Broadcast Operations (OpTypeBroadcast, OpTypeBroadcastInDim) - -```go -func (b *Builder) Broadcast(operandOp backends.Op, shape shapes.Shape) (backends.Op, error) { - // Use Tile + ExpandDims to implement broadcast -} - -func (b *Builder) BroadcastInDim(operandOp backends.Op, shape shapes.Shape, broadcastDims []int) (backends.Op, error) { - // More flexible broadcast with dimension mapping -} -``` - -**Tasks:** -- [ ] Implement Broadcast using Tile/ExpandDims -- [ ] Implement BroadcastInDim -- [ ] Add tests - -### 2.3 Iota (OpTypeIota) - -Generate a tensor with values [0, 1, 2, ..., N-1] along an axis. - -```go -// go-coreml/model/ops.go -func (b *Builder) Range(start, end, step *Value) *Value { - // MIL operation "range_1d" -} - -// gomlx/backends/coreml/ops.go -func (b *Builder) Iota(shape shapes.Shape, iotaDim int) (backends.Op, error) { - // Use range_1d + broadcast to create iota tensor -} -``` - -**Tasks:** -- [ ] Implement Range in go-coreml -- [ ] Implement Iota wrapper -- [ ] Add tests - -### 2.4 Additional Unary Operations - -| OpType | MIL Op | Notes | -|--------|--------|-------| -| OpTypeRsqrt | `rsqrt` | 1/sqrt(x) | -| OpTypeExpm1 | N/A | exp(x) - 1, implement as Exp(x) - 1 | -| OpTypeLog1p | N/A | log(1 + x), implement as Log(Add(1, x)) | -| OpTypeIsFinite | `isfinite` | Check for finite values | -| OpTypeIsNaN | `isnan` | Check for NaN values | - -**Tasks:** -- [ ] Add Rsqrt to go-coreml (MIL has `rsqrt`) -- [ ] Implement Rsqrt, Expm1, Log1p, IsFinite, IsNaN wrappers -- [ ] Add tests - -### 2.5 Logical Operations - -| OpType | MIL Op | Notes | -|--------|--------|-------| -| OpTypeLogicalAnd | `logical_and` | Boolean AND | -| OpTypeLogicalOr | `logical_or` | Boolean OR | -| OpTypeLogicalNot | `logical_not` | Boolean NOT | -| OpTypeLogicalXor | `logical_xor` | Boolean XOR | - -**Tasks:** -- [ ] Add logical ops to go-coreml -- [ ] Add gomlx wrappers -- [ ] Add tests - -### 2.6 Clamp (OpTypeClamp) - -```go -func (b *Builder) Clamp(x, min, max backends.Op) (backends.Op, error) { - // Implement as: Max(min, Min(max, x)) - // Or use MIL's clip operation -} -``` - -**Tasks:** -- [ ] Add Clip to go-coreml (MIL has `clip`) -- [ ] Implement Clamp wrapper -- [ ] Add tests - -### 2.7 ConvertDType (OpTypeConvertDType) - -```go -func (b *Builder) ConvertDType(x backends.Op, dtype dtypes.DType) (backends.Op, error) { - // MIL operation "cast" -} -``` - -**Tasks:** -- [ ] Add Cast to go-coreml -- [ ] Implement ConvertDType wrapper -- [ ] Add tests for supported conversions - ---- - -## Priority 3: Additional go-coreml Operations - -### 3.1 L2 Normalization - -```go -// go-coreml/model/ops.go -func (b *Builder) L2Norm(x *Value, axes []int64, epsilon float32) *Value { - // MIL operation "l2_norm" -} -``` - -### 3.2 Linear (Fused MatMul + Bias) - -```go -// go-coreml/model/ops.go -func (b *Builder) Linear(x, weight, bias *Value) *Value { - // MIL operation "linear" - // More efficient than separate matmul + add -} -``` - -### 3.3 Einsum - -```go -// go-coreml/model/ops.go -func (b *Builder) Einsum(equation string, inputs []*Value) *Value { - // MIL operation "einsum" - // Powerful for attention and tensor contractions -} -``` - -**Tasks:** -- [ ] Implement L2Norm -- [ ] Implement Linear -- [ ] Research Einsum MIL support and implement -- [ ] Add tests - ---- - -## Priority 4: ReduceWindow (Pooling via GoMLX) - -GoMLX uses `ReduceWindow` for pooling operations rather than dedicated pool ops. - -```go -func (b *Builder) ReduceWindow( - operand backends.Op, - reduceOpType ReduceOpType, // Sum, Max, Min, Product - windowDims []int, - strides []int, - paddingLow, paddingHigh []int, -) (backends.Op, error) { - // Map to MaxPool/AvgPool for supported cases - // Return error for unsupported configurations -} -``` - -**Tasks:** -- [ ] Implement ReduceWindow mapping to MaxPool/AvgPool -- [ ] Handle padding conversion -- [ ] Add tests - ---- - -## Priority 5: Dynamic Operations - -### 5.1 DynamicSlice - -```go -func (b *Builder) DynamicSlice( - operand backends.Op, - startIndices []backends.Op, // Runtime values - sliceSizes []int, -) (backends.Op, error) { - // CoreML's slice_by_index supports dynamic indices -} -``` - -### 5.2 DynamicUpdateSlice - -```go -func (b *Builder) DynamicUpdateSlice( - operand, update backends.Op, - startIndices []backends.Op, -) (backends.Op, error) { - // MIL operation "scatter" or "slice_update" -} -``` - -**Tasks:** -- [ ] Research MIL dynamic slice support -- [ ] Implement DynamicSlice -- [ ] Implement DynamicUpdateSlice -- [ ] Add tests - ---- - -## Priority 6: Integration Testing - -### 6.1 Simple CNN Model - -Test end-to-end: -- Conv2D + BatchNorm + ReLU -- MaxPool -- Flatten + Dense - -```go -func TestSimpleCNN(t *testing.T) { - // Build model - conv1 := Conv(input, weights1, ...) - bn1 := BatchNorm(conv1, ...) - relu1 := Max(bn1, Constant(0)) - pool1 := ReduceWindow(relu1, Max, [2,2], ...) - // ... more layers - // Verify output shape and reasonable values -} -``` - -### 6.2 Transformer Attention Block - -Test: -- Q, K, V projections (DotGeneral) -- Attention scores (batched matmul) -- Softmax -- Weighted sum (batched matmul) -- Concatenation (multi-head) - -```go -func TestAttentionBlock(t *testing.T) { - // Multi-head self-attention - // Q, K, V = Linear(input) - // scores = Softmax(Q @ K^T / sqrt(d)) - // output = scores @ V - // Verify shapes and numerics -} -``` - -### 6.3 Performance Benchmarks - -Compare CoreML backend performance against: -- simplego backend (baseline) -- XLA CPU backend (if available) - -**Metrics:** -- Compilation time -- Execution time -- Memory usage - -**Tasks:** -- [ ] Implement CNN integration test -- [ ] Implement attention integration test -- [ ] Add performance benchmarks -- [ ] Document performance characteristics - ---- - -## Implementation Order - -### Sprint 1: Enhanced DotGeneral (1 week) -1. Add batch dimension support -2. Add arbitrary axis support via transpose -3. Comprehensive tests - -### Sprint 2: Core Missing Wrappers (1 week) -1. Concatenate -2. Iota -3. Rsqrt, Expm1, Log1p -4. Clamp -5. ConvertDType - -### Sprint 3: Logical and Broadcast Ops (0.5 weeks) -1. Logical ops (And, Or, Not, Xor) -2. Broadcast, BroadcastInDim -3. IsFinite, IsNaN - -### Sprint 4: Additional MIL Ops (0.5 weeks) -1. L2Norm -2. Linear -3. Einsum (if MIL supports it) - -### Sprint 5: ReduceWindow and Dynamic Ops (1 week) -1. ReduceWindow -> Pool mapping -2. DynamicSlice -3. DynamicUpdateSlice - -### Sprint 6: Integration Testing (1 week) -1. CNN model test -2. Attention block test -3. Performance benchmarks -4. Documentation - ---- - -## Success Criteria - -- [ ] DotGeneral supports batch dimensions -- [ ] DotGeneral supports arbitrary contracting axes -- [ ] Concatenate working in gomlx backend -- [ ] All logical operations implemented -- [ ] Broadcast operations working -- [ ] ReduceWindow maps to pooling -- [ ] CNN model runs end-to-end -- [ ] Attention block runs end-to-end -- [ ] Performance benchmarks documented -- [ ] 90%+ of common ML operations covered - ---- - -## Operations NOT Planned - -These operations are low priority or not suitable for CoreML: - -| Operation | Reason | -|-----------|--------| -| Bitwise ops | Rarely used in ML inference | -| Complex number ops | Limited CoreML support | -| Collective ops | CoreML is single-device | -| Sort | Limited use in inference | -| While loops | CoreML prefers unrolled graphs | -| Scatter ops | Complex, limited MIL support | -| FFT | Specialized, can add later if needed | - ---- - -## Appendix: MIL Operations Reference - -### Not Yet Implemented MIL Ops - -| MIL Operation | Use Case | -|---------------|----------| -| `range_1d` | Iota implementation | -| `rsqrt` | Rsqrt | -| `clip` | Clamp | -| `cast` | ConvertDType | -| `logical_and/or/not/xor` | Logical ops | -| `isfinite`, `isnan` | Numeric checks | -| `l2_norm` | L2 normalization | -| `linear` | Fused dense layer | -| `einsum` | Einstein summation | -| `scatter` | Dynamic update | - -### MIL Documentation - -- https://apple.github.io/coremltools/docs-guides/source/ops-reference.html -- https://github.com/apple/coremltools/tree/main/coremltools/converters/mil/mil/ops diff --git a/specs/006-weight-blob-support.md b/specs/006-weight-blob-support.md deleted file mode 100644 index 199c14e..0000000 --- a/specs/006-weight-blob-support.md +++ /dev/null @@ -1,233 +0,0 @@ -# Spec: Weight Blob Support for Large Models - -## Summary - -Implement external weight storage using CoreML's blob file format. Currently all tensor data (weights, constants) is embedded inline in the protobuf using `ImmediateValue`. For large models, this is inefficient - the blob format stores weights in a separate binary file that can be memory-mapped. - -## Problem - -- Embedding gigabytes of weights inline in protobuf is slow and memory-intensive -- Protobuf parsing loads everything into memory at once -- Large models (LLMs, vision transformers) become impractical - -## Solution - -CoreML's MIL format supports `BlobFileValue` which references external binary files: - -```protobuf -message Value { - oneof value { - ImmediateValue immediateValue = 3; - BlobFileValue blobFileValue = 5; // External storage - } -} - -message BlobFileValue { - string fileName = 1; // "@model_path/weights/weight.bin" - uint64 offset = 2; // Byte offset to blob_metadata -} -``` - -## Blob File Format - -Based on coremltools' MILBlob implementation. - -### File Layout - -``` -[storage_header (64B)] -[blob_metadata_0 (64B)] [data_0 (padded to 64B alignment)] -[blob_metadata_1 (64B)] [data_1 (padded to 64B alignment)] -... -``` - -### storage_header (64 bytes) - -```go -type StorageHeader struct { - Count uint32 // Number of blob entries - Version uint32 // Always 2 - Reserved [48]byte // Zero padding -} -``` - -### blob_metadata (64 bytes) - -```go -type BlobMetadata struct { - Sentinel uint32 // 0xDEADBEEF - MilDType uint32 // BlobDataType enum - SizeInBytes uint64 // Size of raw data - Offset uint64 // Absolute file offset to data - PaddingSizeInBits uint64 // For sub-byte types (0 for normal types) - Reserved [32]byte // Zero padding -} -``` - -### BlobDataType Enum - -| Type | Code | -|----------|------| -| Float16 | 1 | -| Float32 | 2 | -| UInt8 | 3 | -| Int8 | 4 | -| BFloat16 | 5 | -| Int16 | 6 | -| UInt16 | 7 | -| Int32 | 14 | -| UInt32 | 15 | - -### Constants - -- `DefaultAlignment = 64` -- `BlobMetadataSentinel = 0xDEADBEEF` -- `BlobVersion = 2` - -## Implementation Plan - -### 1. New Package: `blob/` - -**blob/format.go** - Data structures -```go -package blob - -const ( - DefaultAlignment = 64 - BlobMetadataSentinel = 0xDEADBEEF - BlobVersion = 2 -) - -type DataType uint32 - -const ( - DataTypeFloat16 DataType = 1 - DataTypeFloat32 DataType = 2 - // ... -) - -type StorageHeader struct { - Count uint32 - Version uint32 - Reserved [48]byte -} - -type BlobMetadata struct { - Sentinel uint32 - MilDType uint32 - SizeInBytes uint64 - Offset uint64 - PaddingSizeInBits uint64 - Reserved [32]byte -} -``` - -**blob/writer.go** - Write weight.bin files -```go -type Writer struct { - file *os.File - offset uint64 - entries []BlobMetadata -} - -func NewWriter(path string) (*Writer, error) -func (w *Writer) AddBlob(dtype DataType, data []byte) (offset uint64, err error) -func (w *Writer) Close() error -``` - -### 2. Update model/serialize.go - -Add blob-aware serialization: - -```go -type SerializeOptions struct { - // ... existing fields ... - - // UseBlobStorage enables external weight storage - UseBlobStorage bool - // BlobThreshold is the minimum tensor size (bytes) to use blob storage - BlobThreshold int64 // default: 1024 -} - -func SaveMLPackageWithBlobs(model *spec.Model, path string, opts SerializeOptions) error -``` - -### 3. Update model/builder.go - -Track constants for potential blob export: - -```go -type Builder struct { - // ... existing fields ... - - // weights tracks constant tensors that may be exported to blobs - weights []*weightEntry -} - -type weightEntry struct { - name string - dtype DType - data []byte - milVal *milspec.Value -} -``` - -### 4. Integration - -When `UseBlobStorage` is enabled: -1. During `SaveMLPackage`, scan for large constants -2. Write them to `weights/weight.bin` using blob format -3. Replace `ImmediateValue` with `BlobFileValue` referencing the blob -4. Update manifest to include weight file - -## API Changes - -### Option 1: Builder Option (Recommended) - -```go -// Enable blob storage during model construction -b := model.NewBuilder("main", model.WithBlobStorage(true)) - -// Threshold can be configured -b := model.NewBuilder("main", model.WithBlobStorage(true), model.WithBlobThreshold(4096)) -``` - -### Option 2: Serialize Option - -```go -// Enable blob storage at serialization time -opts := model.SerializeOptions{ - UseBlobStorage: true, - BlobThreshold: 1024, -} -model.SaveMLPackage(m, path, opts) -``` - -Recommend Option 1 since weights need to be tracked during construction. - -## Testing - -1. **Unit tests** for blob writer (format correctness) -2. **Round-trip test**: create model with large constant, save with blobs, load and verify -3. **Threshold test**: verify small constants stay inline -4. **Integration test**: verify CoreML can load blob-backed models - -## File Changes - -``` -blob/ -├── format.go # NEW: Data structures and constants -├── writer.go # NEW: Blob file writer -└── writer_test.go # NEW: Tests - -model/ -├── builder.go # MODIFY: Add weight tracking, blob options -├── serialize.go # MODIFY: Add blob-aware serialization -└── serialize_test.go # MODIFY: Add blob tests -``` - -## References - -- [coremltools MIL.proto](https://github.com/apple/coremltools/blob/main/mlmodel/format/MIL.proto) -- [coremltools MILBlob StorageFormat.hpp](https://github.com/apple/coremltools/blob/main/mlmodel/src/MILBlob/Blob/StorageFormat.hpp) -- [coremltools MILBlob BlobDataType.hpp](https://github.com/apple/coremltools/blob/main/mlmodel/src/MILBlob/Blob/BlobDataType.hpp) diff --git a/specs/007-full-backend-inference.md b/specs/007-full-backend-inference.md deleted file mode 100644 index ffa01c8..0000000 --- a/specs/007-full-backend-inference.md +++ /dev/null @@ -1,199 +0,0 @@ -# go-coreml Full Backend Implementation Plan - -## Goal -Make go-coreml a fully supported GoMLX backend capable of full model inference, including transformers like DeBERTa. - ---- - -## Current State Analysis - -### What go-coreml Already Has (Strong Foundation) -- **79/81 MIL operations** implemented in `model/ops.go` -- **Control flow**: While/If loops with nested block support -- **Core transformer ops**: MatMul, Softmax, LayerNorm, Gather, Transpose, Reshape -- **Activations**: GELU, ReLU, Tanh, Sigmoid, Silu -- **Reductions**: Sum, Mean, Max, Min, ArgMax -- **Convolution**: Conv2D, ConvTranspose, pooling operations - -### What's Missing for Full Backend Parity - -| Category | Missing Operations | MIL Support | Effort | -|----------|-------------------|-------------|--------| -| **Logical** | LogicalAnd/Or/Not/Xor | MIL has them | Low | -| **BatchNorm** | BatchNormForInference/Training | MIL has batch_norm | Medium | -| **Scatter** | ScatterSum/Max/Min | MIL ScatterND | Medium | -| **Misc** | Identity, BroadcastInDim, Clamp, DynamicSlice | MIL support | Low | -| **Bitwise** | BitwiseAnd/Or/Xor, Shifts | NO MIL support | High | -| **Complex** | Complex numbers, FFT | NO MIL support | Very High | -| **Control** | Call, Sort(custom comparator) | NO MIL equivalent | High | - ---- - -## Implementation Phases - -### Phase 1: Transformer-Critical Operations -**Goal**: Enable DeBERTa/BERT inference - -**Operations to implement in `gomlx/function.go`:** - -```go -// 1. Identity - trivial no-op -func (f *Function) Identity(x backends.Value) (backends.Value, error) - -// 2. BroadcastInDim - map to MIL broadcast_to -func (f *Function) BroadcastInDim(x backends.Value, shape []int, broadcastDims []int) (backends.Value, error) - -// 3. Clamp - map to MIL Clip -func (f *Function) Clamp(x, min, max backends.Value) (backends.Value, error) - -// 4. Logical operations - wire existing MIL ops -func (f *Function) LogicalAnd(lhs, rhs backends.Value) (backends.Value, error) -func (f *Function) LogicalOr(lhs, rhs backends.Value) (backends.Value, error) -func (f *Function) LogicalNot(x backends.Value) (backends.Value, error) -func (f *Function) LogicalXor(lhs, rhs backends.Value) (backends.Value, error) - -// 5. Special value checks - wire existing MIL ops -func (f *Function) IsFinite(x backends.Value) (backends.Value, error) -func (f *Function) IsNaN(x backends.Value) (backends.Value, error) - -// 6. DynamicSlice - map to MIL SliceBySize -func (f *Function) DynamicSlice(operand backends.Value, startIndices []backends.Value, sliceSizes []int) (backends.Value, error) -``` - -**Files to modify:** -- `gomlx/function.go` - Add operation implementations -- `gomlx/capabilities.go` - Update supported operations -- `gomlx/function_test.go` - Add unit tests - -### Phase 2: BatchNorm and Advanced Reductions -**Goal**: Support training and more model architectures - -```go -// BatchNorm - map to MIL batch_norm -func (f *Function) BatchNormForInference(operand, scale, offset, mean, variance backends.Value, epsilon float32, featureAxis int) (backends.Value, error) - -// Compose for training (returns mean, variance, normalized) -func (f *Function) BatchNormForTraining(operand, scale, offset backends.Value, epsilon float32, featureAxis int) (normalized, batchMean, batchVar backends.Value, err error) - -// Reduce logical - compose from type conversion + reduce -func (f *Function) ReduceLogicalAnd(x backends.Value, axes ...int) (backends.Value, error) -func (f *Function) ReduceLogicalOr(x backends.Value, axes ...int) (backends.Value, error) - -// Remainder - use MIL floor_mod or compose -func (f *Function) Rem(lhs, rhs backends.Value) (backends.Value, error) -``` - -### Phase 3: Scatter Operations -**Goal**: Support gradient operations and advanced indexing - -```go -// Scatter ops - map to MIL ScatterND with different modes -func (f *Function) ScatterSum(operand, indices, updates backends.Value, ...) (backends.Value, error) // mode="add" -func (f *Function) ScatterMax(operand, indices, updates backends.Value, ...) (backends.Value, error) // mode="max" -func (f *Function) ScatterMin(operand, indices, updates backends.Value, ...) (backends.Value, error) // mode="min" - -// SelectAndScatter - complex composition from ArgMax + Scatter -func (f *Function) SelectAndScatterMax(...) (backends.Value, error) -``` - -### Phase 4: TotalOrder Comparisons -**Goal**: Full numerical correctness - -```go -// Compose from IsNaN/IsFinite + regular comparisons -func (f *Function) EqualTotalOrder(lhs, rhs backends.Value) (backends.Value, error) -func (f *Function) LessThanTotalOrder(lhs, rhs backends.Value) (backends.Value, error) -// etc. -``` - -### Phase 5: Documentation and Limitations -**Goal**: Document what's NOT supported - -**Mark as unsupported (no MIL equivalent):** -- Bitwise operations (BitwiseAnd/Or/Xor/Not) -- Shift operations (ShiftLeft/Right) -- Complex numbers (Complex, Real, Imag, Conj) -- FFT operations -- Call (function calls) - must inline -- Sort with custom comparators -- Bitcast, BitCount, Clz - ---- - -## Critical Files - -| File | Purpose | Lines | -|------|---------|-------| -| `gomlx/function.go` | Main operation implementations | 2717 | -| `gomlx/capabilities.go` | Supported ops reporting | 103 | -| `model/ops.go` | MIL operation layer | 1865 | -| `gomlx/backends/standard_ops.go` | Reference interface | 635 | - ---- - -## Testing Strategy - -### 1. Unit Tests (per operation) -```go -// In gomlx/function_test.go -func TestLogicalAnd(t *testing.T) { /* test cases */ } -func TestBatchNormForInference(t *testing.T) { /* test cases */ } -``` - -### 2. Integration Test (DeBERTa inference) -```go -func TestDeBERTaInference(t *testing.T) { - // Use gopeft/e2e/reranker model as test case - // Compare outputs with XLA backend -} -``` - -### 3. Numerical Accuracy Tests -```go -func TestNumericalAccuracyVsXLA(t *testing.T) { - // Run same operations on CoreML and XLA - // Verify results match within tolerance -} -``` - ---- - -## Verification - -1. **Build all packages:** - ```bash - cd ~/Documents/af/go-coreml && go build ./... - ``` - -2. **Run unit tests:** - ```bash - cd ~/Documents/af/go-coreml && go test -v ./gomlx/... - ``` - -3. **Run integration test with DeBERTa:** - ```bash - cd ~/Documents/af/gopeft && go test -v ./e2e/reranker/... -run TestDeBERTaCrossEncoderForward - ``` - -4. **Benchmark CoreML vs XLA:** - ```bash - cd ~/Documents/af/gopeft && go test -bench=BenchmarkBackendComparison ./e2e/reranker/ - ``` - ---- - -## Risk Mitigation - -| Risk | Mitigation | -|------|------------| -| MIL op behavior differs from XLA | Comprehensive numerical tests | -| Complex Gather patterns fail | Document supported patterns | -| Performance regression | Benchmark against XLA | -| Missing ops block models | Clear capability reporting | - ---- - -## Summary - -**Quick win:** Phase 1 enables transformer inference -**Key insight:** go-coreml already has 95% of needed MIL ops, just need to wire them to GoMLX interface diff --git a/specs/007-gaps.md b/specs/007-gaps.md deleted file mode 100644 index 4ccf64f..0000000 --- a/specs/007-gaps.md +++ /dev/null @@ -1,44 +0,0 @@ -High Priority Issues - -1. Error Handling (1 panic that should be error): -- gomlx/buffer.go:113 - Buffer finalization -- (FIXED) model/ops.go:1325 - Concat with no inputs - now returns nil and sets error -- (FIXED) model/ops.go:1543 - Einsum with wrong input count - now returns nil and sets error - -2. Missing Operations for GoMLX backend: -- Call - calling sub-functions -- Sort - sorting with comparator -- While - loop control flow (MIL support implemented in model layer, GoMLX integration pending) -- If - conditional branching (MIL support implemented in model layer, GoMLX integration pending) -- Complex Gather - multi-axis gather - -3. Documented TODOs that block use cases: -- MinPool not implemented (could do -MaxPool(-x)) -- Input dilation > 1 not supported -- Batch group count > 1 not supported in Conv - -Medium Priority - -4. Capability declaration mismatches: -- Some ops declared as "supported" but only partially work (Gather, Pad interior, ReduceWindow with min/product) - -5. Documentation gaps: -- No limitations document -- No troubleshooting guide -- No performance tuning guide -- (ADDED) Control flow operations documented in docs/control-flow-operations.md - -Recent Progress (Control Flow) - -The model layer now supports nested blocks for control flow operations: - -- `model.BlockBuilder` - Build nested blocks within operations -- `model.Cond` - Conditional execution (if/else) with true/false branch blocks -- `model.WhileLoop` - Loop execution with condition and body blocks -- `serialize_blob.go` - Updated to handle nested blocks in operations - -The GoMLX backend still returns NotImplementedError for While/If since full integration -requires closure compilation. Users can: -- Use `Where()` for element-wise conditionals -- Unroll loops at graph construction time -- Use the model layer directly for control flow diff --git a/specs/008-high-dim-tensors.md b/specs/008-high-dim-tensors.md deleted file mode 100644 index 6051b67..0000000 --- a/specs/008-high-dim-tensors.md +++ /dev/null @@ -1,161 +0,0 @@ -# Plan: Decompose High-Rank Reshapes for CoreML - -## Context - -Florence-2's DaViT vision encoder uses window attention that creates rank-6 intermediate tensors via reshape: -`[batch, h, w, c]` → `[batch, h_win, win_h, w_win, win_w, c]` (rank 6). -CoreML's runtime limits reshape to rank ≤ 5, causing compilation failure: -``` -Rank of the shape parameter must be between 0 and 5 (inclusive) in reshape -``` - -There are 24 such reshapes in the model. The standard pattern is always -`reshape(rank≥6) → transpose → reshape(rank≤5)` — the high-rank intermediate -is immediately transposed and collapsed. Apple's coremltools has a built-in graph -pass `expand_high_rank_reshape_and_transpose` that decomposes these into rank-4 -operations. We need to port this pass to go-coreml's `model.Builder.Build()`. - -## Algorithm (from Apple's coremltools) - -### Pattern Match -Find three consecutive operations: `reshape₁ → transpose → reshape₂` where: -- `reshape₁` output rank ≥ 6 -- `reshape₂` output rank ≤ 5 -- Intermediate values each have exactly one consumer -- Intermediate values are not model outputs - -### Decomposition -1. **Group consecutive axes** in the transpose permutation. E.g., perm `[0,1,3,4,2,5]` groups as `[[0,1],[3,4],[2],[5]]` -2. **Compute merged shape**: product of original shape dims within each group -3. **Compute merged perm**: permutation of the groups -4. If merged rank < 6: emit one reshape (to merged shape) + one transpose (merged perm) -5. If merged rank ≥ 6: emit iterative rank-4 `reshape + transpose([0,2,1,3])` pairs that bubble each axis into place -6. Final reshape to the original `reshape₂` output shape, **reusing the original output name** so downstream references remain valid - -### `_get_prod` Helper -Product of shape elements from `start` to `end`, skipping indices in a memo set: -``` -_get_prod(start, end, shape, skip) = ∏ shape[i] for i in [start,end) if i not in skip -``` - -### Iterative Rank-4 Decomposition (rank ≥ 6 path) -``` -leading_dim = 1 -memo = {} -for i in range(rank): - axis = perm[i] - dim = shape[axis] - memo.add(axis) - reshape_shape = [leading_dim, _get_prod(0, axis, shape, memo), dim, _get_prod(axis+1, rank, shape, memo)] - x = reshape(x, reshape_shape) - x = transpose(x, [0, 2, 1, 3]) - leading_dim *= dim -``` - -## Changes - -### 1. Create `go-coreml/model/optimize.go` - -New file containing the graph pass with these functions: - -**`(b *Builder) expandHighRankReshapes()`** — Main entry point called from `Build()`: -- Build a consumer map: `map[string][]int` mapping value name → indices of consuming operations -- Scan operations for "reshape" ops with output rank ≥ 6 -- For each match, check if the single consumer is "transpose", and its single consumer is "reshape" with rank ≤ 5 -- Check intermediates are not in `b.outputs` -- Call decomposition function, collect replacement operations -- Rebuild `b.operations` slice, replacing each matched triple with its replacement sequence -- Update `b.values` map for new intermediate values - -**`decomposeReshapeTranspose(...) []*milspec.Operation`** — Decomposition logic: -- Extract: input name, high-rank shape, perm, final shape, output name, dtype -- Group consecutive axes in perm -- Compute merged shape and merged perm -- If merged rank < 6: emit reshape + transpose -- If merged rank ≥ 6: emit iterative rank-4 pairs -- Emit final reshape with the **original output name** of reshape₂ - -**Helper functions:** -- `getOpOutputShape(op) []int64` — extract shape from operation's output TensorType dimensions -- `getOpOutputDType(op) DType` — extract dtype from operation's output TensorType -- `getOpInputName(op, paramName) string` — extract the name reference from an input argument -- `getInlineInt32s(op, paramName) []int32` — extract inline Int32 constant values from an argument -- `makeReshapeOp(inputName, outputName string, shape []int64, dtype DType) *milspec.Operation` — create a reshape operation in protobuf -- `makeTransposeOp(inputName, outputName string, perm []int64, inputShape []int64, dtype DType) *milspec.Operation` — create a transpose operation in protobuf -- `getProd(start, end int, shape []int64, skip map[int]bool) int64` — product of shape elements - -### 2. Modify `go-coreml/model/builder.go` — Call the pass in `Build()` - -Add `b.expandHighRankReshapes()` at the start of `Build()` (line ~430), before the operations are packaged into the Block: -```go -func (b *Builder) Build() *Program { - // Optimize: decompose high-rank reshape patterns for CoreML compatibility. - b.expandHighRankReshapes() - - // Build function inputs - inputs := make([]*milspec.NamedValueType, len(b.inputs)) - ... -``` - -### 3. Create `go-coreml/model/optimize_test.go` - -Unit tests for the graph pass: -- **TestExpandRank6ReshapeTranspose**: Build a graph with `reshape([2,3] → [1,1,2,1,3,1]) → transpose → reshape(...)`, verify the resulting operations are all rank ≤ 5 -- **TestExpandRank5Passthrough**: Build a graph with rank-5 reshape, verify no transformation occurs -- **TestExpandWindowAttentionPattern**: Simulate DaViT's exact pattern: `[1,14,14,192] → reshape [1,2,7,2,7,192] → transpose [0,1,3,2,4,5] → reshape [1,4,7,7,192]` - -## Key Implementation Details - -### Protobuf Access Patterns - -Extract input name: -```go -op.Inputs["x"].Arguments[0].GetName() -``` - -Extract inline Int32 constant: -```go -op.Inputs["perm"].Arguments[0].GetValue().GetImmediateValue().GetTensor().GetInts().Values -``` - -Extract output shape: -```go -for _, dim := range op.Outputs[0].Type.GetTensorType().Dimensions { - size := int64(dim.GetConstant().Size) -} -``` - -### Name Preservation - -The final reshape in the replacement sequence **must** use the same output name as the original `reshape₂`. This ensures all downstream operations that reference that name continue to work. Intermediate operations use `b.genName()` for unique names. - -### Existing Utilities to Reuse - -- `toInt32Slice([]int64) []int32` — `model/ops.go:1763` -- `createValue(dtype, shape, data) *milspec.Value` — `model/builder.go:186` -- `b.genName(prefix) string` — `model/builder.go:113` - -## Files Modified - -| File | Change | -|------|--------| -| `go-coreml/model/optimize.go` | **New** — graph pass implementation | -| `go-coreml/model/optimize_test.go` | **New** — unit tests | -| `go-coreml/model/builder.go` | Add `b.expandHighRankReshapes()` call in `Build()` | - -## Verification - -1. Run go-coreml model package tests: - ```bash - cd /Users/ajroetker/go/src/github.com/gomlx/go-coreml && go test ./model/ -v -count=1 - ``` - -2. Run full go-coreml test suite: - ```bash - cd /Users/ajroetker/go/src/github.com/gomlx/go-coreml && go test ./... -v -count=1 - ``` - -3. Run Florence-2 CoreML E2E test: - ```bash - cd /Users/ajroetker/go/src/github.com/antflydb/antfly/termite && GOEXPERIMENT=simd go test -v -tags coreml ./e2e/ -run TestFlorence2CoreMLSingleStep -timeout 10m -count=1 - ```