From b7bf4848323fd5b07a433ea578a90b47f92d2599 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 11 Dec 2025 13:30:08 +0200 Subject: [PATCH 1/2] [fix] Norm step for non-nvidia hardware --- .../tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java | 6 ++++++ .../tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java | 9 +++++++++ .../tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java | 9 +++++++++ .../tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java | 9 +++++++++ .../tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java | 9 +++++++++ 5 files changed, 42 insertions(+) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 39a7cd61..a98f9860 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -234,6 +234,12 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3Config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("attn_rms_qkv_projection", Phi3Kernels::fusedRmsNormQKVMatmulDirect, context, phi3State.wrapX, // input phi3State.wrapQ, // output Q diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index aded0c81..a6f1c95c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -271,6 +271,15 @@ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) config.rmsNormEps(), // epsilon qwen2State.localSize); // local memory size + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.temp, + config.dim(), + config.rmsNormEps()); + } + // Fused RMS Apply + QKV Projection unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmul, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 453a4d7c..b99a1ab3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -264,6 +264,15 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3Config.rmsNormEps(), // epsilon qwen3State.localSize); // local memory size + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.temp, + config.dim(), + config.rmsNormEps()); + } + // Fused RMS Apply + QKV Projection unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmul, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index 204f8f48..e3ce6efa 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -236,6 +236,15 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3Config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.temp, + config.dim(), + config.rmsNormEps()); + } + // Fused: RMS apply + Q8 QKV matmul + direct Q/K/V split unifiedLayer.task("attn_rms_qkv_projection_q8", TransformerComputeKernelsLayered::fusedRmsNormQKVMatmulQ8, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index ae45ea2f..6aea5559 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -190,6 +190,15 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) config.rmsNormEps(), // epsilon qwen3State.localSize); // local memory size + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.temp, + config.dim(), + config.rmsNormEps()); + } + // Fused RMS Apply + QKV Projection unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmulQ8_0, From 7b830ea305e7da5803e4bad123cb536833da29a7 Mon Sep 17 00:00:00 2001 From: Michalis Papadimitriou Date: Thu, 11 Dec 2025 13:45:12 +0200 Subject: [PATCH 2/2] Update src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../layers/type/q8_0/Phi3Q8_0FFNLayers.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index e3ce6efa..46b0737d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -236,14 +236,14 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3Config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size - if (shouldUseFinalNormalization()) { - unifiedLayer.task("attn_rms_finalize", - TransformerComputeKernelsLayered::reductionFinalNormalization, - context, - state.temp, - config.dim(), - config.rmsNormEps()); - } + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.temp, + config.dim(), + config.rmsNormEps()); + } // Fused: RMS apply + Q8 QKV matmul + direct Q/K/V split unifiedLayer.task("attn_rms_qkv_projection_q8",