Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,38 @@ jobs:
MMTEST_FAST_TEST: 1


# Runs the test suite for the wasm32-wasip1 target, once with empty RUSTFLAGS
# ("fallback") and once with the simd128 target feature enabled ("simd128").
wasm_test:
runs-on: ubuntu-latest
strategy:
matrix:
include:
# No extra target features.
- rust: stable
target: wasm32-wasip1
rustflags: ""
name: fallback
# Same toolchain and target, compiled with simd128 enabled.
- rust: stable
target: wasm32-wasip1
rustflags: "-C target-feature=+simd128"
name: simd128

name: wasm_test/${{ matrix.target }}/${{ matrix.name }}
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
toolchain: ${{ matrix.rust }}
targets: ${{ matrix.target }}
- name: Install wasmtime
uses: bytecodealliance/actions/wasmtime/setup@v1
- name: Tests
run: |
cargo test -v --tests --lib --release --no-fail-fast --target "${{ matrix.target }}"
env:
# Per-matrix compiler flags; this is what switches simd128 on or off.
RUSTFLAGS: ${{ matrix.rustflags }}
# Cargo runner for the wasm32-wasip1 target: test binaries are launched through wasmtime.
# NOTE(review): `--dir=.` presumably grants the guest access to the working directory - confirm.
CARGO_TARGET_WASM32_WASIP1_RUNNER: "wasmtime --dir=."
# NOTE(review): presumably read by the test suite to shorten test runs - confirm.
MMTEST_FAST_TEST: 1


cargo-careful:
runs-on: ubuntu-latest
name: cargo-careful
Expand Down
175 changes: 175 additions & 0 deletions src/sgemm_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ struct KernelSse2;
#[cfg(target_arch="aarch64")]
#[cfg(has_aarch64_simd)]
struct KernelNeon;
/// Marker type for the wasm32 `simd128` kernel.
///
/// Only compiled when the target is wasm32 and the `simd128` target feature
/// is enabled at compile time.
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
struct KernelWasmSimd;
struct KernelFallback;

type T = f32;
Expand Down Expand Up @@ -62,6 +64,11 @@ pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
return selector.select(KernelNeon);
}
}
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
{
return selector.select(KernelWasmSimd);
}
#[allow(unreachable_code)]
return selector.select(KernelFallback);
}

Expand Down Expand Up @@ -279,6 +286,38 @@ impl GemmKernel for KernelFallback {
}
}

/// `GemmKernel` implementation backed by the wasm32 `simd128` micro-kernel.
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
impl GemmKernel for KernelWasmSimd {
    type Elem = T;

    // Tile dimensions of the micro-kernel.
    type MRTy = U8;
    type NRTy = U8;

    /// Alignment value reported for this kernel.
    // NOTE(review): presumably bytes (a v128 is 16 bytes wide) - confirm against `GemmKernel`.
    #[inline(always)]
    fn align_to() -> usize {
        16
    }

    /// This kernel does not report itself as always masked.
    #[inline(always)]
    fn always_masked() -> bool {
        false
    }

    /// Forwards `archparam::S_MC`.
    #[inline(always)]
    fn mc() -> usize {
        archparam::S_MC
    }

    /// Forwards `archparam::S_KC`.
    #[inline(always)]
    fn kc() -> usize {
        archparam::S_KC
    }

    /// Forwards `archparam::S_NC`.
    #[inline(always)]
    fn nc() -> usize {
        archparam::S_NC
    }

    /// Delegates to `kernel_target_wasm_simd`, passing every argument through
    /// unchanged.
    #[inline(always)]
    unsafe fn kernel(
        k: usize,
        alpha: T,
        a: *const T,
        b: *const T,
        beta: T,
        c: *mut T,
        rsc: isize,
        csc: isize,
    ) {
        kernel_target_wasm_simd(k, alpha, a, b, beta, c, rsc, csc)
    }
}

// no inline for unmasked kernels
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[target_feature(enable="fma")]
Expand Down Expand Up @@ -692,6 +731,131 @@ unsafe fn kernel_target_neon(k: usize, alpha: T, a: *const T, b: *const T,
}
}

/// 8 x 8 `f32` micro-kernel built on the wasm32 `simd128` intrinsics.
///
/// Accumulates the product of the `a` and `b` panels over `k` steps and
/// writes `C <- alpha * A B (+ beta * C)` to the tile at `c`.
///
/// Each of the `k` iterations reads eight consecutive `f32` values from `a`
/// and eight from `b`, then advances the pointers by `MR` / `NR` elements.
/// Element (i, j) of the tile is addressed as `c + rsc * i + csc * j`.
/// `C` is only read when `beta != 0`.
///
/// # Safety
///
/// Every access goes through raw pointers without checks. The caller must
/// make sure that `a` and `b` are readable for all `k` steps and that every
/// tile element reachable from `c` through `rsc` / `csc` is valid for the
/// reads and writes described above.
#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
unsafe fn kernel_target_wasm_simd(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
use core::arch::wasm32::*;
const MR: usize = KernelWasmSimd::MR;
const NR: usize = KernelWasmSimd::NR;

// When C has unit row stride, swap the roles of a and b together with the
// two strides. The kernel then produces the transposed tile, and the
// swapped `csc` equals 1, so the contiguous vector paths below are taken.
let (mut a, mut b, rsc, csc) = if rsc == 1 { (b, a, csc, rsc) } else { (a, b, rsc, csc) };

// Kernel 8 x 8 (a x b)
// Four quadrants of 4 x 4
// Each accumulator array holds four row vectors of four lanes.
let zero = f32x4_splat(0.);
let mut ab11 = [zero; 4];
let mut ab12 = [zero; 4];
let mut ab21 = [zero; 4];
let mut ab22 = [zero; 4];

// ab_ij = a_i * b_j for all i, j
// (wasm SIMD has no lane-FMA; extract+splat into mul+add)
// Row r of the quadrant accumulates the b vector scaled by lane r of the a vector.
macro_rules! ab_ij_equals_ai_bj {
($dest:ident, $av:expr, $bv:expr) => {
$dest[0] = f32x4_add($dest[0], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<0>($av))));
$dest[1] = f32x4_add($dest[1], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<1>($av))));
$dest[2] = f32x4_add($dest[2], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<2>($av))));
$dest[3] = f32x4_add($dest[3], f32x4_mul($bv, f32x4_splat(f32x4_extract_lane::<3>($av))));
}
}

// One rank-1 update of the 8 x 8 accumulator per step of k.
for _ in 0..k {
// Elements 0..4 and 4..8 of the current a and b slices.
let a1 = v128_load(a as *const v128);
let b1 = v128_load(b as *const v128);
let a2 = v128_load(a.add(4) as *const v128);
let b2 = v128_load(b.add(4) as *const v128);

ab_ij_equals_ai_bj!(ab11, a1, b1);
ab_ij_equals_ai_bj!(ab12, a1, b2);
ab_ij_equals_ai_bj!(ab21, a2, b1);
ab_ij_equals_ai_bj!(ab22, a2, b2);

a = a.add(MR);
b = b.add(NR);
}

// Pointer to tile element (i, j) of C, using the (possibly swapped) strides.
macro_rules! c {
($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize));
}

// NOTE(review): `loop4!` is defined outside this function; presumably it
// expands its body once for each i in 0..4 - confirm.
// ab *= alpha
let alphav = f32x4_splat(alpha);
loop4!(i, ab11[i] = f32x4_mul(ab11[i], alphav));
loop4!(i, ab12[i] = f32x4_mul(ab12[i], alphav));
loop4!(i, ab21[i] = f32x4_mul(ab21[i], alphav));
loop4!(i, ab22[i] = f32x4_mul(ab22[i], alphav));

// load one v128 from four pointers
// (lane n is filled from the n-th pointer; used when a row of C is not contiguous)
macro_rules! loadq_from_pointers {
($p0:expr, $p1:expr, $p2:expr, $p3:expr) => ({
let v = f32x4_splat(0.);
let v = v128_load32_lane::<0>(v, $p0 as *const u32);
let v = v128_load32_lane::<1>(v, $p1 as *const u32);
let v = v128_load32_lane::<2>(v, $p2 as *const u32);
let v = v128_load32_lane::<3>(v, $p3 as *const u32);
v
});
}

// C is not read at all when beta is zero.
if beta != 0. {
// load existing value in C
let mut c11 = [zero; 4];
let mut c12 = [zero; 4];
let mut c21 = [zero; 4];
let mut c22 = [zero; 4];

if csc == 1 {
// Unit column stride: four consecutive elements of a row load as one vector.
loop4!(i, c11[i] = v128_load(c![i + 0, 0] as *const v128));
loop4!(i, c12[i] = v128_load(c![i + 0, 4] as *const v128));
loop4!(i, c21[i] = v128_load(c![i + 4, 0] as *const v128));
loop4!(i, c22[i] = v128_load(c![i + 4, 4] as *const v128));
} else {
// Strided columns: gather each row vector element by element.
loop4!(i, c11[i] = loadq_from_pointers!(c![i + 0, 0], c![i + 0, 1], c![i + 0, 2], c![i + 0, 3]));
loop4!(i, c12[i] = loadq_from_pointers!(c![i + 0, 4], c![i + 0, 5], c![i + 0, 6], c![i + 0, 7]));
loop4!(i, c21[i] = loadq_from_pointers!(c![i + 4, 0], c![i + 4, 1], c![i + 4, 2], c![i + 4, 3]));
loop4!(i, c22[i] = loadq_from_pointers!(c![i + 4, 4], c![i + 4, 5], c![i + 4, 6], c![i + 4, 7]));
}

let betav = f32x4_splat(beta);
// ab += β C
loop4!(i, ab11[i] = f32x4_add(ab11[i], f32x4_mul(c11[i], betav)));
loop4!(i, ab12[i] = f32x4_add(ab12[i], f32x4_mul(c12[i], betav)));
loop4!(i, ab21[i] = f32x4_add(ab21[i], f32x4_mul(c21[i], betav)));
loop4!(i, ab22[i] = f32x4_add(ab22[i], f32x4_mul(c22[i], betav)));
}

// c <- ab
// which is in full
// C <- α A B (+ β C)
if csc == 1 {
// Unit column stride: store each row vector with a single vector store.
loop4!(i, v128_store(c![i + 0, 0] as *mut v128, ab11[i]));
loop4!(i, v128_store(c![i + 0, 4] as *mut v128, ab12[i]));
loop4!(i, v128_store(c![i + 4, 0] as *mut v128, ab21[i]));
loop4!(i, v128_store(c![i + 4, 4] as *mut v128, ab22[i]));
} else {
// Strided columns: write every lane with its own 32-bit store.
loop4!(i, v128_store32_lane::<0>(ab11[i], c![i + 0, 0] as *mut u32));
loop4!(i, v128_store32_lane::<1>(ab11[i], c![i + 0, 1] as *mut u32));
loop4!(i, v128_store32_lane::<2>(ab11[i], c![i + 0, 2] as *mut u32));
loop4!(i, v128_store32_lane::<3>(ab11[i], c![i + 0, 3] as *mut u32));

loop4!(i, v128_store32_lane::<0>(ab12[i], c![i + 0, 4] as *mut u32));
loop4!(i, v128_store32_lane::<1>(ab12[i], c![i + 0, 5] as *mut u32));
loop4!(i, v128_store32_lane::<2>(ab12[i], c![i + 0, 6] as *mut u32));
loop4!(i, v128_store32_lane::<3>(ab12[i], c![i + 0, 7] as *mut u32));

loop4!(i, v128_store32_lane::<0>(ab21[i], c![i + 4, 0] as *mut u32));
loop4!(i, v128_store32_lane::<1>(ab21[i], c![i + 4, 1] as *mut u32));
loop4!(i, v128_store32_lane::<2>(ab21[i], c![i + 4, 2] as *mut u32));
loop4!(i, v128_store32_lane::<3>(ab21[i], c![i + 4, 3] as *mut u32));

loop4!(i, v128_store32_lane::<0>(ab22[i], c![i + 4, 4] as *mut u32));
loop4!(i, v128_store32_lane::<1>(ab22[i], c![i + 4, 5] as *mut u32));
loop4!(i, v128_store32_lane::<2>(ab22[i], c![i + 4, 6] as *mut u32));
loop4!(i, v128_store32_lane::<3>(ab22[i], c![i + 4, 7] as *mut u32));
}
}

#[inline]
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
Expand Down Expand Up @@ -775,6 +939,17 @@ mod tests {
}
}

#[cfg(all(target_arch="wasm32", target_feature="simd128"))]
mod test_kernel_wasm {
    use super::super::*;
    use super::test_a_kernel;

    /// Runs the shared kernel test helper against the wasm `simd128` kernel.
    #[test]
    fn wasm_simd_8x8() {
        test_a_kernel::<KernelWasmSimd, _>("wasm_simd_8x8");
    }
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
mod test_kernel_x86 {
use super::test_a_kernel;
Expand Down
Loading