diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5b9c9f5..a39685f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -156,6 +156,38 @@ jobs: MMTEST_FAST_TEST: 1 + wasm_test: + runs-on: ubuntu-latest + strategy: + matrix: + include: + - rust: stable + target: wasm32-wasip1 + rustflags: "" + name: fallback + - rust: stable + target: wasm32-wasip1 + rustflags: "-C target-feature=+simd128" + name: simd128 + + name: wasm_test/${{ matrix.target }}/${{ matrix.name }} + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + toolchain: ${{ matrix.rust }} + targets: ${{ matrix.target }} + - name: Install wasmtime + uses: bytecodealliance/actions/wasmtime/setup@v1 + - name: Tests + run: | + cargo test -v --tests --lib --release --no-fail-fast --target "${{ matrix.target }}" + env: + RUSTFLAGS: ${{ matrix.rustflags }} + CARGO_TARGET_WASM32_WASIP1_RUNNER: "wasmtime --dir=." + MMTEST_FAST_TEST: 1 + + cargo-careful: runs-on: ubuntu-latest name: cargo-careful diff --git a/src/sgemm_kernel.rs b/src/sgemm_kernel.rs index 28fe8ed..a2d0d52 100644 --- a/src/sgemm_kernel.rs +++ b/src/sgemm_kernel.rs @@ -30,6 +30,8 @@ struct KernelSse2; #[cfg(target_arch="aarch64")] #[cfg(has_aarch64_simd)] struct KernelNeon; +#[cfg(all(target_arch="wasm32", target_feature="simd128"))] +struct KernelWasmSimd; struct KernelFallback; type T = f32; @@ -62,6 +64,11 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> { return selector.select(KernelNeon); } } + #[cfg(all(target_arch="wasm32", target_feature="simd128"))] + { + return selector.select(KernelWasmSimd); + } + #[allow(unreachable_code)] return selector.select(KernelFallback); } @@ -279,6 +286,38 @@ impl GemmKernel for KernelFallback { } } +#[cfg(all(target_arch="wasm32", target_feature="simd128"))] +impl GemmKernel for KernelWasmSimd { + type Elem = T; + + type MRTy = U8; + type NRTy = U8; + + #[inline(always)] + fn align_to() -> usize { 16 } + + #[inline(always)] + fn
always_masked() -> bool { false } + + #[inline(always)] + fn nc() -> usize { archparam::S_NC } + #[inline(always)] + fn kc() -> usize { archparam::S_KC } + #[inline(always)] + fn mc() -> usize { archparam::S_MC } + + #[inline(always)] + unsafe fn kernel( + k: usize, + alpha: T, + a: *const T, + b: *const T, + beta: T, + c: *mut T, rsc: isize, csc: isize) { + kernel_target_wasm_simd(k, alpha, a, b, beta, c, rsc, csc) + } +} + // no inline for unmasked kernels #[cfg(any(target_arch="x86", target_arch="x86_64"))] #[target_feature(enable="fma")] @@ -692,6 +731,131 @@ unsafe fn kernel_target_neon(k: usize, alpha: T, a: *const T, b: *const T, } } +#[cfg(all(target_arch="wasm32", target_feature="simd128"))] +unsafe fn kernel_target_wasm_simd(k: usize, alpha: T, a: *const T, b: *const T, + beta: T, c: *mut T, rsc: isize, csc: isize) +{ + use core::arch::wasm32::*; + const MR: usize = KernelWasmSimd::MR; + const NR: usize = KernelWasmSimd::NR; + + let (mut a, mut b, rsc, csc) = if rsc == 1 { (b, a, csc, rsc) } else { (a, b, rsc, csc) }; + + // Kernel 8 x 8 (a x b) + // Four quadrants of 4 x 4 + let zero = f32x4_splat(0.); + let mut ab11 = [zero; 4]; + let mut ab12 = [zero; 4]; + let mut ab21 = [zero; 4]; + let mut ab22 = [zero; 4]; + + // ab_ij = a_i * b_j for all i, j + // (wasm SIMD has no lane-FMA; extract+splat into mul+add) + macro_rules! 
ab_ij_equals_ai_bj { + ($dest:ident, $av:expr, $bv:expr) => { + $dest[0] = f32x4_add($dest[0], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<0>($av)))); + $dest[1] = f32x4_add($dest[1], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<1>($av)))); + $dest[2] = f32x4_add($dest[2], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<2>($av)))); + $dest[3] = f32x4_add($dest[3], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<3>($av)))); + } + } + + for _ in 0..k { + let a1 = v128_load(a as *const v128); + let b1 = v128_load(b as *const v128); + let a2 = v128_load(a.add(4) as *const v128); + let b2 = v128_load(b.add(4) as *const v128); + + ab_ij_equals_ai_bj!(ab11, a1, b1); + ab_ij_equals_ai_bj!(ab12, a1, b2); + ab_ij_equals_ai_bj!(ab21, a2, b1); + ab_ij_equals_ai_bj!(ab22, a2, b2); + + a = a.add(MR); + b = b.add(NR); + } + + macro_rules! c { + ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize)); + } + + // ab *= alpha + let alphav = f32x4_splat(alpha); + loop4!(i, ab11[i] = f32x4_mul(ab11[i], alphav)); + loop4!(i, ab12[i] = f32x4_mul(ab12[i], alphav)); + loop4!(i, ab21[i] = f32x4_mul(ab21[i], alphav)); + loop4!(i, ab22[i] = f32x4_mul(ab22[i], alphav)); + + // load one v128 from four pointers + macro_rules! loadq_from_pointers { + ($p0:expr, $p1:expr, $p2:expr, $p3:expr) => ({ + let v = f32x4_splat(0.); + let v = v128_load32_lane::<0>(v, $p0 as *const u32); + let v = v128_load32_lane::<1>(v, $p1 as *const u32); + let v = v128_load32_lane::<2>(v, $p2 as *const u32); + let v = v128_load32_lane::<3>(v, $p3 as *const u32); + v + }); + } + + if beta != 0. 
{ + // load existing value in C + let mut c11 = [zero; 4]; + let mut c12 = [zero; 4]; + let mut c21 = [zero; 4]; + let mut c22 = [zero; 4]; + + if csc == 1 { + loop4!(i, c11[i] = v128_load(c![i + 0, 0] as *const v128)); + loop4!(i, c12[i] = v128_load(c![i + 0, 4] as *const v128)); + loop4!(i, c21[i] = v128_load(c![i + 4, 0] as *const v128)); + loop4!(i, c22[i] = v128_load(c![i + 4, 4] as *const v128)); + } else { + loop4!(i, c11[i] = loadq_from_pointers!(c![i + 0, 0], c![i + 0, 1], c![i + 0, 2], c![i + 0, 3])); + loop4!(i, c12[i] = loadq_from_pointers!(c![i + 0, 4], c![i + 0, 5], c![i + 0, 6], c![i + 0, 7])); + loop4!(i, c21[i] = loadq_from_pointers!(c![i + 4, 0], c![i + 4, 1], c![i + 4, 2], c![i + 4, 3])); + loop4!(i, c22[i] = loadq_from_pointers!(c![i + 4, 4], c![i + 4, 5], c![i + 4, 6], c![i + 4, 7])); + } + + let betav = f32x4_splat(beta); + // ab += β C + loop4!(i, ab11[i] = f32x4_add(ab11[i], f32x4_mul(c11[i], betav))); + loop4!(i, ab12[i] = f32x4_add(ab12[i], f32x4_mul(c12[i], betav))); + loop4!(i, ab21[i] = f32x4_add(ab21[i], f32x4_mul(c21[i], betav))); + loop4!(i, ab22[i] = f32x4_add(ab22[i], f32x4_mul(c22[i], betav))); + } + + // c <- ab + // which is in full + // C <- α A B (+ β C) + if csc == 1 { + loop4!(i, v128_store(c![i + 0, 0] as *mut v128, ab11[i])); + loop4!(i, v128_store(c![i + 0, 4] as *mut v128, ab12[i])); + loop4!(i, v128_store(c![i + 4, 0] as *mut v128, ab21[i])); + loop4!(i, v128_store(c![i + 4, 4] as *mut v128, ab22[i])); + } else { + loop4!(i, v128_store32_lane::<0>(ab11[i], c![i + 0, 0] as *mut u32)); + loop4!(i, v128_store32_lane::<1>(ab11[i], c![i + 0, 1] as *mut u32)); + loop4!(i, v128_store32_lane::<2>(ab11[i], c![i + 0, 2] as *mut u32)); + loop4!(i, v128_store32_lane::<3>(ab11[i], c![i + 0, 3] as *mut u32)); + + loop4!(i, v128_store32_lane::<0>(ab12[i], c![i + 0, 4] as *mut u32)); + loop4!(i, v128_store32_lane::<1>(ab12[i], c![i + 0, 5] as *mut u32)); + loop4!(i, v128_store32_lane::<2>(ab12[i], c![i + 0, 6] as *mut u32)); + 
loop4!(i, v128_store32_lane::<3>(ab12[i], c![i + 0, 7] as *mut u32)); + + loop4!(i, v128_store32_lane::<0>(ab21[i], c![i + 4, 0] as *mut u32)); + loop4!(i, v128_store32_lane::<1>(ab21[i], c![i + 4, 1] as *mut u32)); + loop4!(i, v128_store32_lane::<2>(ab21[i], c![i + 4, 2] as *mut u32)); + loop4!(i, v128_store32_lane::<3>(ab21[i], c![i + 4, 3] as *mut u32)); + + loop4!(i, v128_store32_lane::<0>(ab22[i], c![i + 4, 4] as *mut u32)); + loop4!(i, v128_store32_lane::<1>(ab22[i], c![i + 4, 5] as *mut u32)); + loop4!(i, v128_store32_lane::<2>(ab22[i], c![i + 4, 6] as *mut u32)); + loop4!(i, v128_store32_lane::<3>(ab22[i], c![i + 4, 7] as *mut u32)); + } +} + #[inline] unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T, beta: T, c: *mut T, rsc: isize, csc: isize) @@ -775,6 +939,17 @@ mod tests { } } + #[cfg(all(target_arch="wasm32", target_feature="simd128"))] + mod test_kernel_wasm { + use super::test_a_kernel; + use super::super::*; + + #[test] + fn wasm_simd_8x8() { + test_a_kernel::<KernelWasmSimd>("wasm_simd_8x8"); + } + } + #[cfg(any(target_arch="x86", target_arch="x86_64"))] mod test_kernel_x86 { use super::test_a_kernel;