-
Notifications
You must be signed in to change notification settings - Fork 184
AIR CLI Integration: air run Command Pt. 1 - Add GPU accelerator type and compute config model
#5602
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
AIR CLI Integration: air run Command Pt. 1 - Add GPU accelerator type and compute config model
#5602
Changes from all commits
6c80f23
059bd61
2ccd069
89042d0
c99239c
235f1bd
8a97e0f
e04b698
31121ad
3883791
472a1fe
0ab1008
2615cd7
d3bb64b
3deceab
9072c29
fddbdfd
a69e0d3
fa3a1a2
62be1a1
9efd3d1
6850d23
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| package aircmd | ||
|
|
||
| import ( | ||
| "fmt" | ||
| "strings" | ||
| ) | ||
|
|
||
| // gpuType is a wire-facing accelerator type submitted to the training service. | ||
| // The number in the name is the partition count (e.g. GPU_8xH100 is 8 GPUs). | ||
| type gpuType string | ||
|
|
||
| const ( | ||
| gpuType1xA10 gpuType = "GPU_1xA10" | ||
| gpuType8xH100 gpuType = "GPU_8xH100" | ||
| gpuType1xH100 gpuType = "GPU_1xH100" | ||
| ) | ||
|
|
||
| // gpuTypes lists every valid type. Used for validation error messages. | ||
| var gpuTypes = []gpuType{gpuType1xA10, gpuType1xH100, gpuType8xH100} | ||
|
|
||
| func validGPUTypesHint() string { | ||
| names := make([]string, len(gpuTypes)) | ||
| for i, g := range gpuTypes { | ||
| names[i] = string(g) | ||
| } | ||
| return "valid types are: " + strings.Join(names, ", ") | ||
| } | ||
|
|
||
| // parseGPUType resolves a YAML accelerator_type string to a gpuType. The match is | ||
| // exact: the server's lookup is case-sensitive. | ||
| func parseGPUType(value string) (gpuType, error) { | ||
| switch gpuType(value) { | ||
| case gpuType1xA10, gpuType8xH100, gpuType1xH100: | ||
| return gpuType(value), nil | ||
| } | ||
| return "", fmt.Errorf("invalid GPU type %q: %s", value, validGPUTypesHint()) | ||
| } | ||
|
|
||
| // gpusPerNode returns the per-node GPU count, which is the partition count from | ||
| // the name (GPU_1xH100 -> 1, GPU_8xH100 -> 8). num_accelerators must be a | ||
| // round multiple of this since accelerators are allocated in whole nodes. | ||
| func gpusPerNode(g gpuType) (int, error) { | ||
| switch g { | ||
| case gpuType1xA10, gpuType1xH100: | ||
| return 1, nil | ||
| case gpuType8xH100: | ||
| return 8, nil | ||
| } | ||
| // Unreachable: callers resolve g through parseGPUType first, which rejects | ||
| // unknown types. Kept as a defensive guard. | ||
| return 0, fmt.Errorf("invalid GPU type %q", string(g)) | ||
| } | ||
|
|
||
| // computeConfig is the `compute` block of the run YAML: which accelerators to | ||
| // use and how many. | ||
| type computeConfig struct { | ||
| NumAccelerators int `yaml:"num_accelerators"` | ||
| AcceleratorType string `yaml:"accelerator_type"` | ||
| } | ||
|
|
||
| // validate checks the compute block against the backend's constraints. | ||
| func (c computeConfig) validate() error { | ||
| g, err := parseGPUType(c.AcceleratorType) | ||
| if err != nil { | ||
| return fmt.Errorf("compute.accelerator_type: %w", err) | ||
| } | ||
|
|
||
| if c.NumAccelerators <= 0 { | ||
| return fmt.Errorf("compute.num_accelerators must be positive, got %d", c.NumAccelerators) | ||
| } | ||
|
|
||
| perNode, err := gpusPerNode(g) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| if c.NumAccelerators%perNode != 0 { | ||
| return fmt.Errorf("compute.num_accelerators for %s must be a multiple of %d, got %d", c.AcceleratorType, perNode, c.NumAccelerators) | ||
| } | ||
|
Comment on lines
+72
to
+78
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm off the opinion this kind of check should be done in the backend. @maggiewang-db @ben-hansen-db @vinchenzo-db wdyt? can we do that easily using Training Service logic?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that based on the project milestones and as I discussed with Maggie yesterday, we want to port this in phases. As written in the project doc, we want to first port the run functionality directly as is (including the validation) and then move the validation & add handlers to the backend in milestone 3. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. But my plan is to do that later in Milestone 3.2 after the initial lift and shift. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds good |
||
|
|
||
| return nil | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| package aircmd | ||
|
|
||
| import ( | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| func TestParseGPUType(t *testing.T) { | ||
| tests := []struct { | ||
| in string | ||
| want gpuType | ||
| }{ | ||
| {"GPU_1xA10", gpuType1xA10}, | ||
| {"GPU_8xH100", gpuType8xH100}, | ||
| {"GPU_1xH100", gpuType1xH100}, | ||
| } | ||
| for _, tt := range tests { | ||
| t.Run(tt.in, func(t *testing.T) { | ||
| got, err := parseGPUType(tt.in) | ||
| require.NoError(t, err) | ||
| assert.Equal(t, tt.want, got) | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| func TestParseGPUTypeInvalid(t *testing.T) { | ||
| // Wrong casing is rejected rather than fixed up; legacy types (h100_80gb, a10) | ||
| // can no longer be submitted; unknown types are rejected. | ||
| for _, in := range []string{"gpu_1xa10", "GPU_1XA10", "GPU_2xH100", "h100_80gb", "a10", "b200", ""} { | ||
| t.Run(in, func(t *testing.T) { | ||
| _, err := parseGPUType(in) | ||
| require.Error(t, err) | ||
| assert.Contains(t, err.Error(), "valid types are") | ||
| }) | ||
| } | ||
| } | ||
|
|
||
| func TestGPUsPerNode(t *testing.T) { | ||
| tests := []struct { | ||
| in gpuType | ||
| want int | ||
| }{ | ||
| {gpuType1xA10, 1}, | ||
| {gpuType1xH100, 1}, | ||
| {gpuType8xH100, 8}, | ||
| } | ||
| for _, tt := range tests { | ||
| t.Run(string(tt.in), func(t *testing.T) { | ||
| got, err := gpusPerNode(tt.in) | ||
| require.NoError(t, err) | ||
| assert.Equal(t, tt.want, got) | ||
| }) | ||
| } | ||
|
|
||
| _, err := gpusPerNode(gpuType("nonsense")) | ||
| require.Error(t, err) | ||
| } | ||
|
|
||
| func TestComputeConfigValidate(t *testing.T) { | ||
| tests := []struct { | ||
| name string | ||
| cfg computeConfig | ||
| wantErr string // substring; empty means the config is valid | ||
| }{ | ||
| {"single node", computeConfig{NumAccelerators: 8, AcceleratorType: "GPU_8xH100"}, ""}, | ||
| {"multiple nodes", computeConfig{NumAccelerators: 16, AcceleratorType: "GPU_8xH100"}, ""}, | ||
| {"single-gpu partitions", computeConfig{NumAccelerators: 3, AcceleratorType: "GPU_1xH100"}, ""}, | ||
| {"unknown type", computeConfig{NumAccelerators: 8, AcceleratorType: "b200"}, "accelerator_type"}, | ||
| {"legacy type rejected", computeConfig{NumAccelerators: 8, AcceleratorType: "h100_80gb"}, "accelerator_type"}, | ||
| {"non-positive count", computeConfig{NumAccelerators: 0, AcceleratorType: "GPU_1xH100"}, "must be positive"}, | ||
| {"count not a multiple", computeConfig{NumAccelerators: 4, AcceleratorType: "GPU_8xH100"}, "multiple of 8"}, | ||
| } | ||
| for _, tt := range tests { | ||
| t.Run(tt.name, func(t *testing.T) { | ||
| err := tt.cfg.validate() | ||
| if tt.wantErr == "" { | ||
| require.NoError(t, err) | ||
| return | ||
| } | ||
| require.Error(t, err) | ||
| assert.Contains(t, err.Error(), tt.wantErr) | ||
| }) | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: By the time validate() reaches gpusPerNode(), parseGPUType() has already guaranteed g is valid.
It's ok to leave the code as is to be defensive. Just add a comment this shouldn't be reachable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added comment, thanks