metadata) {
}
}
+ // Alternative: check by metadata keys if name-based detection fails
+ if (metadata.containsKey("granite.block_count")) {
+ return ModelType.GRANITE;
+ }
+
return ModelType.UNKNOWN;
}
diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java
new file mode 100644
index 00000000..0c7750fd
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java
@@ -0,0 +1,276 @@
+package org.beehive.gpullama3.tokenizer;
+
+import org.beehive.gpullama3.auxiliary.Pair;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * GPT-2-style BPE tokenizer for Granite models.
+ *
+ * Supports both Granite 3.3 (refact pretokenizer, 49K vocab) and Granite 4.0 (dbrx pretokenizer, 100K vocab).
+ */
+public class GraniteTokenizer implements Tokenizer {
+    /** Byte value (0..255) -> printable unicode code point, per GPT-2 byte-level BPE. */
+    static final Map<Integer, Integer> BYTE_ENCODER = bytesToUnicode();
+    /** Inverse of {@link #BYTE_ENCODER}: code point -> original byte value. */
+    static final Map<Integer, Integer> BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey));
+
+    // Pretokenizer patterns
+    private static final String REFACT_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+
+    private static final String DBRX_PATTERN = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
+
+    // Instance fields
+    private final Pattern compiledPattern;
+    private final Vocabulary vocabulary;
+    // (first-token-id, second-token-id) -> merged-token-id, in BPE merge priority order (map value is the merged id, lookup order via merges rank).
+    private final Map<Pair<Integer, Integer>, Integer> merges;
+    // Special token text -> token id.
+    private final Map<String, Integer> specialTokens;
+
+    // Token IDs (version-dependent)
+    private final int bosTokenId;
+    private final int eosTokenId;
+    private final int padTokenId;
+    private final String pretokenizerType;
+
+    /**
+     * Builds the tokenizer from GGUF metadata.
+     *
+     * @param metadata   GGUF metadata map; reads the "tokenizer.ggml.*" keys
+     * @param vocabulary model vocabulary used for token <-> id lookups
+     */
+    public GraniteTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+        this.vocabulary = vocabulary;
+
+        // Detect pretokenizer type and select pattern (defaults to refact for Granite 3.x)
+        this.pretokenizerType = (String) metadata.getOrDefault("tokenizer.ggml.pre", "refact");
+        String pattern = switch (pretokenizerType) {
+            case "dbrx" -> DBRX_PATTERN;
+            default -> REFACT_PATTERN;
+        };
+        this.compiledPattern = Pattern.compile(pattern);
+
+        // Read token IDs from metadata
+        this.bosTokenId = getIntFromMetadata(metadata, "tokenizer.ggml.bos_token_id", 0);
+        this.eosTokenId = getIntFromMetadata(metadata, "tokenizer.ggml.eos_token_id", 0);
+        this.padTokenId = getIntFromMetadata(metadata, "tokenizer.ggml.padding_token_id", 0);
+
+        // Load merges: each line is "<left> <right>" in merge priority order
+        String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges");
+        List<Pair<Integer, Integer>> mergeList = Arrays.stream(mergeLines).map(line -> line.split(" "))
+                .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList();
+
+        // Collect special tokens
+        Map<String, Integer> specialTokens = new HashMap<>();
+        int allTokens = vocabulary.size();
+        for (int i = 0; i < allTokens; i++) {
+            String token = vocabulary.get(i);
+            if (token.startsWith("<|") && token.endsWith("|>")) {
+                specialTokens.put(token, i);
+            }
+            // Also catch bare angle-bracket style tokens used in some Granite models
+            if (token.startsWith("<") && token.endsWith(">") && !token.contains(" ")) {
+                specialTokens.putIfAbsent(token, i);
+            }
+        }
+        this.specialTokens = Map.copyOf(specialTokens);
+
+        // Build merge map: (left id, right id) -> merged token id
+        this.merges = new HashMap<>();
+        for (Pair<Integer, Integer> pair : mergeList) {
+            String merged = vocabulary.get(pair.first()) + vocabulary.get(pair.second());
+            int mergeIndex = vocabulary.getIndex(merged).orElseThrow();
+            this.merges.put(pair, mergeIndex);
+        }
+    }
+
+    /** Reads an int metadata value, falling back to {@code defaultValue} if absent or non-numeric. */
+    private static int getIntFromMetadata(Map<String, Object> metadata, String key, int defaultValue) {
+        Object value = metadata.get(key);
+        if (value instanceof Number num) {
+            return num.intValue();
+        }
+        return defaultValue;
+    }
+
+    /** Returns every match of {@code pattern} in {@code text}, in order. */
+    private static List<String> findAll(Pattern pattern, String text) {
+        List<String> matches = new ArrayList<>();
+        Matcher matcher = pattern.matcher(text);
+        while (matcher.find()) {
+            matches.add(matcher.group());
+        }
+        return matches;
+    }
+
+    /** Replaces every adjacent occurrence of {@code pair} in {@code ids} with the merged id {@code idx}. */
+    private static List<Integer> merge(List<Integer> ids, Pair<Integer, Integer> pair, int idx) {
+        List<Integer> newIds = new ArrayList<>();
+        int i = 0;
+        while (i < ids.size()) {
+            if (i < ids.size() - 1 && ids.get(i).equals(pair.first()) && ids.get(i + 1).equals(pair.second())) {
+                newIds.add(idx);
+                i += 2;
+            } else {
+                newIds.add(ids.get(i));
+                i++;
+            }
+        }
+        return newIds;
+    }
+
+    /**
+     * GPT-2 byte-to-unicode table: maps each byte 0..255 to a printable code point
+     * so arbitrary bytes can round-trip through the string-keyed vocabulary.
+     */
+    private static Map<Integer, Integer> bytesToUnicode() {
+        List<Integer> bs = new ArrayList<>();
+        IntStream.rangeClosed('!', '~').forEach(bs::add);
+        IntStream.rangeClosed('¡', '¬').forEach(bs::add);
+        IntStream.rangeClosed('®', 'ÿ').forEach(bs::add);
+
+        List<Integer> cs = new ArrayList<>(bs);
+        int n = 0;
+        for (int b = 0; b < 256; b++) {
+            if (!bs.contains(b)) {
+                // Non-printable bytes get shifted into the 256+ code point range
+                bs.add(b);
+                cs.add(256 + n);
+                n++;
+            }
+        }
+        return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get));
+    }
+
+    public int getBosTokenId() {
+        return bosTokenId;
+    }
+
+    // === Tokenizer interface ===
+
+    public int getEosTokenId() {
+        return eosTokenId;
+    }
+
+    public int getPadTokenId() {
+        return padTokenId;
+    }
+
+    public String getPretokenizerType() {
+        return pretokenizerType;
+    }
+
+    @Override
+    public Map<String, Integer> getSpecialTokens() {
+        return specialTokens;
+    }
+
+    // === Encoding ===
+
+    @Override
+    public boolean isSpecialToken(int tokenIndex) {
+        return specialTokens.containsValue(tokenIndex);
+    }
+
+    @Override
+    public boolean shouldDisplayToken(int token) {
+        return !isSpecialToken(token);
+    }
+
+    public String regexPattern() {
+        return compiledPattern != null ? compiledPattern.pattern() : null;
+    }
+
+    // @Override
+    public int[] encode(String text) {
+        // Map raw UTF-8 bytes to their printable byte-level representation first
+        StringBuilder sb = new StringBuilder();
+        byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
+        for (byte b : bytes) {
+            sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
+        }
+        return encodeImpl(sb.toString());
+    }
+
+    @Override
+    public List<Integer> encodeAsList(String text) {
+        return Arrays.stream(encode(text)).boxed().toList();
+    }
+
+    private int[] encodeImpl(String text) {
+        return encode(text, Set.of()).stream().mapToInt(i -> i).toArray();
+    }
+
+    /**
+     * Encodes {@code text}, mapping any occurrence of a token in {@code allowedSpecial}
+     * directly to its special-token id instead of BPE-encoding it.
+     */
+    public List<Integer> encode(String text, Set<String> allowedSpecial) {
+        if (allowedSpecial.isEmpty()) {
+            return encodeOrdinary(text);
+        }
+
+        assert specialTokens.keySet().containsAll(allowedSpecial);
+        // BUGFIX: String.split discards the matched delimiters, which silently dropped the
+        // special tokens themselves. Walk the matches explicitly so each special token is
+        // kept and mapped to its id, and the text between matches is BPE-encoded.
+        String specialPattern = allowedSpecial.stream().map(Pattern::quote).collect(Collectors.joining("|", "(", ")"));
+        Matcher specialMatcher = Pattern.compile(specialPattern).matcher(text);
+
+        List<Integer> ids = new ArrayList<>();
+        int last = 0;
+        while (specialMatcher.find()) {
+            if (specialMatcher.start() > last) {
+                ids.addAll(encodeOrdinary(text.substring(last, specialMatcher.start())));
+            }
+            ids.add(specialTokens.get(specialMatcher.group()));
+            last = specialMatcher.end();
+        }
+        if (last < text.length()) {
+            ids.addAll(encodeOrdinary(text.substring(last)));
+        }
+        return ids;
+    }
+
+    /** Encodes text with no special-token handling: pretokenize, then BPE each chunk. */
+    public List<Integer> encodeOrdinary(String text) {
+        List<String> textChunks = findAll(compiledPattern, text);
+        List<Integer> ids = new ArrayList<>();
+        for (String chunk : textChunks) {
+            ids.addAll(encodeChunk(chunk));
+        }
+        return ids;
+    }
+
+    // === Helpers ===
+
+    /** Runs the BPE merge loop on a single pretokenized chunk. */
+    private List<Integer> encodeChunk(String chunk) {
+        // Seed with one token per character (byte-level representation)
+        List<Integer> ids = new ArrayList<>();
+        for (char c : chunk.toCharArray()) {
+            int tokenIndex = vocabulary.getIndex(String.valueOf(c)).orElseThrow();
+            ids.add(tokenIndex);
+        }
+
+        // Repeatedly apply the lowest-ranked (highest priority) known merge
+        while (ids.size() >= 2) {
+            Map<Pair<Integer, Integer>, Integer> stats = getStats(ids);
+            Pair<Integer, Integer> pair = stats.keySet().stream().min(Comparator.comparingInt(key -> merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow();
+            if (!merges.containsKey(pair)) {
+                break; // no applicable merges remain
+            }
+            ids = merge(ids, pair, merges.get(pair));
+        }
+        return ids;
+    }
+
+    // === Decoding ===
+
+    @Override
+    public String decode(List<Integer> tokens) {
+        String decoded = decodeImpl(tokens);
+        // Reverse the byte-level encoding back to raw bytes, then interpret as UTF-8
+        int[] decodedBytesAsInts = decoded.codePoints().map(cp -> BYTE_DECODER.getOrDefault(cp, cp)).toArray();
+        byte[] rawBytes = new byte[decodedBytesAsInts.length];
+        for (int i = 0; i < decodedBytesAsInts.length; i++) {
+            rawBytes[i] = (byte) decodedBytesAsInts[i];
+        }
+        return new String(rawBytes, StandardCharsets.UTF_8);
+    }
+
+    private String decodeImpl(List<Integer> tokens) {
+        StringBuilder sb = new StringBuilder();
+        for (int token : tokens) {
+            sb.append(vocabulary.get(token));
+        }
+        return sb.toString();
+    }
+
+    /** Counts adjacent id pairs in {@code ids}; candidates for the next BPE merge. */
+    private Map<Pair<Integer, Integer>, Integer> getStats(List<Integer> ids) {
+        Map<Pair<Integer, Integer>, Integer> map = new HashMap<>();
+        for (int i = 0; i + 1 < ids.size(); i++) {
+            Pair<Integer, Integer> key = new Pair<>(ids.get(i), ids.get(i + 1));
+            map.merge(key, 1, Integer::sum);
+        }
+        return map;
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java
new file mode 100644
index 00000000..36b9803e
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java
@@ -0,0 +1,396 @@
+package org.beehive.gpullama3.tornadovm.kernels;
+
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.annotations.Parallel;
+import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.api.types.HalfFloat;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+
+import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimized;
+import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimizedQ8_0Byte;
+import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimizedSingle;
+
+/**
+ * TornadoVM compute kernels specialized for Granite models.
+ *
+ * Granite applies extra scalar multipliers at several points of the forward pass
+ * (embedding scale, residual scale, attention scale, logits scale); each kernel here
+ * is a scaled variant of the generic transformer kernels.
+ */
+public class GraniteKernels {
+
+ /** Converts FP16 activations to FP32, multiplying each element by the Granite embedding scale. */
+ public static void convertFP16toFP32withGraniteScale(KernelContext context, HalfFloatArray x, FloatArray wrapX, float embeddingScale) {
+ int i = context.globalIdx;
+ wrapX.set(i, embeddingScale * x.get(i).getFloat32());
+ }
+
+ /** Dequantizes one Q8_0 element per thread to FP32, multiplied by the Granite embedding scale. */
+ public static void convertQ8_0toFP32withGraniteScale(KernelContext context, ByteArray x, FloatArray wrapX, float embeddingScale) {
+ int globalId = context.globalIdx;
+ int totalElements = wrapX.getSize();
+
+ if (globalId >= totalElements) {
+ return;
+ }
+
+ // Q8_0 block structure constants
+ int blockSize = 32;
+ int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants
+
+ // Calculate which block and position within block
+ int blockIdx = globalId / blockSize;
+ int withinBlockIdx = globalId % blockSize;
+
+ // Calculate byte offset for this Q8_0 block
+ int blockByteOffset = blockIdx * Q8_0_BLOCK_BYTES;
+
+ // Load scale (first 2 bytes of block as HalfFloat)
+ HalfFloat scale = x.getHalfFloat(blockByteOffset);
+ float scaleFloat = scale.getFloat32();
+
+ // Load quantized value (skip 2-byte scale, then index within block)
+ byte quantValue = x.get(blockByteOffset + 2 + withinBlockIdx);
+
+ // Dequantize: float_value = quantized_value * scale
+ float dequantizedValue = ((float) quantValue) * scaleFloat;
+
+ // Store result in output FloatArray
+ wrapX.set(globalId, embeddingScale * dequantizedValue);
+ }
+
+ // @formatter:off
+ /**
+ * FP16 matrix-vector product with the result scaled by logitsScale.
+ * One workgroup computes one output row; thread 0 writes the result.
+ */
+ public static void matrixVectorGenericWithGraniteScale(
+ KernelContext context,
+ HalfFloatArray x,
+ FloatArray hb, // output
+ HalfFloatArray w,
+ int dim1, // inner loop
+ int dim0, // outer loop
+ int localWorkGroupSize,
+ float logitsScale
+ ) {
+ // One row per workgroup (not per thread)
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+ int localSize = localWorkGroupSize;
+
+ // Early exit if this workgroup is beyond our output dimension
+ if (rowId >= dim0) {
+ return;
+ }
+ float sum = matrixVectorRowMajorOptimizedSingle(context, localSize, x, w, dim1);
+
+ // Thread 0 in each workgroup writes the final result
+ if (localId == 0) {
+ hb.set(rowId, sum * logitsScale);
+ }
+ }
+
+
+ /**
+ * Tiled flash-attention-style kernel with an online softmax, using the Granite
+ * attentionScale instead of the usual 1/sqrt(headSize) score scaling.
+ * One workgroup processes one head; the sequence is consumed in tiles of
+ * BLOCK_SIZE_C key/value vectors staged through local memory.
+ */
+ public static void processHeadsFlashAttentionWithGraniteScale(KernelContext context,
+ FloatArray q, FloatArray key_cache, FloatArray value_cache,
+ FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul,
+ IntArray positionHolder, int layer, int contextLength, float attentionScale) {
+
+ // Thread and workgroup information
+ int tid = context.localIdx;
+ int h = context.groupIdx; // Each workgroup processes one head
+ int localSize = context.localGroupSizeX;
+
+ // Early exit if this workgroup is beyond our head count
+ // This relies on the kernel being launched with nHeads workgroups.
+ if (h >= nHeads) {
+ return;
+ }
+
+ int pos = positionHolder.get(0);
+ int loff = layer * contextLength * kvDim;
+ int kvHeadIdx = h / kvMul;
+ int BLOCK_SIZE_C = 16;
+
+ // Allocate shared memory for tiled computation
+ float[] q_shared = context.allocateFloatLocalArray(headSize);
+ float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
+ float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
+ float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C);
+ float[] shared_tile_max_holder = context.allocateFloatLocalArray(1); // FIX: For broadcasting tile max
+
+ // Thread-local accumulators for online softmax
+ float maxScore = Float.NEGATIVE_INFINITY;
+ float sumExp = 0.0f;
+
+ // Thread-local output accumulation
+ float[] output = new float[headSize];
+ for (int i = 0; i < headSize; i++) {
+ output[i] = 0.0f;
+ }
+
+ // Load query vector into shared memory
+ for (int i = tid; i < headSize; i += localSize) {
+ q_shared[i] = q.get(h * headSize + i);
+ }
+
+ context.localBarrier();
+
+ // Process sequence in tiles
+ for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) {
+ int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos);
+
+ // Load key and value vectors for this tile
+ // Each thread loads a portion of the K and V vectors for the tile
+ for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
+ int k_v_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile
+ int tileMemOffset = k_v_idx_in_tile * headSize;
+ for (int d = 0; d < headSize; d++) {
+ int kvCacheAbsolutePos = tIdxInSeq;
+ int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + d;
+ k_tile[tileMemOffset + d] = key_cache.get(kvOffset);
+ v_tile[tileMemOffset + d] = value_cache.get(kvOffset);
+ }
+ }
+
+ context.localBarrier();
+
+ // Compute attention scores for this tile
+ // Each thread computes one score for the tile
+ for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
+ int score_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile
+
+ float score = 0.0f;
+ for (int d = 0; d < headSize; d++) {
+ score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d];
+ }
+ score *= attentionScale;
+ s_tile[score_idx_in_tile] = score;
+ }
+
+ context.localBarrier();
+
+ // Find max score in this tile (all threads compute it redundantly over the small s_tile)
+ float tileLocalMax = Float.NEGATIVE_INFINITY;
+ for (int i = 0; i <= tileEnd - tileC; i++) { // Iterate over valid scores in s_tile
+ if (s_tile[i] > tileLocalMax) {
+ tileLocalMax = s_tile[i];
+ }
+ }
+
+ // Broadcast max to all threads via shared memory
+ if (tid == 0) {
+ shared_tile_max_holder[0] = tileLocalMax; // FIX: Use dedicated holder
+ }
+ context.localBarrier();
+ float currentTileMax = shared_tile_max_holder[0]; // FIX: Read from dedicated holder
+
+ // Determine if we need to rescale previous results
+ float newMax = Math.max(maxScore, currentTileMax);
+ if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
+ float scale = TornadoMath.exp(maxScore - newMax);
+ sumExp *= scale;
+ for (int d = 0; d < headSize; d++) {
+ output[d] *= scale;
+ }
+ }
+ maxScore = newMax;
+
+ // Process each key-value pair using original scores from s_tile
+ // All threads iterate over all scores in the current tile
+ for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) {
+ // s_tile[t_idx_in_s_tile] now correctly refers to the original score
+ float expScore = TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore);
+ sumExp += expScore;
+
+ for (int d = 0; d < headSize; d++) {
+ output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d];
+ }
+ }
+ context.localBarrier(); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load
+ }
+
+ // Normalize and write final results
+ float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; // Avoid division by zero, return 0 if sumExp is 0
+ for (int d = tid; d < headSize; d += localSize) {
+ xb.set(h * headSize + d, output[d] * normFactor);
+ }
+ }
+ // @formatter:on
+
+ /**
+ * FP16 matrix-vector product whose result is scaled by residualScale and
+ * accumulated into hb (residual connection). One workgroup per output row.
+ */
+ public static void matrixVectorGenericWithResidualGranite(KernelContext context, FloatArray x, FloatArray hb,
+ HalfFloatArray w, int n, int d, int localWorkGroupSize, float residualScale) {
+ // One row per workgroup (not per thread)
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+ int localSize = localWorkGroupSize;
+
+ // Early exit if this workgroup is beyond our output dimension
+ if (rowId >= d) {
+ return;
+ }
+
+ float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n);
+
+ // Thread 0 in each workgroup writes the final result
+ if (localId == 0) {
+ float residual = residualScale * sum;
+ float result = hb.get(rowId) + residual;
+ hb.set(rowId, result);
+ }
+ }
+
+ /**
+ * Non-tiled attention over all heads using @Parallel head-level parallelism;
+ * scores are scaled by attentionScale (Granite) rather than 1/sqrt(headSize).
+ */
+ public static void processHeadsParallelGranite(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
+ IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength, float attentionScale) {
+
+ int pos = positionHolder.get(0);
+ int loff = layer * contextLength * kvDim;
+
+ // Parallelize computation across attention heads
+ for (@Parallel int h = 0; h < nHeads; h++) {
+ // Process each head in parallel
+ processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt, attentionScale);
+ }
+ }
+
+ /** Scalar attention for a single head: scores -> softmax -> weighted sum of values. */
+ private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos,
+ FloatArray wrapAtt, float attentionScale) {
+
+ // Base index for this head's attention weights
+ int headOffset = h * (pos + 1);
+
+ // STEP 1: Calculate attention scores for all timesteps
+ for (int t = 0; t <= pos; t++) {
+ int kvHeadIdx = h / kvMul;
+ int keyOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
+
+ float score = 0.0f;
+ for (int i = 0; i < headSize; i++) {
+ score += allQ.get(h * headSize + i) * key_cache.get(keyOffset + i);
+ }
+ score *= attentionScale; //TODO: might need score = score * attentionScale;
+// score = score / TornadoMath.sqrt(headSize);
+
+ // Store in attention buffer
+ wrapAtt.set(headOffset + t, score);
+ }
+
+ // STEP 2: Find max score for softmax stability
+ float maxScore = wrapAtt.get(headOffset);
+ for (int t = 1; t <= pos; t++) {
+ float val = wrapAtt.get(headOffset + t);
+ if (val > maxScore) {
+ maxScore = val;
+ }
+ }
+
+ // STEP 3: Compute exponentials and sum
+ float sum = 0.0f;
+ for (int t = 0; t <= pos; t++) {
+ int idx = headOffset + t;
+ float expScore = TornadoMath.exp(wrapAtt.get(idx) - maxScore);
+ wrapAtt.set(idx, expScore);
+ sum += expScore;
+ }
+
+ // STEP 4: Normalize
+ float normFactor = (sum > 0.0f) ? (1.0f / sum) : (1.0f / (pos + 1));
+ for (int t = 0; t <= pos; t++) {
+ int idx = headOffset + t;
+ wrapAtt.set(idx, wrapAtt.get(idx) * normFactor);
+ }
+
+ // STEP 5: Compute weighted sum of values for each dimension
+ for (int i = 0; i < headSize; i++) {
+ float weightedSum = 0.0f;
+ for (int t = 0; t <= pos; t++) {
+ int kvHeadIdx = h / kvMul;
+ int valueOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
+ weightedSum += wrapAtt.get(headOffset + t) * value_cache.get(valueOffset + i);
+ }
+ allXb.set(h * headSize + i, weightedSum);
+ }
+ }
+
+ /**
+ * Q8_0 matrix-vector product whose result is scaled by residualScale and
+ * accumulated into hb (residual connection). One workgroup per output row.
+ */
+ public static void matrixVectorGenericWithResidualQ8_0ByteWithGraniteScale(KernelContext context, FloatArray x,
+ FloatArray hb, ByteArray w, int n, int d, int localWorkGroupSize, float residualScale) {
+ // One row per workgroup (not per thread)
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+ int localSize = localWorkGroupSize;
+
+ // Early exit if this workgroup is beyond our output dimension
+ if (rowId >= d) {
+ return;
+ }
+
+ float sum = matrixVectorRowMajorOptimizedQ8_0Byte(context, localSize, x, w, n);
+
+ // Thread 0 in each workgroup writes the final result
+ if (localId == 0) {
+ float residualScaledSum = residualScale * sum;
+ float result = hb.get(rowId) + residualScaledSum;
+ hb.set(rowId, result);
+ }
+ }
+
+ /** Q8_0 matrix-vector product with the result scaled by logitsScale; one workgroup per row. */
+ public static void matrixVectorGenericQ8ByteWithGraniteScale(KernelContext context, FloatArray x, FloatArray output, ByteArray q,
+ int dim1, int dim0, int localWorkGroupSize, float logitsScale) {
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+
+ if (rowId >= dim0) {
+ return;
+ }
+
+ float sum = matrixVectorRowMajorOptimizedQ8_0Byte(context, localWorkGroupSize, x, q, dim1);
+
+ // Thread 0 writes the result
+ if (localId == 0) {
+ output.set(rowId, sum * logitsScale);
+ }
+ }
+
+ /**
+ * Fused RoPE rotation + KV-cache write. Each thread rotates one (even, odd)
+ * pair of Q (and, within kvDim, K), writes rotated K and unrotated V straight
+ * into the caches at the current position — no separate copy kernel needed.
+ */
+ public static void ropeRotationWithCacheCopy(KernelContext context, IntArray positionHolder, FloatArray sq, // Q vector (in/out)
+ FloatArray sk, // K vector (in/out)
+ FloatArray sv, // V vector (in only)
+ FloatArray keyCache, // Key cache (out)
+ FloatArray valueCache, // Value cache (out)
+ int kvDim, int headSize,
+ float ropeTheta,
+ int layer, int contextLength) {
+
+ int i = context.globalIdx * 2;
+ int pos = positionHolder.get(0);
+
+ // Bounds check for Q rotation (Q has dim elements, processed in pairs)
+ if (i + 1 < sq.getSize()) {
+ // RoPE frequency calculation
+ int head_dim = i % headSize;
+ // TornadoMath.pow(ropeTheta, head_dim / (float) headSize);
+ float freq = 1.0f / TornadoMath.pow(ropeTheta, head_dim / (float) headSize);
+ float val = pos * freq;
+ float fcr = TornadoMath.cos(val);
+ float fci = TornadoMath.sin(val);
+
+ // Rotate Q
+ float v0q = sq.get(i);
+ float v1q = sq.get(i + 1);
+ sq.set(i, v0q * fcr - v1q * fci);
+ sq.set(i + 1, v0q * fci + v1q * fcr);
+
+ // Rotate K AND write to cache (only for kvDim elements)
+ if (i + 1 < kvDim) {
+ float v0k = sk.get(i);
+ float v1k = sk.get(i + 1);
+ float rotated0 = v0k * fcr - v1k * fci;
+ float rotated1 = v0k * fci + v1k * fcr;
+
+ // Write rotated K back to sk
+ sk.set(i, rotated0);
+ sk.set(i + 1, rotated1);
+
+ // Direct cache write (fused - no separate copy kernel!)
+ int cacheOffset = layer * contextLength * kvDim + pos * kvDim;
+ keyCache.set(cacheOffset + i, rotated0);
+ keyCache.set(cacheOffset + i + 1, rotated1);
+
+ // Copy V to cache (V doesn't need rotation)
+ valueCache.set(cacheOffset + i, sv.get(i));
+ valueCache.set(cacheOffset + i + 1, sv.get(i + 1));
+ }
+ }
+
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java
index 1684a5b8..ca844e51 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java
@@ -1,5 +1,6 @@
package org.beehive.gpullama3.tornadovm.layerplanner.base;
+import org.beehive.gpullama3.inference.state.GraniteState;
import org.beehive.gpullama3.tensor.GGMLType;
import org.beehive.gpullama3.inference.state.LlamaState;
import org.beehive.gpullama3.inference.state.Phi3State;
@@ -8,10 +9,12 @@
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner;
@@ -55,6 +58,7 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) {
case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model);
case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model);
+ case GRANITE -> new GraniteFP16LayerPlanner((GraniteState) state, model);
case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
default -> throw new UnsupportedOperationException("FP16 not supported for model: " + model.getModelType());
};
@@ -67,6 +71,7 @@ private static GenericLayerPlanner createQ8_0Planner(State state, Model model) {
case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model);
case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model);
+ case GRANITE -> new GraniteQ8_0LayerPlanner((GraniteState) state, model);
case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model);
default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType());
};
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java
new file mode 100644
index 00000000..7cc97d64
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java
@@ -0,0 +1,26 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
+
+import org.beehive.gpullama3.inference.state.GraniteState;
+import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.granite.GraniteConfiguration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer;
+
+// NOTE(review): generic type arguments appear to have been stripped from this diff —
+// GraniteTornadoWeights and GraniteConfiguration are imported but unused in the visible
+// code. The extends clause is presumably FP16LayerPlanner<GraniteState, GraniteConfiguration,
+// GraniteTornadoWeights>; confirm parameter order against FP16LayerPlanner's declaration.
+/**
+ * FP16 layer planner for Granite models: wires the Granite-scaled activation,
+ * FFN, and logits layers into the Tornado forward plan.
+ */
+public class GraniteFP16LayerPlanner extends FP16LayerPlanner {
+ public GraniteFP16LayerPlanner(GraniteState state, Model model) {
+ super(state, model);
+ validateQuantizationType();
+ setupTornadoForwardPlan();
+ }
+
+ @Override
+ protected void initializeLayerComponents() {
+ // Granite variants of each component; the logits layer chains off the last FFN task graph
+ this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config);
+ this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType);
+ this.logitsLayer = new LogitsGraniteFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+ }
+
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java
new file mode 100644
index 00000000..ee818080
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java
@@ -0,0 +1,26 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0;
+
+import org.beehive.gpullama3.inference.state.GraniteState;
+import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.granite.GraniteConfiguration;
+import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layers.ActivationGranite;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers;
+import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer;
+
+// NOTE(review): generic type arguments appear to have been stripped from this diff —
+// GraniteTornadoWeights and GraniteConfiguration are imported but unused in the visible
+// code. The extends clause is presumably Q8_0LayerPlanner<GraniteState, GraniteConfiguration,
+// GraniteTornadoWeights>; confirm parameter order against Q8_0LayerPlanner's declaration.
+/**
+ * Q8_0 layer planner for Granite models: wires the Granite-scaled activation,
+ * FFN, and logits layers into the Tornado forward plan.
+ */
+public class GraniteQ8_0LayerPlanner extends Q8_0LayerPlanner {
+
+ public GraniteQ8_0LayerPlanner(GraniteState state, Model model) {
+ super(state, model);
+ validateQuantizationType();
+ setupTornadoForwardPlan();
+ }
+
+ @Override
+ protected void initializeLayerComponents() {
+ // Granite variants of each component; the logits layer chains off the last FFN task graph
+ this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config);
+ this.ffnLayers = new GraniteQ8_0FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType);
+ this.logitsLayer = new LogitsGraniteQ8_0Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java
new file mode 100644
index 00000000..20002dac
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java
@@ -0,0 +1,68 @@
+package org.beehive.gpullama3.tornadovm.layers;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.granite.GraniteConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.WorkerGrid1D;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+
+/**
+ * Activation-update layer for Granite: converts the token embedding to FP32 on
+ * device while applying Granite's embedding scale, for FP16 or Q8_0 embeddings.
+ */
+public class ActivationGranite extends Activation {
+ // Task graph that dequantizes/converts state.embeddingX into state.wrapX each execution.
+ private final TaskGraph activationUpdate;
+
+ // Granite is a special case where activation X is scaled by the embedding-scale float value stored inside the model.
+ public ActivationGranite(String taskGraphHandle, State state, Weights weights, GraniteConfiguration config) {
+ super(taskGraphHandle, state, weights, config);
+
+ KernelContext kernelContext = new KernelContext();
+
+ // @formatter:off
+ // Pick the conversion kernel matching the embedding quantization; both variants
+ // re-transfer embeddingX every execution and keep wrapX resident on device.
+ switch (config.quantization()) {
+ case "FP16" -> {
+ this.activationUpdate = new TaskGraph(taskGraphHandle)
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
+ .task("updateX", GraniteKernels::convertFP16toFP32withGraniteScale, kernelContext, (HalfFloatArray) state.embeddingX, state.wrapX, config.embeddingScale())
+ .persistOnDevice(state.wrapX);
+ }
+ case "Q8_0" -> {
+ this.activationUpdate = new TaskGraph(taskGraphHandle)
+ .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX)
+ .task("updateX", GraniteKernels::convertQ8_0toFP32withGraniteScale, kernelContext, (ByteArray) state.embeddingX, state.wrapX, config.embeddingScale())
+ .persistOnDevice(state.wrapX);
+ }
+ default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
+ }
+ // @formatter:on
+ }
+
+ @Override
+ public GridScheduler updateGridScheduler(GridScheduler scheduler) {
+ // One thread per embedding element, local work size 128
+ WorkerGrid worker = new WorkerGrid1D(config.dim());
+ worker.setLocalWork(128, 1, 1);
+ scheduler.addWorkerGrid("activationUpdate.updateX", worker);
+ return scheduler;
+ }
+
+ @Override
+ public GridScheduler getGridScheduler() {
+ // No standalone scheduler; grids are registered via updateGridScheduler instead.
+ return null;
+ }
+
+ @Override
+ public TaskGraph getTaskGraph() {
+ return activationUpdate;
+ }
+
+ @Override
+ public ImmutableTaskGraph getImmutableTaskGraph() {
+ return activationUpdate.snapshot();
+ }
+
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java
new file mode 100644
index 00000000..0387ae54
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java
@@ -0,0 +1,375 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
+import org.beehive.gpullama3.model.granite.GraniteConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+import java.util.List;
+import java.util.stream.IntStream;
+
+public class GraniteFP16FFNLayers extends AbstractFFNLayers {
+
+ TaskGraph ffnTaskGraphs;
+ GridScheduler scheduler;
+ List ffnLayerTaskGraphs;
+
+ public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) {
+ super(taskGraph, state, weights, config, schedulerType);
+ this.ffnLayerTaskGraphs = setupFFNLayered();
+ }
+
+ @Override
+ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
+ WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
+
+ int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize());
+
+ int fusedQKVRows = config.dim() + 2 * config.kvDim();
+ int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512);
+
+ // Map workers to tasks
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ // === Attention Block ===
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
+ // === FFN Block ===
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
+ }
+ return tornadoForwardScheduler;
+ }
+
+ @Override
+ public GridScheduler getGridScheduler() {
+ return scheduler;
+ }
+
+ @Override
+ public TaskGraph getTaskGraph() {
+ return ffnTaskGraphs;
+ }
+
+ @Override
+ public ImmutableTaskGraph getImmutableTaskGraph() {
+ return null;
+ }
+
+ public List getFfnLayerTaskGraphs() {
+ return ffnLayerTaskGraphs;
+ }
+
+ List setupFFNLayered() {
+ return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> {
+ var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i);
+ if (i == config.numberOfLayers() - 1) {
+ setupLastID(ffnLayer.getTaskGraphName()); // record last layer's graph name so the logits layer can consume wrapX from it
+ }
+ return ffnLayer.snapshot();
+ }).toList();
+ }
+
+ // @formatter:off
+ /**
+ * Transformer Layer Task Flow (GraniteFP16FFNLayers)
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ * ATTENTION BLOCK
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * wrapX (FP32)
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ attn_rms_reduce │──▶ temp (partial sums)
+ * └────────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌──────────────────┐
+ * │ attn_rms_finalize│──▶ temp (final scale)
+ * └────────┬─────────┘
+ * │
+ * ▼
+ * ┌─────────────────────┐
+ * │ attn_rms_apply_fp16 │──▶ wrapXbFP16 (normalized, FP16)
+ * └──────────┬──────────┘
+ * │
+ * ▼
+ * ┌────────────────┐ ┌─────────────────────────────┐
+ * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │
+ * └───────┬────────┘ └─────────────────────────────┘
+ * │
+ * ▼
+ * ┌───────────────────┐ ┌─────────────────────────────────────┐
+ * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │
+ * └─────────┬─────────┘ └─────────────────────────────────────┘
+ * │
+ * ▼
+ * ┌───────────┐
+ * │ attention │──▶ wrapXb (attention output)
+ * └─────┬─────┘
+ * │
+ * ▼
+ * ┌──────────────────┐
+ * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection)
+ * └────────┬─────────┘
+ * │
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │ FFN BLOCK
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │
+ * ▼
+ * ┌────────────────┐
+ * │ ffn_rms_reduce │──▶ tempFFN (partial sums)
+ * └───────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌─────────────────┐
+ * │ ffn_rms_finalize│──▶ tempFFN (final scale)
+ * └────────┬────────┘
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)
+ * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU)
+ * │
+ * ▼
+ * ┌──────────────┐
+ * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection)
+ * └──────┬───────┘
+ * │
+ * ▼
+ * wrapX (FP32) ──▶ [next layer or logits]
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps)
+ *
+ * Data Flow Summary:
+ * Input: wrapX (FP32) - hidden state from previous layer
+ * Output: wrapX (FP32) - updated hidden state with residual connections
+ *
+ * Key Fusion Points:
+ * • qkv_projection: Fused Q/K/V matmuls (3→1 kernel)
+ * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel)
+ * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel)
+ *
+ */
+ TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) {
+ var layerTaskGraphName = "layer_" + layerIndex;
+ TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
+
+ // === Data Setup ===
+ unifiedLayer.consumeFromDevice(state.wrapX);
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+ weights.wqLayered[layerIndex].asHalfFloatArray(),
+ weights.wkLayered[layerIndex].asHalfFloatArray(),
+ weights.wvLayered[layerIndex].asHalfFloatArray(),
+ weights.woLayered[layerIndex].asHalfFloatArray(),
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
+ weights.w1Layered[layerIndex].asHalfFloatArray(),
+ weights.w2Layered[layerIndex].asHalfFloatArray(),
+ weights.w3Layered[layerIndex].asHalfFloatArray());
+ unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
+
+ // === Attention Block ===
+ // RMS Normalization
+ unifiedLayer.task("attn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, state.temp, state.wrapX,
+ config.dim(), config.rmsNormEps(), state.localSize);
+
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("attn_rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, state.temp, config.dim(), config.rmsNormEps());
+ }
+
+ unifiedLayer.task("attn_rms_apply_fp16",
+ TransformerComputeKernels::mapContextWithQuantize,
+ context, state.wrapXbFP16, state.wrapX,
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);
+
+ // QKV Projection (fused)
+ unifiedLayer.task("qkv_projection",
+ TransformerComputeKernelsLayered::fusedQKVMatmulX,
+ context,
+ state.wrapXbFP16, // input (FP16, produced by attn_rms_apply_fp16)
+ state.wrapQ, // output Q
+ state.wrapK, // output K
+ state.wrapV, // output V
+ weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq
+ weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk
+ weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv
+ config.dim(), // dim
+ config.kvDim(), // kvDim
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // RoPE + KV Cache
+ unifiedLayer.task("rope_and_kv_cache",
+ GraniteKernels::ropeRotationWithCacheCopy,
+ context,
+ state.positionHolder,
+ state.wrapQ, // Q (in/out)
+ state.wrapK, // K (in/out)
+ state.wrapV, // V (in only)
+ state.wrapKeyCache, // Key cache (out)
+ state.wrapValueCache, // Value cache (out)
+ config.kvDim(),
+ config.headSize(),
+ config.ropeTheta(), // RoPE base frequency from config — TODO confirm it is read from the model file
+ layerIndex,
+ config.contextLength());
+ // Attention
+ configureAttention(unifiedLayer, layerIndex, config );
+ // Output Projection (Wo) with residual
+ unifiedLayer.task("attn_output_proj",
+ GraniteKernels::matrixVectorGenericWithResidualGranite,
+ context, state.wrapXb, state.wrapX,
+ weights.woLayered[layerIndex].asHalfFloatArray(),
+ config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC,
+ config.residualScale()
+ );
+
+ // === FFN Block ===
+ // RMS Normalization
+ unifiedLayer.task("ffn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, state.tempFFN, state.wrapX,
+ config.dim(), config.rmsNormEps(), state.localSize);
+
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("ffn_rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, state.tempFFN, config.dim(), config.rmsNormEps());
+ }
+
+ unifiedLayer.task("rms_ffn_gate_up",
+ TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp,
+ context,
+ state.wrapX, // raw input (FP32)
+ state.wrapHb, // output
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights
+ state.tempFFN, // RMS scale factor
+ weights.w1Layered[layerIndex].asHalfFloatArray(), // W1
+ weights.w3Layered[layerIndex].asHalfFloatArray(), // W3
+ config.dim(), // input dimension
+ config.hiddenDim(), // output dimension
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Down projection (W2) with residual
+ unifiedLayer.task("ffn_down_proj",
+ GraniteKernels::matrixVectorGenericWithResidualGranite,
+ context, state.wrapHb, state.wrapX,
+ weights.w2Layered[layerIndex].asHalfFloatArray(),
+ config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC, config.residualScale());
+
+ unifiedLayer.persistOnDevice(state.wrapX);
+
+ return unifiedLayer;
+ }
+
+ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
+ if (layerIndex == 0) {
+ // First layer: Transfer initial data to device (one-time transfer)
+ unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+ state.positionHolder,
+ state.temp, state.tempFFN
+ );
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ // Kernel context
+ context,
+ // Intermediate buffers
+ state.wrapXb, state.wrapXb2,
+ // QKV vectors
+ state.wrapQ, state.wrapK, state.wrapV,
+ // KV cache
+ state.wrapKeyCache, state.wrapValueCache,
+ // Attention & FFN buffers
+ state.wrapAtt, state.wrapHb, state.wrapXbFP16);
+ } else {
+ // Subsequent layers: Consume data already on device from previous layer
+ unifiedLayer.consumeFromDevice(
+ // Kernel context
+ context,
+ // Intermediate buffers
+ state.wrapXb, state.wrapXb2,
+ // QKV vectors
+ state.wrapQ, state.wrapK, state.wrapV,
+ // KV cache
+ state.wrapKeyCache, state.wrapValueCache,
+ // Attention & FFN buffers
+ state.wrapAtt, state.wrapHb,
+ // Position & misc
+ state.positionHolder, state.wrapXbFP16);
+ }
+ return unifiedLayer;
+ }
+
+ private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex, GraniteConfiguration config) {
+ if (schedulerType == SchedulerType.NVIDIA) {
+ // Flash Attention (optimized for NVIDIA GPUs)
+ return unifiedLayer.task("attention",
+ GraniteKernels::processHeadsFlashAttentionWithGraniteScale,
+ context,
+ state.wrapQ, // Query
+ state.wrapKeyCache, // Key cache
+ state.wrapValueCache, // Value cache
+ state.wrapXb, // Output
+ config.numberOfHeads(),
+ config.headSize(),
+ config.kvDim(),
+ config.kvMul(),
+ state.positionHolder,
+ layerIndex,
+ config.contextLength(),
+ config.attentionScale()
+ );
+ } else {
+ // Standard parallel attention (for non-NVIDIA backends)
+ return unifiedLayer.task("attention",
+ GraniteKernels::processHeadsParallelGranite,
+ state.wrapQ, // Query
+ state.wrapKeyCache, // Key cache
+ state.wrapValueCache, // Value cache
+ state.wrapXb, // Output
+ config.numberOfHeads(),
+ config.headSize(),
+ config.kvDim(),
+ config.kvMul(),
+ config.contextLength(), // seqLen parameter
+ state.positionHolder,
+ state.wrapAtt, // Attention weights buffer
+ layerIndex,
+ config.contextLength(),
+ config.attentionScale());
+ }
+ }
+ // @formatter:on
+
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java
new file mode 100644
index 00000000..d55d707e
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java
@@ -0,0 +1,125 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.granite.GraniteConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.WorkerGrid1D;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+public class LogitsGraniteFP16Layer extends LogitsFP16Layer {
+ private String lastTaskGraphID;
+ private TaskGraph logitsTaskGraph;
+ private ImmutableTaskGraph immutableLogitsGraph;
+ private GridScheduler scheduler;
+ private SchedulerType schedulerType;
+
+ public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
+ super(name, state, weights, config, lastTaskGraphID, schedulerType);
+ this.lastTaskGraphID = lastTaskGraphID;
+ this.schedulerType = schedulerType;
+ var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor");
+ this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config);
+ }
+
+ // @formatter:off
+ private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) {
+ var logits = new TaskGraph("logits");
+ // === Data Setup ===
+ logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
+ logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
+ logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ // Kernel context
+ context,
+ // Output buffer
+ state.wrapLogits,
+ // Intermediate FP16 buffer
+ state.wrapXbFP16,
+ // Weights
+ weights.wclsByteArray.asHalfFloatArray(),
+ weights.rms_final_weight_as_floatArray.asFloatArray());
+
+ // === Final RMS Normalization ===
+ logits.task("rms_reduce",
+ TransformerComputeKernels::reductionOneBlockWithLayer,
+ context,
+ state.tempLogits, // output: partial sums + final scale factor
+ state.wrapX, // input: hidden state
+ config.dim(), // dimension
+ config.rmsNormEps(), // epsilon for numerical stability
+ state.localSize); // local workgroup size
+
+ if (schedulerType == SchedulerType.NON_NVIDIA) {
+ logits.task("rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context,
+ state.tempLogits, // in/out: combines partial sums
+ config.dim(), // dimension
+ config.rmsNormEps()); // epsilon
+ }
+
+ logits.task("rms_apply_fp16",
+ TransformerComputeKernels::mapContextWithQuantizeLogits,
+ context,
+ state.wrapXbFP16, // output: normalized (FP16)
+ state.wrapX, // input: hidden state
+ weights.rms_final_weight_as_floatArray.asFloatArray(), // RMS weights
+ state.tempLogits); // scale factor from reduction
+
+ // === Vocabulary Projection ===
+ logits.task("vocab_proj",
+ GraniteKernels::matrixVectorGenericWithGraniteScale,
+ context,
+ state.wrapXbFP16, // input (FP16)
+ state.wrapLogits, // output
+ weights.wclsByteArray.asHalfFloatArray(), // vocabulary weights
+ config.dim(), // input dimension
+ config.vocabularySize(), // output dimension
+ LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS,
+ config.logitScale()); // granite logit scaling
+
+ // === Transfer Results to Host ===
+ logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
+ return logits;
+ }
+ // @formatter:on
+
+ @Override
+ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
+ WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256);
+ var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
+ var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
+ vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
+ tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS);
+ tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS);
+ tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
+ return tornadoForwardScheduler;
+ }
+
+ @Override
+ public GridScheduler getGridScheduler() {
+ return scheduler;
+ }
+
+ @Override
+ public TaskGraph getTaskGraph() {
+ return logitsTaskGraph;
+ }
+
+ @Override
+ public ImmutableTaskGraph getImmutableTaskGraph() {
+ return immutableLogitsGraph;
+ }
+}
+
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java
new file mode 100644
index 00000000..b7e036d4
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java
@@ -0,0 +1,376 @@
+package org.beehive.gpullama3.tornadovm.layers.type.q8_0;
+
+import org.beehive.gpullama3.inference.state.GraniteState;
+import org.beehive.gpullama3.inference.state.LlamaState;
+import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.granite.Granite;
+import org.beehive.gpullama3.model.granite.GraniteConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+import java.util.List;
+import java.util.stream.IntStream;
+
+public class GraniteQ8_0FFNLayers extends AbstractFFNLayers {
+
+ GridScheduler scheduler;
+ List ffnLayerTaskGraphs;
+
+ public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) {
+ super(taskGraphName, state, weights, config, schedulerType);
+ ffnLayerTaskGraphs = setupFFNLayered();
+ }
+
+ @Override
+ public GridScheduler getGridScheduler() {
+ return scheduler;
+ }
+
+ @Override
+ public TaskGraph getTaskGraph() {
+ return null;
+ }
+
+ @Override
+ public ImmutableTaskGraph getImmutableTaskGraph() {
+ return null;
+ }
+
+ List setupFFNLayered() {
+ return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> {
+ var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i);
+ if (i == config.numberOfLayers() - 1) {
+ setupLastID(ffnLayer.getTaskGraphName());
+ }
+ return ffnLayer.snapshot();
+ }).toList();
+ }
+
+ /**
+ * Transformer Layer Task Flow (GraniteQ8_0FFNLayers)
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ * ATTENTION BLOCK
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * wrapX (FP32)
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ attn_rms_reduce │──▶ temp (partial sums)
+ * └────────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌──────────────────┐
+ * │ attn_rms_finalize│──▶ temp (final scale)
+ * └────────┬─────────┘
+ * │
+ * ▼
+ * ┌────────────────┐
+ * │ attn_rms_apply │──▶ wrapXb (normalized, FP32)
+ * └───────┬────────┘
+ * │
+ * ▼
+ * ┌────────────────┐ ┌─────────────────────────────┐
+ * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │
+ * └───────┬────────┘ └─────────────────────────────┘
+ * │
+ * ▼
+ * ┌───────────────────┐ ┌─────────────────────────────────────┐
+ * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │
+ * └─────────┬─────────┘ └─────────────────────────────────────┘
+ * │
+ * ▼
+ * ┌───────────┐
+ * │ attention │──▶ wrapXb (attention output)
+ * └─────┬─────┘
+ * │
+ * ▼
+ * ┌──────────────────┐
+ * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection)
+ * └────────┬─────────┘
+ * │
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │ FFN BLOCK
+ * ══════════╪═══════════════════════════════════════════════════════════════════
+ * │
+ * ▼
+ * ┌────────────────┐
+ * │ ffn_rms_reduce │──▶ tempFFN (partial sums)
+ * └───────┬────────┘
+ * │
+ * ▼ (optional: NON_NVIDIA only)
+ * ┌─────────────────┐
+ * │ ffn_rms_finalize│──▶ tempFFN (final scale)
+ * └────────┬────────┘
+ * │
+ * ▼
+ * ┌─────────────────┐
+ * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3)
+ * └────────┬────────┘ (fully fused: RMS reduce/apply + W1/W3 matmuls + SiLU + GLU)
+ * │
+ * ▼
+ * ┌──────────────┐
+ * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection)
+ * └──────┬───────┘
+ * │
+ * ▼
+ * wrapX (FP32) ──▶ [next layer or logits]
+ *
+ * ══════════════════════════════════════════════════════════════════════════════
+ *
+ * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps)
+ *
+ * Data Flow Summary:
+ * Input: wrapX (FP32) - hidden state from previous layer
+ * Output: wrapX (FP32) - updated hidden state with residual connections
+ *
+ * Key Fusion Points:
+ * • qkv_projection: Fused Q/K/V matmuls with Q8 dequantization (3→1 kernel)
+ * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel)
+ * • rms_ffn_gate_up: Fully fused RMS norm + W1/W3 matmuls + SiLU + GLU (5→1 kernel)
+ *
+ * Quantization: Q8_0 format (8-bit weights with block-wise scaling)
+ *
+ */
+ TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) {
+ var layerTaskGraphName = "layer_" + layerIndex;
+ TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName);
+
+ // === Data Setup ===
+ unifiedLayer.consumeFromDevice(state.wrapX);
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ // Copy-in weights per layer for batched-layered layout (Q8 format)
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+ weights.wqLayered[layerIndex].asByteArray(),
+ weights.wkLayered[layerIndex].asByteArray(),
+ weights.wvLayered[layerIndex].asByteArray(),
+ weights.woLayered[layerIndex].asByteArray(),
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
+ weights.w1Layered[layerIndex].asByteArray(),
+ weights.w2Layered[layerIndex].asByteArray(),
+ weights.w3Layered[layerIndex].asByteArray());
+ unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
+
+ // === Attention Block ===
+ // RMS Normalization
+ unifiedLayer.task("attn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, state.temp, state.wrapX,
+ config.dim(), config.rmsNormEps(), state.localSize);
+
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("attn_rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, state.temp, config.dim(), config.rmsNormEps());
+ }
+
+ unifiedLayer.task("attn_rms_apply",
+ TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
+ context, state.wrapXb, state.wrapX,
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);
+
+ // QKV Projection (fused with Q8 dequantization)
+ unifiedLayer.task("qkv_projection",
+ TransformerComputeKernelsLayered::fusedQKVMatmulQ8,
+ context,
+ state.wrapXb, // input (FP32)
+ state.wrapQ, // output Q
+ state.wrapK, // output K
+ state.wrapV, // output V
+ weights.wqLayered[layerIndex].asByteArray(), // Wq (Q8)
+ weights.wkLayered[layerIndex].asByteArray(), // Wk (Q8)
+ weights.wvLayered[layerIndex].asByteArray(), // Wv (Q8)
+ config.dim(), // dim
+ config.kvDim(), // kvDim
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // RoPE + KV Cache
+ unifiedLayer.task("rope_and_kv_cache",
+ GraniteKernels::ropeRotationWithCacheCopy,
+ context,
+ state.positionHolder,
+ state.wrapQ, // Q (in/out)
+ state.wrapK, // K (in/out)
+ state.wrapV, // V (in only)
+ state.wrapKeyCache, // Key cache (out)
+ state.wrapValueCache, // Value cache (out)
+ config.kvDim(),
+ config.headSize(),
+ config.ropeTheta(),
+ layerIndex,
+ config.contextLength());
+
+ // Attention
+ configureAttention(unifiedLayer, layerIndex, config);
+
+ // Output Projection (Wo) with residual (Q8 dequantization)
+ unifiedLayer.task("attn_output_proj",
+ GraniteKernels::matrixVectorGenericWithResidualQ8_0ByteWithGraniteScale,
+ context, state.wrapXb, state.wrapX,
+ weights.woLayered[layerIndex].asByteArray(),
+ config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC, config.residualScale());
+
+ // === FFN Block ===
+ // RMS Normalization
+ unifiedLayer.task("ffn_rms_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, state.tempFFN, state.wrapX,
+ config.dim(), config.rmsNormEps(), state.localSize);
+
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("ffn_rms_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, state.tempFFN, config.dim(), config.rmsNormEps());
+ }
+
+ // Fully fused: RMS apply + Gate/Up projections + SiLU + GLU (Q8 dequantization)
+ unifiedLayer.task("rms_ffn_gate_up",
+ TransformerComputeKernelsLayered::fullyFusedRmsNormFFNGateUpQ8,
+ context,
+ state.wrapX, // raw input (FP32)
+ state.wrapHb, // output
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights
+ weights.w1Layered[layerIndex].asByteArray(), // W1 (Q8)
+ weights.w3Layered[layerIndex].asByteArray(), // W3 (Q8)
+ config.dim(), // input dimension
+ config.hiddenDim(), // output dimension
+ LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Down projection (W2) with residual (Q8 dequantization)
+ unifiedLayer.task("ffn_down_proj",
+ GraniteKernels::matrixVectorGenericWithResidualQ8_0ByteWithGraniteScale,
+ context, state.wrapHb, state.wrapX,
+ weights.w2Layered[layerIndex].asByteArray(),
+ config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC, config.residualScale());
+
+ // Keep activation X on device for next layer
+ unifiedLayer.persistOnDevice(state.wrapX);
+
+ return unifiedLayer;
+ }
+
+ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
+ // First layer: Transfer initial data to device (one-time transfer)
+ if (layerIndex == 0) {
+ // Transfer all attention-related data: query, key, value matrices and their caches
+ unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+ state.positionHolder,
+ state.temp, state.tempFFN); //
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, //
+ context,
+ state.wrapXb, state.wrapXb2, //
+ state.wrapQ, state.wrapK, state.wrapV, //
+ state.wrapKeyCache, state.wrapValueCache, //
+ state.wrapAtt, state.wrapHb); //
+ } else {
+ // Subsequent layers: Consume data already on device from previous layer
+ unifiedLayer.consumeFromDevice(
+ context,
+ state.wrapXb, state.wrapXb2, //
+ state.wrapQ, state.wrapK, state.wrapV, //
+ state.wrapKeyCache, state.wrapValueCache, //
+ state.wrapAtt, state.wrapHb, //
+ state.positionHolder //
+ );
+ }
+ return unifiedLayer;
+ }
+
+ @Override
+ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
+ // === Worker Grid Definitions ===
+ WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
+
+ int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Fused QKV: dim rows for Q + kvDim rows for K + kvDim rows for V
+ int fusedQkvGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC;
+ WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512);
+
+ WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize());
+
+ // === Per-Layer Grid Assignments (ordered by task graph flow) ===
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ // --- Attention Block ---
+ // RMS Normalization
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
+ // --- FFN Block ---
+ // RMS Normalization
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
+ // Fused RMS + Gate/Up Projections
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
+ // Down Projection
+ tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
+ }
+
+ return tornadoForwardScheduler;
+ }
+
+ public List getFfnLayerTaskGraphs() {
+ return ffnLayerTaskGraphs;
+ }
+
+ private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex, GraniteConfiguration config) {
+ if (schedulerType == SchedulerType.NVIDIA) {
+ // Flash Attention (optimized for NVIDIA GPUs)
+ return unifiedLayer.task("attention",
+ GraniteKernels::processHeadsFlashAttentionWithGraniteScale,
+ context,
+ state.wrapQ, // Query
+ state.wrapKeyCache, // Key cache
+ state.wrapValueCache, // Value cache
+ state.wrapXb, // Output
+ config.numberOfHeads(),
+ config.headSize(),
+ config.kvDim(),
+ config.kvMul(),
+ state.positionHolder,
+ layerIndex,
+ config.contextLength(),
+ config.attentionScale()
+ );
+ } else {
+ // Standard parallel attention (for non-NVIDIA backends)
+ return unifiedLayer.task("attention",
+ GraniteKernels::processHeadsParallelGranite,
+ state.wrapQ, // Query
+ state.wrapKeyCache, // Key cache
+ state.wrapValueCache, // Value cache
+ state.wrapXb, // Output
+ config.numberOfHeads(),
+ config.headSize(),
+ config.kvDim(),
+ config.kvMul(),
+ config.contextLength(), // seqLen parameter
+ state.positionHolder,
+ state.wrapAtt, // Attention weights buffer
+ layerIndex,
+ config.contextLength(),
+ config.attentionScale());
+ }
+ }
+ // @formatter:on
+}
+
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java
new file mode 100644
index 00000000..d1fd4f0c
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java
@@ -0,0 +1,118 @@
+package org.beehive.gpullama3.tornadovm.layers.type.q8_0;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.granite.GraniteConfiguration;
+import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid1D;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+/**
+ * Final-layer (logits) Q8_0 task graph for Granite models.
+ *
+ * Builds a TornadoVM task graph that consumes the device-resident hidden state
+ * produced by the last transformer layer, applies the final RMS normalization,
+ * and projects onto the vocabulary with the Granite logit scale applied.
+ */
+public class LogitsGraniteQ8_0Layer extends LogitsQ8_0Layer {
+    /** Id of the preceding task graph whose device-resident output (wrapX) we consume. */
+    private final String lastTaskGraphID;
+    /** Task graph: final RMS norm + scaled vocabulary projection. */
+    private final TaskGraph logitsTaskGraph;
+    // NOTE(review): immutableLogitsGraph and scheduler are never assigned anywhere in
+    // this class, so getImmutableTaskGraph()/getGridScheduler() return null unless
+    // they are populated reflectively or via a mechanism outside this file — confirm
+    // the intended lifecycle with the layer-planner code.
+    private ImmutableTaskGraph immutableLogitsGraph;
+    private GridScheduler scheduler;
+    /** Backend family; selects the NON_NVIDIA-only "rms_finalize" task. */
+    private final SchedulerType schedulerType;
+
+    public LogitsGraniteQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) {
+        super(taskGraphName, state, weights, config, lastTaskGraphID, schedulerType);
+        this.lastTaskGraphID = lastTaskGraphID;
+        // BUGFIX: schedulerType must be assigned BEFORE building the task graph.
+        // setupLogitsTaskGraph() reads this.schedulerType to decide whether to add
+        // the "rms_finalize" task; previously the field was assigned after the
+        // call, so it was still null during graph construction and the NON_NVIDIA
+        // finalization task was never added.
+        this.schedulerType = schedulerType;
+        var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsGraniteQ8_0Layer", "TornadoTensor");
+        this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config);
+    }
+
+    /**
+     * Registers worker grids for the logits tasks on the forward scheduler:
+     * a row-major 1D grid for the vocabulary projection and an RMS-norm worker
+     * (narrower local size for Qwen2 weights) for the normalization tasks.
+     */
+    @Override
+    public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
+        var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256);
+        var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
+        var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
+        vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
+        tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker);
+        tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS);
+        tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS);
+        return tornadoForwardScheduler;
+    }
+
+    // @formatter:off
+    /**
+     * Builds the logits task graph: consume hidden state from the previous graph,
+     * RMS-normalize it (with an extra finalization pass on non-NVIDIA backends),
+     * then run the Q8_0 vocabulary projection with the Granite logit scale, and
+     * transfer the logits back to the host every execution.
+     */
+    private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) {
+        var logits = new TaskGraph("logits");
+        // === Data Setup ===
+        logits.consumeFromDevice(lastTaskGraphID, state.wrapX);
+        logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
+        logits.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+                context, //
+                state.wrapLogits, //
+                weights.wclsByteArray.asByteArray(), //
+                weights.rms_final_weight_as_floatArray);
+
+        // === Final RMS Normalization ===
+        logits.task("rms_reduce",
+                TransformerComputeKernels::reductionOneBlockWithLayer,
+                context,
+                state.tempLogits,      // output: partial sums + final scale factor
+                state.wrapX,           // input: hidden state
+                config.dim(),          // dimension
+                config.rmsNormEps(),   // epsilon for numerical stability
+                state.localSize);      // local workgroup size
+
+        if (schedulerType == SchedulerType.NON_NVIDIA) {
+            // Non-NVIDIA backends need an explicit cross-workgroup finalization pass.
+            logits.task("rms_finalize",
+                    TransformerComputeKernelsLayered::reductionFinalNormalization,
+                    context,
+                    state.tempLogits,
+                    config.dim(),
+                    config.rmsNormEps());
+        }
+        logits.task("mapContextLogits",
+                TransformerComputeKernels::reductionOneBlock2WithLogits,
+                context,
+                state.wrapX,
+                weights.rms_final_weight_as_floatArray.asFloatArray(),
+                state.tempLogits);
+
+        // === Vocabulary Projection (Q8_0, Granite logit scale applied in-kernel) ===
+        logits.task("vocab_proj", GraniteKernels::matrixVectorGenericQ8ByteWithGraniteScale, //
+                context,
+                state.wrapX,
+                state.wrapLogits,
+                weights.wclsByteArray.asByteArray(),
+                config.dim(),
+                config.vocabularySize(),
+                LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS,
+                config.logitScale());
+
+        // === Transfer Results to Host ===
+        logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
+        return logits;
+    }
+    // @formatter:on
+
+    /** @return the grid scheduler (currently never assigned here — see field note). */
+    @Override
+    public GridScheduler getGridScheduler() {
+        return scheduler;
+    }
+
+    @Override
+    public TaskGraph getTaskGraph() {
+        return logitsTaskGraph;
+    }
+
+    /** @return the snapshotted immutable graph (currently never assigned here — see field note). */
+    @Override
+    public ImmutableTaskGraph getImmutableTaskGraph() {
+        return immutableLogitsGraph;
+    }
+}