Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
6c80f23
experimental/air: scaffold AI runtime CLI command package
riddhibhagwat-db Jun 11, 2026
059bd61
experimental/air: rename `status` subcommand to `get`
riddhibhagwat-db Jun 12, 2026
2ccd069
experimental/air: implement the `air get` command
riddhibhagwat-db Jun 14, 2026
89042d0
experimental/air: rename stale TestBuildStatusData to TestBuildGetData
riddhibhagwat-db Jun 14, 2026
c99239c
experimental/air: apply testifylint fixes in get/format tests
riddhibhagwat-db Jun 14, 2026
235f1bd
experimental/air: disambiguate JOB_RUN_ID, hide --review, add -i alias
riddhibhagwat-db Jun 16, 2026
8a97e0f
Merge branch 'main' into air-integration-m0
riddhibhagwat-db Jun 16, 2026
e04b698
Merge branch 'main' into air-integration-m0
riddhibhagwat-db Jun 16, 2026
31121ad
experimental/air: implement the `air get` command
riddhibhagwat-db Jun 14, 2026
3883791
experimental/air: rename stale TestBuildStatusData to TestBuildGetData
riddhibhagwat-db Jun 14, 2026
472a1fe
experimental/air: apply testifylint fixes in get/format tests
riddhibhagwat-db Jun 14, 2026
0ab1008
experimental/air: print `air get` training config to stdout
riddhibhagwat-db Jun 16, 2026
2615cd7
experimental/air: report latest-attempt timing and round duration
riddhibhagwat-db Jun 16, 2026
d3bb64b
experimental/air: link MLflow output to the latest attempt
riddhibhagwat-db Jun 16, 2026
3deceab
experimental/air: render JSON error envelopes and align `air get` JSO…
riddhibhagwat-db Jun 16, 2026
9072c29
experimental/air: restore last-attempt timing for `air get`
riddhibhagwat-db Jun 17, 2026
fddbdfd
experimental/air: fix `air get` MLflow link request
riddhibhagwat-db Jun 17, 2026
a69e0d3
experimental/air: rename `air get` to `air get run` and match Python …
riddhibhagwat-db Jun 17, 2026
fa3a1a2
experimental/air: add GPU accelerator type and compute config model
riddhibhagwat-db Jun 14, 2026
62be1a1
experimental/air: drop node pool / pool name compute fields
riddhibhagwat-db Jun 16, 2026
9efd3d1
Merge branch 'air-integration-m1-1' into air-integration-m2-1
riddhibhagwat-db Jun 17, 2026
6850d23
Merge branch 'air-cli' into air-integration-m2-1
riddhibhagwat-db Jun 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions experimental/air/cmd/compute.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package aircmd

import (
"fmt"
"strings"
)

// gpuType names an accelerator type exactly as it is submitted over the wire to
// the training service. The digit in front of the "x" is the partition count,
// so GPU_8xH100 stands for 8 GPUs.
type gpuType string

// Accelerator types recognized by this package.
const (
	gpuType1xA10  gpuType = "GPU_1xA10"
	gpuType1xH100 gpuType = "GPU_1xH100"
	gpuType8xH100 gpuType = "GPU_8xH100"
)

// gpuTypes enumerates every valid accelerator type. Validation error messages
// list the types in this order.
var gpuTypes = []gpuType{
	gpuType1xA10,
	gpuType1xH100,
	gpuType8xH100,
}

// validGPUTypesHint renders the entries of gpuTypes as a comma-separated list
// prefixed with "valid types are: ", for use in validation error messages.
func validGPUTypesHint() string {
	var hint strings.Builder
	hint.WriteString("valid types are: ")
	for idx, candidate := range gpuTypes {
		if idx > 0 {
			hint.WriteString(", ")
		}
		hint.WriteString(string(candidate))
	}
	return hint.String()
}

// parseGPUType converts the accelerator_type string from the run YAML into a
// gpuType. Only exact, case-sensitive matches are accepted because the server's
// lookup is case-sensitive. Any other value yields an empty gpuType and an
// error that lists the valid types.
func parseGPUType(value string) (gpuType, error) {
	candidate := gpuType(value)
	known := candidate == gpuType1xA10 ||
		candidate == gpuType8xH100 ||
		candidate == gpuType1xH100
	if !known {
		return "", fmt.Errorf("invalid GPU type %q: %s", value, validGPUTypesHint())
	}
	return candidate, nil
}

// gpusPerNode returns the per-node GPU count, which is the partition count from
// the name (GPU_1xH100 -> 1, GPU_8xH100 -> 8). num_accelerators must be a
// round multiple of this since accelerators are allocated in whole nodes.
//
// It returns 0 and an error when g is not one of the known gpuType constants.
func gpusPerNode(g gpuType) (int, error) {
switch g {
// Single-GPU partitions.
case gpuType1xA10, gpuType1xH100:
return 1, nil
// Eight-GPU partition.
case gpuType8xH100:
return 8, nil
}
// Unreachable: callers resolve g through parseGPUType first, which rejects
// unknown types. Kept as a defensive guard.
return 0, fmt.Errorf("invalid GPU type %q", string(g))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: By the time validate() reaches gpusPerNode(), parseGPUType() has already guaranteed that g is valid.
It's OK to leave the code as is to be defensive; just add a comment noting that this shouldn't be reachable.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comment, thanks

}

// computeConfig is the `compute` block of the run YAML: which accelerators to
// use and how many.
type computeConfig struct {
// NumAccelerators is the requested accelerator count. validate requires it
// to be positive and a multiple of the per-node GPU count of AcceleratorType.
NumAccelerators int `yaml:"num_accelerators"`
// AcceleratorType is the accelerator type name as written in the YAML.
// validate resolves it through parseGPUType, which only accepts exact matches.
AcceleratorType string `yaml:"accelerator_type"`
}

// validate checks the compute block against the backend's constraints, in this
// order: accelerator_type must parse as a known GPU type, num_accelerators must
// be positive, and num_accelerators must be a multiple of that type's per-node
// GPU count. It returns an error for the first check that fails.
func (c computeConfig) validate() error {
g, err := parseGPUType(c.AcceleratorType)
if err != nil {
// Prefix the wrapped error with the YAML field path.
return fmt.Errorf("compute.accelerator_type: %w", err)
}

// Reject zero and negative counts before the multiple-of check below.
if c.NumAccelerators <= 0 {
return fmt.Errorf("compute.num_accelerators must be positive, got %d", c.NumAccelerators)
}

// g came from parseGPUType, so gpusPerNode is not expected to fail here; its
// error is still propagated unchanged as a defensive measure.
perNode, err := gpusPerNode(g)
if err != nil {
return err
}
// Accelerators are allocated in whole nodes, so a count that leaves a partial
// node is rejected.
if c.NumAccelerators%perNode != 0 {
return fmt.Errorf("compute.num_accelerators for %s must be a multiple of %d, got %d", c.AcceleratorType, perNode, c.NumAccelerators)
}
Comment on lines +72 to +78

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm of the opinion that this kind of check should be done in the backend. @maggiewang-db @ben-hansen-db @vinchenzo-db wdyt? Can we do that easily using Training Service logic?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that based on the project milestones and as I discussed with Maggie yesterday, we want to port this in phases. As written in the project doc, we want to first port the run functionality directly as is (including the validation) and then move the validation & add handlers to the backend in milestone 3.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. But my plan is to do that later in Milestone 3.2 after the initial lift and shift.
It needs some design to decide which validations to move to backend, which validations to keep in client

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good


return nil
}
86 changes: 86 additions & 0 deletions experimental/air/cmd/compute_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package aircmd

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestParseGPUType checks that each supported accelerator string resolves to
// its matching gpuType constant without error.
func TestParseGPUType(t *testing.T) {
	cases := []struct {
		input    string
		expected gpuType
	}{
		{input: "GPU_1xA10", expected: gpuType1xA10},
		{input: "GPU_8xH100", expected: gpuType8xH100},
		{input: "GPU_1xH100", expected: gpuType1xH100},
	}
	for _, tc := range cases {
		t.Run(tc.input, func(t *testing.T) {
			parsed, err := parseGPUType(tc.input)
			require.NoError(t, err)
			assert.Equal(t, tc.expected, parsed)
		})
	}
}

// TestParseGPUTypeInvalid checks that parseGPUType rejects anything other than
// an exact supported name and that the error lists the valid types.
func TestParseGPUTypeInvalid(t *testing.T) {
	// Wrong casing is rejected rather than fixed up; legacy types (h100_80gb, a10)
	// can no longer be submitted; unknown types are rejected.
	rejected := []string{
		"gpu_1xa10",
		"GPU_1XA10",
		"GPU_2xH100",
		"h100_80gb",
		"a10",
		"b200",
		"",
	}
	for _, input := range rejected {
		t.Run(input, func(t *testing.T) {
			_, err := parseGPUType(input)
			require.Error(t, err)
			assert.Contains(t, err.Error(), "valid types are")
		})
	}
}

// TestGPUsPerNode checks the per-node GPU count reported for each known
// accelerator type, and that an unknown type produces an error.
func TestGPUsPerNode(t *testing.T) {
	cases := []struct {
		accelerator gpuType
		expected    int
	}{
		{accelerator: gpuType1xA10, expected: 1},
		{accelerator: gpuType1xH100, expected: 1},
		{accelerator: gpuType8xH100, expected: 8},
	}
	for _, tc := range cases {
		t.Run(string(tc.accelerator), func(t *testing.T) {
			count, err := gpusPerNode(tc.accelerator)
			require.NoError(t, err)
			assert.Equal(t, tc.expected, count)
		})
	}

	// A value that is not one of the known constants must yield an error.
	_, err := gpusPerNode(gpuType("nonsense"))
	require.Error(t, err)
}

// TestComputeConfigValidate runs computeConfig.validate over valid and invalid
// compute blocks. For invalid ones it checks that the error mentions the
// expected substring.
func TestComputeConfigValidate(t *testing.T) {
	cases := []struct {
		label     string
		config    computeConfig
		errSubstr string // substring; empty means the config is valid
	}{
		{
			label:  "single node",
			config: computeConfig{NumAccelerators: 8, AcceleratorType: "GPU_8xH100"},
		},
		{
			label:  "multiple nodes",
			config: computeConfig{NumAccelerators: 16, AcceleratorType: "GPU_8xH100"},
		},
		{
			label:  "single-gpu partitions",
			config: computeConfig{NumAccelerators: 3, AcceleratorType: "GPU_1xH100"},
		},
		{
			label:     "unknown type",
			config:    computeConfig{NumAccelerators: 8, AcceleratorType: "b200"},
			errSubstr: "accelerator_type",
		},
		{
			label:     "legacy type rejected",
			config:    computeConfig{NumAccelerators: 8, AcceleratorType: "h100_80gb"},
			errSubstr: "accelerator_type",
		},
		{
			label:     "non-positive count",
			config:    computeConfig{NumAccelerators: 0, AcceleratorType: "GPU_1xH100"},
			errSubstr: "must be positive",
		},
		{
			label:     "count not a multiple",
			config:    computeConfig{NumAccelerators: 4, AcceleratorType: "GPU_8xH100"},
			errSubstr: "multiple of 8",
		},
	}
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			err := tc.config.validate()
			if tc.errSubstr != "" {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tc.errSubstr)
				return
			}
			require.NoError(t, err)
		})
	}
}
Loading