@@ -1734,8 +1734,10 @@ pub async fn sandbox_create(
17341734 }
17351735 None => None ,
17361736 } ;
1737- let requested_gpu =
1738- gpu || gpu_count. is_some ( ) || image. as_deref ( ) . is_some_and ( image_requests_gpu) ;
1737+ let requested_gpu = gpu
1738+ || gpu_count. is_some ( )
1739+ || gpu_device. is_some_and ( |device_id| !device_id. trim ( ) . is_empty ( ) )
1740+ || image. as_deref ( ) . is_some_and ( image_requests_gpu) ;
17391741
17401742 let providers_v2_enabled = gateway_providers_v2_enabled ( & mut client) . await ?;
17411743 let inferred_types: Vec < String > = if providers_v2_enabled {
@@ -1753,11 +1755,13 @@ pub async fn sandbox_create(
17531755
17541756 let policy = load_sandbox_policy ( policy) ?;
17551757 let resource_limits = build_sandbox_resource_limits ( cpu, memory) ?;
1758+ let driver_config = gpu_driver_config_from_cli ( gpu_device) ;
17561759
1757- let template = if image. is_some ( ) || resource_limits. is_some ( ) {
1760+ let template = if image. is_some ( ) || resource_limits. is_some ( ) || driver_config . is_some ( ) {
17581761 Some ( SandboxTemplate {
17591762 image : image. unwrap_or_default ( ) ,
17601763 resources : resource_limits,
1764+ driver_config,
17611765 ..SandboxTemplate :: default ( )
17621766 } )
17631767 } else {
@@ -1766,11 +1770,7 @@ pub async fn sandbox_create(
17661770
17671771 let request = CreateSandboxRequest {
17681772 spec : Some ( SandboxSpec {
1769- resource_requirements : resource_requirements_from_cli (
1770- requested_gpu,
1771- gpu_device,
1772- gpu_count,
1773- ) ,
1773+ resource_requirements : resource_requirements_from_cli ( requested_gpu, gpu_count) ,
17741774 policy,
17751775 providers : configured_providers,
17761776 template,
@@ -2197,17 +2197,48 @@ pub async fn sandbox_create(
21972197
21982198fn resource_requirements_from_cli (
21992199 requested_gpu : bool ,
2200- gpu_device : Option < & str > ,
22012200 gpu_count : Option < u32 > ,
22022201) -> Option < SandboxResourceRequirements > {
2203- requested_gpu. then ( || SandboxResourceRequirements {
2204- gpu : Some ( GpuResourceRequirement {
2205- device_ids : gpu_device
2206- . filter ( |device_id| !device_id. is_empty ( ) )
2207- . map ( |device_id| vec ! [ device_id. to_string( ) ] )
2208- . unwrap_or_default ( ) ,
2209- count : gpu_count,
2210- } ) ,
2202+ requested_gpu. then_some ( SandboxResourceRequirements {
2203+ gpu : Some ( GpuResourceRequirement { count : gpu_count } ) ,
2204+ } )
2205+ }
2206+
2207+ fn gpu_driver_config_from_cli ( gpu_device : Option < & str > ) -> Option < prost_types:: Struct > {
2208+ use prost_types:: { ListValue , Struct , Value , value:: Kind } ;
2209+
2210+ fn string_value ( value : & str ) -> Value {
2211+ Value {
2212+ kind : Some ( Kind :: StringValue ( value. to_string ( ) ) ) ,
2213+ }
2214+ }
2215+
2216+ fn driver_block ( gpu_device : & str ) -> Value {
2217+ Value {
2218+ kind : Some ( Kind :: StructValue ( Struct {
2219+ fields : std:: iter:: once ( (
2220+ "gpu_device_ids" . to_string ( ) ,
2221+ Value {
2222+ kind : Some ( Kind :: ListValue ( ListValue {
2223+ values : vec ! [ string_value( gpu_device) ] ,
2224+ } ) ) ,
2225+ } ,
2226+ ) )
2227+ . collect ( ) ,
2228+ } ) ) ,
2229+ }
2230+ }
2231+
2232+ let gpu_device = gpu_device. filter ( |device_id| !device_id. trim ( ) . is_empty ( ) ) ?;
2233+
2234+ Some ( Struct {
2235+ fields : [
2236+ ( "docker" . to_string ( ) , driver_block ( gpu_device) ) ,
2237+ ( "podman" . to_string ( ) , driver_block ( gpu_device) ) ,
2238+ ( "vm" . to_string ( ) , driver_block ( gpu_device) ) ,
2239+ ]
2240+ . into_iter ( )
2241+ . collect ( ) ,
22112242 } )
22122243}
22132244
@@ -7460,10 +7491,10 @@ mod tests {
74607491 dockerfile_sources_supported_for_gateway, format_endpoint, format_gateway_select_header,
74617492 format_gateway_select_items, format_provider_attachment_table, gateway_add,
74627493 gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label,
7463- git_sync_files, http_health_check , image_requests_gpu , import_local_package_mtls_bundle ,
7464- inferred_provider_type , package_managed_tls_dirs , parse_cli_setting_value ,
7465- parse_credential_expiry_cli_value , parse_credential_expiry_pairs , parse_credential_pairs ,
7466- plaintext_gateway_is_remote, progress_step_from_metadata,
7494+ git_sync_files, gpu_driver_config_from_cli , http_health_check , image_requests_gpu ,
7495+ import_local_package_mtls_bundle , inferred_provider_type , package_managed_tls_dirs ,
7496+ parse_cli_setting_value , parse_credential_expiry_cli_value , parse_credential_expiry_pairs ,
7497+ parse_credential_pairs , plaintext_gateway_is_remote, progress_step_from_metadata,
74677498 provider_profile_allows_refresh_bootstrap, provisioning_timeout_message,
74687499 ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from,
74697500 resource_requirements_from_cli, sandbox_should_persist, sandbox_upload_plan,
@@ -7948,47 +7979,40 @@ mod tests {
79487979
79497980 #[ test]
79507981 fn resource_requirements_from_cli_uses_presence_for_default_gpu ( ) {
7951- let requirements = resource_requirements_from_cli ( true , None , None )
7982+ let requirements = resource_requirements_from_cli ( true , None )
79527983 . expect ( "resource requirements should be present" ) ;
79537984 let gpu = requirements. gpu . expect ( "GPU requirement should be present" ) ;
79547985
7955- assert ! ( gpu. device_ids. is_empty( ) ) ;
79567986 assert_eq ! ( gpu. count, None ) ;
79577987 }
79587988
79597989 #[ test]
7960- fn resource_requirements_from_cli_maps_gpu_device_to_one_device_id ( ) {
7961- let requirements = resource_requirements_from_cli ( true , Some ( "0000:2d:00.0" ) , None )
7962- . expect ( "resource requirements should be present" ) ;
7963- let gpu = requirements. gpu . expect ( "GPU requirement should be present" ) ;
7990+ fn gpu_driver_config_from_cli_maps_gpu_device_to_driver_blocks ( ) {
7991+ let config = gpu_driver_config_from_cli ( Some ( "nvidia.com/gpu=0" ) )
7992+ . expect ( "driver config should be present" ) ;
79647993
7965- assert_eq ! ( gpu. device_ids, vec![ "0000:2d:00.0" ] ) ;
7966- assert_eq ! ( gpu. count, None ) ;
7994+ assert ! ( config. fields. contains_key( "docker" ) ) ;
7995+ assert ! ( config. fields. contains_key( "podman" ) ) ;
7996+ assert ! ( config. fields. contains_key( "vm" ) ) ;
79677997 }
79687998
79697999 #[ test]
79708000 fn resource_requirements_from_cli_maps_gpu_count ( ) {
79718001 let requirements =
7972- resource_requirements_from_cli ( true , None , Some ( 2 ) ) . expect ( "requirements should exist" ) ;
8002+ resource_requirements_from_cli ( true , Some ( 2 ) ) . expect ( "requirements should exist" ) ;
79738003 let gpu = requirements. gpu . expect ( "GPU requirement should be present" ) ;
79748004
7975- assert ! ( gpu. device_ids. is_empty( ) ) ;
79768005 assert_eq ! ( gpu. count, Some ( 2 ) ) ;
79778006 }
79788007
79798008 #[ test]
7980- fn resource_requirements_from_cli_preserves_device_and_gpu_count_for_gateway_validation ( ) {
7981- let requirements = resource_requirements_from_cli ( true , Some ( "nvidia.com/gpu=0" ) , Some ( 2 ) )
7982- . expect ( "requirements should exist" ) ;
7983- let gpu = requirements. gpu . expect ( "GPU requirement should be present" ) ;
7984-
7985- assert_eq ! ( gpu. device_ids, vec![ "nvidia.com/gpu=0" ] ) ;
7986- assert_eq ! ( gpu. count, Some ( 2 ) ) ;
8009+ fn gpu_driver_config_from_cli_omits_empty_device ( ) {
8010+ assert ! ( gpu_driver_config_from_cli( Some ( "" ) ) . is_none( ) ) ;
79878011 }
79888012
79898013 #[ test]
79908014 fn resource_requirements_from_cli_omits_gpu_request_when_not_requested ( ) {
7991- assert ! ( resource_requirements_from_cli( false , Some ( "0" ) , None ) . is_none( ) ) ;
8015+ assert ! ( resource_requirements_from_cli( false , None ) . is_none( ) ) ;
79928016 }
79938017
79948018 #[ test]
0 commit comments