diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java index 39a7cd61..a98f9860 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Phi3FP16FFNLayers.java @@ -234,6 +234,12 @@ TaskGraph setupSinglePhi3FFNLayer(Phi3TornadoWeights weights, int layerIndex) { phi3Config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + unifiedLayer.task("attn_rms_qkv_projection", Phi3Kernels::fusedRmsNormQKVMatmulDirect, context, phi3State.wrapX, // input phi3State.wrapQ, // output Q diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java index aded0c81..a6f1c95c 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen2FP16FFNLayers.java @@ -271,6 +271,15 @@ TaskGraph setupSingleQwen2FFNLayer(Qwen2TornadoWeights weights, int layerIndex) config.rmsNormEps(), // epsilon qwen2State.localSize); // local memory size + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.temp, + config.dim(), + config.rmsNormEps()); + } + // Fused RMS Apply + QKV Projection unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmul, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java index 453a4d7c..b99a1ab3 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Qwen3FP16FFNLayers.java @@ -264,6 +264,15 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) qwen3Config.rmsNormEps(), // epsilon qwen3State.localSize); // local memory size + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.temp, + config.dim(), + config.rmsNormEps()); + } + // Fused RMS Apply + QKV Projection unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmul, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java index 204f8f48..46b0737d 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Phi3Q8_0FFNLayers.java @@ -236,6 +236,15 @@ TaskGraph setupSinglePhi3Q8_0FFNLayer(Phi3TornadoWeights weights, int layerIndex phi3Config.rmsNormEps(), // epsilon phi3State.localSize); // local memory size + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.temp, + config.dim(), + config.rmsNormEps()); + } + // Fused: RMS apply + Q8 QKV matmul + direct Q/K/V split unifiedLayer.task("attn_rms_qkv_projection_q8", TransformerComputeKernelsLayered::fusedRmsNormQKVMatmulQ8, diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java index ae45ea2f..6aea5559 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java @@ -190,6 +190,15 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex) config.rmsNormEps(), // epsilon qwen3State.localSize); // local memory size + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.temp, + config.dim(), + config.rmsNormEps()); + } + // Fused RMS Apply + QKV Projection unifiedLayer.task("attn_rms_qkv_projection", Qwen3Kernels::fusedRmsNormQKVMatmulQ8_0,