diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ee75d10546e..786e341512f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -812,6 +812,21 @@ struct ggml_backend_cuda_context { ~ggml_backend_cuda_context(); + int get_max_supported_preempt_level(int device_id) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id)); + int arch = prop.major * 10 + prop.minor; + switch (arch) { + case 35: // Kepler: K20, K40, GTX TITAN + return 2; // kPreemptLevelDeactivate + case 70: // Volta: V100 + case 86: // Ampere: RTX 30系列 + return 3; // kPreemptLevelInterrupt + default: // 其他架构(包括 A100 arch=80) + return 1; // kPreemptLevelBlock + } + } + cudaStream_t stream(int device, int stream) { if (streams[device][stream] == nullptr) { ggml_cuda_set_device(device); @@ -819,7 +834,7 @@ struct ggml_backend_cuda_context { HwQueueHandle hwqueue; CudaQueueCreate(&hwqueue,streams[device][stream]); XQueueHandle xqueue; - XQueueCreate(&xqueue, hwqueue, kPreemptLevelDeactivate, kQueueCreateFlagNone); + XQueueCreate(&xqueue, hwqueue, get_max_supported_preempt_level(device), kQueueCreateFlagNone); XHintPriority(xqueue, priority); // In XSched, lower number means lower priority } return streams[device][stream]; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0b939ca4854..a6c926e67c0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2864,7 +2864,7 @@ static void ggml_backend_cuda_set_priority(ggml_backend_t backend, int prio) { HwQueueHandle hwqueue; CudaQueueCreate(&hwqueue,stream); XQueueHandle xqueue; - XQueueCreate(&xqueue, hwqueue, kPreemptLevelDeactivate, kQueueCreateFlagNone); + XQueueCreate(&xqueue, hwqueue, cuda_ctx->get_max_supported_preempt_level(device), kQueueCreateFlagNone); XHintPriority(xqueue, prio); // In XSched, lower number means lower priority } }