From 6fbf3904da597dbe1faf41da246108fdc553fb39 Mon Sep 17 00:00:00 2001 From: ajz34 Date: Thu, 25 Jun 2026 11:53:42 +0800 Subject: [PATCH] fix thread oversubscription of broadcast matmul (BLAS devices) --- crates-device/rstsr-aocl/src/matmul.rs | 60 +++++++++++----------- crates-device/rstsr-blis/src/matmul.rs | 60 +++++++++++----------- crates-device/rstsr-kml/src/matmul.rs | 60 +++++++++++----------- crates-device/rstsr-mkl/src/matmul.rs | 60 +++++++++++----------- crates-device/rstsr-openblas/src/matmul.rs | 60 +++++++++++----------- 5 files changed, 150 insertions(+), 150 deletions(-) diff --git a/crates-device/rstsr-aocl/src/matmul.rs b/crates-device/rstsr-aocl/src/matmul.rs index fbe7d58..53cebf1 100644 --- a/crates-device/rstsr-aocl/src/matmul.rs +++ b/crates-device/rstsr-aocl/src/matmul.rs @@ -155,37 +155,37 @@ where let itc_rest = IterLayoutColMajor::new(&lc_rest)?; if n_task >= 4 * nthreads { // parallel outer, sequential matmul - with_num_threads(1, || { - let task = || { - ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( - |((ia_rest, ib_rest), ic_rest)| -> Result<()> { - // prepare layout - let mut la_m = la_matmul.clone(); - let mut lb_m = lb_matmul.clone(); - let mut lc_m = lc_matmul.clone(); - unsafe { - la_m.set_offset(ia_rest); - lb_m.set_offset(ib_rest); - lc_m.set_offset(ic_rest); - } - // move mutable reference into parallel closure - let c = unsafe { - let c_ptr = c.as_ptr() as *mut TC; - let c_len = c.len(); - from_raw_parts_mut(c_ptr, c_len) - }; - // clone alpha and beta - let alpha = alpha.clone(); - let beta = beta.clone(); + let task = || { + ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( + |((ia_rest, ib_rest), ic_rest)| -> Result<()> { + // prepare layout + let mut la_m = la_matmul.clone(); + let mut lb_m = lb_matmul.clone(); + let mut lc_m = lc_matmul.clone(); + unsafe { + la_m.set_offset(ia_rest); + lb_m.set_offset(ib_rest); + lc_m.set_offset(ic_rest); + } + // move mutable reference into parallel closure + let c = unsafe { + let c_ptr = c.as_ptr() as *mut TC; + let c_len = c.len(); + from_raw_parts_mut(c_ptr, c_len) + }; + // clone alpha and beta + let alpha = alpha.clone(); + let beta = beta.clone(); + with_num_threads(1, || { gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, None) - }, - ) - }; - match pool { - Some(pool) => pool.install(task), - None => task(), - } - }) + }) + }, + ) + }; + match pool { + Some(pool) => pool.install(task), + None => task(), + } } else { // sequential outer, parallel matmul with_num_threads(nthreads, || -> Result<()> { diff --git a/crates-device/rstsr-blis/src/matmul.rs b/crates-device/rstsr-blis/src/matmul.rs index fbe7d58..53cebf1 100644 --- a/crates-device/rstsr-blis/src/matmul.rs +++ b/crates-device/rstsr-blis/src/matmul.rs @@ -155,37 +155,37 @@ where let itc_rest = IterLayoutColMajor::new(&lc_rest)?; if n_task >= 4 * nthreads { // parallel outer, sequential matmul - with_num_threads(1, || { - let task = || { - ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( - |((ia_rest, ib_rest), ic_rest)| -> Result<()> { - // prepare layout - let mut la_m = la_matmul.clone(); - let mut lb_m = lb_matmul.clone(); - let mut lc_m = lc_matmul.clone(); - unsafe { - la_m.set_offset(ia_rest); - lb_m.set_offset(ib_rest); - lc_m.set_offset(ic_rest); - } - // move mutable reference into parallel closure - let c = unsafe { - let c_ptr = c.as_ptr() as *mut TC; - let c_len = c.len(); - from_raw_parts_mut(c_ptr, c_len) - }; - // clone alpha and beta - let alpha = alpha.clone(); - let beta = beta.clone(); + let task = || { + ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( + |((ia_rest, ib_rest), ic_rest)| -> Result<()> { + // prepare layout + let mut la_m = la_matmul.clone(); + let mut lb_m = lb_matmul.clone(); + let mut lc_m = lc_matmul.clone(); + unsafe { + la_m.set_offset(ia_rest); + lb_m.set_offset(ib_rest); + lc_m.set_offset(ic_rest); + } + // move mutable reference into parallel closure + let c = unsafe { + let c_ptr = c.as_ptr() as *mut TC; + let c_len = c.len(); + from_raw_parts_mut(c_ptr, c_len) + }; + // clone alpha and beta + let alpha = alpha.clone(); + let beta = beta.clone(); + with_num_threads(1, || { gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, None) - }, - ) - }; - match pool { - Some(pool) => pool.install(task), - None => task(), - } - }) + }) + }, + ) + }; + match pool { + Some(pool) => pool.install(task), + None => task(), + } } else { // sequential outer, parallel matmul with_num_threads(nthreads, || -> Result<()> { diff --git a/crates-device/rstsr-kml/src/matmul.rs b/crates-device/rstsr-kml/src/matmul.rs index fbe7d58..53cebf1 100644 --- a/crates-device/rstsr-kml/src/matmul.rs +++ b/crates-device/rstsr-kml/src/matmul.rs @@ -155,37 +155,37 @@ where let itc_rest = IterLayoutColMajor::new(&lc_rest)?; if n_task >= 4 * nthreads { // parallel outer, sequential matmul - with_num_threads(1, || { - let task = || { - ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( - |((ia_rest, ib_rest), ic_rest)| -> Result<()> { - // prepare layout - let mut la_m = la_matmul.clone(); - let mut lb_m = lb_matmul.clone(); - let mut lc_m = lc_matmul.clone(); - unsafe { - la_m.set_offset(ia_rest); - lb_m.set_offset(ib_rest); - lc_m.set_offset(ic_rest); - } - // move mutable reference into parallel closure - let c = unsafe { - let c_ptr = c.as_ptr() as *mut TC; - let c_len = c.len(); - from_raw_parts_mut(c_ptr, c_len) - }; - // clone alpha and beta - let alpha = alpha.clone(); - let beta = beta.clone(); + let task = || { + ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( + |((ia_rest, ib_rest), ic_rest)| -> Result<()> { + // prepare layout + let mut la_m = la_matmul.clone(); + let mut lb_m = lb_matmul.clone(); + let mut lc_m = lc_matmul.clone(); + unsafe { + la_m.set_offset(ia_rest); + lb_m.set_offset(ib_rest); + lc_m.set_offset(ic_rest); + } + // move mutable reference into parallel closure + let c = unsafe { + let c_ptr = c.as_ptr() as *mut TC; + let c_len = c.len(); + from_raw_parts_mut(c_ptr, c_len) + }; + // clone alpha and beta + let alpha = alpha.clone(); + let beta = beta.clone(); + with_num_threads(1, || { gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, None) - }, - ) - }; - match pool { - Some(pool) => pool.install(task), - None => task(), - } - }) + }) + }, + ) + }; + match pool { + Some(pool) => pool.install(task), + None => task(), + } } else { // sequential outer, parallel matmul with_num_threads(nthreads, || -> Result<()> { diff --git a/crates-device/rstsr-mkl/src/matmul.rs b/crates-device/rstsr-mkl/src/matmul.rs index fbe7d58..53cebf1 100644 --- a/crates-device/rstsr-mkl/src/matmul.rs +++ b/crates-device/rstsr-mkl/src/matmul.rs @@ -155,37 +155,37 @@ where let itc_rest = IterLayoutColMajor::new(&lc_rest)?; if n_task >= 4 * nthreads { // parallel outer, sequential matmul - with_num_threads(1, || { - let task = || { - ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( - |((ia_rest, ib_rest), ic_rest)| -> Result<()> { - // prepare layout - let mut la_m = la_matmul.clone(); - let mut lb_m = lb_matmul.clone(); - let mut lc_m = lc_matmul.clone(); - unsafe { - la_m.set_offset(ia_rest); - lb_m.set_offset(ib_rest); - lc_m.set_offset(ic_rest); - } - // move mutable reference into parallel closure - let c = unsafe { - let c_ptr = c.as_ptr() as *mut TC; - let c_len = c.len(); - from_raw_parts_mut(c_ptr, c_len) - }; - // clone alpha and beta - let alpha = alpha.clone(); - let beta = beta.clone(); + let task = || { + ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( + |((ia_rest, ib_rest), ic_rest)| -> Result<()> { + // prepare layout + let mut la_m = la_matmul.clone(); + let mut lb_m = lb_matmul.clone(); + let mut lc_m = lc_matmul.clone(); + unsafe { + la_m.set_offset(ia_rest); + lb_m.set_offset(ib_rest); + lc_m.set_offset(ic_rest); + } + // move mutable reference into parallel closure + let c = unsafe { + let c_ptr = c.as_ptr() as *mut TC; + let c_len = c.len(); + from_raw_parts_mut(c_ptr, c_len) + }; + // clone alpha and beta + let alpha = alpha.clone(); + let beta = beta.clone(); + with_num_threads(1, || { gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, None) - }, - ) - }; - match pool { - Some(pool) => pool.install(task), - None => task(), - } - }) + }) + }, + ) + }; + match pool { + Some(pool) => pool.install(task), + None => task(), + } } else { // sequential outer, parallel matmul with_num_threads(nthreads, || -> Result<()> { diff --git a/crates-device/rstsr-openblas/src/matmul.rs b/crates-device/rstsr-openblas/src/matmul.rs index 785478b..8d03db2 100644 --- a/crates-device/rstsr-openblas/src/matmul.rs +++ b/crates-device/rstsr-openblas/src/matmul.rs @@ -155,37 +155,37 @@ where let itc_rest = IterLayoutColMajor::new(&lc_rest)?; if n_task >= 4 * nthreads { // parallel outer, sequential matmul - with_num_threads(1, || { - let task = || { - ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( - |((ia_rest, ib_rest), ic_rest)| -> Result<()> { - // prepare layout - let mut la_m = la_matmul.clone(); - let mut lb_m = lb_matmul.clone(); - let mut lc_m = lc_matmul.clone(); - unsafe { - la_m.set_offset(ia_rest); - lb_m.set_offset(ib_rest); - lc_m.set_offset(ic_rest); - } - // move mutable reference into parallel closure - let c = unsafe { - let c_ptr = c.as_ptr() as *mut TC; - let c_len = c.len(); - from_raw_parts_mut(c_ptr, c_len) - }; - // clone alpha and beta - let alpha = alpha.clone(); - let beta = beta.clone(); + let task = || { + ita_rest.into_par_iter().zip(itb_rest).zip(itc_rest).try_for_each( + |((ia_rest, ib_rest), ic_rest)| -> Result<()> { + // prepare layout + let mut la_m = la_matmul.clone(); + let mut lb_m = lb_matmul.clone(); + let mut lc_m = lc_matmul.clone(); + unsafe { + la_m.set_offset(ia_rest); + lb_m.set_offset(ib_rest); + lc_m.set_offset(ic_rest); + } + // move mutable reference into parallel closure + let c = unsafe { + let c_ptr = c.as_ptr() as *mut TC; + let c_len = c.len(); + from_raw_parts_mut(c_ptr, c_len) + }; + // clone alpha and beta + let alpha = alpha.clone(); + let beta = beta.clone(); + with_num_threads(1, || { gemm_blas_ix2_no_conj_dispatch(c, &lc_m, a, &la_m, b, &lb_m, alpha, beta, None) - }, - ) - }; - match pool { - Some(pool) => pool.install(task), - None => task(), - } - }) + }) + }, + ) + }; + match pool { + Some(pool) => pool.install(task), + None => task(), + } } else { // sequential outer, parallel matmul with_num_threads(nthreads, || -> Result<()> {