diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index 249f6257..1551f320 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -128,6 +128,20 @@ jobs: ./llama-tornado --gpu --${{ matrix.backend.name }} \ --model /$MODELS_DIR/Phi-3-mini-4k-instruct-fp16.gguf \ --prompt "Say hello" + - name: FP16 - Run Granite-3.2-2b-instruct-f16.gguf + run: | + cd ${{ github.workspace }} + export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model /$MODELS_DIR/granite-3.2-2b-instruct-f16.gguf \ + --prompt "Say hello" + - name: FP16 - Run Granite-4.0-1b-F16.gguf + run: | + cd ${{ github.workspace }} + export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model /$MODELS_DIR/granite-4.0-1b-F16.gguf \ + --prompt "Say hello" - name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf run: | cd ${{ github.workspace }} @@ -163,3 +177,18 @@ jobs: ./llama-tornado --gpu --${{ matrix.backend.name }} \ --model $MODELS_DIR/Mistral-7B-Instruct-v0.3.Q8_0.gguf \ --prompt "Say hello" + - name: Q8 - Run Granite-3.2-2b-instruct-Q8.gguf + run: | + cd ${{ github.workspace }} + export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model /$MODELS_DIR/granite-3.2-2b-instruct-Q8_0.gguf \ + --prompt "Say hello" + - name: Q8 - Run Granite-4.0-1b-Q8_0.gguf + run: | + cd ${{ github.workspace }} + export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH" + ./llama-tornado --gpu --${{ matrix.backend.name }} \ + --model /$MODELS_DIR/granite-4.0-1b-Q8_0.gguf \ + --prompt "Say hello" + diff --git a/Makefile b/Makefile index 3f44bac9..9e41301c 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ clean: $(MVN) clean install: - $(MVN) install -DskipTests + $(MVN) install -DskipTests # Package the project without running tests package: diff --git a/README.md b/README.md index 5a096fa4..e1591ea8 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Llama3 models written in native Java automatically accelerated on GPUs with TornadoVM. Runs Llama3 inference efficiently using TornadoVM's GPU acceleration.

-Currently, supports Llama3, Mistral, Qwen2.5, Qwen3 and Phi3 models in the GGUF format. +Currently, supports Llama3, Mistral, Qwen2.5, Qwen3, Phi-3, IBM Granite 3.2+ and IBM Granite 4.0 models in the GGUF format. Also, it is used as GPU inference engine in Quarkus and @@ -89,7 +89,7 @@ All pre-built SDKs are available on the TornadoVM [Releases Page](https://github wget https://github.com/beehive-lab/TornadoVM/releases/download/v2.1.0/tornadovm-2.1.0-opencl-linux-amd64.zip unzip tornadovm-2.1.0-opencl-linux-amd64.zip # Replace manually with the absolute path of the extracted folder -export TORNADO_SDK="/tornadovm-2.1.0-opencl" +export TORNADOVM_HOME="/tornadovm-2.1.0-opencl" export PATH=$TORNADO_SDK/bin:$PATH tornado --devices @@ -102,7 +102,7 @@ tornado --version wget https://github.com/beehive-lab/TornadoVM/releases/download/v2.1.0/tornadovm-2.1.0-opencl-mac-aarch64.zip unzip tornadovm-2.1.0-opencl-mac-aarch64.zip # Replace manually with the absolute path of the extracted folder -export TORNADO_SDK="/tornadovm-2.1.0-opencl" +export TORNADOVM_HOME="/tornadovm-2.1.0-opencl" export PATH=$TORNADO_SDK/bin:$PATH tornado --devices @@ -251,7 +251,7 @@ You can run llama-tornado as a pure Java script using [JBang](https://www.jbang. ### Prerequisites for JBang 1. **Install JBang**: Follow the [JBang installation guide](https://www.jbang.dev/download/) -2. **TornadoVM SDK**: You still need TornadoVM installed and `TORNADO_SDK` environment variable set (see Setup section above) +2. **TornadoVM SDK**: You still need TornadoVM installed and `TORNADOVM_HOME` environment variable set (see Setup section above) ### Quick Start with JBang @@ -295,6 +295,13 @@ jbang LlamaTornadoCli.java -m beehive-llama-3.2-1b-instruct-fp16.gguf \ ### Llama3.2 Collection [https://huggingface.co/collections/beehive-lab/llama3-gpullama3java](https://huggingface.co/collections/beehive-lab/llama3-gpullama3java) +### IBM Granite 4.0 Collection +[https://huggingface.co/collections/beehive-lab/granite-40-language-models-gpullama3java](https://huggingface.co/collections/beehive-lab/granite-40-language-models-gpullama3java) + + +### IBM Granite 3.3 Collection +[https://huggingface.co/collections/beehive-lab/granite-33-language-models-gpullama3java](https://huggingface.co/collections/beehive-lab/granite-33-language-models-gpullama3java) + ### Qwen 2.5 Collection [https://huggingface.co/collections/beehive-lab/qwen-25-gpullama3java](https://huggingface.co/collections/beehive-lab/qwen-25-gpullama3java) diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index 475f711e..b679c25b 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -11,6 +11,7 @@ import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.model.phi3.Phi3Configuration; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; @@ -546,6 +547,127 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke return state.logits; } + /** + * Forward pass for Granite models with µP scaling factors applied. + *

+ * Granite uses the same transformer architecture as Llama but with maximal update parameterization (µP) + * scaling factors applied at specific points: + * <ul> + * <li>Embedding scaling: multiply embeddings after lookup</li> + * <li>Attention scaling: use custom multiplier instead of 1/sqrt(headDim)</li> + * <li>Residual scaling: multiply residual connections</li> + * <li>Logit scaling: divide logits by the scaling factor</li> + * </ul>
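+ * <p>A rough worked example using the default factors from {@code GraniteConfiguration.createDefault} (12.0, 0.22, 0.0078125, 16.0; real checkpoints supply their own values via GGUF metadata): + * an embedding value of 0.01 enters the first layer as 0.01 * 12.0 = 0.12, a query-key dot product of 10.0 becomes an attention score of 10.0 * 0.0078125 = 0.078125 + * (instead of 10.0 / sqrt(headSize)), an attention-block output of 0.5 contributes 0.5 * 0.22 = 0.11 to the residual stream, and a raw logit of 8.0 becomes 8.0 / 16.0 = 0.5 before sampling.</p>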
+ */ + public static FloatTensor forwardGranite(Model model, State state, int token, int position) { + final GraniteConfiguration config = (GraniteConfiguration) model.configuration(); + final StandardWeights weights = (StandardWeights) model.weights(); + int dim = config.dim(); + int headSize = config.headSize(); + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); + float attentionScale = config.attentionScale(); + float residualScale = config.residualScale(); + float embeddingScale = config.embeddingScale(); + float logitScale = config.logitScale(); + + // copy the token embedding into x + weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); + // Apply Granite embedding scaling + state.x.mapInPlace(v -> v * embeddingScale); + + // forward all the layers + for (int l = 0; l < config.numberOfLayers(); l++) { + // attention rmsnorm + rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps()); + + // qkv matmuls for this position + weights.wq[l].matmul(state.xb, state.q, dim, dim); + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + + // RoPE relative positional encoding + for (int i = 0; i < dim; i += 2) { + int head_dim = i % headSize; + float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2)); + float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2)); + int rotn = i < kvDim ? 2 : 1; + for (int v = 0; v < rotn; v++) { + FloatTensor vec = v == 0 ? state.q : state.k; + float v0 = vec.getFloat(i); + float v1 = vec.getFloat(i + 1); + vec.setFloat(i, v0 * fcr - v1 * fci); + vec.setFloat(i + 1, v0 * fci + v1 * fcr); + } + } + + // save key,value at this time step to kv cache + state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + + int curLayer = l; + + // multihead attention with Granite attention scaling + Parallel.parallelFor(0, config.numberOfHeads(), h -> { + int qOffset = h * headSize; + int attOffset = h * config.contextLength(); + + for (int t = 0; t <= position; t++) { + int keyCacheOffset = t * kvDim + (h / kvMul) * headSize; + float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); + // Granite uses custom attention multiplier instead of 1/sqrt(headSize) + score *= attentionScale; + state.att.setFloat(attOffset + t, score); + } + + state.att.softmaxInPlace(attOffset, position + 1); + + int xbOffset = h * headSize; + state.xb.fillInPlace(xbOffset, headSize, 0f); + + for (int t = 0; t <= position; t++) { + int vOffset = t * kvDim + (h / kvMul) * headSize; + float a = state.att.getFloat(attOffset + t); + state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); + } + }); + + // final matmul to get the output of the attention + weights.wo[l].matmul(state.xb, state.xb2, dim, dim); + + // residual connection with Granite scaling + state.xb2.mapInPlace(v -> v * residualScale); + state.x.addInPlace(state.xb2); + + // ffn rmsnorm + rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps()); + + // FFN: self.w2(F.silu(self.w1(x)) * self.w3(x)) + weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim); + weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim); + + // SwiGLU non-linearity + state.hb.mapInPlace(value -> value / (float) (1.0 + 
Math.exp(-value))); + state.hb.multiplyInPlace(state.hb2); + + // final matmul to get the output of the ffn + weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim()); + + // residual connection with Granite scaling + state.xb.mapInPlace(v -> v * residualScale); + state.x.addInPlace(state.xb); + } + + rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps()); + + weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim); + + // Apply Granite logit scaling (divide by the scaling factor) + state.logits.mapInPlace(v -> v * logitScale); + + return state.logits; + } + static void copyChunk(FloatTensor in, FloatTensor out, int dim1In, int dim1Out, int nChunks, int chunkNo) { assert (dim1In == dim1Out * nChunks); final int startOffsetInDim1 = chunkNo * dim1Out; diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java b/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java index 7599f1b0..a9c65223 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java @@ -531,4 +531,138 @@ public static List generateTokensGPUPhi3(Model model, State state, int return generatedTokens; } + + /** + * Generates tokens using the Granite model with CPU inference. + * Identical pattern to generateTokensLlama but calls forwardGranite. + */ + public static List generateTokensGranite(Model model, State state, int startPosition, + List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + long startNanos = System.nanoTime(); + long inferenceStartNanos = 0; + + Object logits; + if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) { + maxTokens = model.configuration().contextLength(); + } + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; + int nextToken; + int promptIndex = 0; + int pos = startPosition; + + while (pos < maxTokens) { + // Call Granite-specific forward pass + logits = InferenceCore.forwardGranite(model, state, currentToken, pos); + + if (promptIndex < promptTokens.size()) { + nextToken = promptTokens.get(promptIndex++); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + } else { + if (inferenceStartNanos == 0) { + inferenceStartNanos = System.nanoTime(); + } + + nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0; + int totalTokens = promptIndex + generatedTokens.size(); + + LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds); + + return generatedTokens; + } + + /** + * Generates tokens using the Granite model with GPU (TornadoVM) inference. + * Identical pattern to generateTokensGPULlama. 
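+ * <p>A hypothetical call site (the variable names here are illustrative only and do not appear in this patch): + * <pre>{@code + * List<Integer> out = InferenceEngine.generateTokensGPUGranite(model, state, 0, promptTokens, + *         stopTokens, 256, sampler, false, null, tornadoVMPlan); + * }</pre>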
+ */ + public static List generateTokensGPUGranite(Model model, State state, int startPosition, + List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMMasterPlan) { + long startNanos = System.nanoTime(); + long inferenceStartNanos = 0; + + Object logits; + if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) { + maxTokens = model.configuration().contextLength(); + } + + List generatedTokens = new ArrayList<>(); + + int currentToken = state.latestToken; + int nextToken; + int promptIndex = 0; + int pos = startPosition; + + while (pos < maxTokens) { + // Call TornadoVM forward pass (same as Llama for now) + logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMMasterPlan); + + if (promptIndex < promptTokens.size()) { + nextToken = promptTokens.get(promptIndex++); + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + } else { + if (inferenceStartNanos == 0) { + inferenceStartNanos = System.nanoTime(); + } + + nextToken = sampler.sampleToken(logits); + + if (echo) { + System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); + } + + generatedTokens.add(nextToken); + + if (onTokenGenerated != null) { + onTokenGenerated.accept(nextToken); + } + + if (stopTokens.contains(nextToken)) { + break; + } + } + + currentToken = nextToken; + state.latestToken = currentToken; + pos++; + } + + long endNanos = System.nanoTime(); + double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0; + int totalTokens = promptIndex + generatedTokens.size(); + + LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds); + + return generatedTokens; + } } \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/inference/state/GraniteState.java b/src/main/java/org/beehive/gpullama3/inference/state/GraniteState.java new file mode 100644 index 00000000..531fa951 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/state/GraniteState.java @@ -0,0 +1,83 @@ +package org.beehive.gpullama3.inference.state; + +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.model.Configuration; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +import java.util.stream.Stream; + +/** + * Represents the state of the Granite model during inference. + * This class extends {@link State} to include model-specific functionalities + * and configurations tailored for the Granite model. + * + *

+ * <p>Note: GraniteState tensor shapes are identical to LlamaState + * since Granite uses the same transformer architecture as Llama, + * with differences only in the scaling factors applied.</p>
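+ * <p>For example (hypothetical dimensions, not taken from a specific Granite checkpoint): with {@code dim = 2048}, 32 attention heads and 8 key/value heads, + * {@code kvDim = 2048 * 8 / 32 = 512}, so each layer's key and value cache is allocated as {@code contextLength x kvDim} = contextLength x 512 floats.</p>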

+ */ +public final class GraniteState extends State { + + public GraniteState(Configuration config, int batchsize) { + super(config, batchsize); + } + + @Override + protected StateFields createStateFields(Configuration config) { + StateFields fields = new StateFields(); + + // Allocation with Granite dimensions (identical to Llama) + fields.x = ArrayFloatTensor.allocate(config.dim()); + fields.xb = ArrayFloatTensor.allocate(config.dim()); + fields.xb2 = ArrayFloatTensor.allocate(config.dim()); + fields.hb = ArrayFloatTensor.allocate(config.hiddenDim()); + fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim()); + fields.q = ArrayFloatTensor.allocate(config.dim()); + fields.k = ArrayFloatTensor.allocate(config.dim()); + fields.v = ArrayFloatTensor.allocate(config.dim()); + fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength()); + fields.logits = ArrayFloatTensor.allocate(config.vocabularySize()); + + // Key-value cache with Granite dimensions + int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads(); + fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), kvDim)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); + + // TornadoVM wrappers with Granite dimensions + fields.wrapX = new FloatArray(config.dim()); + fields.wrapXb = new FloatArray(config.dim()); + fields.wrapXb2 = new FloatArray(config.dim()); + fields.wrapHb = new FloatArray(config.hiddenDim()); + fields.wrapHb2 = new FloatArray(config.hiddenDim()); + + switch (config.quantization()) { + case "FP16" -> fields.createActivationFP16(config.dim()); + case "Q8_0" -> fields.createActivationQ8_0(config.dim()); + default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); + } + fields.wrapLogits = new FloatArray(config.vocabularySize()); + fields.wrapQ = new FloatArray(config.dim()); + fields.wrapK = new FloatArray(config.dim()); + fields.wrapV = new FloatArray(config.dim()); + + fields.wrapXFP16 = new HalfFloatArray(config.dim()); + fields.wrapXbFP16 = new HalfFloatArray(config.dim()); + // dim vs kvdim + fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); + fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers()); + fields.wrapValueCache.init(0.f); + fields.wrapKeyCache.init(0.f); + fields.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength()); + fields.positionHolder = new IntArray(1); + + // Temporary arrays + fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize)); + fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize)); + fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize)); + + return fields; + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/standard/GraniteStandardWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/standard/GraniteStandardWeights.java new file mode 100644 index 00000000..0666a1f1 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/weights/standard/GraniteStandardWeights.java @@ -0,0 +1,75 @@ +package org.beehive.gpullama3.inference.weights.standard; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.standard.FloatTensor; + 
+/** + * A model-specific implementation of {@link StandardWeights} for the Granite model. + * This class encapsulates the weights required for performing inference + * using the Granite model in the standard CPU-based format. + * + *

+ * <p>Note: Granite uses the same weight structure as Llama, + * with differences only in the scaling factors applied during inference.</p>

+ */ +public class GraniteStandardWeights extends StandardWeights { + + // @formatter:off + /** + * Constructor for GraniteStandardWeights. + * + * @param token_embedding_table The token embedding table tensor. + * @param rms_att_weight Array of RMS attention weights tensors. + * @param wq Array of query weight tensors. + * @param wk Array of key weight tensors. + * @param wv Array of value weight tensors. + * @param wo Array of output weight tensors. + * @param rms_ffn_weight Array of RMS feed-forward network weights. + * @param w1 Array of first feed-forward layer weights (gate). + * @param w2 Array of second feed-forward layer weights (down). + * @param w3 Array of third feed-forward layer weights (up). + * @param rms_final_weight Final RMS weight tensor. + * @param freq_cis_real Real part of frequency cis tensor (RoPE). + * @param freq_cis_imag Imaginary part of frequency cis tensor (RoPE). + * @param wcls Output/classification weight tensor (or shared embedding). + * @param weightType The GGML weight type (FP16 or Q8_0). + */ + public GraniteStandardWeights( + FloatTensor token_embedding_table, + FloatTensor[] rms_att_weight, + FloatTensor[] wq, + FloatTensor[] wk, + FloatTensor[] wv, + FloatTensor[] wo, + FloatTensor[] rms_ffn_weight, + FloatTensor[] w1, + FloatTensor[] w2, + FloatTensor[] w3, + FloatTensor rms_final_weight, + FloatTensor freq_cis_real, + FloatTensor freq_cis_imag, + FloatTensor wcls, + GGMLType weightType) { + // call to StandardWeights constructor + super(token_embedding_table, + rms_att_weight, + wq, + wk, + wv, + wo, + rms_ffn_weight, + w1, + w2, + w3, + rms_final_weight, + freq_cis_real, + freq_cis_imag, + wcls, + weightType); + } + // @formatter:on + + @Override + public GGMLType getWeightType() { + return weightType; + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/GraniteTornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/GraniteTornadoWeights.java new file mode 100644 index 00000000..5257a966 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/GraniteTornadoWeights.java @@ -0,0 +1,61 @@ +package org.beehive.gpullama3.inference.weights.tornado; + +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.tornado.TornadoTensor; + +/** + * A model-specific implementation of {@link TornadoWeights} for the Granite model. + * This class encapsulates the weights required for performing inference + * using the Granite model in the TornadoVM GPU-accelerated format. + * + *

+ * <p>Note: Granite uses the same weight structure as Llama in TornadoVM format, + * with differences only in the scaling factors applied during inference.</p>

+ */ +public class GraniteTornadoWeights extends TornadoWeights { + // @formatter:off + /** + * Constructor for GraniteTornadoWeights. + * + * @param tokenEmbeddingTable The token embedding table tensor. + * @param rms_att_weightLayered Array of RMS attention weights tensors. + * @param wqLayered Array of query weight tensors. + * @param wkLayered Array of key weight tensors. + * @param wvLayered Array of value weight tensors. + * @param woLayered Array of output weight tensors. + * @param rms_ffn_weightLayered Array of RMS feed-forward network weights. + * @param w1Layered Array of first feed-forward layer weights (gate). + * @param w2Layered Array of second feed-forward layer weights (down). + * @param w3Layered Array of third feed-forward layer weights (up). + * @param rms_final_weight_as_floatArray Final RMS weight tensor. + * @param freq_cis_realFlat Real part of frequency cis tensor (RoPE). + * @param freq_cis_imagFlat Imaginary part of frequency cis tensor (RoPE). + * @param wclsByteArray Output/classification weight tensor (or shared embedding). + * @param weightType The GGML weight type (FP16 or Q8_0). + */ + public GraniteTornadoWeights( + TornadoTensor tokenEmbeddingTable, + TornadoTensor[] rms_att_weightLayered, + TornadoTensor[] wqLayered, + TornadoTensor[] wkLayered, + TornadoTensor[] wvLayered, + TornadoTensor[] woLayered, + TornadoTensor[] rms_ffn_weightLayered, + TornadoTensor[] w1Layered, + TornadoTensor[] w2Layered, + TornadoTensor[] w3Layered, + TornadoTensor rms_final_weight_as_floatArray, + TornadoTensor freq_cis_realFlat, + TornadoTensor freq_cis_imagFlat, + TornadoTensor wclsByteArray, + GGMLType weightType) { + super(tokenEmbeddingTable, rms_att_weightLayered, + wqLayered, wkLayered, wvLayered, woLayered, + rms_ffn_weightLayered, + w1Layered, w2Layered, w3Layered, + rms_final_weight_as_floatArray, + freq_cis_realFlat, freq_cis_imagFlat, + wclsByteArray, + weightType); + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/model/ModelType.java b/src/main/java/org/beehive/gpullama3/model/ModelType.java index ce88a69b..fb46ff6e 100644 --- a/src/main/java/org/beehive/gpullama3/model/ModelType.java +++ b/src/main/java/org/beehive/gpullama3/model/ModelType.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.model; +import org.beehive.gpullama3.model.loader.GraniteLoader; import org.beehive.gpullama3.tensor.GGUF; import org.beehive.gpullama3.model.loader.LlamaModelLoader; import org.beehive.gpullama3.model.loader.MistralModelLoader; @@ -64,6 +65,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo } }, + GRANITE { + @Override + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + return new GraniteLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel(); + } + }, + UNKNOWN { @Override public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index e2a166b0..d3466a8e 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.model.format; +import org.beehive.gpullama3.tokenizer.GraniteTokenizer; import org.beehive.gpullama3.tokenizer.LlamaTokenizer; import org.beehive.gpullama3.tokenizer.MistralTokenizer; import 
org.beehive.gpullama3.tokenizer.Phi3Tokenizer; @@ -12,6 +13,7 @@ public interface ChatFormat { static ChatFormat create(Object tokenizer, ChatTokens chatTokens) { return switch (tokenizer) { + case GraniteTokenizer graniteTokenizer -> new GraniteChatFormat(graniteTokenizer); case LlamaTokenizer llamaTokenizer -> new LlamaChatFormat(llamaTokenizer); case MistralTokenizer mistralTokenizer -> new MistralChatFormat(mistralTokenizer); case Qwen3Tokenizer qwen3Tokenizer -> new Qwen3ChatFormat(qwen3Tokenizer, chatTokens); diff --git a/src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java new file mode 100644 index 00000000..a5ea7947 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/GraniteChatFormat.java @@ -0,0 +1,85 @@ +package org.beehive.gpullama3.model.format; + +import org.beehive.gpullama3.tokenizer.GraniteTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Chat format for Granite models. + * + * Granite uses a different chat template than Llama: + * <|start_of_role|>system<|end_of_role|>...<|end_of_text|> + * <|start_of_role|>user<|end_of_role|>...<|end_of_text|> + * <|start_of_role|>assistant<|end_of_role|>... + */ +public class GraniteChatFormat implements ChatFormat { + + protected final Tokenizer tokenizer; + protected final int startRole; + protected final int endRole; + protected final int endOfText; + protected final Set stopTokens; + + public GraniteChatFormat(Tokenizer tokenizer) { + this.tokenizer = tokenizer; + Map specialTokens = tokenizer.getSpecialTokens(); + + this.startRole = specialTokens.getOrDefault("<|start_of_role|>", -1); + this.endRole = specialTokens.getOrDefault("<|end_of_role|>", -1); + + // Use tokenizer's EOS token instead of hardcoding + if (tokenizer instanceof GraniteTokenizer graniteTokenizer) { + this.endOfText = graniteTokenizer.getEosTokenId(); + } else { + this.endOfText = specialTokens.getOrDefault("<|end_of_text|>", 0); + } + + this.stopTokens = Set.of(endOfText); + } + + @Override + public int getBeginOfText() { + return endOfText; // For Granite, token 0 is both BOS and EOS + } + + @Override + public Set getStopTokens() { + return stopTokens; + } + + @Override + public List encodeHeader(Message message) { + List tokens = new ArrayList<>(); + if (startRole >= 0) { + tokens.add(startRole); + } + tokens.addAll(tokenizer.encodeAsList(message.role().name())); + if (endRole >= 0) { + tokens.add(endRole); + } + return tokens; + } + + @Override + public List encodeMessage(Message message) { + List tokens = encodeHeader(message); + tokens.addAll(tokenizer.encodeAsList(message.content().strip())); + tokens.add(endOfText); + return tokens; + } + + public List encodeDialogPrompt(boolean appendAssistantTurn, List dialog) { + List tokens = new ArrayList<>(); + for (Message message : dialog) { + tokens.addAll(encodeMessage(message)); + } + if (appendAssistantTurn) { + tokens.addAll(encodeHeader(new Message(ChatFormat.Role.ASSISTANT, ""))); + } + return tokens; + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java index 80987a06..c98a72c9 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/LlamaChatFormat.java @@ -1,6 +1,7 @@ package 
org.beehive.gpullama3.model.format; import org.beehive.gpullama3.tokenizer.LlamaTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; import java.util.ArrayList; import java.util.List; @@ -9,7 +10,7 @@ public class LlamaChatFormat implements ChatFormat { - protected final LlamaTokenizer tokenizer; + protected final Tokenizer tokenizer; protected final int beginOfText; protected final int endHeader; protected final int startHeader; @@ -18,7 +19,7 @@ public class LlamaChatFormat implements ChatFormat { protected final int endOfMessage; protected final Set stopTokens; - public LlamaChatFormat(LlamaTokenizer tokenizer) { + public LlamaChatFormat(Tokenizer tokenizer) { this.tokenizer = tokenizer; Map specialTokens = tokenizer.getSpecialTokens(); this.beginOfText = specialTokens.get("<|begin_of_text|>"); diff --git a/src/main/java/org/beehive/gpullama3/model/granite/Granite.java b/src/main/java/org/beehive/gpullama3/model/granite/Granite.java new file mode 100644 index 00000000..a4743e38 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/granite/Granite.java @@ -0,0 +1,99 @@ +package org.beehive.gpullama3.model.granite; + +import org.beehive.gpullama3.inference.InferenceCore; +import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.AbstractModel; +import org.beehive.gpullama3.model.ModelType; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.tokenizer.GraniteTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +public class Granite extends AbstractModel { + + private final GraniteConfiguration configuration; + + public Granite(GraniteConfiguration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) { + super(tokenizer, weights, chatFormat, null); + this.configuration = configuration; + } + + @Override + public GraniteConfiguration configuration() { + return configuration; + } + + @Override + public GraniteTokenizer tokenizer() { + return (GraniteTokenizer) tokenizer; + } + + @Override + public ModelType getModelType() { + return ModelType.GRANITE; + } + + @Override + public State createNewState() { + State state = new GraniteState(configuration(), -1); + // Granite uses token 0 (<|end_of_text|>) as BOS - it's multi-purpose + // Token 0 is the default BOS for Granite + state.latestToken = 0; + return state; + } + + @Override + public State createNewState(int batchsize) { + State state = new GraniteState(configuration(), batchsize); + // Token 0 is the default BOS for Granite + state.latestToken = 0; + return state; + } + + @Override + public void forward(State state, int token, int position) { + // Uses Granite-specific forward with scaling factors + InferenceCore.forwardGranite(this, state, token, position); + } + + @Override + public List generateTokens(State state, int startPosition, List promptTokens, + Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + return InferenceEngine.generateTokensGranite(this, state, startPosition, promptTokens, + stopTokens, maxTokens, sampler, echo, onTokenGenerated); + } + + @Override + public List generateTokensGPU(State 
state, int startPosition, List promptTokens, + Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + return InferenceEngine.generateTokensGPUGranite(this, state, startPosition, promptTokens, + stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); + } + + // Convenience accessors for scaling factors (used in forward pass) + public float embeddingScale() { + return configuration.embeddingScale(); + } + + public float residualScale() { + return configuration.residualScale(); + } + + public float attentionScale() { + return configuration.attentionScale(); + } + + public float logitScale() { + return configuration.logitScale(); + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/granite/GraniteConfiguration.java b/src/main/java/org/beehive/gpullama3/model/granite/GraniteConfiguration.java new file mode 100644 index 00000000..db03fc5a --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/granite/GraniteConfiguration.java @@ -0,0 +1,152 @@ +package org.beehive.gpullama3.model.granite; + +import org.beehive.gpullama3.model.Configuration; + +// @formatter:off +public record GraniteConfiguration( + String quantization, + int dim, + int hiddenDim, + int numberOfLayers, + int numberOfHeads, + int numberOfKeyValueHeads, + int vocabularySize, + int contextLength, + float rmsNormEps, + float ropeTheta, + // Granite-specific scaling factors (µP parameterization) + float embeddingMultiplier, // multiply embeddings after lookup + float residualMultiplier, // multiply residual additions + float attentionMultiplier, // replaces 1/sqrt(headDim) + float logitsScaling, // DIVIDE logits by this value + boolean tieWordEmbeddings // share input/output embeddings +) implements Configuration { + + @Override + public String quantization() { + return quantization; + } + + @Override + public int numberOfHeadsKey() { + // Granite uses standard GQA, same as Llama + return numberOfKeyValueHeads; + } + + @Override + public int contextLengthModel() { + return contextLength; + } + + /** Size of each attention head (derived from dim / numberOfHeads) */ + @Override + public int headSize() { + return dim / numberOfHeads; + } + + /** Key/value dimension (derived from dim * numberOfKeyValueHeads / numberOfHeads) */ + @Override + public int kvDim() { + return dim * numberOfKeyValueHeads / numberOfHeads; + } + + /** Multiplier for key/value sharing in grouped-query attention */ + @Override + public int kvMul() { + return numberOfHeads / numberOfKeyValueHeads; + } + + /** + * Creates a new Configuration with a different context length. 
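+ * <p>For instance, {@code config.withContextLength(4096)} (an illustrative value) returns a copy of this configuration with the new context length, while a negative argument returns this same instance unchanged.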
+ * + * @param newContextLength The new context length to use + * @return A new Configuration instance with updated context length, + * or the current instance if newContextLength is negative + */ + // @formatter:off + public GraniteConfiguration withContextLength(int newContextLength) { + if (newContextLength < 0) { + return this; // no change + } + return new GraniteConfiguration( + this.quantization, + this.dim, + this.hiddenDim, + this.numberOfLayers, + this.numberOfHeads, + this.numberOfKeyValueHeads, + this.vocabularySize, + newContextLength, + this.rmsNormEps, + this.ropeTheta, + this.embeddingMultiplier, + this.residualMultiplier, + this.attentionMultiplier, + this.logitsScaling, + this.tieWordEmbeddings + ); + } + // @formatter:on + + /** + * Accessor for embedding scale (alias for embeddingMultiplier) + */ + public float embeddingScale() { + return embeddingMultiplier; + } + + /** + * Accessor for residual scale (alias for residualMultiplier) + */ + public float residualScale() { + return residualMultiplier; + } + + /** + * Accessor for attention scale (alias for attentionMultiplier) + */ + public float attentionScale() { + return attentionMultiplier; + } + + /** + * Accessor for logit scale (alias for logitsScaling) + */ + public float logitScale() { + return logitsScaling; + } + + /** + * Factory method to create GraniteConfiguration with default scaling values + * for Granite 3.x models. + */ + public static GraniteConfiguration createDefault( + String quantization, + int dim, + int hiddenDim, + int numberOfLayers, + int numberOfHeads, + int numberOfKeyValueHeads, + int vocabularySize, + int contextLength, + float rmsNormEps, + float ropeTheta) { + return new GraniteConfiguration( + quantization, + dim, + hiddenDim, + numberOfLayers, + numberOfHeads, + numberOfKeyValueHeads, + vocabularySize, + contextLength, + rmsNormEps, + ropeTheta, + 12.0f, // embeddingMultiplier + 0.22f, // residualMultiplier + 0.0078125f, // attentionMultiplier (1/128) + 16.0f, // logitsScaling (divisor) + true // tieWordEmbeddings + ); + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java new file mode 100644 index 00000000..c22491e7 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/loader/GraniteLoader.java @@ -0,0 +1,170 @@ +package org.beehive.gpullama3.model.loader; + +import org.beehive.gpullama3.auxiliary.Pair; +import org.beehive.gpullama3.inference.operation.RoPE; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.standard.GraniteStandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.granite.Granite; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import org.beehive.gpullama3.tokenizer.GraniteTokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tokenizer.Vocabulary; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; + +import java.nio.channels.FileChannel; +import 
java.util.Map; + +import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayOfTensors; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayOfTornadoTensors; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadTensor; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadTornadoTensor; + +public class GraniteLoader extends AbstractModelLoader { + + public GraniteLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + super(fileChannel, gguf, contextLength, useTornadovm); + } + + @Override + protected Vocabulary loadVocabulary(Map metadata) { + // Granite uses the same token format as Llama + return Vocabulary.loadLlamaVocabulary(metadata); + } + + @Override + protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { + return new GraniteTokenizer(metadata, vocabulary); + } + + // @formatter:off + @Override + protected GraniteConfiguration createConfiguration(Map metadata) { + int vocabSize = metadata.containsKey("granite.vocab_size") + ? (int) metadata.get("granite.vocab_size") + : (int) metadata.get("tokenizer.ggml.tokens.length"); + + // Extract Granite-specific metadata keys + float embeddingScale = (float) metadata.getOrDefault("granite.embedding_scale", 12.0f); + float residualScale = (float) metadata.getOrDefault("granite.residual_scale", 0.22f); + float attentionScale = (float) metadata.getOrDefault("granite.attention.scale", 0.0078125f); + float logitScale = (float) metadata.getOrDefault("granite.logit_scale", 16.0f); + + int kvHeads; + Object kvHeadsObj = metadata.get("granite.attention.head_count_kv"); + if (kvHeadsObj instanceof int[] kvHeadsArray) { + // Granite 4.0: per-layer array - take first value (assuming uniform for now) + kvHeads = kvHeadsArray[0]; + } else if (kvHeadsObj instanceof Integer) { + // Granite 3.3: scalar value + kvHeads = (Integer) kvHeadsObj; + } else { + // Fallback to head count (no GQA) + kvHeads = (int) metadata.get("granite.attention.head_count"); + } + + return new GraniteConfiguration( + getModelQuantization(metadata), + (int) metadata.get("granite.embedding_length"), + (int) metadata.get("granite.feed_forward_length"), + (int) metadata.get("granite.block_count"), + (int) metadata.get("granite.attention.head_count"), + kvHeads, + vocabSize, + (int) metadata.get("granite.context_length"), + (float) metadata.getOrDefault("granite.attention.layer_norm_rms_epsilon", 1e-5f), + (float) metadata.getOrDefault("granite.rope.freq_base", 10000f), + embeddingScale, + residualScale, + attentionScale, + logitScale, + true // Granite ties word embeddings + ).withContextLength(contextLength); + } + + @Override + protected Pair precomputeRopeFrequencies(GraniteConfiguration config) { + return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), + false, 1.0f, 1.0f, 1.0f, config.contextLength()); + } + // @formatter:on + + @Override + protected Granite createModel(GraniteConfiguration config, Tokenizer tokenizer, Weights weights) { + return new Granite(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); + } + + // @formatter:off + @Override + protected Weights createStandardWeights(Map tensorEntries, + GraniteConfiguration config, + Pair ropeFreqs, + GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + final int nl = config.numberOfLayers(); + + return new GraniteStandardWeights( + loadTensor(tokenEmbeddings), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".attn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + loadTensor(tensorEntries.get("output_norm.weight")), + new ArrayFloatTensor(ropeFreqs.first()), + new ArrayFloatTensor(ropeFreqs.second()), + loadTensor(outputWeight), + outputWeight.ggmlType()); + } + // @formatter:on + + // @formatter:off + @Override + protected Weights createTornadoVMWeights(Map tensorEntries, + GraniteConfiguration config, + Pair ropeFreqs, + GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + GGMLType ggmlType = outputWeight.ggmlType(); + + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + System.out.println("Loading model weights in TornadoVM format (loading " + ggmlType + ")"); + } + + // Validate supported types + if (ggmlType != GGMLType.F16 && ggmlType != GGMLType.Q8_0) { + throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); + } + + final int nl = config.numberOfLayers(); + return new GraniteTornadoWeights( + loadTornadoTensor(tokenEmbeddings), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), + loadTornadoTensor(tensorEntries.get("output_norm.weight")), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.first())), + new FP32TornadoTensor(FloatArray.fromArray(ropeFreqs.second())), + loadTornadoTensor(outputWeight), + ggmlType + ); + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java index a5e5e59c..392113be 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java @@ -48,7 +48,9 @@ private static ModelType detectModelType(Map metadata) { // Check by name first if (name != null) { String lowerName = name.toLowerCase(); - if (lowerName.contains("mistral")) { + if (lowerName.contains("granite")) { + return ModelType.GRANITE; + } else if (lowerName.contains("mistral")) { return ModelType.MISTRAL; } else if (lowerName.contains("llama")) { return ModelType.LLAMA_3; @@ -63,6 +65,11 @@ private static ModelType detectModelType(Map metadata) { } } + // Alternative: check by metadata keys if name-based detection fails + if (metadata.containsKey("granite.block_count")) { + return ModelType.GRANITE; + } + return ModelType.UNKNOWN; } diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java new file mode 100644 index 00000000..0c7750fd --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tokenizer/GraniteTokenizer.java @@ -0,0 +1,276 @@ +package org.beehive.gpullama3.tokenizer; + +import org.beehive.gpullama3.auxiliary.Pair; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * GPT-2-style BPE tokenizer for Granite models. + *

+ * Supports both Granite 3.3 (refact pretokenizer, 49K vocab) and Granite 4.0 (dbrx pretokenizer, 100K vocab). + */ +public class GraniteTokenizer implements Tokenizer { + static final Map BYTE_ENCODER = bytesToUnicode(); + static final Map BYTE_DECODER = BYTE_ENCODER.entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + + // Pretokenizer patterns + private static final String REFACT_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; + + private static final String DBRX_PATTERN = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; + + // Instance fields + private final Pattern compiledPattern; + private final Vocabulary vocabulary; + private final Map, Integer> merges; + private final Map specialTokens; + + // Token IDs (version-dependent) + private final int bosTokenId; + private final int eosTokenId; + private final int padTokenId; + private final String pretokenizerType; + + public GraniteTokenizer(Map metadata, Vocabulary vocabulary) { + this.vocabulary = vocabulary; + + // Detect pretokenizer type and select pattern + this.pretokenizerType = (String) metadata.getOrDefault("tokenizer.ggml.pre", "refact"); + String pattern = switch (pretokenizerType) { + case "dbrx" -> DBRX_PATTERN; + default -> REFACT_PATTERN; + }; + this.compiledPattern = Pattern.compile(pattern); + + // Read token IDs from metadata + this.bosTokenId = getIntFromMetadata(metadata, "tokenizer.ggml.bos_token_id", 0); + this.eosTokenId = getIntFromMetadata(metadata, "tokenizer.ggml.eos_token_id", 0); + this.padTokenId = getIntFromMetadata(metadata, "tokenizer.ggml.padding_token_id", 0); + + // Load merges + String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges"); + List> mergeList = Arrays.stream(mergeLines).map(line -> line.split(" ")) + .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList(); + + // Collect special tokens + Map specialTokens = new HashMap<>(); + int allTokens = vocabulary.size(); + for (int i = 0; i < allTokens; i++) { + String token = vocabulary.get(i); + if (token.startsWith("<|") && token.endsWith("|>")) { + specialTokens.put(token, i); + } + // Also catch style tokens used in some Granite models + if (token.startsWith("<") && token.endsWith(">") && !token.contains(" ")) { + specialTokens.putIfAbsent(token, i); + } + } + this.specialTokens = Map.copyOf(specialTokens); + + // Build merge map + this.merges = new HashMap<>(); + for (Pair pair : mergeList) { + String merged = vocabulary.get(pair.first()) + vocabulary.get(pair.second()); + int mergeIndex = vocabulary.getIndex(merged).orElseThrow(); + this.merges.put(pair, mergeIndex); + } + } + + private static int getIntFromMetadata(Map metadata, String key, int defaultValue) { + Object value = metadata.get(key); + if (value instanceof Number num) { + return num.intValue(); + } + return defaultValue; + } + + // === Token ID accessors === + private static List findAll(Pattern pattern, String text) { + List matches = new ArrayList<>(); + Matcher matcher = pattern.matcher(text); + while (matcher.find()) { + matches.add(matcher.group()); + } + return matches; + } + + private static List merge(List ids, Pair pair, int idx) { + List newIds = new ArrayList<>(); + int i = 0; + while (i < ids.size()) { + if (i < ids.size() - 1 && 
ids.get(i).equals(pair.first()) && ids.get(i + 1).equals(pair.second())) { + newIds.add(idx); + i += 2; + } else { + newIds.add(ids.get(i)); + i++; + } + } + return newIds; + } + + private static Map bytesToUnicode() { + List bs = new ArrayList<>(); + IntStream.rangeClosed('!', '~').forEach(bs::add); + IntStream.rangeClosed('¡', '¬').forEach(bs::add); + IntStream.rangeClosed('®', 'ÿ').forEach(bs::add); + + List cs = new ArrayList<>(bs); + int n = 0; + for (int b = 0; b < 256; b++) { + if (!bs.contains(b)) { + bs.add(b); + cs.add(256 + n); + n++; + } + } + return IntStream.range(0, bs.size()).boxed().collect(Collectors.toMap(bs::get, cs::get)); + } + + public int getBosTokenId() { + return bosTokenId; + } + + // === Tokenizer interface === + + public int getEosTokenId() { + return eosTokenId; + } + + public int getPadTokenId() { + return padTokenId; + } + + public String getPretokenizerType() { + return pretokenizerType; + } + + @Override + public Map getSpecialTokens() { + return specialTokens; + } + + // === Encoding === + + @Override + public boolean isSpecialToken(int tokenIndex) { + return specialTokens.containsValue(tokenIndex); + } + + @Override + public boolean shouldDisplayToken(int token) { + return !isSpecialToken(token); + } + + public String regexPattern() { + return compiledPattern != null ? compiledPattern.pattern() : null; + } + + // @Override + public int[] encode(String text) { + StringBuilder sb = new StringBuilder(); + byte[] bytes = text.getBytes(StandardCharsets.UTF_8); + for (byte b : bytes) { + sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); + } + return encodeImpl(sb.toString()); + } + + @Override + public List encodeAsList(String text) { + return Arrays.stream(encode(text)).boxed().toList(); + } + + private int[] encodeImpl(String text) { + return encode(text, Set.of()).stream().mapToInt(i -> i).toArray(); + } + + // === Decoding === + + public List encode(String text, Set allowedSpecial) { + if (allowedSpecial.isEmpty()) { + return encodeOrdinary(text); + } + + assert specialTokens.keySet().containsAll(allowedSpecial); + String specialPattern = allowedSpecial.stream().map(Pattern::quote).collect(Collectors.joining("|", "(", ")")); + String[] specialChunks = text.split(specialPattern); + + List ids = new ArrayList<>(); + for (String part : specialChunks) { + if (allowedSpecial.contains(part)) { + ids.add(specialTokens.get(part)); + } else { + ids.addAll(encodeOrdinary(part)); + } + } + return ids; + } + + public List encodeOrdinary(String text) { + List textChunks = findAll(compiledPattern, text); + List ids = new ArrayList<>(); + for (String chunk : textChunks) { + ids.addAll(encodeChunk(chunk)); + } + return ids; + } + + // === Helpers === + + private List encodeChunk(String chunk) { + List ids = new ArrayList<>(); + for (char c : chunk.toCharArray()) { + int tokenIndex = vocabulary.getIndex(String.valueOf(c)).orElseThrow(); + ids.add(tokenIndex); + } + + while (ids.size() >= 2) { + Map, Integer> stats = getStats(ids); + Pair pair = stats.keySet().stream().min(Comparator.comparingInt(key -> merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow(); + if (!merges.containsKey(pair)) { + break; + } + ids = merge(ids, pair, merges.get(pair)); + } + return ids; + } + + @Override + public String decode(List tokens) { + String decoded = decodeImpl(tokens); + int[] decodedBytesAsInts = decoded.codePoints().map(cp -> BYTE_DECODER.getOrDefault(cp, cp)).toArray(); + byte[] rawBytes = new byte[decodedBytesAsInts.length]; + for (int i = 0; i < 
decodedBytesAsInts.length; i++) { + rawBytes[i] = (byte) decodedBytesAsInts[i]; + } + return new String(rawBytes, StandardCharsets.UTF_8); + } + + private String decodeImpl(List tokens) { + StringBuilder sb = new StringBuilder(); + for (int token : tokens) { + sb.append(vocabulary.get(token)); + } + return sb.toString(); + } + + private Map, Integer> getStats(List ids) { + Map, Integer> map = new HashMap<>(); + for (int i = 0; i + 1 < ids.size(); i++) { + Pair key = new Pair<>(ids.get(i), ids.get(i + 1)); + map.merge(key, 1, Integer::sum); + } + return map; + } +} \ No newline at end of file diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java new file mode 100644 index 00000000..36b9803e --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/GraniteKernels.java @@ -0,0 +1,396 @@ +package org.beehive.gpullama3.tornadovm.kernels; + +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.annotations.Parallel; +import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.HalfFloat; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimized; +import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimizedQ8_0Byte; +import static org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered.matrixVectorRowMajorOptimizedSingle; + +public class GraniteKernels { + + public static void convertFP16toFP32withGraniteScale(KernelContext context, HalfFloatArray x, FloatArray wrapX, float embeddingScale) { + int i = context.globalIdx; + wrapX.set(i, embeddingScale * x.get(i).getFloat32()); + } + + public static void convertQ8_0toFP32withGraniteScale(KernelContext context, ByteArray x, FloatArray wrapX, float embeddingScale) { + int globalId = context.globalIdx; + int totalElements = wrapX.getSize(); + + if (globalId >= totalElements) { + return; + } + + // Q8_0 block structure constants + int blockSize = 32; + int Q8_0_BLOCK_BYTES = 34; // 2 bytes scale + 32 bytes quants + + // Calculate which block and position within block + int blockIdx = globalId / blockSize; + int withinBlockIdx = globalId % blockSize; + + // Calculate byte offset for this Q8_0 block + int blockByteOffset = blockIdx * Q8_0_BLOCK_BYTES; + + // Load scale (first 2 bytes of block as HalfFloat) + HalfFloat scale = x.getHalfFloat(blockByteOffset); + float scaleFloat = scale.getFloat32(); + + // Load quantized value (skip 2-byte scale, then index within block) + byte quantValue = x.get(blockByteOffset + 2 + withinBlockIdx); + + // Dequantize: float_value = quantized_value * scale + float dequantizedValue = ((float) quantValue) * scaleFloat; + + // Store result in output FloatArray + wrapX.set(globalId, embeddingScale * dequantizedValue); + } + + // @formatter:off + public static void matrixVectorGenericWithGraniteScale( + KernelContext context, + HalfFloatArray x, + FloatArray hb, // output + HalfFloatArray w, + int dim1, // inner loop + int dim0, // outer loop + int localWorkGroupSize, + float logitsScale + ) { + // One row per workgroup (not per thread) + int rowId = 
context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= dim0) { + return; + } + float sum = matrixVectorRowMajorOptimizedSingle(context, localSize, x, w, dim1); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + hb.set(rowId, sum * logitsScale); + } + } + + + public static void processHeadsFlashAttentionWithGraniteScale(KernelContext context, + FloatArray q, FloatArray key_cache, FloatArray value_cache, + FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, + IntArray positionHolder, int layer, int contextLength, float attentionScale) { + + // Thread and workgroup information + int tid = context.localIdx; + int h = context.groupIdx; // Each workgroup processes one head + int localSize = context.localGroupSizeX; + + // Early exit if this workgroup is beyond our head count + // This relies on the kernel being launched with nHeads workgroups. + if (h >= nHeads) { + return; + } + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + int kvHeadIdx = h / kvMul; + int BLOCK_SIZE_C = 16; + + // Allocate shared memory for tiled computation + float[] q_shared = context.allocateFloatLocalArray(headSize); + float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize); + float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize); + float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C); + float[] shared_tile_max_holder = context.allocateFloatLocalArray(1); // FIX: For broadcasting tile max + + // Thread-local accumulators for online softmax + float maxScore = Float.NEGATIVE_INFINITY; + float sumExp = 0.0f; + + // Thread-local output accumulation + float[] output = new float[headSize]; + for (int i = 0; i < headSize; i++) { + output[i] = 0.0f; + } + + // Load query vector into shared memory + for (int i = tid; i < headSize; i += localSize) { + q_shared[i] = q.get(h * headSize + i); + } + + context.localBarrier(); + + // Process sequence in tiles + for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) { + int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos); + + // Load key and value vectors for this tile + // Each thread loads a portion of the K and V vectors for the tile + for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) { + int k_v_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile + int tileMemOffset = k_v_idx_in_tile * headSize; + for (int d = 0; d < headSize; d++) { + int kvCacheAbsolutePos = tIdxInSeq; + int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + d; + k_tile[tileMemOffset + d] = key_cache.get(kvOffset); + v_tile[tileMemOffset + d] = value_cache.get(kvOffset); + } + } + + context.localBarrier(); + + // Compute attention scores for this tile + // Each thread computes one score for the tile + for (int tIdxInSeq = tileC + tid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) { + int score_idx_in_tile = tIdxInSeq - tileC; // 0, 1, 2, or 3 for this tile + + float score = 0.0f; + for (int d = 0; d < headSize; d++) { + score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d]; + } + score *= attentionScale; + s_tile[score_idx_in_tile] = score; + } + + context.localBarrier(); + + // Find max score in this tile (all threads compute it redundantly over the small s_tile) + float tileLocalMax = Float.NEGATIVE_INFINITY; + for (int i = 0; i <= tileEnd - tileC; i++) { // Iterate over valid scores in s_tile + if 
(s_tile[i] > tileLocalMax) { + tileLocalMax = s_tile[i]; + } + } + + // Broadcast max to all threads via shared memory + if (tid == 0) { + shared_tile_max_holder[0] = tileLocalMax; // FIX: Use dedicated holder + } + context.localBarrier(); + float currentTileMax = shared_tile_max_holder[0]; // FIX: Read from dedicated holder + + // Determine if we need to rescale previous results + float newMax = Math.max(maxScore, currentTileMax); + if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) { + float scale = TornadoMath.exp(maxScore - newMax); + sumExp *= scale; + for (int d = 0; d < headSize; d++) { + output[d] *= scale; + } + } + maxScore = newMax; + + // Process each key-value pair using original scores from s_tile + // All threads iterate over all scores in the current tile + for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) { + // s_tile[t_idx_in_s_tile] now correctly refers to the original score + float expScore = TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore); + sumExp += expScore; + + for (int d = 0; d < headSize; d++) { + output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d]; + } + } + context.localBarrier(); // Ensure all threads finish with s_tile, k_tile, v_tile before next tile load + } + + // Normalize and write final results + float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f; // Avoid division by zero, return 0 if sumExp is 0 + for (int d = tid; d < headSize; d += localSize) { + xb.set(h * headSize + d, output[d] * normFactor); + } + } + // @formatter:on + + public static void matrixVectorGenericWithResidualGranite(KernelContext context, FloatArray x, FloatArray hb, + HalfFloatArray w, int n, int d, int localWorkGroupSize, float residualScale) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = matrixVectorRowMajorOptimized(context, localSize, x, w, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float residual = residualScale * sum; + float result = hb.get(rowId) + residual; + hb.set(rowId, result); + } + } + + public static void processHeadsParallelGranite(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen, + IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength, float attentionScale) { + + int pos = positionHolder.get(0); + int loff = layer * contextLength * kvDim; + + // Parallelize computation across attention heads + for (@Parallel int h = 0; h < nHeads; h++) { + // Process each head in parallel + processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt, attentionScale); + } + } + + private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos, + FloatArray wrapAtt, float attentionScale) { + + // Base index for this head's attention weights + int headOffset = h * (pos + 1); + + // STEP 1: Calculate attention scores for all timesteps + for (int t = 0; t <= pos; t++) { + int kvHeadIdx = h / kvMul; + int keyOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize); + + float score = 0.0f; + for (int i = 0; i < headSize; i++) { + score += allQ.get(h * headSize + i) * key_cache.get(keyOffset + i); + } + 
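// Illustrative sketch (not part of this diff): the online-softmax bookkeeping used by
// processHeadsFlashAttentionWithGraniteScale above, in plain scalar Java. Whenever a new
// tile raises the running max, the accumulated exponential sum and partial outputs are
// rescaled by exp(oldMax - newMax), so the streamed result matches a conventional
// softmax-weighted sum over all positions. Names and array shapes are illustrative only.
static float[] onlineSoftmaxWeightedSum(float[] scores, float[][] values) {
    int headSize = values[0].length;
    float maxScore = Float.NEGATIVE_INFINITY;
    float sumExp = 0.0f;
    float[] out = new float[headSize];
    for (int t = 0; t < scores.length; t++) {
        float newMax = Math.max(maxScore, scores[t]);
        if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
            float rescale = (float) Math.exp(maxScore - newMax); // rescale previously accumulated state
            sumExp *= rescale;
            for (int d = 0; d < headSize; d++) {
                out[d] *= rescale;
            }
        }
        maxScore = newMax;
        float e = (float) Math.exp(scores[t] - maxScore);
        sumExp += e;
        for (int d = 0; d < headSize; d++) {
            out[d] += e * values[t][d];
        }
    }
    for (int d = 0; d < headSize; d++) {
        out[d] /= sumExp; // final normalization
    }
    return out;
}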
score *= attentionScale; //TODO: might need score = score * attentionScale; +// score = score / TornadoMath.sqrt(headSize); + + // Store in attention buffer + wrapAtt.set(headOffset + t, score); + } + + // STEP 2: Find max score for softmax stability + float maxScore = wrapAtt.get(headOffset); + for (int t = 1; t <= pos; t++) { + float val = wrapAtt.get(headOffset + t); + if (val > maxScore) { + maxScore = val; + } + } + + // STEP 3: Compute exponentials and sum + float sum = 0.0f; + for (int t = 0; t <= pos; t++) { + int idx = headOffset + t; + float expScore = TornadoMath.exp(wrapAtt.get(idx) - maxScore); + wrapAtt.set(idx, expScore); + sum += expScore; + } + + // STEP 4: Normalize + float normFactor = (sum > 0.0f) ? (1.0f / sum) : (1.0f / (pos + 1)); + for (int t = 0; t <= pos; t++) { + int idx = headOffset + t; + wrapAtt.set(idx, wrapAtt.get(idx) * normFactor); + } + + // STEP 5: Compute weighted sum of values for each dimension + for (int i = 0; i < headSize; i++) { + float weightedSum = 0.0f; + for (int t = 0; t <= pos; t++) { + int kvHeadIdx = h / kvMul; + int valueOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize); + weightedSum += wrapAtt.get(headOffset + t) * value_cache.get(valueOffset + i); + } + allXb.set(h * headSize + i, weightedSum); + } + } + + public static void matrixVectorGenericWithResidualQ8_0ByteWithGraniteScale(KernelContext context, FloatArray x, + FloatArray hb, ByteArray w, int n, int d, int localWorkGroupSize, float residualScale) { + // One row per workgroup (not per thread) + int rowId = context.groupIdx; + int localId = context.localIdx; + int localSize = localWorkGroupSize; + + // Early exit if this workgroup is beyond our output dimension + if (rowId >= d) { + return; + } + + float sum = matrixVectorRowMajorOptimizedQ8_0Byte(context, localSize, x, w, n); + + // Thread 0 in each workgroup writes the final result + if (localId == 0) { + float residualScaledSum = residualScale * sum; + float result = hb.get(rowId) + residualScaledSum; + hb.set(rowId, result); + } + } + + public static void matrixVectorGenericQ8ByteWithGraniteScale(KernelContext context, FloatArray x, FloatArray output, ByteArray q, + int dim1, int dim0, int localWorkGroupSize, float logitsScale) { + int rowId = context.groupIdx; + int localId = context.localIdx; + + if (rowId >= dim0) { + return; + } + + float sum = matrixVectorRowMajorOptimizedQ8_0Byte(context, localWorkGroupSize, x, q, dim1); + + // Thread 0 writes the result + if (localId == 0) { + output.set(rowId, sum * logitsScale); + } + } + + public static void ropeRotationWithCacheCopy(KernelContext context, IntArray positionHolder, FloatArray sq, // Q vector (in/out) + FloatArray sk, // K vector (in/out) + FloatArray sv, // V vector (in only) + FloatArray keyCache, // Key cache (out) + FloatArray valueCache, // Value cache (out) + int kvDim, int headSize, + float ropeTheta, + int layer, int contextLength) { + + int i = context.globalIdx * 2; + int pos = positionHolder.get(0); + + // Bounds check for Q rotation (Q has dim elements, processed in pairs) + if (i + 1 < sq.getSize()) { + // RoPE frequency calculation + int head_dim = i % headSize; + // TornadoMath.pow(ropeTheta, head_dim / (float) headSize); + float freq = 1.0f / TornadoMath.pow(ropeTheta, head_dim / (float) headSize); + float val = pos * freq; + float fcr = TornadoMath.cos(val); + float fci = TornadoMath.sin(val); + + // Rotate Q + float v0q = sq.get(i); + float v1q = sq.get(i + 1); + sq.set(i, v0q * fcr - v1q * fci); + sq.set(i + 1, v0q * fci + v1q * fcr); + + // 
Rotate K AND write to cache (only for kvDim elements) + if (i + 1 < kvDim) { + float v0k = sk.get(i); + float v1k = sk.get(i + 1); + float rotated0 = v0k * fcr - v1k * fci; + float rotated1 = v0k * fci + v1k * fcr; + + // Write rotated K back to sk + sk.set(i, rotated0); + sk.set(i + 1, rotated1); + + // Direct cache write (fused - no separate copy kernel!) + int cacheOffset = layer * contextLength * kvDim + pos * kvDim; + keyCache.set(cacheOffset + i, rotated0); + keyCache.set(cacheOffset + i + 1, rotated1); + + // Copy V to cache (V doesn't need rotation) + valueCache.set(cacheOffset + i, sv.get(i)); + valueCache.set(cacheOffset + i + 1, sv.get(i + 1)); + } + } + + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java index 1684a5b8..ca844e51 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizationPlannerFactory.java @@ -1,5 +1,6 @@ package org.beehive.gpullama3.tornadovm.layerplanner.base; +import org.beehive.gpullama3.inference.state.GraniteState; import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.state.LlamaState; import org.beehive.gpullama3.inference.state.Phi3State; @@ -8,10 +9,12 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tornadovm.GenericLayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Phi3FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen2FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Qwen3FP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.GraniteQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.LlamaQ8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Phi3Q8_0LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0.Qwen2Q8_0LayerPlanner; @@ -55,6 +58,7 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) { case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model); + case GRANITE -> new GraniteFP16LayerPlanner((GraniteState) state, model); case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); default -> throw new UnsupportedOperationException("FP16 not supported for model: " + model.getModelType()); }; @@ -67,6 +71,7 @@ private static GenericLayerPlanner createQ8_0Planner(State state, Model model) { case QWEN_2 -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3Q8_0LayerPlanner((Qwen3State) state, model); case PHI_3 -> new Phi3Q8_0LayerPlanner((Phi3State) state, model); + case GRANITE -> new GraniteQ8_0LayerPlanner((GraniteState) state, model); case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2Q8_0LayerPlanner((Qwen2State) state, model); default -> throw new UnsupportedOperationException("Q8_0 not supported for model: " + model.getModelType()); }; diff --git 
a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java new file mode 100644 index 00000000..7cc97d64 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/GraniteFP16LayerPlanner.java @@ -0,0 +1,26 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.quantization.FP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.GraniteFP16FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.LogitsGraniteFP16Layer; + +public class GraniteFP16LayerPlanner extends FP16LayerPlanner { + public GraniteFP16LayerPlanner(GraniteState state, Model model) { + super(state, model); + validateQuantizationType(); + setupTornadoForwardPlan(); + } + + @Override + protected void initializeLayerComponents() { + this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config); + this.ffnLayers = new GraniteFP16FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType); + this.logitsLayer = new LogitsGraniteFP16Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java new file mode 100644 index 00000000..ee818080 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/GraniteQ8_0LayerPlanner.java @@ -0,0 +1,26 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.q8_0; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.layerplanner.quantization.Q8_0LayerPlanner; +import org.beehive.gpullama3.tornadovm.layers.ActivationGranite; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.GraniteQ8_0FFNLayers; +import org.beehive.gpullama3.tornadovm.layers.type.q8_0.LogitsGraniteQ8_0Layer; + +public class GraniteQ8_0LayerPlanner extends Q8_0LayerPlanner { + + public GraniteQ8_0LayerPlanner(GraniteState state, Model model) { + super(state, model); + validateQuantizationType(); + setupTornadoForwardPlan(); + } + + @Override + protected void initializeLayerComponents() { + this.activationLayer = new ActivationGranite("activationUpdate", this.state, this.weights, this.config); + this.ffnLayers = new GraniteQ8_0FFNLayers("graniteFFN", this.state, this.weights, this.config, this.schedulerType); + this.logitsLayer = new LogitsGraniteQ8_0Layer("graniteLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java new file 
mode 100644 index 00000000..20002dac --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/ActivationGranite.java @@ -0,0 +1,68 @@ +package org.beehive.gpullama3.tornadovm.layers; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; +import uk.ac.manchester.tornado.api.types.arrays.ByteArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; + +public class ActivationGranite extends Activation { + private final TaskGraph activationUpdate; + + // Granite is a special case where activation X is scaled by embedding scale float value that inside model. + public ActivationGranite(String taskGraphHandle, State state, Weights weights, GraniteConfiguration config) { + super(taskGraphHandle, state, weights, config); + + KernelContext kernelContext = new KernelContext(); + + // @formatter:off + switch (config.quantization()) { + case "FP16" -> { + this.activationUpdate = new TaskGraph(taskGraphHandle) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", GraniteKernels::convertFP16toFP32withGraniteScale, kernelContext, (HalfFloatArray) state.embeddingX, state.wrapX, config.embeddingScale()) + .persistOnDevice(state.wrapX); + } + case "Q8_0" -> { + this.activationUpdate = new TaskGraph(taskGraphHandle) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", GraniteKernels::convertQ8_0toFP32withGraniteScale, kernelContext, (ByteArray) state.embeddingX, state.wrapX, config.embeddingScale()) + .persistOnDevice(state.wrapX); + } + default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); + } + // @formatter:on + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler scheduler) { + WorkerGrid worker = new WorkerGrid1D(config.dim()); + worker.setLocalWork(128, 1, 1); + scheduler.addWorkerGrid("activationUpdate.updateX", worker); + return scheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return null; + } + + @Override + public TaskGraph getTaskGraph() { + return activationUpdate; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return activationUpdate.snapshot(); + } + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java new file mode 100644 index 00000000..0387ae54 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/GraniteFP16FFNLayers.java @@ -0,0 +1,375 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; +import 
org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +public class GraniteFP16FFNLayers extends AbstractFFNLayers { + + TaskGraph ffnTaskGraphs; + GridScheduler scheduler; + List ffnLayerTaskGraphs; + + public GraniteFP16FFNLayers(String taskGraph, State state, Weights weights, GraniteConfiguration config, SchedulerType schedulerType) { + super(taskGraph, state, weights, config, schedulerType); + this.ffnLayerTaskGraphs = setupFFNLayered(); + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + int fusedQKVRows = config.dim() + 2 * config.kvDim(); + int fusedQKVGlobal = fusedQKVRows * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); + + // Map workers to tasks + for (int i = 0; i < config.numberOfLayers(); i++) { + // === Attention Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + // === FFN Block === + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); + } + return tornadoForwardScheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return ffnTaskGraphs; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + List 
setupFFNLayered() { + return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { + var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); + if (i == config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + return ffnLayer.snapshot(); + }).toList(); + } + + // @formatter:off + /** + * Transformer Layer Task Flow (LlamaFP16FFNLayers) + * + * ══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (partial sums) + * └────────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ attn_rms_finalize│──▶ temp (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌─────────────────────┐ + * │ attn_rms_apply_fp16 │──▶ wrapXbFP16 (normalized, FP16) + * └──────────┬──────────┘ + * │ + * ▼ + * ┌────────────────┐ ┌─────────────────────────────┐ + * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │ + * └───────┬────────┘ └─────────────────────────────┘ + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (partial sums) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌─────────────────┐ + * │ ffn_rms_finalize│──▶ tempFFN (final scale) + * └────────┬────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fused: RMS apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points: + * • qkv_projection: Fused Q/K/V matmuls (3→1 kernel) + * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel) + * • rms_ffn_gate_up: Fused RMS apply + W1/W3 matmuls + SiLU + GLU (4→1 kernel) + * + */ + TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + 
weights.woLayered[layerIndex].asHalfFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray()); + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + // === Attention Block === + // RMS Normalization + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantize, + context, state.wrapXbFP16, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + // QKV Projection (fused) + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulX, + context, + state.wrapXbFP16, // input (FP32) + state.wrapQ, // output Q + state.wrapK, // output K + state.wrapV, // output V + weights.wqLayered[layerIndex].asHalfFloatArray(), // Wq + weights.wkLayered[layerIndex].asHalfFloatArray(), // Wk + weights.wvLayered[layerIndex].asHalfFloatArray(), // Wv + config.dim(), // dim + config.kvDim(), // kvDim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // RoPE + KV Cache + unifiedLayer.task("rope_and_kv_cache", + GraniteKernels::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, // Q (in/out) + state.wrapK, // K (in/out) + state.wrapV, // V (in only) + state.wrapKeyCache, // Key cache (out) + state.wrapValueCache, // Value cache (out) + config.kvDim(), + config.headSize(), + config.ropeTheta(), // needs to load it from model + layerIndex, + config.contextLength()); + // Attention + configureAttention(unifiedLayer, layerIndex, config ); + // Output Projection (Wo) with residual + unifiedLayer.task("attn_output_proj", + GraniteKernels::matrixVectorGenericWithResidualGranite, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asHalfFloatArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC, + config.residualScale() + ); + + // === FFN Block === + // RMS Normalization + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fusedRmsNormFFNGateUp, + context, + state.wrapX, // raw input (FP32) + state.wrapHb, // output + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + state.tempFFN, // RMS scale factor + weights.w1Layered[layerIndex].asHalfFloatArray(), // W1 + weights.w3Layered[layerIndex].asHalfFloatArray(), // W3 + config.dim(), // input dimension + config.hiddenDim(), // output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Down projection (W2) with residual + unifiedLayer.task("ffn_down_proj", + GraniteKernels::matrixVectorGenericWithResidualGranite, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asHalfFloatArray(), + 
config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC, config.residualScale()); + + unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + // First layer: Transfer initial data to device (one-time transfer) + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, + state.temp, state.tempFFN + ); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Kernel context + context, + // Intermediate buffers + state.wrapXb, state.wrapXb2, + // QKV vectors + state.wrapQ, state.wrapK, state.wrapV, + // KV cache + state.wrapKeyCache, state.wrapValueCache, + // Attention & FFN buffers + state.wrapAtt, state.wrapHb, state.wrapXbFP16); + } else { + // Subsequent layers: Consume data already on device from previous layer + unifiedLayer.consumeFromDevice( + // Kernel context + context, + // Intermediate buffers + state.wrapXb, state.wrapXb2, + // QKV vectors + state.wrapQ, state.wrapK, state.wrapV, + // KV cache + state.wrapKeyCache, state.wrapValueCache, + // Attention & FFN buffers + state.wrapAtt, state.wrapHb, + // Position & misc + state.positionHolder, state.wrapXbFP16); + } + return unifiedLayer; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex, GraniteConfiguration config) { + if (schedulerType == SchedulerType.NVIDIA) { + // Flash Attention (optimized for NVIDIA GPUs) + return unifiedLayer.task("attention", + GraniteKernels::processHeadsFlashAttentionWithGraniteScale, + context, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), + config.headSize(), + config.kvDim(), + config.kvMul(), + state.positionHolder, + layerIndex, + config.contextLength(), + config.attentionScale() + ); + } else { + // Standard parallel attention (for non-NVIDIA backends) + return unifiedLayer.task("attention", + GraniteKernels::processHeadsParallelGranite, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), + config.headSize(), + config.kvDim(), + config.kvMul(), + config.contextLength(), // seqLen parameter + state.positionHolder, + state.wrapAtt, // Attention weights buffer + layerIndex, + config.contextLength(), + config.attentionScale()); + } + } + // @formatter:on + +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java new file mode 100644 index 00000000..d55d707e --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LogitsGraniteFP16Layer.java @@ -0,0 +1,125 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import 
org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +public class LogitsGraniteFP16Layer extends LogitsFP16Layer { + private String lastTaskGraphID; + private TaskGraph logitsTaskGraph; + private ImmutableTaskGraph immutableLogitsGraph; + private GridScheduler scheduler; + private SchedulerType schedulerType; + + public LogitsGraniteFP16Layer(String name, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); + this.lastTaskGraphID = lastTaskGraphID; + this.schedulerType = schedulerType; + var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsFP16Layer", "TornadoTensor"); + this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config); + } + + // @formatter:off + private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) { + var logits = new TaskGraph("logits"); + // === Data Setup === + logits.consumeFromDevice(lastTaskGraphID, state.wrapX); + logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); + logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Kernel context + context, + // Output buffer + state.wrapLogits, + // Intermediate FP16 buffer + state.wrapXbFP16, + // Weights + weights.wclsByteArray.asHalfFloatArray(), + weights.rms_final_weight_as_floatArray.asFloatArray()); + + // === Final RMS Normalization === + logits.task("rms_reduce", + TransformerComputeKernels::reductionOneBlockWithLayer, + context, + state.tempLogits, // output: partial sums + final scale factor + state.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon for numerical stability + state.localSize); // local workgroup size + + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.tempLogits, // in/out: combines partial sums + config.dim(), // dimension + config.rmsNormEps()); // epsilon + } + + logits.task("rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantizeLogits, + context, + state.wrapXbFP16, // output: normalized (FP16) + state.wrapX, // input: hidden state + weights.rms_final_weight_as_floatArray.asFloatArray(), // RMS weights + state.tempLogits); // scale factor from reduction + + // === Vocabulary Projection === + logits.task("vocab_proj", + GraniteKernels::matrixVectorGenericWithGraniteScale, + context, + state.wrapXbFP16, // input (FP16) + state.wrapLogits, // output + weights.wclsByteArray.asHalfFloatArray(), // vocabulary weights + config.dim(), // input dimension + config.vocabularySize(), // output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, + config.logitScale()); // granite logit scaling + + // === Transfer Results to Host === + logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + return logits; + } + // @formatter:on + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + WorkerGrid logitsRMS = 
WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof Qwen2TornadoWeights ? 32 : 256); + var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; + var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); + tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.rms_apply_fp16", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); + return tornadoForwardScheduler; + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return logitsTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return immutableLogitsGraph; + } +} + diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java new file mode 100644 index 00000000..b7e036d4 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/GraniteQ8_0FFNLayers.java @@ -0,0 +1,376 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.GraniteState; +import org.beehive.gpullama3.inference.state.LlamaState; +import org.beehive.gpullama3.inference.weights.tornado.GraniteTornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.granite.Granite; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +import java.util.List; +import java.util.stream.IntStream; + +public class GraniteQ8_0FFNLayers extends AbstractFFNLayers { + + GridScheduler scheduler; + List ffnLayerTaskGraphs; + + public GraniteQ8_0FFNLayers(String taskGraphName, GraniteState state, GraniteTornadoWeights weights, GraniteConfiguration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + ffnLayerTaskGraphs = setupFFNLayered(); + } + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return null; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return null; + } + + List setupFFNLayered() { + return IntStream.range(0, config.numberOfLayers()).mapToObj(i -> { + var ffnLayer = setupSingleFFNLayer((GraniteTornadoWeights) weights, (GraniteConfiguration) config, i); + if (i == config.numberOfLayers() - 1) { + setupLastID(ffnLayer.getTaskGraphName()); + } + return ffnLayer.snapshot(); + }).toList(); + } + + /** + * Transformer Layer Task Flow (LlamaQ8FFNLayers) + * + * 
══════════════════════════════════════════════════════════════════════════════ + * ATTENTION BLOCK + * ══════════════════════════════════════════════════════════════════════════════ + * + * wrapX (FP32) + * │ + * ▼ + * ┌─────────────────┐ + * │ attn_rms_reduce │──▶ temp (partial sums) + * └────────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌──────────────────┐ + * │ attn_rms_finalize│──▶ temp (final scale) + * └────────┬─────────┘ + * │ + * ▼ + * ┌────────────────┐ + * │ attn_rms_apply │──▶ wrapXb (normalized, FP32) + * └───────┬────────┘ + * │ + * ▼ + * ┌────────────────┐ ┌─────────────────────────────┐ + * │ qkv_projection │──────▶│ wrapQ, wrapK, wrapV (FP32) │ + * └───────┬────────┘ └─────────────────────────────┘ + * │ + * ▼ + * ┌───────────────────┐ ┌─────────────────────────────────────┐ + * │ rope_and_kv_cache │───▶│ Q,K rotated + KeyCache, ValueCache │ + * └─────────┬─────────┘ └─────────────────────────────────────┘ + * │ + * ▼ + * ┌───────────┐ + * │ attention │──▶ wrapXb (attention output) + * └─────┬─────┘ + * │ + * ▼ + * ┌──────────────────┐ + * │ attn_output_proj │──▶ wrapX += Wo · wrapXb (residual connection) + * └────────┬─────────┘ + * │ + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ FFN BLOCK + * ══════════╪═══════════════════════════════════════════════════════════════════ + * │ + * ▼ + * ┌────────────────┐ + * │ ffn_rms_reduce │──▶ tempFFN (partial sums) + * └───────┬────────┘ + * │ + * ▼ (optional: NON_NVIDIA only) + * ┌─────────────────┐ + * │ ffn_rms_finalize│──▶ tempFFN (final scale) + * └────────┬────────┘ + * │ + * ▼ + * ┌─────────────────┐ + * │ rms_ffn_gate_up │──▶ wrapHb = SiLU(RMSNorm(x)·W1) ⊙ (RMSNorm(x)·W3) + * └────────┬────────┘ (fully fused: RMS reduce/apply + W1/W3 matmuls + SiLU + GLU) + * │ + * ▼ + * ┌──────────────┐ + * │ ffn_down_proj│──▶ wrapX += W2 · wrapHb (residual connection) + * └──────┬───────┘ + * │ + * ▼ + * wrapX (FP32) ──▶ [next layer or logits] + * + * ══════════════════════════════════════════════════════════════════════════════ + * + * Task Count: 9 tasks (7 if NVIDIA, skipping rms_finalize steps) + * + * Data Flow Summary: + * Input: wrapX (FP32) - hidden state from previous layer + * Output: wrapX (FP32) - updated hidden state with residual connections + * + * Key Fusion Points: + * • qkv_projection: Fused Q/K/V matmuls with Q8 dequantization (3→1 kernel) + * • rope_and_kv_cache: Fused RoPE rotation + cache write (2→1 kernel) + * • rms_ffn_gate_up: Fully fused RMS norm + W1/W3 matmuls + SiLU + GLU (5→1 kernel) + * + * Quantization: Q8_0 format (8-bit weights with block-wise scaling) + * + */ + TaskGraph setupSingleFFNLayer(GraniteTornadoWeights weights, GraniteConfiguration config, int layerIndex) { + var layerTaskGraphName = "layer_" + layerIndex; + TaskGraph unifiedLayer = new TaskGraph(layerTaskGraphName); + + // === Data Setup === + unifiedLayer.consumeFromDevice(state.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + // Copy-in weights per layer for batched-layered layout (Q8 format) + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asByteArray(), + weights.wkLayered[layerIndex].asByteArray(), + weights.wvLayered[layerIndex].asByteArray(), + weights.woLayered[layerIndex].asByteArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asByteArray(), + weights.w2Layered[layerIndex].asByteArray(), + weights.w3Layered[layerIndex].asByteArray()); + unifiedLayer = 
configureLayerDataTransfers(unifiedLayer, layerIndex); + + // === Attention Block === + // RMS Normalization + unifiedLayer.task("attn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.temp, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.temp, config.dim(), config.rmsNormEps()); + } + + unifiedLayer.task("attn_rms_apply", + TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, + context, state.wrapXb, state.wrapX, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp); + + // QKV Projection (fused with Q8 dequantization) + unifiedLayer.task("qkv_projection", + TransformerComputeKernelsLayered::fusedQKVMatmulQ8, + context, + state.wrapXb, // input (FP32) + state.wrapQ, // output Q + state.wrapK, // output K + state.wrapV, // output V + weights.wqLayered[layerIndex].asByteArray(), // Wq (Q8) + weights.wkLayered[layerIndex].asByteArray(), // Wk (Q8) + weights.wvLayered[layerIndex].asByteArray(), // Wv (Q8) + config.dim(), // dim + config.kvDim(), // kvDim + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // RoPE + KV Cache + unifiedLayer.task("rope_and_kv_cache", + GraniteKernels::ropeRotationWithCacheCopy, + context, + state.positionHolder, + state.wrapQ, // Q (in/out) + state.wrapK, // K (in/out) + state.wrapV, // V (in only) + state.wrapKeyCache, // Key cache (out) + state.wrapValueCache, // Value cache (out) + config.kvDim(), + config.headSize(), + config.ropeTheta(), + layerIndex, + config.contextLength()); + + // Attention + configureAttention(unifiedLayer, layerIndex, config); + + // Output Projection (Wo) with residual (Q8 dequantization) + unifiedLayer.task("attn_output_proj", + GraniteKernels::matrixVectorGenericWithResidualQ8_0ByteWithGraniteScale, + context, state.wrapXb, state.wrapX, + weights.woLayered[layerIndex].asByteArray(), + config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC, config.residualScale()); + + // === FFN Block === + // RMS Normalization + unifiedLayer.task("ffn_rms_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, state.tempFFN, state.wrapX, + config.dim(), config.rmsNormEps(), state.localSize); + + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, state.tempFFN, config.dim(), config.rmsNormEps()); + } + + // Fully fused: RMS apply + Gate/Up projections + SiLU + GLU (Q8 dequantization) + unifiedLayer.task("rms_ffn_gate_up", + TransformerComputeKernelsLayered::fullyFusedRmsNormFFNGateUpQ8, + context, + state.wrapX, // raw input (FP32) + state.wrapHb, // output + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), // RMS weights + weights.w1Layered[layerIndex].asByteArray(), // W1 (Q8) + weights.w3Layered[layerIndex].asByteArray(), // W3 (Q8) + config.dim(), // input dimension + config.hiddenDim(), // output dimension + LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Down projection (W2) with residual (Q8 dequantization) + unifiedLayer.task("ffn_down_proj", + GraniteKernels::matrixVectorGenericWithResidualQ8_0ByteWithGraniteScale, + context, state.wrapHb, state.wrapX, + weights.w2Layered[layerIndex].asByteArray(), + config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC, config.residualScale()); + + // Keep activation X on device for next layer + 
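// Illustrative sketch (not part of this diff): what the rms_ffn_gate_up + ffn_down_proj
// pair above computes, written in plain scalar Java with the RMS normalization omitted
// for brevity. W1/W3 feed the SwiGLU gate, and the W2 projection is folded back into the
// hidden state through the Granite residual multiplier, as in
// matrixVectorGenericWithResidualQ8_0ByteWithGraniteScale. Names and row-major float[][]
// weights are illustrative only.
static void swigluFfnWithScaledResidual(float[] x, float[][] w1, float[][] w3, float[][] w2,
                                        float[] hiddenState, float residualScale) {
    int hiddenDim = w1.length;
    float[] hb = new float[hiddenDim];
    for (int i = 0; i < hiddenDim; i++) {
        float gate = dot(w1[i], x);
        float up = dot(w3[i], x);
        float silu = gate / (1.0f + (float) Math.exp(-gate)); // SiLU(gate) = gate * sigmoid(gate)
        hb[i] = silu * up;                                     // GLU: gate the up-projection
    }
    for (int d = 0; d < hiddenState.length; d++) {
        hiddenState[d] += residualScale * dot(w2[d], hb);      // Granite-scaled residual add
    }
}

static float dot(float[] row, float[] v) {
    float s = 0.0f;
    for (int i = 0; i < v.length; i++) {
        s += row[i] * v[i];
    }
    return s;
}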
unifiedLayer.persistOnDevice(state.wrapX); + + return unifiedLayer; + } + + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + // First layer: Transfer initial data to device (one-time transfer) + if (layerIndex == 0) { + // Transfer all attention-related data: query, key, value matrices and their caches + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + state.positionHolder, + state.temp, state.tempFFN); // + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, // + context, + state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb); // + } else { + // Subsequent layers: Consume data already on device from previous layer + unifiedLayer.consumeFromDevice( + context, + state.wrapXb, state.wrapXb2, // + state.wrapQ, state.wrapK, state.wrapV, // + state.wrapKeyCache, state.wrapValueCache, // + state.wrapAtt, state.wrapHb, // + state.positionHolder // + ); + } + return unifiedLayer; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + // === Worker Grid Definitions === + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256); + + int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + int configHiddenDimRowMajor = config.hiddenDim() * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid configHiddenDimRowMajorWorker = WorkerGridFactory.genericWorker(configHiddenDimRowMajor, LOCAL_WORK_GROUP_SIZE_ALLOC); + + // Fused QKV: dim rows for Q + kvDim rows for K + kvDim rows for V + int fusedQkvGlobal = (config.dim() + 2 * config.kvDim()) * LOCAL_WORK_GROUP_SIZE_ALLOC; + WorkerGrid fusedQkvWorker = WorkerGridFactory.genericWorker(fusedQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC); + + WorkerGrid ropeWithCacheWorker = WorkerGridFactory.genericWorker(config.dim() / 2, 512); + + WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), config.headSize()); + + // === Per-Layer Grid Assignments (ordered by task graph flow) === + for (int i = 0; i < config.numberOfLayers(); i++) { + // --- Attention Block --- + // RMS Normalization + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker); + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker); + // --- FFN Block --- + // RMS Normalization + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker); + // Fused RMS + Gate/Up Projections + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker); + // Down Projection + tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker); + } + + return tornadoForwardScheduler; + } + + public List getFfnLayerTaskGraphs() { + return ffnLayerTaskGraphs; + } + + private TaskGraph configureAttention(TaskGraph unifiedLayer, int layerIndex, 
GraniteConfiguration config) { + if (schedulerType == SchedulerType.NVIDIA) { + // Flash Attention (optimized for NVIDIA GPUs) + return unifiedLayer.task("attention", + GraniteKernels::processHeadsFlashAttentionWithGraniteScale, + context, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), + config.headSize(), + config.kvDim(), + config.kvMul(), + state.positionHolder, + layerIndex, + config.contextLength(), + config.attentionScale() + ); + } else { + // Standard parallel attention (for non-NVIDIA backends) + return unifiedLayer.task("attention", + GraniteKernels::processHeadsParallelGranite, + state.wrapQ, // Query + state.wrapKeyCache, // Key cache + state.wrapValueCache, // Value cache + state.wrapXb, // Output + config.numberOfHeads(), + config.headSize(), + config.kvDim(), + config.kvMul(), + config.contextLength(), // seqLen parameter + state.positionHolder, + state.wrapAtt, // Attention weights buffer + layerIndex, + config.contextLength(), + config.attentionScale()); + } + } + // @formatter:on +} + diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java new file mode 100644 index 00000000..d1fd4f0c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LogitsGraniteQ8_0Layer.java @@ -0,0 +1,118 @@ +package org.beehive.gpullama3.tornadovm.layers.type.q8_0; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.granite.GraniteConfiguration; +import org.beehive.gpullama3.tornadovm.kernels.GraniteKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid1D; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +public class LogitsGraniteQ8_0Layer extends LogitsQ8_0Layer{ + private String lastTaskGraphID; + private TaskGraph logitsTaskGraph; + private ImmutableTaskGraph immutableLogitsGraph; + private GridScheduler scheduler; + private SchedulerType schedulerType; + + public LogitsGraniteQ8_0Layer(String taskGraphName, State state, Weights weights, Configuration config, String lastTaskGraphID, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, lastTaskGraphID, schedulerType); + this.lastTaskGraphID = lastTaskGraphID; + var tornadoWeights = requireWeightsType(weights, TornadoWeights.class, "LogitsQ8_0Layer", "TornadoTensor"); + this.logitsTaskGraph = setupLogitsTaskGraph(tornadoWeights, (GraniteConfiguration) config); + this.schedulerType = schedulerType; + } + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + var logitsRMS = WorkerGridFactory.createRmsNormWorker(config.dim(), weights instanceof 
Qwen2TornadoWeights ? 32 : 256); + var vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS; + var vocabWorker = new WorkerGrid1D(vocabSizeRowMajor); + vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1); + tornadoForwardScheduler.addWorkerGrid("logits.vocab_proj", vocabWorker); + tornadoForwardScheduler.addWorkerGrid("logits.rms_reduce", logitsRMS); + tornadoForwardScheduler.addWorkerGrid("logits.mapContextLogits", logitsRMS); + return tornadoForwardScheduler; + } + + // @formatter:off + private TaskGraph setupLogitsTaskGraph(TornadoWeights weights, GraniteConfiguration config) { + var logits = new TaskGraph("logits"); + // === Data Setup === + logits.consumeFromDevice(lastTaskGraphID, state.wrapX); + logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); + logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, // + state.wrapLogits, // + weights.wclsByteArray.asByteArray(), // + weights.rms_final_weight_as_floatArray); + + // === Final RMS Normalization === + logits.task("rms_reduce", + TransformerComputeKernels::reductionOneBlockWithLayer, + context, + state.tempLogits, // output: partial sums + final scale factor + state.wrapX, // input: hidden state + config.dim(), // dimension + config.rmsNormEps(), // epsilon for numerical stability + state.localSize); // local workgroup size + + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.tempLogits, + config.dim(), + config.rmsNormEps()); + } + logits.task("mapContextLogits", + TransformerComputeKernels::reductionOneBlock2WithLogits, + context, + state.wrapX, + weights.rms_final_weight_as_floatArray.asFloatArray(), + state.tempLogits); + + // === Vocabulary vocab_proj === + logits.task("vocab_proj", GraniteKernels::matrixVectorGenericQ8ByteWithGraniteScale, // + context, + state.wrapX, + state.wrapLogits, + weights.wclsByteArray.asByteArray(), + config.dim(), + config.vocabularySize(), + LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, + config.logitScale() + + ); + + // === Transfer Results to Host === + logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + return logits; + } + // @formatter:on + + @Override + public GridScheduler getGridScheduler() { + return scheduler; + } + + @Override + public TaskGraph getTaskGraph() { + return logitsTaskGraph; + } + + @Override + public ImmutableTaskGraph getImmutableTaskGraph() { + return immutableLogitsGraph; + } + +}
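The Granite-specific kernels added in this diff thread four model-supplied multipliers through the forward pass: embeddingScale on the token embedding (ActivationGranite and the convert*withGraniteScale kernels), attentionScale on the raw attention scores, residualScale on the attention- and FFN-output residual additions (matrixVectorGenericWithResidualGranite and its Q8_0 variant), and logitScale on the final vocabulary projection (matrixVectorGenericWithGraniteScale / matrixVectorGenericQ8ByteWithGraniteScale). The plain-Java sketch below illustrates the attentionScale case for a single head, mirroring processHeadsParallelGranite, where the explicit model-config scale replaces the usual 1/sqrt(headSize) factor on the scores; the method name and array shapes are illustrative and not part of this diff.

// Illustrative sketch: single-head attention for one query position with the Granite
// attention scale. keys and values hold one row per cached position (0..pos).
static float[] attendOneHead(float[] q, float[][] keys, float[][] values, float attentionScale) {
    int pos = keys.length - 1;
    int headSize = q.length;
    float[] att = new float[pos + 1];

    // 1) scaled dot-product scores
    float maxScore = Float.NEGATIVE_INFINITY;
    for (int t = 0; t <= pos; t++) {
        float score = 0.0f;
        for (int i = 0; i < headSize; i++) {
            score += q[i] * keys[t][i];
        }
        att[t] = score * attentionScale;
        maxScore = Math.max(maxScore, att[t]);
    }

    // 2) softmax over timesteps 0..pos (max-subtracted for numerical stability)
    float sum = 0.0f;
    for (int t = 0; t <= pos; t++) {
        att[t] = (float) Math.exp(att[t] - maxScore);
        sum += att[t];
    }
    for (int t = 0; t <= pos; t++) {
        att[t] /= sum;
    }

    // 3) softmax-weighted sum of values
    float[] out = new float[headSize];
    for (int t = 0; t <= pos; t++) {
        for (int i = 0; i < headSize; i++) {
            out[i] += att[t] * values[t][i];
        }
    }
    return out;
}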