Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ba7b94a
Add Granite model support in inference pipeline
mikepapadim Dec 17, 2025
a50e0ca
Add Granite model-specific implementations for inference pipeline
mikepapadim Dec 17, 2025
af363a8
Fix Makefile indentation on 'install' target
mikepapadim Dec 17, 2025
c14b086
Refactor `GraniteTokenizer` to simplify special token handling and re…
mikepapadim Dec 17, 2025
b1414cd
Add Granite-specific TornadoVM components and FP16 inference pipelines
mikepapadim Dec 17, 2025
62acbca
Introduce parallelized attention computation in Granite kernels
mikepapadim Dec 17, 2025
8a6578d
Remove unused `LlamaTornadoWeights` import and clean up formatting in…
mikepapadim Dec 17, 2025
79293b3
Update README to include IBM Granite 3.1+ model support and link to G…
mikepapadim Dec 17, 2025
0808a9d
Update README to replace `TORNADO_SDK` with `TORNADOVM_HOME` for cons…
mikepapadim Dec 17, 2025
7803ad2
Update `GraniteFP16LayerPlanner` to use `LogitsGraniteFP16Layer` for …
mikepapadim Dec 17, 2025
a0fe7da
Refactor `LogitsGraniteFP16Layer` to extend `LogitsFP16Layer` and upd…
mikepapadim Dec 17, 2025
4cfcb59
Add TornadoVM Q8_0 support for Granite inference pipeline
mikepapadim Dec 17, 2025
cc965de
Use `GraniteTokenizer` for EOS token retrieval in `GraniteChatFormat`…
mikepapadim Dec 17, 2025
dc69ae4
Add `ropeRotationWithCacheCopy` kernel to `GraniteKernels` for RoPE c…
mikepapadim Dec 17, 2025
1aeb5e1
Refactor `GraniteLoader`: enhance metadata handling, improve formatti…
mikepapadim Dec 17, 2025
34144a5
Replace `TransformerComputeKernelsLayered::ropeRotationWithCacheCopy`…
mikepapadim Dec 17, 2025
a6b6a80
Refactor `GraniteTokenizer`: add support for Granite 4.0, enhance met…
mikepapadim Dec 17, 2025
aafa156
Replace `TransformerComputeKernelsLayered::ropeRotationWithCacheCopy`…
mikepapadim Dec 17, 2025
a8759a2
Update README to add IBM Granite 4.0 collection link
mikepapadim Dec 17, 2025
c1459d9
Update README to include support for Phi-3, IBM Granite 3.2+, and IBM…
mikepapadim Dec 17, 2025
2571614
Add build step for running Granite 3.2 model in CI pipelineAdd workfl…
mikepapadim Dec 17, 2025
a45b532
Add workflow step to run Granite-3.2-2b-instruct-Q8.gguf during CI
mikepapadim Dec 17, 2025
c2b91c4
Add workflow steps to run Granite-4.0-1b-F16.gguf and Granite-4.0-1b-…
mikepapadim Dec 17, 2025
49eb298
Fix logits scaling in `GraniteKernels`: correct scaling order for `hb…
mikepapadim Dec 18, 2025
150c5ee
Remove unused imports from `GraniteFP16LayerPlanner`.
mikepapadim Dec 18, 2025
59eb425
Rename `Granite8_0LayerPlanner` to `GraniteQ8_0LayerPlanner` for cons…
mikepapadim Dec 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions .github/workflows/build-and-run.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ jobs:
./llama-tornado --gpu --${{ matrix.backend.name }} \
--model /$MODELS_DIR/Phi-3-mini-4k-instruct-fp16.gguf \
--prompt "Say hello"
# NOTE(review): both Granite steps pass "--model /$MODELS_DIR/..." (leading slash,
# matching the Phi-3 step above), while the Mistral step below uses "$MODELS_DIR/..."
# without it. Presumably MODELS_DIR is a bare directory name anchored at the
# filesystem root — TODO confirm and unify the convention across all steps.
- name: FP16 - Run Granite-3.2-2b-instruct-f16.gguf
  run: |
    cd ${{ github.workspace }}
    export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
    ./llama-tornado --gpu --${{ matrix.backend.name }} \
      --model /$MODELS_DIR/granite-3.2-2b-instruct-f16.gguf \
      --prompt "Say hello"
- name: FP16 - Run Granite-4.0-1b-F16.gguf
  run: |
    cd ${{ github.workspace }}
    export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
    ./llama-tornado --gpu --${{ matrix.backend.name }} \
      --model /$MODELS_DIR/granite-4.0-1b-F16.gguf \
      --prompt "Say hello"
- name: Q8 - Run Llama-3.2-1B-Instruct-Q8_0.gguf
run: |
cd ${{ github.workspace }}
Expand Down Expand Up @@ -163,3 +177,18 @@ jobs:
./llama-tornado --gpu --${{ matrix.backend.name }} \
--model $MODELS_DIR/Mistral-7B-Instruct-v0.3.Q8_0.gguf \
--prompt "Say hello"
# NOTE(review): "--model /$MODELS_DIR/..." uses a leading slash while the Mistral
# step above uses "$MODELS_DIR/..." — confirm MODELS_DIR is meant to be resolved
# from / and unify the convention. Also note the step name says "Q8.gguf" but the
# file on disk is "...Q8_0.gguf" (cosmetic, but confusing in CI logs).
- name: Q8 - Run Granite-3.2-2b-instruct-Q8.gguf
  run: |
    cd ${{ github.workspace }}
    export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
    ./llama-tornado --gpu --${{ matrix.backend.name }} \
      --model /$MODELS_DIR/granite-3.2-2b-instruct-Q8_0.gguf \
      --prompt "Say hello"
- name: Q8 - Run Granite-4.0-1b-Q8_0.gguf
  run: |
    cd ${{ github.workspace }}
    export PATH="$TORNADO_SDK/bin:$JAVA_HOME/bin:$PATH"
    ./llama-tornado --gpu --${{ matrix.backend.name }} \
      --model /$MODELS_DIR/granite-4.0-1b-Q8_0.gguf \
      --prompt "Say hello"

2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ clean:
$(MVN) clean

install:
$(MVN) install -DskipTests
$(MVN) install -DskipTests

# Package the project without running tests
package:
Expand Down
15 changes: 11 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
<strong>Llama3</strong> models written in <strong>native Java</strong> automatically accelerated on GPUs with <a href="https://github.com/beehive-lab/TornadoVM" target="_blank"><strong>TornadoVM</strong></a>.
Runs Llama3 inference efficiently using TornadoVM's GPU acceleration.
<br><br>
Currently, supports <strong>Llama3</strong>, <strong>Mistral</strong>, <strong>Qwen2.5</strong>, <strong>Qwen3</strong> and <strong>Phi3</strong> models in the GGUF format.
Currently, supports <strong>Llama3</strong>, <strong>Mistral</strong>, <strong>Qwen2.5</strong>, <strong>Qwen3</strong>, <strong>Phi-3</strong>, <strong>IBM Granite 3.2+</strong>, and <strong>IBM Granite 4.0</strong> models in the GGUF format.
Also, it is used as GPU inference engine in
<a href="https://docs.quarkiverse.io/quarkus-langchain4j/dev/gpullama3-chat-model.html" target="_blank">Quarkus</a>
and
Expand Down Expand Up @@ -89,7 +89,7 @@ All pre-built SDKs are available on the TornadoVM [Releases Page](https://github
wget https://github.com/beehive-lab/TornadoVM/releases/download/v2.1.0/tornadovm-2.1.0-opencl-linux-amd64.zip
unzip tornadovm-2.1.0-opencl-linux-amd64.zip
# Replace <path-to-sdk> manually with the absolute path of the extracted folder
export TORNADO_SDK="<path-to-sdk>/tornadovm-2.1.0-opencl"
export TORNADOVM_HOME="<path-to-sdk>/tornadovm-2.1.0-opencl"
export PATH=$TORNADOVM_HOME/bin:$PATH

tornado --devices
Expand All @@ -102,7 +102,7 @@ tornado --version
wget https://github.com/beehive-lab/TornadoVM/releases/download/v2.1.0/tornadovm-2.1.0-opencl-mac-aarch64.zip
unzip tornadovm-2.1.0-opencl-mac-aarch64.zip
# Replace <path-to-sdk> manually with the absolute path of the extracted folder
export TORNADO_SDK="<path-to-sdk>/tornadovm-2.1.0-opencl"
export TORNADOVM_HOME="<path-to-sdk>/tornadovm-2.1.0-opencl"
export PATH=$TORNADOVM_HOME/bin:$PATH

tornado --devices
Expand Down Expand Up @@ -251,7 +251,7 @@ You can run llama-tornado as a pure Java script using [JBang](https://www.jbang.
### Prerequisites for JBang

1. **Install JBang**: Follow the [JBang installation guide](https://www.jbang.dev/download/)
2. **TornadoVM SDK**: You still need TornadoVM installed and `TORNADO_SDK` environment variable set (see Setup section above)
2. **TornadoVM SDK**: You still need TornadoVM installed and `TORNADOVM_HOME` environment variable set (see Setup section above)

### Quick Start with JBang

Expand Down Expand Up @@ -295,6 +295,13 @@ jbang LlamaTornadoCli.java -m beehive-llama-3.2-1b-instruct-fp16.gguf \
### Llama3.2 Collection
[https://huggingface.co/collections/beehive-lab/llama3-gpullama3java](https://huggingface.co/collections/beehive-lab/llama3-gpullama3java)

### IBM Granite 4.0 Collection
[https://huggingface.co/collections/beehive-lab/granite-40-language-models-gpullama3java](https://huggingface.co/collections/beehive-lab/granite-40-language-models-gpullama3java)

### IBM Granite 3.3 Collection
[https://huggingface.co/collections/beehive-lab/granite-33-language-models-gpullama3java](https://huggingface.co/collections/beehive-lab/granite-33-language-models-gpullama3java)

### Qwen 2.5 Collection
[https://huggingface.co/collections/beehive-lab/qwen-25-gpullama3java](https://huggingface.co/collections/beehive-lab/qwen-25-gpullama3java)

Expand Down
122 changes: 122 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
Expand Down Expand Up @@ -546,6 +547,127 @@ public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int toke
return state.logits;
}

/**
 * Forward pass for Granite models with µP scaling factors applied.
 * <p>
 * Granite uses the same transformer architecture as Llama but with maximal update parameterization (µP)
 * scaling factors applied at specific points:
 * <ul>
 *     <li>Embedding scaling: multiply embeddings after lookup</li>
 *     <li>Attention scaling: use custom multiplier instead of 1/sqrt(headDim)</li>
 *     <li>Residual scaling: multiply residual connections (attention and FFN outputs)</li>
 *     <li>Logit scaling: applied to the final logits — see the note before the last step below</li>
 * </ul>
 *
 * @param model    the Granite model; {@code model.configuration()} must be a {@link GraniteConfiguration}
 *                 and {@code model.weights()} a {@code StandardWeights}
 * @param state    mutable inference state (activations, KV caches, logits) — updated in place
 * @param token    id of the input token for this step
 * @param position 0-based absolute position of {@code token} in the sequence
 * @return the logits tensor ({@code state.logits}, updated in place)
 */
public static FloatTensor forwardGranite(Model model, State state, int token, int position) {
    final GraniteConfiguration config = (GraniteConfiguration) model.configuration();
    final StandardWeights weights = (StandardWeights) model.weights();
    int dim = config.dim();
    int headSize = config.headSize();
    // Grouped-query attention: the KV projection is smaller than the Q projection
    // when numberOfKeyValueHeads < numberOfHeads; kvMul query heads share one KV head.
    int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
    int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads();
    float attentionScale = config.attentionScale();
    float residualScale = config.residualScale();
    float embeddingScale = config.embeddingScale();
    float logitScale = config.logitScale();

    // copy the token embedding into x
    weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);
    // Apply Granite embedding scaling (µP: embeddings are multiplied after lookup)
    state.x.mapInPlace(v -> v * embeddingScale);

    // forward all the layers
    for (int l = 0; l < config.numberOfLayers(); l++) {
        // attention rmsnorm
        rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps());

        // qkv matmuls for this position
        weights.wq[l].matmul(state.xb, state.q, dim, dim);
        weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
        weights.wv[l].matmul(state.xb, state.v, kvDim, dim);

        // RoPE relative positional encoding: rotate consecutive (even, odd) pairs.
        // While i < kvDim both q and k are rotated (rotn == 2); past kvDim only q
        // remains (k/v are narrower under GQA), so rotn == 1.
        for (int i = 0; i < dim; i += 2) {
            int head_dim = i % headSize;
            float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2));
            float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2));
            int rotn = i < kvDim ? 2 : 1;
            for (int v = 0; v < rotn; v++) {
                FloatTensor vec = v == 0 ? state.q : state.k;
                float v0 = vec.getFloat(i);
                float v1 = vec.getFloat(i + 1);
                vec.setFloat(i, v0 * fcr - v1 * fci);
                vec.setFloat(i + 1, v0 * fci + v1 * fcr);
            }
        }

        // save key,value at this time step to kv cache
        state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
        state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);

        // effectively-final copy of the loop index for capture in the lambda below
        int curLayer = l;

        // multihead attention with Granite attention scaling; heads are independent,
        // so they are processed in parallel
        Parallel.parallelFor(0, config.numberOfHeads(), h -> {
            int qOffset = h * headSize;
            int attOffset = h * config.contextLength();

            // scores against every cached position up to and including the current one
            for (int t = 0; t <= position; t++) {
                int keyCacheOffset = t * kvDim + (h / kvMul) * headSize;
                float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
                // Granite uses custom attention multiplier instead of 1/sqrt(headSize)
                score *= attentionScale;
                state.att.setFloat(attOffset + t, score);
            }

            state.att.softmaxInPlace(attOffset, position + 1);

            // weighted sum of cached values, accumulated into xb for this head
            int xbOffset = h * headSize;
            state.xb.fillInPlace(xbOffset, headSize, 0f);

            for (int t = 0; t <= position; t++) {
                int vOffset = t * kvDim + (h / kvMul) * headSize;
                float a = state.att.getFloat(attOffset + t);
                state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
            }
        });

        // final matmul to get the output of the attention
        weights.wo[l].matmul(state.xb, state.xb2, dim, dim);

        // residual connection with Granite scaling (µP residual multiplier)
        state.xb2.mapInPlace(v -> v * residualScale);
        state.x.addInPlace(state.xb2);

        // ffn rmsnorm
        rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps());

        // FFN: self.w2(F.silu(self.w1(x)) * self.w3(x))
        weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
        weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);

        // SwiGLU non-linearity: silu(x) = x * sigmoid(x)
        state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
        state.hb.multiplyInPlace(state.hb2);

        // final matmul to get the output of the ffn
        weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());

        // residual connection with Granite scaling
        state.xb.mapInPlace(v -> v * residualScale);
        state.x.addInPlace(state.xb);
    }

    // final rmsnorm, then project to vocabulary logits
    rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());

    weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);

    // Granite logit scaling. NOTE(review): Granite's µP description says logits are
    // DIVIDED by the logits_scaling factor, but this line MULTIPLIES by
    // config.logitScale() — presumably GraniteConfiguration stores the reciprocal
    // (1 / logits_scaling). TODO confirm against GraniteConfiguration/GraniteLoader.
    state.logits.mapInPlace(v -> v * logitScale);

    return state.logits;
}

static void copyChunk(FloatTensor in, FloatTensor out, int dim1In, int dim1Out, int nChunks, int chunkNo) {
assert (dim1In == dim1Out * nChunks);
final int startOffsetInDim1 = chunkNo * dim1Out;
Expand Down
134 changes: 134 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/InferenceEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -531,4 +531,138 @@ public static List<Integer> generateTokensGPUPhi3(Model model, State state, int

return generatedTokens;
}

/**
 * CPU-side autoregressive token generation for Granite models.
 * <p>
 * Follows the same loop structure as {@code generateTokensLlama}, but routes every
 * step through {@link InferenceCore#forwardGranite} so the Granite µP scaling
 * factors are applied. Prompt tokens are force-fed first; sampling starts only once
 * the prompt is exhausted. Records totals via {@code LastRunMetrics}.
 *
 * @param model            the Granite model to run
 * @param state            mutable inference state; {@code state.latestToken} is read for the
 *                         first step and updated after every accepted token
 * @param startPosition    sequence position at which generation resumes
 * @param promptTokens     tokens to force-feed before sampling begins
 * @param stopTokens       token ids that terminate generation once sampled
 * @param maxTokens        generation budget; values &lt; 0 or beyond the context length are
 *                         clamped to the model's context length
 * @param sampler          sampler used to pick each token from the logits
 * @param echo             when true, every token (prompt and generated) is decoded to stderr
 * @param onTokenGenerated optional callback invoked for each freshly sampled token (may be null)
 * @return the list of sampled (non-prompt) tokens
 */
public static List<Integer> generateTokensGranite(Model model, State state, int startPosition,
        List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
        IntConsumer onTokenGenerated) {
    final long wallClockStart = System.nanoTime();
    long firstSampleNanos = 0;

    // Clamp the budget to the model's context window.
    final int limit = (maxTokens < 0 || model.configuration().contextLength() < maxTokens)
            ? model.configuration().contextLength()
            : maxTokens;

    final List<Integer> sampledTokens = new ArrayList<>();
    int token = state.latestToken;
    int promptCursor = 0;

    for (int pos = startPosition; pos < limit; pos++) {
        // Granite-specific forward pass (applies µP scaling on CPU).
        final Object logits = InferenceCore.forwardGranite(model, state, token, pos);

        final boolean feedingPrompt = promptCursor < promptTokens.size();
        final int next;
        if (feedingPrompt) {
            next = promptTokens.get(promptCursor++);
        } else {
            if (firstSampleNanos == 0) {
                // Timestamp of the first sampled (non-prompt) token.
                firstSampleNanos = System.nanoTime();
            }
            next = sampler.sampleToken(logits);
        }

        if (echo) {
            System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(next))));
        }

        if (!feedingPrompt) {
            sampledTokens.add(next);
            if (onTokenGenerated != null) {
                onTokenGenerated.accept(next);
            }
            if (stopTokens.contains(next)) {
                // Stop token ends generation; state.latestToken keeps the previous token,
                // matching the other generateTokens* implementations.
                break;
            }
        }

        token = next;
        state.latestToken = token;
    }

    final double elapsedSeconds = (System.nanoTime() - wallClockStart) / 1_000_000_000.0;
    LastRunMetrics.setMetrics(promptCursor + sampledTokens.size(), elapsedSeconds);

    return sampledTokens;
}

/**
 * GPU-side (TornadoVM) autoregressive token generation for Granite models.
 * <p>
 * Mirrors {@code generateTokensGPULlama}: each step delegates to the shared
 * {@link InferenceCore#forwardTornadoVM} pass driven by the supplied
 * {@link TornadoVMMasterPlan} (no Granite-specific forward entry point is used here).
 * Prompt tokens are force-fed before sampling begins; totals are recorded via
 * {@code LastRunMetrics}.
 *
 * @param model               the Granite model to run
 * @param state               mutable inference state; {@code state.latestToken} seeds the loop
 *                            and is updated after every accepted token
 * @param startPosition       sequence position at which generation resumes
 * @param promptTokens        tokens to force-feed before sampling begins
 * @param stopTokens          token ids that terminate generation once sampled
 * @param maxTokens           generation budget; values &lt; 0 or beyond the context length are
 *                            clamped to the model's context length
 * @param sampler             sampler used to pick each token from the logits
 * @param echo                when true, every token is decoded to stderr
 * @param onTokenGenerated    optional per-token callback (may be null)
 * @param tornadoVMMasterPlan pre-built TornadoVM execution plan for the forward pass
 * @return the list of sampled (non-prompt) tokens
 */
public static List<Integer> generateTokensGPUGranite(Model model, State state, int startPosition,
        List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
        IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMMasterPlan) {
    final long wallClockStart = System.nanoTime();
    long firstSampleNanos = 0;

    // Clamp the budget to the model's context window.
    final int limit = (maxTokens < 0 || model.configuration().contextLength() < maxTokens)
            ? model.configuration().contextLength()
            : maxTokens;

    final List<Integer> sampledTokens = new ArrayList<>();
    int token = state.latestToken;
    int promptCursor = 0;

    for (int pos = startPosition; pos < limit; pos++) {
        // Shared TornadoVM forward pass (same entry point the Llama GPU path uses).
        final Object logits = InferenceCore.forwardTornadoVM(model, state, token, pos, tornadoVMMasterPlan);

        final boolean feedingPrompt = promptCursor < promptTokens.size();
        final int next;
        if (feedingPrompt) {
            next = promptTokens.get(promptCursor++);
        } else {
            if (firstSampleNanos == 0) {
                // Timestamp of the first sampled (non-prompt) token.
                firstSampleNanos = System.nanoTime();
            }
            next = sampler.sampleToken(logits);
        }

        if (echo) {
            System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(next))));
        }

        if (!feedingPrompt) {
            sampledTokens.add(next);
            if (onTokenGenerated != null) {
                onTokenGenerated.accept(next);
            }
            if (stopTokens.contains(next)) {
                // Stop token ends generation without advancing state.latestToken,
                // matching the other generateTokens* implementations.
                break;
            }
        }

        token = next;
        state.latestToken = token;
    }

    final double elapsedSeconds = (System.nanoTime() - wallClockStart) / 1_000_000_000.0;
    LastRunMetrics.setMetrics(promptCursor + sampledTokens.size(), elapsedSeconds);

    return sampledTokens;
}
}
Loading