From ba7b94a7f828dbcf339ee95e54479d8967841235 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 15:23:56 +0200 Subject: [PATCH 01/26] Add Granite model support in inference pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Introduced `Granite` model type in `ModelType`. - Added `forwardGranite` method with µP scaling in `InferenceCore`. - Implemented token generation methods (`generateTokensGranite`, `generateTokensGPUGranite`) for Granite models. - Updated `ModelLoader` to detect Granite models via metadata or name. - Enhanced tokenizer and chat format compatibility with Granite. --- .../gpullama3/inference/InferenceCore.java | 122 ++++++++++++++++ .../gpullama3/inference/InferenceEngine.java | 134 ++++++++++++++++++ .../beehive/gpullama3/model/ModelType.java | 8 ++ .../gpullama3/model/format/ChatFormat.java | 2 + .../model/format/LlamaChatFormat.java | 5 +- .../gpullama3/model/loader/ModelLoader.java | 9 +- 6 files changed, 277 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index 475f711e..061ff3ed 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -11,6 +11,7 @@ import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.model.phi3.Phi3Configuration; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; @@ -546,6 +547,127 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke return state.logits; } + /** + * Forward pass for Granite models with µP scaling factors applied. + *

+ * Granite uses the same transformer architecture as Llama but with maximal update parameterization (µP) + * scaling factors applied at specific points: + *

+ */ + public static FloatTensor forwardGranite(Model model, State state, int token, int position) { + final GraniteConfiguration config = (GraniteConfiguration) model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float attentionScale = config.attentionScale(); + float residualScale = config.residualScale(); + float embeddingScale = config.embeddingScale(); + float logitScale = config.logitScale(); + + // copy the token embedding into x + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + // Apply Granite embedding scaling + state.x.mapInPlace(v -> v * embeddingScale); + + // forward all the layers + for (int l = 0; l < config.numberOfLayers(); l++) { + // attention rmsnorm + rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps()); + + // qkv matmuls for this position + weights.wq[l].matmul(state.xb, state.q, dim, dim); + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + + // RoPE relative positional encoding + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int v = 0; v < rotn; v++) { + FloatTensor vec = v == 0 ? 
state.q : state.k; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + + // save key,value at this time step to kv cache + state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + + int curLayer = l; + + // multihead attention with Granite attention scaling + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= position; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + // Granite uses custom attention multiplier instead of 1/sqrt(headSize) + score *= attentionScale; + state.att.setFloat(attOffset + t, score); + } + + state.att.softmaxInPlace(attOffset, position + 1); + + int xbOffset = h * headSize; + state.xb.fillInPlace(xbOffset, headSize, 0f); + + for (int t = 0; t <= position; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + } + }); + + // final matmul to get the output of the attention + weights.wo[l].matmul(state.xb, state.xb2, dim, dim); + + // residual connection with Granite scaling + state.xb2.mapInPlace(v -> v * residualScale); + state.x.addInPlace(state.xb2); + + // ffn rmsnorm + rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps()); + + // FFN: self.w2(F.silu(self.w1(x)) * self.w3(x)) + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim); + + // SwiGLU non-linearity + state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); + state.hb.multiplyInPlace(state.hb2); + + // final matmul to get the output 
of the ffn + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim()); + + // residual connection with Granite scaling + state.xb.mapInPlace(v -> v * residualScale); + state.x.addInPlace(state.xb); + } + + rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps()); + + weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim); + + // Apply Granite logit scaling (divide by the scaling factor) + state.logits.mapInPlace(v -> v / logitScale); + + return state.logits; + } + static void copyChunk(FloatTensor in, FloatTensor out, int dim1In, int dim1Out, int nChunks, int chunkNo) { assert (dim1In == dim1Out * nChunks); final int startOffsetInDim1 = chunkNo * dim1Out; diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java index 7599f1b0..a9c65223 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java @@ -531,4 +531,138 @@ public static List generateTokensGPUPhi3(Model model, State state, int return generatedTokens; } + + /** + * Generates tokens using the Granite model with CPU inference. + * Identical pattern to generateTokensLlama but calls forwardGranite. 
+ */ + public static List generateTokensGranite(Model model, State state, int startPosition, + List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + long startNanos = System.nanoTime(); + long inferenceStartNanos = 0; + + Object logits; + if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) { + maxTokens = model.configuration().contextLength(); + } + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; + int nextToken; + int promptIndex = 0; + int pos = startPosition; + + while (pos < maxTokens) { + // Call Granite-specific forward pass + logits = InferenceCore.forwardGranite(model, state, currentToken, pos); + + if (promptIndex < promptTokens.size()) { + nextToken = promptTokens.get(promptIndex++); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + } else { + if (inferenceStartNanos == 0) { + inferenceStartNanos = System.nanoTime(); + } + + nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0; + int totalTokens = promptIndex + generatedTokens.size(); + + LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds); + + return generatedTokens; + } + + /** + * Generates tokens using the Granite model with GPU (TornadoVM) inference. + * Identical pattern to generateTokensGPULlama. 
+ */ + public static List generateTokensGPUGranite(Model model, State state, int startPosition, + List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMMasterPlan) { + long startNanos = System.nanoTime(); + long inferenceStartNanos = 0; + + Object logits; + if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) { + maxTokens = model.configuration().contextLength(); + } + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; + int nextToken; + int promptIndex = 0; + int pos = startPosition; + + while (pos < maxTokens) { + // Call TornadoVM forward pass (same as Llama for now) + logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMMasterPlan); + + if (promptIndex < promptTokens.size()) { + nextToken = promptTokens.get(promptIndex++); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + } else { + if (inferenceStartNanos == 0) { + inferenceStartNanos = System.nanoTime(); + } + + nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0; + int totalTokens = promptIndex + generatedTokens.size(); + + LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds); + + return generatedTokens; + } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/ModelType.java b/src/main/java/org/beehive/gpullama3/model/ModelType.java index ce88a69b..fb46ff6e 100644 --- 
a/src/main/java/org/beehive/gpullama3/model/ModelType.java +++ b/src/main/java/org/beehive/gpullama3/model/ModelType.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.model; +import org.beehive.gpullama3.model.loader.GraniteLoader; import org.beehive.gpullama3.tensor.GGUF; import org.beehive.gpullama3.model.loader.LlamaModelLoader; import org.beehive.gpullama3.model.loader.MistralModelLoader; @@ -64,6 +65,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo } }, + GRANITE { + @Override + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + return new GraniteLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel(); + } + }, + UNKNOWN { @Override public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index e2a166b0..d3466a8e 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.model.format; +import org.beehive.gpullama3.tokenizer.GraniteTokenizer; import org.beehive.gpullama3.tokenizer.LlamaTokenizer; import org.beehive.gpullama3.tokenizer.MistralTokenizer; import org.beehive.gpullama3.tokenizer.Phi3Tokenizer; @@ -12,6 +13,7 @@ public interface ChatFormat { static ChatFormat create(Object tokenizer, ChatTokens chatTokens) { return switch (tokenizer) { + case GraniteTokenizer graniteTokenizer -> new GraniteChatFormat(graniteTokenizer); case LlamaTokenizer llamaTokenizer -> new LlamaChatFormat(llamaTokenizer); case MistralTokenizer mistralTokenizer -> new MistralChatFormat(mistralTokenizer); case Qwen3Tokenizer qwen3Tokenizer -> new Qwen3ChatFormat(qwen3Tokenizer, chatTokens); diff --git 
a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 80987a06..c98a72c9 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -1,6 +1,7 @@ package org.beehive.gpullama3.model.format; import org.beehive.gpullama3.tokenizer.LlamaTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; import java.util.ArrayList; import java.util.List; @@ -9,7 +10,7 @@ public class LlamaChatFormat implements ChatFormat { - protected final LlamaTokenizer tokenizer; + protected final Tokenizer tokenizer; protected final int beginOfText; protected final int endHeader; protected final int startHeader; @@ -18,7 +19,7 @@ public class LlamaChatFormat implements ChatFormat { protected final int endOfMessage; protected final Set stopTokens; - public LlamaChatFormat(LlamaTokenizer tokenizer) { + public LlamaChatFormat(Tokenizer tokenizer) { this.tokenizer = tokenizer; Map specialTokens = tokenizer.getSpecialTokens(); this.beginOfText = specialTokens.get("<|begin_of_text|>"); diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java index a5e5e59c..392113be 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java @@ -48,7 +48,9 @@ private static ModelType detectModelType(Map metadata) { // Check by name first if (name != null) { String lowerName = name.toLowerCase(); - if (lowerName.contains("mistral")) { + if (lowerName.contains("granite")) { + return ModelType.GRANITE; + } else if (lowerName.contains("mistral")) { return ModelType.MISTRAL; } else if (lowerName.contains("llama")) { return ModelType.LLAMA_3; @@ -63,6 +65,11 @@ private static ModelType detectModelType(Map metadata) { } } + // Alternative: check 
by metadata keys if name-based detection fails + if (metadata.containsKey("granite.block_count")) { + return ModelType.GRANITE; + } + return ModelType.UNKNOWN; } From a50e0cabcbc8875f8534faaf6a2a62f925ca338b Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 15:25:59 +0200 Subject: [PATCH 02/26] Add Granite model-specific implementations for inference pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Introduced `GraniteStandardWeights` and `GraniteTornadoWeights` for CPU and TornadoVM GPU weight handling. - Added `GraniteState` to manage model-specific state during inference. - Implemented `GraniteTokenizer` with GPT-2-style BPE support adapted for Granite. - Added `GraniteLoader` to handle model loading, configuration, and weight initialization. - Created `GraniteConfiguration` to define model-specific parameters and scaling factors (µP parameterization). --- .../inference/state/GraniteState.java | 83 ++++++ .../standard/GraniteStandardWeights.java | 75 +++++ .../tornado/GraniteTornadoWeights.java | 61 ++++ .../model/format/GraniteChatFormat.java | 76 +++++ .../gpullama3/model/granite/Granite.java | 99 ++++++ .../model/granite/GraniteConfiguration.java | 152 ++++++++++ .../gpullama3/model/loader/GraniteLoader.java | 163 ++++++++++ .../gpullama3/tokenizer/GraniteTokenizer.java | 281 ++++++++++++++++++ 8 files changed, 990 insertions(+) create mode 100644 src/main/java/org/beehive/gpullama3/inference/state/GraniteState.java create mode 100644 src/main/java/org/beehive/gpullama3/inference/weights/standard/GraniteStandardWeights.java create mode 100644 src/main/java/org/beehive/gpullama3/inference/weights/tornado/GraniteTornadoWeights.java create mode 100644 src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java create mode 100644 src/main/java/org/beehive/gpullama3/model/granite/Granite.java create mode 100644 src/main/java/org/beehive/gpullama3/model/granite/GraniteConfiguration.java 
create mode 100644 src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java create mode 100644 src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java diff --git a/src/main/java/org/beehive/gpullama3/inference/state/GraniteState.java b/src/main/java/org/beehive/gpullama3/inference/state/GraniteState.java new file mode 100644 index 00000000..531fa951 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/state/GraniteState.java @@ -0,0 +1,83 @@ +package org.beehive.gpullama3.inference.state; + +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.model.Configuration; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +import java.util.stream.Stream; + +/** + * Represents the state of the Granite model during inference. + * This class extends {@link State} to include model-specific functionalities + * and configurations tailored for the Granite model. + * + *

Note: GraniteState tensor shapes are identical to LlamaState + * since Granite uses the same transformer architecture as Llama, + * with differences only in the scaling factors applied.

+ */ +public final class GraniteState extends State { + + public GraniteState(Configuration config, int batchsize) { + super(config, batchsize); + } + + @Override + protected StateFields createStateFields(Configuration config) { + StateFields fields = new StateFields(); + + // Allocation with Granite dimensions (identical to Llama) + fields.x = ArrayFloatTensor.allocate(config.dim()); + fields.xb = ArrayFloatTensor.allocate(config.dim()); + fields.xb2 = ArrayFloatTensor.allocate(config.dim()); + fields.hb = ArrayFloatTensor.allocate(config.hiddenDim()); + fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim()); + fields.q = ArrayFloatTensor.allocate(config.dim()); + fields.k = ArrayFloatTensor.allocate(config.dim()); + fields.v = ArrayFloatTensor.allocate(config.dim()); + fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength()); + fields.logits = ArrayFloatTensor.allocate(config.vocabularySize()); + + // Key-value cache with Granite dimensions + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + + // TornadoVM wrappers with Granite dimensions + fields.wrapX = new FloatArray(config.dim()); + fields.wrapXb = new FloatArray(config.dim()); + fields.wrapXb2 = new FloatArray(config.dim()); + fields.wrapHb = new FloatArray(config.hiddenDim()); + fields.wrapHb2 = new FloatArray(config.hiddenDim()); + + switch (config.quantization()) { + case "FP16" -> fields.createActivationFP16(config.dim()); + case "Q8_0" -> fields.createActivationQ8_0(config.dim()); + default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); + } + fields.wrapLogits 
= new FloatArray(config.vocabularySize()); + fields.wrapQ = new FloatArray(config.dim()); + fields.wrapK = new FloatArray(config.dim()); + fields.wrapV = new FloatArray(config.dim()); + + fields.wrapXFP16 = new HalfFloatArray(config.dim()); + fields.wrapXbFP16 = new HalfFloatArray(config.dim()); + // dim vs kvdim + fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); + fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); + fields.wrapValueCache.init(0.f); + fields.wrapKeyCache.init(0.f); + fields.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength()); + fields.positionHolder = new IntArray(1); + + // Temporary arrays + fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize)); + fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize)); + fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize)); + + return fields; + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/standard/GraniteStandardWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/standard/GraniteStandardWeights.java new file mode 100644 index 00000000..0666a1f1 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/weights/standard/GraniteStandardWeights.java @@ -0,0 +1,75 @@ +package org.beehive.gpullama3.inference.weights.standard; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.standard.FloatTensor; + +/** + * A model-specific implementation of {@link StandardWeights} for the Granite model. + * This class encapsulates the weights required for performing inference + * using the Granite model in the standard CPU-based format. + * + *

Note: Granite uses the same weight structure as Llama, + * with differences only in the scaling factors applied during inference.

+ */ +public class GraniteStandardWeights extends StandardWeights { + + // @formatter:off + /** + * Constructor for GraniteStandardWeights. + * + * @param token_embedding_table The token embedding table tensor. + * @param rms_att_weight Array of RMS attention weights tensors. + * @param wq Array of query weight tensors. + * @param wk Array of key weight tensors. + * @param wv Array of value weight tensors. + * @param wo Array of output weight tensors. + * @param rms_ffn_weight Array of RMS feed-forward network weights. + * @param w1 Array of first feed-forward layer weights (gate). + * @param w2 Array of second feed-forward layer weights (down). + * @param w3 Array of third feed-forward layer weights (up). + * @param rms_final_weight Final RMS weight tensor. + * @param freq_cis_real Real part of frequency cis tensor (RoPE). + * @param freq_cis_imag Imaginary part of frequency cis tensor (RoPE). + * @param wcls Output/classification weight tensor (or shared embedding). + * @param weightType The GGML weight type (FP16 or Q8_0). 
+ */ + public GraniteStandardWeights( + FloatTensor token_embedding_table, + FloatTensor[] rms_att_weight, + FloatTensor[] wq, + FloatTensor[] wk, + FloatTensor[] wv, + FloatTensor[] wo, + FloatTensor[] rms_ffn_weight, + FloatTensor[] w1, + FloatTensor[] w2, + FloatTensor[] w3, + FloatTensor rms_final_weight, + FloatTensor freq_cis_real, + FloatTensor freq_cis_imag, + FloatTensor wcls, + GGMLType weightType) { + // call to StandardWeights constructor + super(token_embedding_table, + rms_att_weight, + wq, + wk, + wv, + wo, + rms_ffn_weight, + w1, + w2, + w3, + rms_final_weight, + freq_cis_real, + freq_cis_imag, + wcls, + weightType); + } + // @formatter:on + + @Override + public GGMLType getWeightType() { + return weightType; + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/GraniteTornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/GraniteTornadoWeights.java new file mode 100644 index 00000000..5257a966 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/GraniteTornadoWeights.java @@ -0,0 +1,61 @@ +package org.beehive.gpullama3.inference.weights.tornado; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.tornado.TornadoTensor; + +/** + * A model-specific implementation of {@link TornadoWeights} for the Granite model. + * This class encapsulates the weights required for performing inference + * using the Granite model in the TornadoVM GPU-accelerated format. + * + *

Note: Granite uses the same weight structure as Llama in TornadoVM format, + * with differences only in the scaling factors applied during inference.

+ */ +public class GraniteTornadoWeights extends TornadoWeights { + // @formatter:off + /** + * Constructor for GraniteTornadoWeights. + * + * @param tokenEmbeddingTable The token embedding table tensor. + * @param rms_att_weightLayered Array of RMS attention weights tensors. + * @param wqLayered Array of query weight tensors. + * @param wkLayered Array of key weight tensors. + * @param wvLayered Array of value weight tensors. + * @param woLayered Array of output weight tensors. + * @param rms_ffn_weightLayered Array of RMS feed-forward network weights. + * @param w1Layered Array of first feed-forward layer weights (gate). + * @param w2Layered Array of second feed-forward layer weights (down). + * @param w3Layered Array of third feed-forward layer weights (up). + * @param rms_final_weight_as_floatArray Final RMS weight tensor. + * @param freq_cis_realFlat Real part of frequency cis tensor (RoPE). + * @param freq_cis_imagFlat Imaginary part of frequency cis tensor (RoPE). + * @param wclsByteArray Output/classification weight tensor (or shared embedding). + * @param weightType The GGML weight type (FP16 or Q8_0). 
+ */ + public GraniteTornadoWeights( + TornadoTensor tokenEmbeddingTable, + TornadoTensor[] rms_att_weightLayered, + TornadoTensor[] wqLayered, + TornadoTensor[] wkLayered, + TornadoTensor[] wvLayered, + TornadoTensor[] woLayered, + TornadoTensor[] rms_ffn_weightLayered, + TornadoTensor[] w1Layered, + TornadoTensor[] w2Layered, + TornadoTensor[] w3Layered, + TornadoTensor rms_final_weight_as_floatArray, + TornadoTensor freq_cis_realFlat, + TornadoTensor freq_cis_imagFlat, + TornadoTensor wclsByteArray, + GGMLType weightType) { + super(tokenEmbeddingTable, rms_att_weightLayered, + wqLayered, wkLayered, wvLayered, woLayered, + rms_ffn_weightLayered, + w1Layered, w2Layered, w3Layered, + rms_final_weight_as_floatArray, + freq_cis_realFlat, freq_cis_imagFlat, + wclsByteArray, + weightType); + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java new file mode 100644 index 00000000..e5da7ad9 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java @@ -0,0 +1,76 @@ +package org.beehive.gpullama3.model.format; + +import org.beehive.gpullama3.tokenizer.Tokenizer; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Chat format for Granite models. + * + * Granite uses a different chat template than Llama: + * <|start_of_role|>system<|end_of_role|>...<|end_of_text|> + * <|start_of_role|>user<|end_of_role|>...<|end_of_text|> + * <|start_of_role|>assistant<|end_of_role|>... 
+ */ +public class GraniteChatFormat implements ChatFormat { + + protected final Tokenizer tokenizer; + protected final int startRole; + protected final int endRole; + protected final int endOfText; + protected final Set stopTokens; + + public GraniteChatFormat(Tokenizer tokenizer) { + this.tokenizer = tokenizer; + Map specialTokens = tokenizer.getSpecialTokens(); + this.startRole = specialTokens.getOrDefault("<|start_of_role|>", -1); + this.endRole = specialTokens.getOrDefault("<|end_of_role|>", -1); + this.endOfText = specialTokens.getOrDefault("<|end_of_text|>", 0); // Token 0 is end_of_text for Granite + this.stopTokens = Set.of(endOfText); + } + + @Override + public int getBeginOfText() { + return endOfText; // For Granite, token 0 is both BOS and EOS + } + + @Override + public Set getStopTokens() { + return stopTokens; + } + + @Override + public List encodeHeader(Message message) { + List tokens = new ArrayList<>(); + if (startRole >= 0) { + tokens.add(startRole); + } + tokens.addAll(tokenizer.encodeAsList(message.role().name())); + if (endRole >= 0) { + tokens.add(endRole); + } + return tokens; + } + + @Override + public List encodeMessage(Message message) { + List tokens = encodeHeader(message); + tokens.addAll(tokenizer.encodeAsList(message.content().strip())); + tokens.add(endOfText); + return tokens; + } + + public List encodeDialogPrompt(boolean appendAssistantTurn, List dialog) { + List tokens = new ArrayList<>(); + for (Message message : dialog) { + tokens.addAll(encodeMessage(message)); + } + if (appendAssistantTurn) { + tokens.addAll(encodeHeader(new Message(ChatFormat.Role.ASSISTANT, ""))); + } + return tokens; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/granite/Granite.java b/src/main/java/org/beehive/gpullama3/model/granite/Granite.java new file mode 100644 index 00000000..a4743e38 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/granite/Granite.java @@ -0,0 +1,99 @@ +package org.beehive.gpullama3.model.granite; + 
+import org.beehive.gpullama3.inference.InferenceCore; +import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.AbstractModel; +import org.beehive.gpullama3.model.ModelType; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.tokenizer.GraniteTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +public class Granite extends AbstractModel { + + private final GraniteConfiguration configuration; + + public Granite(GraniteConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) { + super(tokenizer, weights, chatFormat, null); + this.configuration = configuration; + } + + @Override + public GraniteConfiguration configuration() { + return configuration; + } + + @Override + public GraniteTokenizer tokenizer() { + return (GraniteTokenizer) tokenizer; + } + + @Override + public ModelType getModelType() { + return ModelType.GRANITE; + } + + @Override + public State createNewState() { + State state = new GraniteState(configuration(), -1); + // Granite uses token 0 (<|end_of_text|>) as BOS - it's multi-purpose + // Token 0 is the default BOS for Granite + state.latestToken = 0; + return state; + } + + @Override + public State createNewState(int batchsize) { + State state = new GraniteState(configuration(), batchsize); + // Token 0 is the default BOS for Granite + state.latestToken = 0; + return state; + } + + @Override + public void forward(State state, int token, int position) { + // Uses Granite-specific forward with scaling factors + InferenceCore.forwardGranite(this, state, token, position); 
+ } + + @Override + public List generateTokens(State state, int startPosition, List promptTokens, + Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + return InferenceEngine.generateTokensGranite(this, state, startPosition, promptTokens, + stopTokens, maxTokens, sampler, echo, onTokenGenerated); + } + + @Override + public List generateTokensGPU(State state, int startPosition, List promptTokens, + Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + return InferenceEngine.generateTokensGPUGranite(this, state, startPosition, promptTokens, + stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); + } + + // Convenience accessors for scaling factors (used in forward pass) + public float embeddingScale() { + return configuration.embeddingScale(); + } + + public float residualScale() { + return configuration.residualScale(); + } + + public float attentionScale() { + return configuration.attentionScale(); + } + + public float logitScale() { + return configuration.logitScale(); + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/granite/GraniteConfiguration.java b/src/main/java/org/beehive/gpullama3/model/granite/GraniteConfiguration.java new file mode 100644 index 00000000..db03fc5a --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/granite/GraniteConfiguration.java @@ -0,0 +1,152 @@ +package org.beehive.gpullama3.model.granite; + +import org.beehive.gpullama3.model.Configuration; + +// @formatter:off +public record GraniteConfiguration( + String quantization, + int dim, + int hiddenDim, + int numberOfLayers, + int numberOfHeads, + int numberOfKeyValueHeads, + int vocabularySize, + int contextLength, + float rmsNormEps, + float ropeTheta, + // Granite-specific scaling factors (µP parameterization) + float embeddingMultiplier, // multiply embeddings after lookup + float 
residualMultiplier, // multiply residual additions + float attentionMultiplier, // replaces 1/sqrt(headDim) + float logitsScaling, // DIVIDE logits by this value + boolean tieWordEmbeddings // share input/output embeddings +) implements Configuration { + + @Override + public String quantization() { + return quantization; + } + + @Override + public int numberOfHeadsKey() { + // Granite uses standard GQA, same as Llama + return numberOfKeyValueHeads; + } + + @Override + public int contextLengthModel() { + return contextLength; + } + + /** Size of each attention head (derived from dim / numberOfHeads) */ + @Override + public int headSize() { + return dim / numberOfHeads; + } + + /** Key/value dimension (derived from dim * numberOfKeyValueHeads / numberOfHeads) */ + @Override + public int kvDim() { + return dim * numberOfKeyValueHeads / numberOfHeads; + } + + /** Multiplier for key/value sharing in grouped-query attention */ + @Override + public int kvMul() { + return numberOfHeads / numberOfKeyValueHeads; + } + + /** + * Creates a new Configuration with a different context length. 
+ * + * @param newContextLength The new context length to use + * @return A new Configuration instance with updated context length, + * or the current instance if newContextLength is negative + */ + // @formatter:off + public GraniteConfiguration withContextLength(int newContextLength) { + if (newContextLength < 0) { + return this; // no change + } + return new GraniteConfiguration( + this.quantization, + this.dim, + this.hiddenDim, + this.numberOfLayers, + this.numberOfHeads, + this.numberOfKeyValueHeads, + this.vocabularySize, + newContextLength, + this.rmsNormEps, + this.ropeTheta, + this.embeddingMultiplier, + this.residualMultiplier, + this.attentionMultiplier, + this.logitsScaling, + this.tieWordEmbeddings + ); + } + // @formatter:on + + /** + * Accessor for embedding scale (alias for embeddingMultiplier) + */ + public float embeddingScale() { + return embeddingMultiplier; + } + + /** + * Accessor for residual scale (alias for residualMultiplier) + */ + public float residualScale() { + return residualMultiplier; + } + + /** + * Accessor for attention scale (alias for attentionMultiplier) + */ + public float attentionScale() { + return attentionMultiplier; + } + + /** + * Accessor for logit scale (alias for logitsScaling) + */ + public float logitScale() { + return logitsScaling; + } + + /** + * Factory method to create GraniteConfiguration with default scaling values + * for Granite 3.x models. 
+ */ + public static GraniteConfiguration createDefault( + String quantization, + int dim, + int hiddenDim, + int numberOfLayers, + int numberOfHeads, + int numberOfKeyValueHeads, + int vocabularySize, + int contextLength, + float rmsNormEps, + float ropeTheta) { + return new GraniteConfiguration( + quantization, + dim, + hiddenDim, + numberOfLayers, + numberOfHeads, + numberOfKeyValueHeads, + vocabularySize, + contextLength, + rmsNormEps, + ropeTheta, + 12.0f, // embeddingMultiplier + 0.22f, // residualMultiplier + 0.0078125f, // attentionMultiplier (1/128) + 16.0f, // logitsScaling (divisor) + true // tieWordEmbeddings + ); + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java new file mode 100644 index 00000000..aadc15d8 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java @@ -0,0 +1,163 @@ +package org.beehive.gpullama3.model.loader; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.auxiliary.Pair; +import org.beehive.gpullama3.inference.operation.RoPE; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.standard.GraniteStandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.granite.Granite; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tokenizer.GraniteTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tokenizer.Vocabulary; + +import 
org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +import java.nio.channels.FileChannel; +import java.util.Map; + +import static org.beehive.gpullama3.model.loader.ModelLoader.*; + +public class GraniteLoader extends AbstractModelLoader { + + public GraniteLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + super(fileChannel, gguf, contextLength, useTornadovm); + } + + @Override + protected Vocabulary loadVocabulary(Map metadata) { + // Granite uses the same token format as Llama + return Vocabulary.loadLlamaVocabulary(metadata); + } + + @Override + protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { + return new GraniteTokenizer(metadata, vocabulary); + } + + // @formatter:off + @Override + protected GraniteConfiguration createConfiguration(Map metadata) { + int vocabSize = metadata.containsKey("granite.vocab_size") + ? (int) metadata.get("granite.vocab_size") + : (int) metadata.get("tokenizer.ggml.tokens.length"); + + // Extract Granite-specific metadata keys + float embeddingScale = (float) metadata.getOrDefault("granite.embedding_scale", 12.0f); + float residualScale = (float) metadata.getOrDefault("granite.residual_scale", 0.22f); + float attentionScale = (float) metadata.getOrDefault("granite.attention.scale", 0.0078125f); + float logitScale = (float) metadata.getOrDefault("granite.logit_scale", 16.0f); + + return new GraniteConfiguration( + getModelQuantization(metadata), + (int) metadata.get("granite.embedding_length"), + (int) metadata.get("granite.feed_forward_length"), + (int) metadata.get("granite.block_count"), + (int) metadata.get("granite.attention.head_count"), + metadata.containsKey("granite.attention.head_count_kv") + ? 
(int) metadata.get("granite.attention.head_count_kv") + : (int) metadata.get("granite.attention.head_count"), + vocabSize, + (int) metadata.get("granite.context_length"), + (float) metadata.getOrDefault("granite.attention.layer_norm_rms_epsilon", 1e-5f), + (float) metadata.getOrDefault("granite.rope.freq_base", 10000f), + embeddingScale, + residualScale, + attentionScale, + logitScale, + true // Granite ties word embeddings + ).withContextLength(contextLength); + } + // @formatter:on + + @Override + protected Pair precomputeRopeFrequencies(GraniteConfiguration config) { + return RoPE.precomputeFreqsCis( + config.contextLength(), + config.dim() / config.numberOfHeads(), + config.ropeTheta(), + false, + 1.0f, 1.0f, 1.0f, + config.contextLength()); + } + + @Override + protected Granite createModel(GraniteConfiguration config, Tokenizer tokenizer, Weights weights) { + return new Granite(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); + } + + // @formatter:off + @Override + protected Weights createStandardWeights(Map tensorEntries, + GraniteConfiguration config, + Pair ropeFreqs, + GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + final int nl = config.numberOfLayers(); + + return new GraniteStandardWeights( + loadTensor(tokenEmbeddings), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".ffn_down.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + loadTensor(tensorEntries.get("output_norm.weight")), + new ArrayFloatTensor(ropeFreqs.first()), + new ArrayFloatTensor(ropeFreqs.second()), + loadTensor(outputWeight), + outputWeight.ggmlType()); + } + // @formatter:on + + // @formatter:off + @Override + protected Weights createTornadoVMWeights(Map tensorEntries, + GraniteConfiguration config, + Pair ropeFreqs, + GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + GGMLType ggmlType = outputWeight.ggmlType(); + + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); + } + + // Validate supported types + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); + } + + final int nl = config.numberOfLayers(); + + return new GraniteTornadoWeights( + loadTornadoTensor(tokenEmbeddings), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), + loadTornadoTensor(tensorEntries.get("output_norm.weight")), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())), + loadTornadoTensor(outputWeight), + ggmlType + ); + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java new file mode 100644 index 00000000..b5f6011e --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java @@ -0,0 +1,281 @@ +package org.beehive.gpullama3.tokenizer; + +import org.beehive.gpullama3.auxiliary.Pair; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * GPT-2-style BPE tokenizer for Granite models. + *

+ * Granite uses the same byte-level BPE tokenization algorithm as Llama,
+ * but with different special tokens and token IDs.
+

+ * BOS/EOS Token: Token ID 0 (<|end_of_text|>) serves both purposes. + */ +public class GraniteTokenizer implements Tokenizer { + static final Map BYTE_ENCODER = bytesToUnicode(); + static final Map BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + private static final String GRANITE_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; + + // general fields + private final Pattern compiledPattern; + private final Vocabulary vocabulary; + // model-specific fields + private final Map, Integer> merges; + private final Map specialTokens; + + public GraniteTokenizer(Map metadata, Vocabulary vocabulary) { + // load from metadata + String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges"); + List> merges = Arrays.stream(mergeLines).map(line -> line.split(" ")) + .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList(); + int allTokens = vocabulary.size(); + + // For Granite, determine base tokens differently than Llama + // Find the first special token ID by looking at vocabulary + int baseTokens = findBaseTokensCount(vocabulary); + int reservedSpecialTokens = allTokens - baseTokens; + List specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList(); + + assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent()); + + Map specialTokens = IntStream.range(0, specialTokensList.size()).boxed() + .collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i)); + + // init tokenizer object fields + this.vocabulary = vocabulary; + this.compiledPattern = Pattern.compile(GRANITE_PATTERN); + this.specialTokens = new HashMap<>(specialTokens); + this.merges = new HashMap<>(); + for (Pair pair : merges) { + int firstIndex = pair.first(); + int secondIndex = pair.second(); + 
int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow(); + this.merges.put(pair, mergeIndex); + } + } + + /** + * Finds the base token count by detecting where special tokens start. + * For Granite, this is typically around token ID 49000-49159. + */ + private static int findBaseTokensCount(Vocabulary vocabulary) { + int allTokens = vocabulary.size(); + // Look for special tokens in the vocabulary + // Granite special tokens typically start after regular vocabulary + // Default to 90% of vocabulary size as an estimate + for (int i = allTokens - 1; i >= 0; i--) { + String token = vocabulary.get(i); + if (token.startsWith("<|") && token.endsWith("|>")) { + // Found a special token, return this index as base + return i; + } + } + // Fallback: assume last few tokens are special + return Math.max(1000, (int) (allTokens * 0.98)); + } + + /** + * Gets the BOS (Beginning of Sequence) token ID. + * For Granite, this is token 0 (<|end_of_text|>). + */ + public int getBosTokenId() { + return 0; + } + + /** + * Gets the EOS (End of Sequence) token ID. + * For Granite, this is also token 0 (<|end_of_text|>). + */ + public int getEosTokenId() { + return 0; + } + + private static List findAll(Pattern pattern, String text) { + List allMatches = new ArrayList<>(); + Matcher matcher = pattern.matcher(text); + while (matcher.find()) { + allMatches.add(matcher.group()); + } + return allMatches; + } + + private static List merge(List ids, Pair pair, int idx) { + List newids = new ArrayList<>(); + int i = 0; + while (i < ids.size()) { + if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) { + newids.add(idx); + i += 2; + } else { + newids.add(ids.get(i)); + i += 1; + } + } + return newids; + } + + /** + * Returns list of utf-8 byte and a corresponding list of unicode strings. 
+ */ + private static Map bytesToUnicode() { + List bs = new ArrayList<>(); + IntStream.rangeClosed('!', '~').forEach(bs::add); + IntStream.rangeClosed('¡', '¬').forEach(bs::add); + IntStream.rangeClosed('®', 'ÿ').forEach(bs::add); + + List cs = new ArrayList<>(bs); + int n = 0; + for (int b = 0; b < 256; ++b) { + if (!bs.contains(b)) { + bs.add(b); + cs.add(256 + n); + n += 1; + } + } + + return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get)); + } + + public String regexPattern() { + if (compiledPattern == null) { + return null; + } + return compiledPattern.pattern(); + } + + @Override + public Map getSpecialTokens() { + return specialTokens; + } + + @Override + public boolean isSpecialToken(int tokenIndex) { + return specialTokens.containsValue(tokenIndex); + } + + @Override + public boolean shouldDisplayToken(int token) { + return !isSpecialToken(token); + } + + private int[] encodeImpl(String text) { + return encode(text, Set.of()).stream().mapToInt(i -> i).toArray(); + } + + /** + * Encode text handling special tokens. + */ + public List encode(String text, Set allowedSpecial) { + Set special = allowedSpecial; + assert getSpecialTokens().keySet().containsAll(special); + if (special.isEmpty()) { + return encodeOrdinary(text); + } + + String specialPattern = special.stream().map(Pattern::quote).collect(Collectors.joining("|", "(", ")")); + String[] specialChunks = text.split(specialPattern); + List ids = new ArrayList<>(); + for (String part : specialChunks) { + if (special.contains(part)) { + ids.add(getSpecialTokens().get(part)); + } else { + ids.addAll(encodeOrdinary(part)); + } + } + return ids; + } + + /** + * Encoding that ignores any special tokens. 
+ */ + public List encodeOrdinary(String text) { + List textChunks = findAll(compiledPattern, text); + List ids = new ArrayList<>(); + for (String chunk : textChunks) { + List chunkIds = encodeChunk(chunk); + ids.addAll(chunkIds); + } + return ids; + } + + private Map, Integer> getStats(List ids) { + Map, Integer> map = new HashMap<>(); + for (int i = 0; i + 1 < ids.size(); i++) { + Pair key = new Pair<>(ids.get(i), ids.get(i + 1)); + map.put(key, map.getOrDefault(key, 0) + 1); + } + return map; + } + + private List encodeChunk(String chunk) { + List ids = new ArrayList<>(); + for (int b : chunk.toCharArray()) { + int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow(); + ids.add(tokenIndex); + } + + while (ids.size() >= 2) { + Map, Integer> stats = getStats(ids); + Pair pair = stats.keySet().stream() + .min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))) + .orElseThrow(); + if (!this.merges.containsKey(pair)) { + break; + } + int idx = this.merges.get(pair); + ids = merge(ids, pair, idx); + } + return ids; + } + + public String decodeImpl(List tokens) { + StringBuilder sb = new StringBuilder(); + for (int token : tokens) { + String tokenString = vocabulary.get(token); + sb.append(tokenString); + } + return sb.toString(); + } + + public int[] encode(String text) { + StringBuilder sb = new StringBuilder(); + byte[] bytes = text.getBytes(StandardCharsets.UTF_8); + for (byte b : bytes) { + sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); + } + return encodeImpl(sb.toString()); + } + + @Override + public List encodeAsList(String text) { + StringBuilder sb = new StringBuilder(); + byte[] bytes = text.getBytes(StandardCharsets.UTF_8); + for (byte b : bytes) { + sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); + } + return Arrays.stream(encodeImpl(sb.toString())).boxed().toList(); + } + + @Override + public String decode(List tokens) { + String decoded = decodeImpl(tokens); + int[] 
decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray(); + byte[] rawBytes = new byte[decodedBytesAsInts.length]; + for (int i = 0; i < decoded.length(); i++) { + rawBytes[i] = (byte) decodedBytesAsInts[i]; + } + return new String(rawBytes, StandardCharsets.UTF_8); + } +} From af363a86f580fc470c41d6dc90d0da294728c4a2 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 15:26:04 +0200 Subject: [PATCH 03/26] Fix Makefile indentation on 'install' target --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 3f44bac9..9e41301c 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ clean: $(MVN) clean install: - $(MVN) install -DskipTests + $(MVN) install -DskipTests # Package the project without running tests package: From c14b08616e2ac0f2f819987dd2d4e71fbfd42aa5 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 15:29:37 +0200 Subject: [PATCH 04/26] Refactor `GraniteTokenizer` to simplify special token handling and remove unused methods. 
--- .../gpullama3/tokenizer/GraniteTokenizer.java | 54 ++++--------------- 1 file changed, 9 insertions(+), 45 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java index b5f6011e..e5699d5a 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java @@ -42,16 +42,15 @@ public GraniteTokenizer(Map metadata, Vocabulary vocabulary) { .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList(); int allTokens = vocabulary.size(); - // For Granite, determine base tokens differently than Llama - // Find the first special token ID by looking at vocabulary - int baseTokens = findBaseTokensCount(vocabulary); - int reservedSpecialTokens = allTokens - baseTokens; - List specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList(); - - assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent()); - - Map specialTokens = IntStream.range(0, specialTokensList.size()).boxed() - .collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i)); + // For Granite, collect ALL special tokens (including token 0) + Map specialTokens = new HashMap<>(); + for (int i = 0; i < allTokens; i++) { + String token = vocabulary.get(i); + // Identify special tokens by their format: start with <| and end with |> + if (token.startsWith("<|") && token.endsWith("|>")) { + specialTokens.put(token, i); + } + } // init tokenizer object fields this.vocabulary = vocabulary; @@ -66,41 +65,6 @@ public GraniteTokenizer(Map metadata, Vocabulary vocabulary) { } } - /** - * Finds the base token count by detecting where special tokens start. - * For Granite, this is typically around token ID 49000-49159. 
- */ - private static int findBaseTokensCount(Vocabulary vocabulary) { - int allTokens = vocabulary.size(); - // Look for special tokens in the vocabulary - // Granite special tokens typically start after regular vocabulary - // Default to 90% of vocabulary size as an estimate - for (int i = allTokens - 1; i >= 0; i--) { - String token = vocabulary.get(i); - if (token.startsWith("<|") && token.endsWith("|>")) { - // Found a special token, return this index as base - return i; - } - } - // Fallback: assume last few tokens are special - return Math.max(1000, (int) (allTokens * 0.98)); - } - - /** - * Gets the BOS (Beginning of Sequence) token ID. - * For Granite, this is token 0 (<|end_of_text|>). - */ - public int getBosTokenId() { - return 0; - } - - /** - * Gets the EOS (End of Sequence) token ID. - * For Granite, this is also token 0 (<|end_of_text|>). - */ - public int getEosTokenId() { - return 0; - } private static List findAll(Pattern pattern, String text) { List allMatches = new ArrayList<>(); From b1414cdbd377d08f0cc14fe3d4fdbfa76d73b5f9 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 18:41:35 +0200 Subject: [PATCH 05/26] Add Granite-specific TornadoVM components and FP16 inference pipelines MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Introduced `GraniteKernels` for optimized kernel operations with FP16 support. - Implemented `GraniteFP16FFNLayers` and `GraniteFP16LayerPlanner` for Transformer-based inference with TornadoVM. - Added `LogitsGraniteFP16Layer` to support Granite logits layer. - Enabled model-specific task graph creation and worker grid configuration tailored for Granite. - Updated `QuantizationPlannerFactory` to integrate `GraniteFP16LayerPlanner`. - Extended support for attention scaling, residual connections, and model-specific configurations (e.g., µP scaling). 
--- .../tornadovm/kernels/GraniteKernels.java | 197 +++++++++ .../base/QuantizationPlannerFactory.java | 3 + .../model/fp16/GraniteFP16LayerPlanner.java | 28 ++ .../tornadovm/layers/ActivationGranite.java | 70 ++++ .../type/fp16/GraniteFP16FFNLayers.java | 375 ++++++++++++++++++ .../type/fp16/LogitsGraniteFP16Layer.java | 122 ++++++ 6 files changed, 795 insertions(+) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java new file mode 100644 index 00000000..c611cc24 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java @@ -0,0 +1,197 @@ +package org.beehive.gpullama3.tornadovm.kernels; + +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimized; +import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimizedSingle; + +public class GraniteKernels { + + public static void convertFP16toFP32withGraniteScale(KernelContext context, HalfFloatArray x, FloatArray wrapX, float embeddingScale) { + int i = 
context.globalIdx; + wrapX.set(i, embeddingScale * x.get(i).getFloat32()); + } + + // @formatter:off + public static void matrixVectorGenericWithGraniteScale( + KernelContext context, + HalfFloatArray x, + FloatArray hb, // output + HalfFloatArray w, + int dim1, // inner loop + int dim0, // outer loop + int localWorkGroupSize, + float logitsScale + ) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= dim0) { + return; + } + float sum = matrixVectorRowMajorOptimizedSingle(context, localSize, x, w, dim1); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + hb.set(rowId, sum); + } + } + + + public static void processHeadsFlashAttentionWithGraniteScale(KernelContext context, + FloatArray q, FloatArray key_cache, FloatArray value_cache, + FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, + IntArray positionHolder, int layer, int contextLength, float attentionScale) { + + // Thread and workgroup information + int tid = context.localIdx; + int h = context.groupIdx; // Each workgroup processes one head + int localSize = context.localGroupSizeX; + + // Early exit if this workgroup is beyond our head count + // This relies on the kernel being launched with nHeads workgroups. 
+ if (h >= nHeads) { + return; + } + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + int kvHeadIdx = h / kvMul; + int BLOCK_SIZE_C = 16; + + // Allocate shared memory for tiled computation + float[] q_shared = context.allocateFloatLocalArray(headSize); + float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize); + float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize); + float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C); + float[] shared_tile_max_holder = context.allocateFloatLocalArray(1); // FIX: For broadcasting tile max + + // Thread-local accumulators for online softmax + float maxScore = Float.NEGATIVE_INFINITY; + float sumExp = 0.0f; + + // Thread-local output accumulation + float[] output = new float[headSize]; + for (int i = 0; i < headSize; i++) { + output[i] = 0.0f; + } + + // Load query vector into shared memory + for (int i = tid; i < headSize; i += localSize) { + q_shared[i] = q.get(h * headSize + i); + } + + context.localBarrier(); + + // Process sequence in tiles + for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) { + int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos); + + // Load key and value vectors for this tile + // Each thread loads a portion of the K and V vectors for the tile + for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) { + int k_v_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile + int tileMemOffset = k_v_idx_in_tile * headSize; + for (int d = 0; d < headSize; d++) { + int kvCacheAbsolutePos = tIdxInSeq; + int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + d; + k_tile[tileMemOffset + d] = key_cache.get(kvOffset); + v_tile[tileMemOffset + d] = value_cache.get(kvOffset); + } + } + + context.localBarrier(); + + // Compute attention scores for this tile + // Each thread computes one score for the tile + for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) 
{ + int score_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile + + float score = 0.0f; + for (int d = 0; d < headSize; d++) { + score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d]; + } + score *= attentionScale; +// score /= TornadoMath.sqrt(headSize); + s_tile[score_idx_in_tile] = score; + } + + context.localBarrier(); + + // Find max score in this tile (all threads compute it redundantly over the small s_tile) + float tileLocalMax = Float.NEGATIVE_INFINITY; + for (int i = 0; i <= tileEnd - tileC; i++) { // Iterate over valid scores in s_tile + if (s_tile[i] > tileLocalMax) { + tileLocalMax = s_tile[i]; + } + } + + // Broadcast max to all threads via shared memory + if (tid == 0) { + shared_tile_max_holder[0] = tileLocalMax; // FIX: Use dedicated holder + } + context.localBarrier(); + float currentTileMax = shared_tile_max_holder[0]; // FIX: Read from dedicated holder + + // Determine if we need to rescale previous results + float newMax = Math.max(maxScore, currentTileMax); + if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) { + float scale = TornadoMath.exp(maxScore - newMax); + sumExp *= scale; + for (int d = 0; d < headSize; d++) { + output[d] *= scale; + } + } + maxScore = newMax; + + // Process each key-value pair using original scores from s_tile + // All threads iterate over all scores in the current tile + for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) { + // s_tile[t_idx_in_s_tile] now correctly refers to the original score + float expScore = TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore); + sumExp += expScore; + + for (int d = 0; d < headSize; d++) { + output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d]; + } + } + context.localBarrier(); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load + } + + // Normalize and write final results + float normFactor = (sumExp > 0.0f) ? 
(1.0f / sumExp) : 0.0f; // Avoid division by zero, return 0 if sumExp is 0 + for (int d = tid; d < headSize; d += localSize) { + xb.set(h * headSize + d, output[d] * normFactor); + } + } + // @formatter:on + + public static void matrixVectorGenericWithResidualGranite(KernelContext context, FloatArray x, FloatArray hb, + HalfFloatArray w, int n, int d, int localWorkGroupSize, float residualScale) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float residual = residualScale * sum; + float result = hb.get(rowId) + residual; + hb.set(rowId, result); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java index 1684a5b8..68153895 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.tornadovm.layerplanner.base; +import org.beehive.gpullama3.inference.state.GraniteState; import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.Phi3State; @@ -8,6 +9,7 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tornadovm.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner; import 
org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner; @@ -55,6 +57,7 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) { case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model); + case GRANITE -> new GraniteFP16LayerPlanner((GraniteState) state, model); case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); default -> throw new UnsupportedOperationException("FP16 not supported for model: " + model.getModelType()); }; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java new file mode 100644 index 00000000..708a8b43 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java @@ -0,0 +1,28 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; + +public class 
GraniteFP16LayerPlanner extends FP16LayerPlanner { + public GraniteFP16LayerPlanner(GraniteState state, Model model) { + super(state, model); + validateQuantizationType(); + setupTornadoForwardPlan(); + } + + @Override + protected void initializeLayerComponents() { + this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config); + this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType); + this.logitsLayer = new LogitsFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java new file mode 100644 index 00000000..2ac30cdf --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java @@ -0,0 +1,70 @@ +package org.beehive.gpullama3.tornadovm.layers; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +public class ActivationGranite extends 
Activation { + private final TaskGraph activationUpdate; + + public ActivationGranite(String taskGraphHandle, State state, Weights weights, GraniteConfiguration config) { + super(taskGraphHandle, state, weights, config); + + KernelContext kernelContext = new KernelContext(); + + // @formatter:off + switch (config.quantization()) { + case "FP16" -> { + this.activationUpdate = new TaskGraph(taskGraphHandle) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", GraniteKernels::convertFP16toFP32withGraniteScale, kernelContext, (HalfFloatArray) state.embeddingX, state.wrapX, config.embeddingScale()) + .persistOnDevice(state.wrapX); + } + case "Q8_0" -> { + this.activationUpdate = new TaskGraph(taskGraphHandle) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, kernelContext, (ByteArray) state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); + } + default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); + } + // @formatter:on + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler scheduler) { + WorkerGrid worker = new WorkerGrid1D(config.dim()); + worker.setLocalWork(128, 1, 1); + scheduler.addWorkerGrid("activationUpdate.updateX", worker); + return scheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return null; + } + + @Override + public TaskGraph getTaskGraph() { + return activationUpdate; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return activationUpdate.snapshot(); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java new file mode 100644 index 00000000..8ba4ded5 --- /dev/null +++ 
b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -0,0 +1,375 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +public class GraniteFP16FFNLayers extends AbstractFFNLayers { + + TaskGraph ffnTaskGraphs; + GridScheduler scheduler; + List ffnLayerTaskGraphs; + + + public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + this.ffnLayerTaskGraphs = setupFFNLayered(); + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = 
WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + int fusedQKVRows = config.dim() + 2 * config.kvDim(); + int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); + + // Map workers to tasks + for (int i = 0; i < config.numberOfLayers(); i++) { + // === Attention Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + // === FFN Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); + } + return tornadoForwardScheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return ffnTaskGraphs; + } + + @Override + public 
ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + List setupFFNLayered() { + return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { + var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); + if (i == config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + return ffnLayer.snapshot(); + }).toList(); + } + + // @formatter:off + /** + * Transformer Layer Task Flow (LlamaFP16FFNLayers) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (partial sums) + * └────────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ attn_rms_finalize│──▶ temp (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌─────────────────────┐ + * │ attn_rms_apply_fp16 │──▶ wrapXbFP16 (normalized, FP16) + * └──────────┬──────────┘ + * │ + * ▼ + * ┌────────────────┐ ┌─────────────────────────────┐ + * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │ + * └───────┬────────┘ └─────────────────────────────┘ + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ 
ffn_rms_reduce │──▶ tempFFN (partial sums) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌─────────────────┐ + * │ ffn_rms_finalize│──▶ tempFFN (final scale) + * └────────┬────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points: + * • qkv_projection: Fused Q/K/V matmuls (3→1 kernel) + * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel) + * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel) + * + */ + TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + unifiedLayer = 
configureLayerDataTransfers(unifiedLayer, layerIndex); + + // === Attention Block === + // RMS Normalization + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantize, + context, state.wrapXbFP16, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + // QKV Projection (fused) + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulX, + context, + state.wrapXbFP16, // input (FP32) + state.wrapQ, // output Q + state.wrapK, // output K + state.wrapV, // output V + weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq + weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk + weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv + config.dim(), // dim + config.kvDim(), // kvDim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // RoPE + KV Cache + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, // Q (in/out) + state.wrapK, // K (in/out) + state.wrapV, // V (in only) + state.wrapKeyCache, // Key cache (out) + state.wrapValueCache, // Value cache (out) + config.kvDim(), + config.headSize(), + layerIndex, + config.contextLength()); + // Attention + configureAttention(unifiedLayer, layerIndex, config ); + // Output Projection (Wo) with residual + unifiedLayer.task("attn_output_proj", + GraniteKernels::matrixVectorGenericWithResidualGranite, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asHalfFloatArray(), + config.dim(), config.dim(), 
LOCAL_WORK_GROUP_SIZE_ALLOC, + config.residualScale() + ); + + // === FFN Block === + // RMS Normalization + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, + context, + state.wrapX, // raw input (FP32) + state.wrapHb, // output + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + state.tempFFN, // RMS scale factor + weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 + weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 + config.dim(), // input dimension + config.hiddenDim(), // output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Down projection (W2) with residual + unifiedLayer.task("ffn_down_proj", + GraniteKernels::matrixVectorGenericWithResidualGranite, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asHalfFloatArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC, config.residualScale()); + + unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + // First layer: Transfer initial data to device (one-time transfer) + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, + state.temp, state.tempFFN + ); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Kernel context + context, + // Intermediate buffers + state.wrapXb, state.wrapXb2, + // QKV vectors + state.wrapQ, state.wrapK, state.wrapV, + // KV cache + state.wrapKeyCache, state.wrapValueCache, 
+ // Attention & FFN buffers + state.wrapAtt, state.wrapHb, state.wrapXbFP16); + } else { + // Subsequent layers: Consume data already on device from previous layer + unifiedLayer.consumeFromDevice( + // Kernel context + context, + // Intermediate buffers + state.wrapXb, state.wrapXb2, + // QKV vectors + state.wrapQ, state.wrapK, state.wrapV, + // KV cache + state.wrapKeyCache, state.wrapValueCache, + // Attention & FFN buffers + state.wrapAtt, state.wrapHb, + // Position & misc + state.positionHolder, state.wrapXbFP16); + } + return unifiedLayer; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex, GraniteConfiguration config) { + if (schedulerType == SchedulerType.NVIDIA) { + // Flash Attention (optimized for NVIDIA GPUs) + return unifiedLayer.task("attention", + GraniteKernels::processHeadsFlashAttentionWithGraniteScale, + context, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), + config.headSize(), + config.kvDim(), + config.kvMul(), + state.positionHolder, + layerIndex, + config.contextLength(), + config.attentionScale() + ); + } else { + // Standard parallel attention (for non-NVIDIA backends) + return unifiedLayer.task("attention", + TransformerComputeKernelsLayered::processHeadsParallel, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), + config.headSize(), + config.kvDim(), + config.kvMul(), + config.contextLength(), // seqLen parameter + state.positionHolder, + state.wrapAtt, // Attention weights buffer + layerIndex, + config.contextLength()); + } + } + // @formatter:on + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java new file mode 100644 index 00000000..f8a12c14 --- 
/dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java @@ -0,0 +1,122 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractLayer; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +public class LogitsGraniteFP16Layer extends AbstractLayer { + private String lastTaskGraphID; + private TaskGraph logitsTaskGraph; + private ImmutableTaskGraph immutableLogitsGraph; + private GridScheduler scheduler; + private SchedulerType schedulerType; + + protected LogitsGraniteFP16Layer(String taskGraphName, State state, Weights weights, GraniteConfiguration config) { + super(taskGraphName, state, weights, config); + } + + // @formatter:off + private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) { + var logits = new TaskGraph("logits"); + // === Data Setup === + logits.consumeFromDevice(lastTaskGraphID, state.wrapX); + 
logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); + logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Kernel context + context, + // Output buffer + state.wrapLogits, + // Intermediate FP16 buffer + state.wrapXbFP16, + // Weights + weights.wclsByteArray.asHalfFloatArray(), + weights.rms_final_weight_as_floatArray.asFloatArray()); + + // === Final RMS Normalization === + logits.task("rms_reduce", + TransformerComputeKernels::reductionOneBlockWithLayer, + context, + state.tempLogits, // output: partial sums + final scale factor + state.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon for numerical stability + state.localSize); // local workgroup size + + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.tempLogits, // in/out: combines partial sums + config.dim(), // dimension + config.rmsNormEps()); // epsilon + } + + logits.task("rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantizeLogits, + context, + state.wrapXbFP16, // output: normalized (FP16) + state.wrapX, // input: hidden state + weights.rms_final_weight_as_floatArray.asFloatArray(), // RMS weights + state.tempLogits); // scale factor from reduction + + // === Vocabulary Projection === + logits.task("vocab_proj", + GraniteKernels::matrixVectorGenericWithGraniteScale, + context, + state.wrapXbFP16, // input (FP16) + state.wrapLogits, // output + weights.wclsByteArray.asHalfFloatArray(), // vocabulary weights + config.dim(), // input dimension + config.vocabularySize(), // output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, + config.logitScale()); // granite logit scaling + + // === Transfer Results to Host === + logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + return logits; + } + // @formatter:on + + @Override + public GridScheduler 
updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256); + var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; + var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); + tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); + return tornadoForwardScheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return logitsTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return immutableLogitsGraph; + } +} + From 62acbcaa1ee914b9647b311f92bab1cc7466f61a Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 20:53:48 +0200 Subject: [PATCH 06/26] Introduce parallelized attention computation in Granite kernels - Added `processHeadsParallelGranite` in `GraniteKernels` for efficient multi-head attention processing with parallelism using TornadoVM. - Updated `GraniteFP16FFNLayers` to utilize the new kernel. - Extended support for configurable attention scaling. 
--- .../tornadovm/kernels/GraniteKernels.java | 73 +++++++++++++++++++ .../type/fp16/GraniteFP16FFNLayers.java | 5 +- 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java index c611cc24..b362bc7a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java @@ -1,6 +1,7 @@ package org.beehive.gpullama3.tornadovm.kernels; import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; @@ -194,4 +195,76 @@ public static void matrixVectorGenericWithResidualGranite(KernelContext context, hb.set(rowId, result); } } + + public static void processHeadsParallelGranite(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen, + IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength, float attentionScale) { + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + + // Parallelize computation across attention heads + for (@Parallel int h = 0; h < nHeads; h++) { + // Process each head in parallel + processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt, attentionScale); + } + } + + private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos, + FloatArray wrapAtt, float attentionScale) { + + // Base index for this head's attention weights + int headOffset = h * (pos + 1); + + // STEP 1: Calculate attention scores for 
all timesteps + for (int t = 0; t <= pos; t++) { + int kvHeadIdx = h / kvMul; + int keyOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize); + + float score = 0.0f; + for (int i = 0; i < headSize; i++) { + score += allQ.get(h * headSize + i) * key_cache.get(keyOffset + i); + } + score *= attentionScale; //TODO: might need score = score * attentionScale; +// score = score / TornadoMath.sqrt(headSize); + + // Store in attention buffer + wrapAtt.set(headOffset + t, score); + } + + // STEP 2: Find max score for softmax stability + float maxScore = wrapAtt.get(headOffset); + for (int t = 1; t <= pos; t++) { + float val = wrapAtt.get(headOffset + t); + if (val > maxScore) { + maxScore = val; + } + } + + // STEP 3: Compute exponentials and sum + float sum = 0.0f; + for (int t = 0; t <= pos; t++) { + int idx = headOffset + t; + float expScore = TornadoMath.exp(wrapAtt.get(idx) - maxScore); + wrapAtt.set(idx, expScore); + sum += expScore; + } + + // STEP 4: Normalize + float normFactor = (sum > 0.0f) ? 
(1.0f / sum) : (1.0f / (pos + 1)); + for (int t = 0; t <= pos; t++) { + int idx = headOffset + t; + wrapAtt.set(idx, wrapAtt.get(idx) * normFactor); + } + + // STEP 5: Compute weighted sum of values for each dimension + for (int i = 0; i < headSize; i++) { + float weightedSum = 0.0f; + for (int t = 0; t <= pos; t++) { + int kvHeadIdx = h / kvMul; + int valueOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize); + weightedSum += wrapAtt.get(headOffset + t) * value_cache.get(valueOffset + i); + } + allXb.set(h * headSize + i, weightedSum); + } + } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java index 8ba4ded5..a8b955cf 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -354,7 +354,7 @@ private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex, Gra } else { // Standard parallel attention (for non-NVIDIA backends) return unifiedLayer.task("attention", - TransformerComputeKernelsLayered::processHeadsParallel, + GraniteKernels::processHeadsParallelGranite, state.wrapQ, // Query state.wrapKeyCache, // Key cache state.wrapValueCache, // Value cache @@ -367,7 +367,8 @@ private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex, Gra state.positionHolder, state.wrapAtt, // Attention weights buffer layerIndex, - config.contextLength()); + config.contextLength(), + config.attentionScale()); } } // @formatter:on From 8a6578db44e216287b7e4d0660d503c8e645a337 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 20:54:46 +0200 Subject: [PATCH 07/26] Remove unused `LlamaTornadoWeights` import and clean up formatting in `GraniteFP16FFNLayers` --- .../tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java | 2 -- 1 file changed, 2 deletions(-) 
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java index a8b955cf..a9c9bd78 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -3,7 +3,6 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; -import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; @@ -26,7 +25,6 @@ public class GraniteFP16FFNLayers extends AbstractFFNLayers { GridScheduler scheduler; List ffnLayerTaskGraphs; - public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) { super(taskGraph, state, weights, config, schedulerType); this.ffnLayerTaskGraphs = setupFFNLayered(); From 79293b39c2506cedb8d5b36b57263611b5e8d741 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 21:28:29 +0200 Subject: [PATCH 08/26] Update README to include IBM Granite 3.1+ model support and link to Granite 3.3 collection --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5a096fa4..e0267d2d 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Llama3 models written in native Java automatically accelerated on GPUs with TornadoVM. Runs Llama3 inference efficiently using TornadoVM's GPU acceleration.

-Currently, supports Llama3, Mistral, Qwen2.5, Qwen3 and Phi3 models in the GGUF format. +Currently, supports Llama3, Mistral, Qwen2.5, Qwen3 and Phi3 , IBM Granite 3.1+ models in the GGUF format. Also, it is used as GPU inference engine in Quarkus and @@ -295,6 +295,9 @@ jbang LlamaTornadoCli.java -m beehive-llama-3.2-1b-instruct-fp16.gguf \ ### Llama3.2 Collection [https://huggingface.co/collections/beehive-lab/llama3-gpullama3java](https://huggingface.co/collections/beehive-lab/llama3-gpullama3java) +### IBM Granite 3.3 Collection +[https://huggingface.co/collections/beehive-lab/granite-33-language-models-gpullama3java](https://huggingface.co/collections/beehive-lab/granite-33-language-models-gpullama3java) + ### Qwen 2.5 Collection [https://huggingface.co/collections/beehive-lab/qwen-25-gpullama3java](https://huggingface.co/collections/beehive-lab/qwen-25-gpullama3java) From 0808a9d360accdb645051593c796a68c1c282ca7 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 21:37:29 +0200 Subject: [PATCH 09/26] Update README to replace `TORNADO_SDK` with `TORNADOVM_HOME` for consistency in environment variable naming --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e0267d2d..aa5b3bfa 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ All pre-built SDKs are available on the TornadoVM [Releases Page](https://github wget https://github.com/beehive-lab/TornadoVM/releases/download/v2.1.0/tornadovm-2.1.0-opencl-linux-amd64.zip unzip tornadovm-2.1.0-opencl-linux-amd64.zip # Replace manually with the absolute path of the extracted folder -export TORNADO_SDK="/tornadovm-2.1.0-opencl" +export TORNADOVM_HOME="/tornadovm-2.1.0-opencl" export PATH=$TORNADO_SDK/bin:$PATH tornado --devices @@ -102,7 +102,7 @@ tornado --version wget https://github.com/beehive-lab/TornadoVM/releases/download/v2.1.0/tornadovm-2.1.0-opencl-mac-aarch64.zip unzip tornadovm-2.1.0-opencl-mac-aarch64.zip # Replace manually 
with the absolute path of the extracted folder -export TORNADO_SDK="/tornadovm-2.1.0-opencl" +export TORNADOVM_HOME="/tornadovm-2.1.0-opencl" export PATH=$TORNADO_SDK/bin:$PATH tornado --devices @@ -251,7 +251,7 @@ You can run llama-tornado as a pure Java script using [JBang](https://www.jbang. ### Prerequisites for JBang 1. **Install JBang**: Follow the [JBang installation guide](https://www.jbang.dev/download/) -2. **TornadoVM SDK**: You still need TornadoVM installed and `TORNADO_SDK` environment variable set (see Setup section above) +2. **TornadoVM SDK**: You still need TornadoVM installed and `TORNADOVM_HOME` environment variable set (see Setup section above) ### Quick Start with JBang From 7803ad26e0fae59cce76f91fa5bf67d3834031c8 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 21:50:29 +0200 Subject: [PATCH 10/26] Update `GraniteFP16LayerPlanner` to use `LogitsGraniteFP16Layer` for logits initialization --- .../layerplanner/model/fp16/GraniteFP16LayerPlanner.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java index 708a8b43..f17608dd 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java @@ -10,6 +10,7 @@ import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer; public class GraniteFP16LayerPlanner extends FP16LayerPlanner { public GraniteFP16LayerPlanner(GraniteState state, Model model) { @@ -22,7 +23,7 @@ public 
GraniteFP16LayerPlanner(GraniteState state, Model model) { protected void initializeLayerComponents() { this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config); this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType); - this.logitsLayer = new LogitsFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + this.logitsLayer = new LogitsGraniteFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); } } From a0fe7da01c487cac4670aeffb714b00fa367ef94 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 21:50:52 +0200 Subject: [PATCH 11/26] Refactor `LogitsGraniteFP16Layer` to extend `LogitsFP16Layer` and update constructor for enhanced configurability --- .../layers/type/fp16/LogitsGraniteFP16Layer.java | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java index f8a12c14..d55d707e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java @@ -11,7 +11,6 @@ import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; -import org.beehive.gpullama3.tornadovm.layers.AbstractLayer; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.TaskGraph; @@ -19,15 +18,19 @@ import uk.ac.manchester.tornado.api.WorkerGrid1D; import 
uk.ac.manchester.tornado.api.enums.DataTransferMode; -public class LogitsGraniteFP16Layer extends AbstractLayer { +public class LogitsGraniteFP16Layer extends LogitsFP16Layer { private String lastTaskGraphID; private TaskGraph logitsTaskGraph; private ImmutableTaskGraph immutableLogitsGraph; private GridScheduler scheduler; private SchedulerType schedulerType; - protected LogitsGraniteFP16Layer(String taskGraphName, State state, Weights weights, GraniteConfiguration config) { - super(taskGraphName, state, weights, config); + public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); + this.lastTaskGraphID = lastTaskGraphID; + this.schedulerType = schedulerType; + var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); + this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config); } // @formatter:off From 4cfcb59870cb740f5dc9c48ea3b45d7346831a00 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 22:15:48 +0200 Subject: [PATCH 12/26] Add TornadoVM Q8_0 support for Granite inference pipeline - Introduced `GraniteKernels` enhancements, adding Q8_0 kernel operations such as `convertQ8_0toFP32withGraniteScale` and fused matrix-vector computation. - Implemented `GraniteQ8_0FFNLayers` and `Granite8_0LayerPlanner` for layered Q8_0 inference. - Added `LogitsGraniteQ8_0Layer` to support Granite logits processing with Q8_0 quantization. - Updated `QuantizationPlannerFactory` to include Q8_0 planning for Granite models. - Enhanced `ActivationGranite` to handle Q8_0 embedding dequantization using Granite-specific scales. 
--- .../tornadovm/kernels/GraniteKernels.java | 75 ++++ .../base/QuantizationPlannerFactory.java | 2 + .../model/q8_0/Granite8_0LayerPlanner.java | 26 ++ .../tornadovm/layers/ActivationGranite.java | 6 +- .../type/q8_0/GraniteQ8_0FFNLayers.java | 374 ++++++++++++++++++ .../type/q8_0/LogitsGraniteQ8_0Layer.java | 118 ++++++ 6 files changed, 597 insertions(+), 4 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Granite8_0LayerPlanner.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java index b362bc7a..c191e133 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java @@ -3,11 +3,14 @@ import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.annotations.Parallel; import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimized; +import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimizedQ8_0Byte; import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimizedSingle; public class GraniteKernels { @@ -17,6 +20,39 @@ public static void 
convertFP16toFP32withGraniteScale(KernelContext context, Half wrapX.set(i, embeddingScale * x.get(i).getFloat32()); } + public static void convertQ8_0toFP32withGraniteScale(KernelContext context, ByteArray x, FloatArray wrapX, float embeddingScale) { + int globalId = context.globalIdx; + int totalElements = wrapX.getSize(); + + if (globalId >= totalElements) { + return; + } + + // Q8_0 block structure constants + int blockSize = 32; + int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants + + // Calculate which block and position within block + int blockIdx = globalId / blockSize; + int withinBlockIdx = globalId % blockSize; + + // Calculate byte offset for this Q8_0 block + int blockByteOffset = blockIdx * Q8_0_BLOCK_BYTES; + + // Load scale (first 2 bytes of block as HalfFloat) + HalfFloat scale = x.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + // Load quantized value (skip 2-byte scale, then index within block) + byte quantValue = x.get(blockByteOffset + 2 + withinBlockIdx); + + // Dequantize: float_value = quantized_value * scale + float dequantizedValue = ((float) quantValue) * scaleFloat; + + // Store result in output FloatArray + wrapX.set(globalId, embeddingScale * dequantizedValue); + } + // @formatter:off public static void matrixVectorGenericWithGraniteScale( KernelContext context, @@ -267,4 +303,43 @@ private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, Fl allXb.set(h * headSize + i, weightedSum); } } + + public static void matrixVectorGenericWithResidualQ8_0ByteWithGraniteScale(KernelContext context, FloatArray x, + FloatArray hb, ByteArray w, int n, int d, int localWorkGroupSize, float residualScale) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = 
matrixVectorRowMajorOptimizedQ8_0Byte(context, localSize, x, w, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float residualScaledSum = residualScale * sum; + float result = hb.get(rowId) + residualScaledSum; + hb.set(rowId, result); + } + } + + public static void matrixVectorGenericQ8ByteWithGraniteScale(KernelContext context, FloatArray x, FloatArray output, ByteArray q, + int dim1, int dim0, int localWorkGroupSize, float logitsScale) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= dim0) { + return; + } + + float sum = matrixVectorRowMajorOptimizedQ8_0Byte(context, localWorkGroupSize, x, q, dim1); + + // Thread 0 writes the result + if (localId == 0) { + output.set(rowId, logitsScale * sum); + } + } } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java index 68153895..a3be4266 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java @@ -14,6 +14,7 @@ import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Granite8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner; @@ -70,6 +71,7 @@ private static GenericLayerPlanner createQ8_0Planner(State state, Model model) { case QWEN_2 -> new 
Qwen2Q8_0LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model); + case GRANITE -> new Granite8_0LayerPlanner((GraniteState) state, model); case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType()); }; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Granite8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Granite8_0LayerPlanner.java new file mode 100644 index 00000000..4150d52a --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Granite8_0LayerPlanner.java @@ -0,0 +1,26 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer; + +public class Granite8_0LayerPlanner extends Q8_0LayerPlanner { + + public Granite8_0LayerPlanner(GraniteState state, Model model) { + super(state, model); + validateQuantizationType(); + setupTornadoForwardPlan(); + } + + @Override + protected void initializeLayerComponents() { + this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config); + this.ffnLayers = new GraniteQ8_0FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType); + 
this.logitsLayer = new LogitsGraniteQ8_0Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java index 2ac30cdf..20002dac 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java @@ -1,12 +1,9 @@ package org.beehive.gpullama3.tornadovm.layers; -import org.beehive.gpullama3.inference.state.GraniteState; import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.inference.weights.Weights; -import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; -import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; import uk.ac.manchester.tornado.api.KernelContext; @@ -20,6 +17,7 @@ public class ActivationGranite extends Activation { private final TaskGraph activationUpdate; + // Granite is a special case where activation X is scaled by embedding scale float value that inside model. 
public ActivationGranite(String taskGraphHandle, State state, Weights weights, GraniteConfiguration config) { super(taskGraphHandle, state, weights, config); @@ -36,7 +34,7 @@ public ActivationGranite(String taskGraphHandle, State state, Weights weights, G case "Q8_0" -> { this.activationUpdate = new TaskGraph(taskGraphHandle) .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) - .task("updateX", TransformerComputeKernels::convertQ8_0toFP32, kernelContext, (ByteArray) state.embeddingX, state.wrapX) + .task("updateX", GraniteKernels::convertQ8_0toFP32withGraniteScale, kernelContext, (ByteArray) state.embeddingX, state.wrapX, config.embeddingScale()) .persistOnDevice(state.wrapX); } default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java new file mode 100644 index 00000000..7907df86 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -0,0 +1,374 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; 
+import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +public class GraniteQ8_0FFNLayers extends AbstractFFNLayers { + + GridScheduler scheduler; + List ffnLayerTaskGraphs; + + public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + ffnLayerTaskGraphs = setupFFNLayered(); + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return null; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + List setupFFNLayered() { + return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { + var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); + if (i == config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + return ffnLayer.snapshot(); + }).toList(); + } + + /** + * Transformer Layer Task Flow (LlamaQ8FFNLayers) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (partial sums) + * └────────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ attn_rms_finalize│──▶ temp (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌────────────────┐ + * │ attn_rms_apply │──▶ wrapXb (normalized, FP32) + * └───────┬────────┘ + * │ + * ▼ + * ┌────────────────┐ ┌─────────────────────────────┐ + * 
│ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │ + * └───────┬────────┘ └─────────────────────────────┘ + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (partial sums) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌─────────────────┐ + * │ ffn_rms_finalize│──▶ tempFFN (final scale) + * └────────┬────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fully fused: RMS reduce/apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points: + * • qkv_projection: Fused Q/K/V matmuls with Q8 dequantization (3→1 kernel) + * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel) + * • rms_ffn_gate_up: Fully fused RMS norm + W1/W3 matmuls + SiLU + GLU (5→1 kernel) + * + * Quantization: Q8_0 format (8-bit weights with 
block-wise scaling) + * + */ + TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Copy-in weights per layer for batched-layered layout (Q8 format) + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + weights.woLayered[layerIndex].asByteArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w2Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + // === Attention Block === + // RMS Normalization + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, state.wrapXb, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + // QKV Projection (fused with Q8 dequantization) + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulQ8, + context, + state.wrapXb, // input (FP32) + state.wrapQ, // output Q + state.wrapK, // output K + state.wrapV, // output V + weights.wqLayered[layerIndex].asByteArray(), // Wq (Q8) + 
weights.wkLayered[layerIndex].asByteArray(), // Wk (Q8) + weights.wvLayered[layerIndex].asByteArray(), // Wv (Q8) + config.dim(), // dim + config.kvDim(), // kvDim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // RoPE + KV Cache + unifiedLayer.task("rope_and_kv_cache", + TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, // Q (in/out) + state.wrapK, // K (in/out) + state.wrapV, // V (in only) + state.wrapKeyCache, // Key cache (out) + state.wrapValueCache, // Value cache (out) + config.kvDim(), + config.headSize(), + layerIndex, + config.contextLength()); + + // Attention + configureAttention(unifiedLayer, layerIndex, config); + + // Output Projection (Wo) with residual (Q8 dequantization) + unifiedLayer.task("attn_output_proj", + GraniteKernels::matrixVectorGenericWithResidualQ8_0ByteWithGraniteScale, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asByteArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC, config.residualScale()); + + // === FFN Block === + // RMS Normalization + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + // Fully fused: RMS apply + Gate/Up projections + SiLU + GLU (Q8 dequantization) + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fullyFusedRmsNormFFNGateUpQ8, + context, + state.wrapX, // raw input (FP32) + state.wrapHb, // output + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + weights.w1Layered[layerIndex].asByteArray(), // W1 (Q8) + weights.w3Layered[layerIndex].asByteArray(), // W3 (Q8) + config.dim(), // input dimension + config.hiddenDim(), 
// output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Down projection (W2) with residual (Q8 dequantization) + unifiedLayer.task("ffn_down_proj", + GraniteKernels::matrixVectorGenericWithResidualQ8_0ByteWithGraniteScale, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asByteArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC, config.residualScale()); + + // Keep activation X on device for next layer + unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + // First layer: Transfer initial data to device (one-time transfer) + if (layerIndex == 0) { + // Transfer all attention-related data: query, key, value matrices and their caches + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, + state.temp, state.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + context, + state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb); // + } else { + // Subsequent layers: Consume data already on device from previous layer + unifiedLayer.consumeFromDevice( + context, + state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb, // + state.positionHolder // + ); + } + return unifiedLayer; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + // === Worker Grid Definitions === + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() 
* LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Fused QKV: dim rows for Q + kvDim rows for K + kvDim rows for V + int fusedQkvGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); + + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + // === Per-Layer Grid Assignments (ordered by task graph flow) === + for (int i = 0; i < config.numberOfLayers(); i++) { + // --- Attention Block --- + // RMS Normalization + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + // --- FFN Block --- + // RMS Normalization + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + // Fused RMS + Gate/Up Projections + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + // Down Projection + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); + } + + return tornadoForwardScheduler; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int 
layerIndex, GraniteConfiguration config) { + if (schedulerType == SchedulerType.NVIDIA) { + // Flash Attention (optimized for NVIDIA GPUs) + return unifiedLayer.task("attention", + GraniteKernels::processHeadsFlashAttentionWithGraniteScale, + context, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), + config.headSize(), + config.kvDim(), + config.kvMul(), + state.positionHolder, + layerIndex, + config.contextLength(), + config.attentionScale() + ); + } else { + // Standard parallel attention (for non-NVIDIA backends) + return unifiedLayer.task("attention", + GraniteKernels::processHeadsParallelGranite, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), + config.headSize(), + config.kvDim(), + config.kvMul(), + config.contextLength(), // seqLen parameter + state.positionHolder, + state.wrapAtt, // Attention weights buffer + layerIndex, + config.contextLength(), + config.attentionScale()); + } + } + // @formatter:on +} + diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java new file mode 100644 index 00000000..d1fd4f0c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java @@ -0,0 +1,118 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import 
org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +public class LogitsGraniteQ8_0Layer extends LogitsQ8_0Layer{ + private String lastTaskGraphID; + private TaskGraph logitsTaskGraph; + private ImmutableTaskGraph immutableLogitsGraph; + private GridScheduler scheduler; + private SchedulerType schedulerType; + + public LogitsGraniteQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, lastTaskGraphID, schedulerType); + this.lastTaskGraphID = lastTaskGraphID; + var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor"); + this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config); + this.schedulerType = schedulerType; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 
32 : 256); + var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; + var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); + tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); + tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); + return tornadoForwardScheduler; + } + + // @formatter:off + private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) { + var logits = new TaskGraph("logits"); + // === Data Setup === + logits.consumeFromDevice(lastTaskGraphID, state.wrapX); + logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); + logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, // + state.wrapLogits, // + weights.wclsByteArray.asByteArray(), // + weights.rms_final_weight_as_floatArray); + + // === Final RMS Normalization === + logits.task("rms_reduce", + TransformerComputeKernels::reductionOneBlockWithLayer, + context, + state.tempLogits, // output: partial sums + final scale factor + state.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon for numerical stability + state.localSize); // local workgroup size + + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.tempLogits, + config.dim(), + config.rmsNormEps()); + } + logits.task("mapContextLogits", + TransformerComputeKernels::reductionOneBlock2WithLogits, + context, + state.wrapX, + weights.rms_final_weight_as_floatArray.asFloatArray(), + state.tempLogits); + + // === Vocabulary vocab_proj === + logits.task("vocab_proj", GraniteKernels::matrixVectorGenericQ8ByteWithGraniteScale, // + context, + state.wrapX, + state.wrapLogits, + 
weights.wclsByteArray.asByteArray(), + config.dim(), + config.vocabularySize(), + LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, + config.logitScale() + + ); + + // === Transfer Results to Host === + logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + return logits; + } + // @formatter:on + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return logitsTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return immutableLogitsGraph; + } + +} From cc965de787c4b2fbc79adcd06e7bbb088bf6626d Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 22:58:55 +0200 Subject: [PATCH 13/26] Use `GraniteTokenizer` for EOS token retrieval in `GraniteChatFormat` to eliminate hardcoded values. --- .../gpullama3/model/format/GraniteChatFormat.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java index e5da7ad9..a5ea7947 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.model.format; +import org.beehive.gpullama3.tokenizer.GraniteTokenizer; import org.beehive.gpullama3.tokenizer.Tokenizer; import java.util.ArrayList; @@ -26,9 +27,17 @@ public class GraniteChatFormat implements ChatFormat { public GraniteChatFormat(Tokenizer tokenizer) { this.tokenizer = tokenizer; Map specialTokens = tokenizer.getSpecialTokens(); + this.startRole = specialTokens.getOrDefault("<|start_of_role|>", -1); this.endRole = specialTokens.getOrDefault("<|end_of_role|>", -1); - this.endOfText = specialTokens.getOrDefault("<|end_of_text|>", 0); // Token 0 is end_of_text for Granite + + // Use tokenizer's EOS token instead of 
hardcoding + if (tokenizer instanceof GraniteTokenizer graniteTokenizer) { + this.endOfText = graniteTokenizer.getEosTokenId(); + } else { + this.endOfText = specialTokens.getOrDefault("<|end_of_text|>", 0); + } + this.stopTokens = Set.of(endOfText); } From dc69ae400c0c296c2454e35104b16ccc71302401 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 22:59:41 +0200 Subject: [PATCH 14/26] Add `ropeRotationWithCacheCopy` kernel to `GraniteKernels` for RoPE computation and fused cache writes --- .../tornadovm/kernels/GraniteKernels.java | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java index c191e133..4670a250 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java @@ -342,4 +342,56 @@ public static void matrixVectorGenericQ8ByteWithGraniteScale(KernelContext conte output.set(rowId, logitsScale * sum); } } + + public static void ropeRotationWithCacheCopy(KernelContext context, IntArray positionHolder, FloatArray sq, // Q vector (in/out) + FloatArray sk, // K vector (in/out) + FloatArray sv, // V vector (in only) + FloatArray keyCache, // Key cache (out) + FloatArray valueCache, // Value cache (out) + int kvDim, int headSize, + float ropeTheta, + int layer, int contextLength) { + + int i = context.globalIdx * 2; + int pos = positionHolder.get(0); + + // Bounds check for Q rotation (Q has dim elements, processed in pairs) + if (i + 1 < sq.getSize()) { + // RoPE frequency calculation + int head_dim = i % headSize; + // TornadoMath.pow(ropeTheta, head_dim / (float) headSize); + float freq = 1.0f / TornadoMath.pow(ropeTheta, head_dim / (float) headSize); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q + float v0q = 
sq.get(i); + float v1q = sq.get(i + 1); + sq.set(i, v0q * fcr - v1q * fci); + sq.set(i + 1, v0q * fci + v1q * fcr); + + // Rotate K AND write to cache (only for kvDim elements) + if (i + 1 < kvDim) { + float v0k = sk.get(i); + float v1k = sk.get(i + 1); + float rotated0 = v0k * fcr - v1k * fci; + float rotated1 = v0k * fci + v1k * fcr; + + // Write rotated K back to sk + sk.set(i, rotated0); + sk.set(i + 1, rotated1); + + // Direct cache write (fused - no separate copy kernel!) + int cacheOffset = layer * contextLength * kvDim + pos * kvDim; + keyCache.set(cacheOffset + i, rotated0); + keyCache.set(cacheOffset + i + 1, rotated1); + + // Copy V to cache (V doesn't need rotation) + valueCache.set(cacheOffset + i, sv.get(i)); + valueCache.set(cacheOffset + i + 1, sv.get(i + 1)); + } + } + + } } From 1aeb5e120f45c2897c80101edc55752b4b7ee23d Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 23:00:55 +0200 Subject: [PATCH 15/26] Refactor `GraniteLoader`: enhance metadata handling, improve formatting, and streamline tensor-loading logic. 
--- .../gpullama3/model/loader/GraniteLoader.java | 169 +++++++++--------- 1 file changed, 88 insertions(+), 81 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java index aadc15d8..c22491e7 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java @@ -1,10 +1,5 @@ package org.beehive.gpullama3.model.loader; -import org.beehive.gpullama3.tensor.GGMLType; -import org.beehive.gpullama3.tensor.GGUF; -import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; -import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; -import org.beehive.gpullama3.tensor.GGMLTensorEntry; import org.beehive.gpullama3.auxiliary.Pair; import org.beehive.gpullama3.inference.operation.RoPE; import org.beehive.gpullama3.inference.weights.Weights; @@ -13,17 +8,24 @@ import org.beehive.gpullama3.model.format.ChatFormat; import org.beehive.gpullama3.model.granite.Granite; import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; import org.beehive.gpullama3.tokenizer.GraniteTokenizer; import org.beehive.gpullama3.tokenizer.Tokenizer; import org.beehive.gpullama3.tokenizer.Vocabulary; - import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; import java.nio.channels.FileChannel; import java.util.Map; -import static org.beehive.gpullama3.model.loader.ModelLoader.*; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayOfTensors; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayOfTornadoTensors; +import static 
org.beehive.gpullama3.model.loader.ModelLoader.loadTensor; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadTornadoTensor; public class GraniteLoader extends AbstractModelLoader { @@ -43,50 +45,56 @@ protected Tokenizer createTokenizer(Map metadata, Vocabulary voc } // @formatter:off - @Override - protected GraniteConfiguration createConfiguration(Map metadata) { - int vocabSize = metadata.containsKey("granite.vocab_size") - ? (int) metadata.get("granite.vocab_size") - : (int) metadata.get("tokenizer.ggml.tokens.length"); - - // Extract Granite-specific metadata keys - float embeddingScale = (float) metadata.getOrDefault("granite.embedding_scale", 12.0f); - float residualScale = (float) metadata.getOrDefault("granite.residual_scale", 0.22f); - float attentionScale = (float) metadata.getOrDefault("granite.attention.scale", 0.0078125f); - float logitScale = (float) metadata.getOrDefault("granite.logit_scale", 16.0f); - - return new GraniteConfiguration( - getModelQuantization(metadata), - (int) metadata.get("granite.embedding_length"), - (int) metadata.get("granite.feed_forward_length"), - (int) metadata.get("granite.block_count"), - (int) metadata.get("granite.attention.head_count"), - metadata.containsKey("granite.attention.head_count_kv") - ? (int) metadata.get("granite.attention.head_count_kv") - : (int) metadata.get("granite.attention.head_count"), - vocabSize, - (int) metadata.get("granite.context_length"), - (float) metadata.getOrDefault("granite.attention.layer_norm_rms_epsilon", 1e-5f), - (float) metadata.getOrDefault("granite.rope.freq_base", 10000f), - embeddingScale, - residualScale, - attentionScale, - logitScale, - true // Granite ties word embeddings - ).withContextLength(contextLength); - } - // @formatter:on + @Override + protected GraniteConfiguration createConfiguration(Map metadata) { + int vocabSize = metadata.containsKey("granite.vocab_size") + ? 
(int) metadata.get("granite.vocab_size") + : (int) metadata.get("tokenizer.ggml.tokens.length"); + + // Extract Granite-specific metadata keys + float embeddingScale = (float) metadata.getOrDefault("granite.embedding_scale", 12.0f); + float residualScale = (float) metadata.getOrDefault("granite.residual_scale", 0.22f); + float attentionScale = (float) metadata.getOrDefault("granite.attention.scale", 0.0078125f); + float logitScale = (float) metadata.getOrDefault("granite.logit_scale", 16.0f); + + int kvHeads; + Object kvHeadsObj = metadata.get("granite.attention.head_count_kv"); + if (kvHeadsObj instanceof int[] kvHeadsArray) { + // Granite 4.0: per-layer array - take first value (assuming uniform for now) + kvHeads = kvHeadsArray[0]; + } else if (kvHeadsObj instanceof Integer) { + // Granite 3.3: scalar value + kvHeads = (Integer) kvHeadsObj; + } else { + // Fallback to head count (no GQA) + kvHeads = (int) metadata.get("granite.attention.head_count"); + } + + return new GraniteConfiguration( + getModelQuantization(metadata), + (int) metadata.get("granite.embedding_length"), + (int) metadata.get("granite.feed_forward_length"), + (int) metadata.get("granite.block_count"), + (int) metadata.get("granite.attention.head_count"), + kvHeads, + vocabSize, + (int) metadata.get("granite.context_length"), + (float) metadata.getOrDefault("granite.attention.layer_norm_rms_epsilon", 1e-5f), + (float) metadata.getOrDefault("granite.rope.freq_base", 10000f), + embeddingScale, + residualScale, + attentionScale, + logitScale, + true // Granite ties word embeddings + ).withContextLength(contextLength); + } @Override protected Pair precomputeRopeFrequencies(GraniteConfiguration config) { - return RoPE.precomputeFreqsCis( - config.contextLength(), - config.dim() / config.numberOfHeads(), - config.ropeTheta(), - false, - 1.0f, 1.0f, 1.0f, - config.contextLength()); + return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), + 
false, 1.0f, 1.0f, 1.0f, config.contextLength()); } + // @formatter:on @Override protected Granite createModel(GraniteConfiguration config, Tokenizer tokenizer, Weights weights) { @@ -94,15 +102,15 @@ protected Granite createModel(GraniteConfiguration config, Tokenizer tokenizer, } // @formatter:off - @Override - protected Weights createStandardWeights(Map tensorEntries, - GraniteConfiguration config, - Pair ropeFreqs, - GGMLTensorEntry tokenEmbeddings, - GGMLTensorEntry outputWeight) { - final int nl = config.numberOfLayers(); - - return new GraniteStandardWeights( + @Override + protected Weights createStandardWeights(Map tensorEntries, + GraniteConfiguration config, + Pair ropeFreqs, + GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + final int nl = config.numberOfLayers(); + + return new GraniteStandardWeights( loadTensor(tokenEmbeddings), loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), loadArrayOfTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".attn_q.weight")), @@ -118,30 +126,29 @@ protected Weights createStandardWeights(Map tensorEntri new ArrayFloatTensor(ropeFreqs.second()), loadTensor(outputWeight), outputWeight.ggmlType()); - } - // @formatter:on - - // @formatter:off - @Override - protected Weights createTornadoVMWeights(Map tensorEntries, - GraniteConfiguration config, - Pair ropeFreqs, - GGMLTensorEntry tokenEmbeddings, - GGMLTensorEntry outputWeight) { - GGMLType ggmlType = outputWeight.ggmlType(); - - if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { - System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); } + // @formatter:on - // Validate supported types - if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { - throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); - } - - final int nl = config.numberOfLayers(); - - return new GraniteTornadoWeights( + // @formatter:off + @Override + protected Weights createTornadoVMWeights(Map tensorEntries, + GraniteConfiguration config, + Pair ropeFreqs, + GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + GGMLType ggmlType = outputWeight.ggmlType(); + + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); + } + + // Validate supported types + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); + } + + final int nl = config.numberOfLayers(); + return new GraniteTornadoWeights( loadTornadoTensor(tokenEmbeddings), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".attn_q.weight")), @@ -157,7 +164,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())), loadTornadoTensor(outputWeight), ggmlType - ); - } - // @formatter:on + ); + } + // @formatter:on } From 34144a597ddb8a33d50fc0409b09153eab9e7e61 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 23:01:08 +0200 Subject: [PATCH 16/26] Replace `TransformerComputeKernelsLayered::ropeRotationWithCacheCopy` call with `GraniteKernels::ropeRotationWithCacheCopy` in `GraniteFP16FFNLayers` and add `ropeTheta` to arguments. --- .../tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java index a9c9bd78..0387ae54 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -232,7 +232,7 @@ TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguratio // RoPE + KV Cache unifiedLayer.task("rope_and_kv_cache", - TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + GraniteKernels::ropeRotationWithCacheCopy, context, state.positionHolder, state.wrapQ, // Q (in/out) @@ -242,6 +242,7 @@ TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguratio state.wrapValueCache, // Value cache (out) config.kvDim(), config.headSize(), + config.ropeTheta(), // needs to load it from model layerIndex, config.contextLength()); // Attention From a6b6a80b2a056fd91b9318f98e891c8c0366130f Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 23:02:39 +0200 Subject: [PATCH 17/26] Refactor `GraniteTokenizer`: add support for Granite 4.0, enhance metadata handling, and improve pretokenizer 
logic. --- .../gpullama3/tokenizer/GraniteTokenizer.java | 231 ++++++++++-------- 1 file changed, 131 insertions(+), 100 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java index e5699d5a..0c7750fd 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java @@ -18,81 +18,107 @@ /** * GPT-2-style BPE tokenizer for Granite models. *

- * Granite uses the same refact BPE tokenization algorithm as Llama, - * but with different special tokens and token IDs. - *

- * BOS/EOS Token: Token ID 0 (<|end_of_text|>) serves both purposes. + * Supports both Granite 3.3 (refact pretokenizer, 49K vocab) and Granite 4.0 (dbrx pretokenizer, 100K vocab). */ public class GraniteTokenizer implements Tokenizer { static final Map BYTE_ENCODER = bytesToUnicode(); static final Map BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); - private static final String GRANITE_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; - // general fields + // Pretokenizer patterns + private static final String REFACT_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; + + private static final String DBRX_PATTERN = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; + + // Instance fields private final Pattern compiledPattern; private final Vocabulary vocabulary; - // model-specific fields private final Map, Integer> merges; private final Map specialTokens; + // Token IDs (version-dependent) + private final int bosTokenId; + private final int eosTokenId; + private final int padTokenId; + private final String pretokenizerType; + public GraniteTokenizer(Map metadata, Vocabulary vocabulary) { - // load from metadata + this.vocabulary = vocabulary; + + // Detect pretokenizer type and select pattern + this.pretokenizerType = (String) metadata.getOrDefault("tokenizer.ggml.pre", "refact"); + String pattern = switch (pretokenizerType) { + case "dbrx" -> DBRX_PATTERN; + default -> REFACT_PATTERN; + }; + this.compiledPattern = Pattern.compile(pattern); + + // Read token IDs from metadata + this.bosTokenId = getIntFromMetadata(metadata, "tokenizer.ggml.bos_token_id", 0); + this.eosTokenId = 
getIntFromMetadata(metadata, "tokenizer.ggml.eos_token_id", 0); + this.padTokenId = getIntFromMetadata(metadata, "tokenizer.ggml.padding_token_id", 0); + + // Load merges String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges"); - List> merges = Arrays.stream(mergeLines).map(line -> line.split(" ")) + List> mergeList = Arrays.stream(mergeLines).map(line -> line.split(" ")) .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList(); - int allTokens = vocabulary.size(); - // For Granite, collect ALL special tokens (including token 0) + // Collect special tokens Map specialTokens = new HashMap<>(); + int allTokens = vocabulary.size(); for (int i = 0; i < allTokens; i++) { String token = vocabulary.get(i); - // Identify special tokens by their format: start with <| and end with |> if (token.startsWith("<|") && token.endsWith("|>")) { specialTokens.put(token, i); } + // Also catch style tokens used in some Granite models + if (token.startsWith("<") && token.endsWith(">") && !token.contains(" ")) { + specialTokens.putIfAbsent(token, i); + } } + this.specialTokens = Map.copyOf(specialTokens); - // init tokenizer object fields - this.vocabulary = vocabulary; - this.compiledPattern = Pattern.compile(GRANITE_PATTERN); - this.specialTokens = new HashMap<>(specialTokens); + // Build merge map this.merges = new HashMap<>(); - for (Pair pair : merges) { - int firstIndex = pair.first(); - int secondIndex = pair.second(); - int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow(); + for (Pair pair : mergeList) { + String merged = vocabulary.get(pair.first()) + vocabulary.get(pair.second()); + int mergeIndex = vocabulary.getIndex(merged).orElseThrow(); this.merges.put(pair, mergeIndex); } } + private static int getIntFromMetadata(Map metadata, String key, int defaultValue) { + Object value = metadata.get(key); + if (value instanceof Number num) { + 
return num.intValue(); + } + return defaultValue; + } + // === Token ID accessors === private static List findAll(Pattern pattern, String text) { - List allMatches = new ArrayList<>(); + List matches = new ArrayList<>(); Matcher matcher = pattern.matcher(text); while (matcher.find()) { - allMatches.add(matcher.group()); + matches.add(matcher.group()); } - return allMatches; + return matches; } private static List merge(List ids, Pair pair, int idx) { - List newids = new ArrayList<>(); + List newIds = new ArrayList<>(); int i = 0; while (i < ids.size()) { - if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) { - newids.add(idx); + if (i < ids.size() - 1 && ids.get(i).equals(pair.first()) && ids.get(i + 1).equals(pair.second())) { + newIds.add(idx); i += 2; } else { - newids.add(ids.get(i)); - i += 1; + newIds.add(ids.get(i)); + i++; } } - return newids; + return newIds; } - /** - * Returns list of utf-8 byte and a corresponding list of unicode strings. 
- */ private static Map bytesToUnicode() { List bs = new ArrayList<>(); IntStream.rangeClosed('!', '~').forEach(bs::add); @@ -101,22 +127,32 @@ private static Map bytesToUnicode() { List cs = new ArrayList<>(bs); int n = 0; - for (int b = 0; b < 256; ++b) { + for (int b = 0; b < 256; b++) { if (!bs.contains(b)) { bs.add(b); cs.add(256 + n); - n += 1; + n++; } } - return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get)); } - public String regexPattern() { - if (compiledPattern == null) { - return null; - } - return compiledPattern.pattern(); + public int getBosTokenId() { + return bosTokenId; + } + + // === Tokenizer interface === + + public int getEosTokenId() { + return eosTokenId; + } + + public int getPadTokenId() { + return padTokenId; + } + + public String getPretokenizerType() { + return pretokenizerType; } @Override @@ -124,6 +160,8 @@ public Map getSpecialTokens() { return specialTokens; } + // === Encoding === + @Override public boolean isSpecialToken(int tokenIndex) { return specialTokens.containsValue(tokenIndex); @@ -134,26 +172,44 @@ public boolean shouldDisplayToken(int token) { return !isSpecialToken(token); } + public String regexPattern() { + return compiledPattern != null ? compiledPattern.pattern() : null; + } + + // @Override + public int[] encode(String text) { + StringBuilder sb = new StringBuilder(); + byte[] bytes = text.getBytes(StandardCharsets.UTF_8); + for (byte b : bytes) { + sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); + } + return encodeImpl(sb.toString()); + } + + @Override + public List encodeAsList(String text) { + return Arrays.stream(encode(text)).boxed().toList(); + } + private int[] encodeImpl(String text) { return encode(text, Set.of()).stream().mapToInt(i -> i).toArray(); } - /** - * Encode text handling special tokens. 
- */ + // === Decoding === + public List encode(String text, Set allowedSpecial) { - Set special = allowedSpecial; - assert getSpecialTokens().keySet().containsAll(special); - if (special.isEmpty()) { + if (allowedSpecial.isEmpty()) { return encodeOrdinary(text); } - String specialPattern = special.stream().map(Pattern::quote).collect(Collectors.joining("|", "(", ")")); + assert specialTokens.keySet().containsAll(allowedSpecial); + String specialPattern = allowedSpecial.stream().map(Pattern::quote).collect(Collectors.joining("|", "(", ")")); String[] specialChunks = text.split(specialPattern); + List ids = new ArrayList<>(); for (String part : specialChunks) { - if (special.contains(part)) { - ids.add(getSpecialTokens().get(part)); + if (allowedSpecial.contains(part)) { + ids.add(specialTokens.get(part)); } else { ids.addAll(encodeOrdinary(part)); } @@ -161,85 +217,60 @@ public List encode(String text, Set allowedSpecial) { return ids; } - /** - * Encoding that ignores any special tokens. 
- */ public List encodeOrdinary(String text) { List textChunks = findAll(compiledPattern, text); List ids = new ArrayList<>(); for (String chunk : textChunks) { - List chunkIds = encodeChunk(chunk); - ids.addAll(chunkIds); + ids.addAll(encodeChunk(chunk)); } return ids; } - private Map, Integer> getStats(List ids) { - Map, Integer> map = new HashMap<>(); - for (int i = 0; i + 1 < ids.size(); i++) { - Pair key = new Pair<>(ids.get(i), ids.get(i + 1)); - map.put(key, map.getOrDefault(key, 0) + 1); - } - return map; - } + // === Helpers === private List encodeChunk(String chunk) { List ids = new ArrayList<>(); - for (int b : chunk.toCharArray()) { - int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow(); + for (char c : chunk.toCharArray()) { + int tokenIndex = vocabulary.getIndex(String.valueOf(c)).orElseThrow(); ids.add(tokenIndex); } while (ids.size() >= 2) { Map, Integer> stats = getStats(ids); - Pair pair = stats.keySet().stream() - .min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))) - .orElseThrow(); - if (!this.merges.containsKey(pair)) { + Pair pair = stats.keySet().stream().min(Comparator.comparingInt(key -> merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow(); + if (!merges.containsKey(pair)) { break; } - int idx = this.merges.get(pair); - ids = merge(ids, pair, idx); + ids = merge(ids, pair, merges.get(pair)); } return ids; } - public String decodeImpl(List tokens) { - StringBuilder sb = new StringBuilder(); - for (int token : tokens) { - String tokenString = vocabulary.get(token); - sb.append(tokenString); - } - return sb.toString(); - } - - public int[] encode(String text) { - StringBuilder sb = new StringBuilder(); - byte[] bytes = text.getBytes(StandardCharsets.UTF_8); - for (byte b : bytes) { - sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); + @Override + public String decode(List tokens) { + String decoded = decodeImpl(tokens); + int[] decodedBytesAsInts = 
decoded.codePoints().map(cp -> BYTE_DECODER.getOrDefault(cp, cp)).toArray(); + byte[] rawBytes = new byte[decodedBytesAsInts.length]; + for (int i = 0; i < decodedBytesAsInts.length; i++) { + rawBytes[i] = (byte) decodedBytesAsInts[i]; } - return encodeImpl(sb.toString()); + return new String(rawBytes, StandardCharsets.UTF_8); } - @Override - public List encodeAsList(String text) { + private String decodeImpl(List tokens) { StringBuilder sb = new StringBuilder(); - byte[] bytes = text.getBytes(StandardCharsets.UTF_8); - for (byte b : bytes) { - sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); + for (int token : tokens) { + sb.append(vocabulary.get(token)); } - return Arrays.stream(encodeImpl(sb.toString())).boxed().toList(); + return sb.toString(); } - @Override - public String decode(List tokens) { - String decoded = decodeImpl(tokens); - int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray(); - byte[] rawBytes = new byte[decodedBytesAsInts.length]; - for (int i = 0; i < decoded.length(); i++) { - rawBytes[i] = (byte) decodedBytesAsInts[i]; + private Map, Integer> getStats(List ids) { + Map, Integer> map = new HashMap<>(); + for (int i = 0; i + 1 < ids.size(); i++) { + Pair key = new Pair<>(ids.get(i), ids.get(i + 1)); + map.merge(key, 1, Integer::sum); } - return new String(rawBytes, StandardCharsets.UTF_8); + return map; } -} +} \ No newline at end of file From aafa1561203b2c7013ffa60773cfff657fd6ad0e Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 23:06:27 +0200 Subject: [PATCH 18/26] Replace `TransformerComputeKernelsLayered::ropeRotationWithCacheCopy` with `GraniteKernels::ropeRotationWithCacheCopy` in `GraniteQ8_0FFNLayers` and add `ropeTheta` to kernel arguments. 
--- .../tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java index 7907df86..b7e036d4 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -5,6 +5,7 @@ import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.granite.Granite; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; @@ -196,7 +197,7 @@ TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguratio // RoPE + KV Cache unifiedLayer.task("rope_and_kv_cache", - TransformerComputeKernelsLayered::ropeRotationWithCacheCopy, + GraniteKernels::ropeRotationWithCacheCopy, context, state.positionHolder, state.wrapQ, // Q (in/out) @@ -206,6 +207,7 @@ TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguratio state.wrapValueCache, // Value cache (out) config.kvDim(), config.headSize(), + config.ropeTheta(), layerIndex, config.contextLength()); From a8759a2e72efa3019507af1b3f075734d88eb31a Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 23:10:16 +0200 Subject: [PATCH 19/26] Update README to add IBM Granite 4.0 collection link --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index aa5b3bfa..ca18fa92 100644 --- a/README.md +++ b/README.md @@ -295,6 +295,10 @@ jbang LlamaTornadoCli.java -m 
beehive-llama-3.2-1b-instruct-fp16.gguf \ ### Llama3.2 Collection [https://huggingface.co/collections/beehive-lab/llama3-gpullama3java](https://huggingface.co/collections/beehive-lab/llama3-gpullama3java) +### IBM Granite 4.0 Collection +[https://huggingface.co/collections/beehive-lab/granite-40-language-models-gpullama3java](https://huggingface.co/collections/beehive-lab/granite-40-language-models-gpullama3java) + + ### IBM Granite 3.3 Collection [https://huggingface.co/collections/beehive-lab/granite-33-language-models-gpullama3java](https://huggingface.co/collections/beehive-lab/granite-33-language-models-gpullama3java) From c1459d9f5cd79163f9823abb99dd85f9e95f7a87 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 23:14:47 +0200 Subject: [PATCH 20/26] Update README to include support for Phi-3, IBM Granite 3.2+, and IBM Granite 4.0 models --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ca18fa92..e1591ea8 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Llama3 models written in native Java automatically accelerated on GPUs with TornadoVM. Runs Llama3 inference efficiently using TornadoVM's GPU acceleration.

-Currently, supports Llama3, Mistral, Qwen2.5, Qwen3 and Phi3 , IBM Granite 3.1+ models in the GGUF format. +Currently, supports Llama3, Mistral, Qwen2.5, Qwen3, Phi-3, IBM Granite 3.2+ and IBM Granite 4.0 models in the GGUF format. Also, it is used as GPU inference engine in Quarkus and From 25716144e067f31f8a29dd7b832742af1a62ef56 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 17 Dec 2025 23:30:48 +0200 Subject: [PATCH 21/26] Add build step for running Granite 3.2 model in CI pipelineAdd workflow step to run Granite-3.2-2b-instruct-f16.gguf during CI --- .github/workflows/build-and-run.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index 249f6257..d53298f7 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -128,6 +128,13 @@ jobs: ./llama-tornado --gpu --${{ matrix.backend.name }} \ --model /$MODELS_DIR/Phi-3-mini-4k-instruct-fp16.gguf \ --prompt "Say hello" + - name: FP16 - Run Granite-3.2-2b-instruct-f16.gguf + run: | + cd ${{ github.workspace }} + export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model /$MODELS_DIR/granite-3.2-2b-instruct-f16.gguf \ + --prompt "Say hello" - name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf run: | cd ${{ github.workspace }} From a45b532ac314ae78396becd880e00d67c0497a4b Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 18 Dec 2025 00:04:55 +0200 Subject: [PATCH 22/26] Add workflow step to run Granite-3.2-2b-instruct-Q8.gguf during CI --- .github/workflows/build-and-run.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index d53298f7..ff8a6b5b 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -170,3 +170,10 @@ jobs: ./llama-tornado --gpu --${{ matrix.backend.name }} \ --model 
$MODELS_DIR/Mistral-7B-Instruct-v0.3.Q8_0.gguf \ --prompt "Say hello" + - name: Q8 - Run Granite-3.2-2b-instruct-Q8.gguf + run: | + cd ${{ github.workspace }} + export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model /$MODELS_DIR/granite-3.2-2b-instruct-Q8_0.gguf \ + --prompt "Say hello" From c2b91c44af1ae8b32d969fe4d38addc28fac7061 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 18 Dec 2025 01:22:48 +0200 Subject: [PATCH 23/26] Add workflow steps to run Granite-4.0-1b-F16.gguf and Granite-4.0-1b-Q8_0.gguf during CI --- .github/workflows/build-and-run.yml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index ff8a6b5b..1551f320 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -135,6 +135,13 @@ jobs: ./llama-tornado --gpu --${{ matrix.backend.name }} \ --model /$MODELS_DIR/granite-3.2-2b-instruct-f16.gguf \ --prompt "Say hello" + - name: FP16 - Run Granite-4.0-1b-F16.gguf + run: | + cd ${{ github.workspace }} + export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model /$MODELS_DIR/granite-4.0-1b-F16.gguf \ + --prompt "Say hello" - name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf run: | cd ${{ github.workspace }} @@ -177,3 +184,11 @@ jobs: ./llama-tornado --gpu --${{ matrix.backend.name }} \ --model /$MODELS_DIR/granite-3.2-2b-instruct-Q8_0.gguf \ --prompt "Say hello" + - name: Q8 - Run Granite-4.0-1b-Q8_0.gguf + run: | + cd ${{ github.workspace }} + export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model /$MODELS_DIR/granite-4.0-1b-Q8_0.gguf \ + --prompt "Say hello" + From 49eb298c27a90a10099511be7a027ce055b27a90 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 18 Dec 2025 09:59:56 +0200 Subject: [PATCH 24/26] Fix logits scaling in 
`GraniteKernels`: correct scaling order for `hb` and `output` writes. --- .../java/org/beehive/gpullama3/inference/InferenceCore.java | 2 +- .../beehive/gpullama3/tornadovm/kernels/GraniteKernels.java | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index 061ff3ed..b679c25b 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -663,7 +663,7 @@ public static FloatTensor forwardGranite(Model model, State state, int token, in weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim); // Apply Granite logit scaling (divide by the scaling factor) - state.logits.mapInPlace(v -> v / logitScale); + state.logits.mapInPlace(v -> v * logitScale); return state.logits; } diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java index 4670a250..36b9803e 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java @@ -77,7 +77,7 @@ public static void matrixVectorGenericWithGraniteScale( // Thread 0 in each workgroup writes the final result if (localId == 0) { - hb.set(rowId, sum); + hb.set(rowId, sum * logitsScale); } } @@ -156,7 +156,6 @@ public static void processHeadsFlashAttentionWithGraniteScale(KernelContext cont score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d]; } score *= attentionScale; -// score /= TornadoMath.sqrt(headSize); s_tile[score_idx_in_tile] = score; } @@ -339,7 +338,7 @@ public static void matrixVectorGenericQ8ByteWithGraniteScale(KernelContext conte // Thread 0 writes the result if (localId == 0) { - output.set(rowId, logitsScale * sum); + output.set(rowId, sum * 
logitsScale); } } From 150c5eefdd2b3a0ddca5317b8c795d0aa3430aae Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 18 Dec 2025 10:00:02 +0200 Subject: [PATCH 25/26] Remove unused imports from `GraniteFP16LayerPlanner`. --- .../layerplanner/model/fp16/GraniteFP16LayerPlanner.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java index f17608dd..7cc97d64 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java @@ -5,11 +5,8 @@ import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; -import org.beehive.gpullama3.tornadovm.layers.Activation; import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LlamaFP16FFNLayers; -import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsFP16Layer; import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer; public class GraniteFP16LayerPlanner extends FP16LayerPlanner { From 59eb425d567e40d955431b780ba02ec95370c203 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Thu, 18 Dec 2025 10:41:51 +0200 Subject: [PATCH 26/26] Rename `Granite8_0LayerPlanner` to `GraniteQ8_0LayerPlanner` for consistency with naming conventions across Q8_0 models. 
--- .../layerplanner/base/QuantizationPlannerFactory.java | 4 ++-- ...anite8_0LayerPlanner.java => GraniteQ8_0LayerPlanner.java} | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) rename src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/{Granite8_0LayerPlanner.java => GraniteQ8_0LayerPlanner.java} (86%) diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java index a3be4266..ca844e51 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java @@ -14,7 +14,7 @@ import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner; -import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Granite8_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner; @@ -71,7 +71,7 @@ private static GenericLayerPlanner createQ8_0Planner(State state, Model model) { case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model); - case GRANITE -> new Granite8_0LayerPlanner((GraniteState) state, model); + case GRANITE -> new GraniteQ8_0LayerPlanner((GraniteState) state, model); case DEEPSEEK_R1_DISTILL_QWEN -> new 
Qwen2Q8_0LayerPlanner((Qwen2State) state, model); default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType()); }; diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Granite8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java similarity index 86% rename from src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Granite8_0LayerPlanner.java rename to src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java index 4150d52a..ee818080 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Granite8_0LayerPlanner.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java @@ -9,9 +9,9 @@ import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers; import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer; -public class Granite8_0LayerPlanner extends Q8_0LayerPlanner { +public class GraniteQ8_0LayerPlanner extends Q8_0LayerPlanner { - public Granite8_0LayerPlanner(GraniteState state, Model model) { + public GraniteQ8_0LayerPlanner(GraniteState state, Model model) { super(state, model); validateQuantizationType(); setupTornadoForwardPlan();