Skip to content

Commit 69be336

Browse files
committed
feat(gpu): route device selection through driver config
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 009a6ee commit 69be336

21 files changed

Lines changed: 1004 additions & 280 deletions

File tree

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

architecture/compute-runtimes.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ but currently ignores them.
4242

4343
GPU requests enter the driver layer through
4444
`SandboxSpec.resource_requirements.gpu`. The compact interim shape supports a
45-
default GPU request, GPU count, and driver-specific device IDs.
45+
default GPU request and GPU count. Exact driver-native device selection is
46+
passed through the selected runtime's `driver_config` block; the gateway
47+
selects that block but does not interpret the nested driver schema. Drivers
48+
that support exact selection validate that the number of unique
49+
`gpu_device_ids` entries matches the portable GPU count.
4650

4751
VM runtime state paths are derived only from driver-validated sandbox IDs
4852
matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a
@@ -81,9 +85,7 @@ users.
8185
Custom sandbox images must include the agent runtime and any system
8286
dependencies, but they should not need to include the gateway. GPU-capable
8387
images must include the user-space libraries required by the workload. The
84-
runtime still owns GPU device injection. GPU requests can include explicit
85-
driver-native device IDs or a requested count; the gateway validates the public
86-
request shape and each runtime enforces the GPU allocation modes it supports.
88+
runtime still owns GPU device injection.
8789

8890
## Deployment Shape
8991

crates/openshell-cli/src/run.rs

Lines changed: 63 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,8 +1734,10 @@ pub async fn sandbox_create(
17341734
}
17351735
None => None,
17361736
};
1737-
let requested_gpu =
1738-
gpu || gpu_count.is_some() || image.as_deref().is_some_and(image_requests_gpu);
1737+
let requested_gpu = gpu
1738+
|| gpu_count.is_some()
1739+
|| gpu_device.is_some_and(|device_id| !device_id.trim().is_empty())
1740+
|| image.as_deref().is_some_and(image_requests_gpu);
17391741

17401742
let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?;
17411743
let inferred_types: Vec<String> = if providers_v2_enabled {
@@ -1753,11 +1755,13 @@ pub async fn sandbox_create(
17531755

17541756
let policy = load_sandbox_policy(policy)?;
17551757
let resource_limits = build_sandbox_resource_limits(cpu, memory)?;
1758+
let driver_config = gpu_driver_config_from_cli(gpu_device);
17561759

1757-
let template = if image.is_some() || resource_limits.is_some() {
1760+
let template = if image.is_some() || resource_limits.is_some() || driver_config.is_some() {
17581761
Some(SandboxTemplate {
17591762
image: image.unwrap_or_default(),
17601763
resources: resource_limits,
1764+
driver_config,
17611765
..SandboxTemplate::default()
17621766
})
17631767
} else {
@@ -1766,11 +1770,7 @@ pub async fn sandbox_create(
17661770

17671771
let request = CreateSandboxRequest {
17681772
spec: Some(SandboxSpec {
1769-
resource_requirements: resource_requirements_from_cli(
1770-
requested_gpu,
1771-
gpu_device,
1772-
gpu_count,
1773-
),
1773+
resource_requirements: resource_requirements_from_cli(requested_gpu, gpu_count),
17741774
policy,
17751775
providers: configured_providers,
17761776
template,
@@ -2197,17 +2197,48 @@ pub async fn sandbox_create(
21972197

21982198
fn resource_requirements_from_cli(
21992199
requested_gpu: bool,
2200-
gpu_device: Option<&str>,
22012200
gpu_count: Option<u32>,
22022201
) -> Option<SandboxResourceRequirements> {
2203-
requested_gpu.then(|| SandboxResourceRequirements {
2204-
gpu: Some(GpuResourceRequirement {
2205-
device_ids: gpu_device
2206-
.filter(|device_id| !device_id.is_empty())
2207-
.map(|device_id| vec![device_id.to_string()])
2208-
.unwrap_or_default(),
2209-
count: gpu_count,
2210-
}),
2202+
requested_gpu.then_some(SandboxResourceRequirements {
2203+
gpu: Some(GpuResourceRequirement { count: gpu_count }),
2204+
})
2205+
}
2206+
2207+
fn gpu_driver_config_from_cli(gpu_device: Option<&str>) -> Option<prost_types::Struct> {
2208+
use prost_types::{ListValue, Struct, Value, value::Kind};
2209+
2210+
fn string_value(value: &str) -> Value {
2211+
Value {
2212+
kind: Some(Kind::StringValue(value.to_string())),
2213+
}
2214+
}
2215+
2216+
fn driver_block(gpu_device: &str) -> Value {
2217+
Value {
2218+
kind: Some(Kind::StructValue(Struct {
2219+
fields: std::iter::once((
2220+
"gpu_device_ids".to_string(),
2221+
Value {
2222+
kind: Some(Kind::ListValue(ListValue {
2223+
values: vec![string_value(gpu_device)],
2224+
})),
2225+
},
2226+
))
2227+
.collect(),
2228+
})),
2229+
}
2230+
}
2231+
2232+
let gpu_device = gpu_device.filter(|device_id| !device_id.trim().is_empty())?;
2233+
2234+
Some(Struct {
2235+
fields: [
2236+
("docker".to_string(), driver_block(gpu_device)),
2237+
("podman".to_string(), driver_block(gpu_device)),
2238+
("vm".to_string(), driver_block(gpu_device)),
2239+
]
2240+
.into_iter()
2241+
.collect(),
22112242
})
22122243
}
22132244

@@ -7460,10 +7491,10 @@ mod tests {
74607491
dockerfile_sources_supported_for_gateway, format_endpoint, format_gateway_select_header,
74617492
format_gateway_select_items, format_provider_attachment_table, gateway_add,
74627493
gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label,
7463-
git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle,
7464-
inferred_provider_type, package_managed_tls_dirs, parse_cli_setting_value,
7465-
parse_credential_expiry_cli_value, parse_credential_expiry_pairs, parse_credential_pairs,
7466-
plaintext_gateway_is_remote, progress_step_from_metadata,
7494+
git_sync_files, gpu_driver_config_from_cli, http_health_check, image_requests_gpu,
7495+
import_local_package_mtls_bundle, inferred_provider_type, package_managed_tls_dirs,
7496+
parse_cli_setting_value, parse_credential_expiry_cli_value, parse_credential_expiry_pairs,
7497+
parse_credential_pairs, plaintext_gateway_is_remote, progress_step_from_metadata,
74677498
provider_profile_allows_refresh_bootstrap, provisioning_timeout_message,
74687499
ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from,
74697500
resource_requirements_from_cli, sandbox_should_persist, sandbox_upload_plan,
@@ -7948,47 +7979,40 @@ mod tests {
79487979

79497980
#[test]
79507981
fn resource_requirements_from_cli_uses_presence_for_default_gpu() {
7951-
let requirements = resource_requirements_from_cli(true, None, None)
7982+
let requirements = resource_requirements_from_cli(true, None)
79527983
.expect("resource requirements should be present");
79537984
let gpu = requirements.gpu.expect("GPU requirement should be present");
79547985

7955-
assert!(gpu.device_ids.is_empty());
79567986
assert_eq!(gpu.count, None);
79577987
}
79587988

79597989
#[test]
7960-
fn resource_requirements_from_cli_maps_gpu_device_to_one_device_id() {
7961-
let requirements = resource_requirements_from_cli(true, Some("0000:2d:00.0"), None)
7962-
.expect("resource requirements should be present");
7963-
let gpu = requirements.gpu.expect("GPU requirement should be present");
7990+
fn gpu_driver_config_from_cli_maps_gpu_device_to_driver_blocks() {
7991+
let config = gpu_driver_config_from_cli(Some("nvidia.com/gpu=0"))
7992+
.expect("driver config should be present");
79647993

7965-
assert_eq!(gpu.device_ids, vec!["0000:2d:00.0"]);
7966-
assert_eq!(gpu.count, None);
7994+
assert!(config.fields.contains_key("docker"));
7995+
assert!(config.fields.contains_key("podman"));
7996+
assert!(config.fields.contains_key("vm"));
79677997
}
79687998

79697999
#[test]
79708000
fn resource_requirements_from_cli_maps_gpu_count() {
79718001
let requirements =
7972-
resource_requirements_from_cli(true, None, Some(2)).expect("requirements should exist");
8002+
resource_requirements_from_cli(true, Some(2)).expect("requirements should exist");
79738003
let gpu = requirements.gpu.expect("GPU requirement should be present");
79748004

7975-
assert!(gpu.device_ids.is_empty());
79768005
assert_eq!(gpu.count, Some(2));
79778006
}
79788007

79798008
#[test]
7980-
fn resource_requirements_from_cli_preserves_device_and_gpu_count_for_gateway_validation() {
7981-
let requirements = resource_requirements_from_cli(true, Some("nvidia.com/gpu=0"), Some(2))
7982-
.expect("requirements should exist");
7983-
let gpu = requirements.gpu.expect("GPU requirement should be present");
7984-
7985-
assert_eq!(gpu.device_ids, vec!["nvidia.com/gpu=0"]);
7986-
assert_eq!(gpu.count, Some(2));
8009+
fn gpu_driver_config_from_cli_omits_empty_device() {
8010+
assert!(gpu_driver_config_from_cli(Some("")).is_none());
79878011
}
79888012

79898013
#[test]
79908014
fn resource_requirements_from_cli_omits_gpu_request_when_not_requested() {
7991-
assert!(resource_requirements_from_cli(false, Some("0"), None).is_none());
8015+
assert!(resource_requirements_from_cli(false, None).is_none());
79928016
}
79938017

79948018
#[test]

crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -922,15 +922,23 @@ async fn sandbox_create_sends_gpu_count_request() {
922922
.expect("sandbox create should succeed");
923923

924924
let requests = create_requests(&server).await;
925-
let gpu = requests[0]
925+
let spec = requests[0]
926926
.spec
927927
.as_ref()
928-
.and_then(|spec| spec.resource_requirements.as_ref())
928+
.expect("sandbox spec should be sent");
929+
let gpu = spec
930+
.resource_requirements
931+
.as_ref()
929932
.and_then(|requirements| requirements.gpu.as_ref())
930933
.expect("GPU request should be sent");
931934

932-
assert!(gpu.device_ids.is_empty());
933935
assert_eq!(gpu.count, Some(2));
936+
assert!(
937+
spec.template
938+
.as_ref()
939+
.and_then(|template| template.driver_config.as_ref())
940+
.is_none()
941+
);
934942
}
935943

936944
#[tokio::test]

0 commit comments

Comments
 (0)