Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions src/main/java/com/example/LlamaApp.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import com.example.inference.engine.impl.Options;
import com.example.loader.weights.ModelLoader;
import com.example.loader.weights.State;
import com.example.tokenizer.impl.Tokenizer;
import com.example.tornadovm.FloatArrayUtils;
import com.example.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
Expand All @@ -29,7 +28,8 @@ public class LlamaApp {
public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
public static final boolean USE_TORNADOVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "false")); // Show performance metrics in interactive mode
public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "true")); // Show performance metrics in interactive mode

/**
* Creates and configures a sampler for token generation based on specified parameters.
*
Expand Down Expand Up @@ -115,7 +115,6 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
return sampler;
}


static void runInteractive(Llama model, Sampler sampler, Options options) {
State state = null;
List<Integer> conversationTokens = new ArrayList<>();
Expand Down Expand Up @@ -162,15 +161,12 @@ static void runInteractive(Llama model, Sampler sampler, Options options) {
// Choose between GPU and CPU path based on configuration
if (USE_TORNADOVM) {
// GPU path using TornadoVM
responseTokens = Llama.generateTokensGPU(model, state, startPosition,
conversationTokens.subList(startPosition, conversationTokens.size()),
stopTokens, options.maxTokens(), sampler, options.echo(),
options.stream() ? tokenConsumer : null, tornadoVMPlan);
responseTokens = Llama.generateTokensGPU(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
} else {
// CPU path
responseTokens = Llama.generateTokens(model, state, startPosition,
conversationTokens.subList(startPosition, conversationTokens.size()),
stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
options.echo(), tokenConsumer);
}

// Include stop token in the prompt history, but not in the response displayed to the user.
Expand Down Expand Up @@ -211,7 +207,7 @@ static void runInteractive(Llama model, Sampler sampler, Options options) {
static void runInstructOnce(Llama model, Sampler sampler, Options options) {
State state = model.createNewState();
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
TornadoVMMasterPlan tornadoVMPlan =null;
TornadoVMMasterPlan tornadoVMPlan = null;

List<Integer> promptTokens = new ArrayList<>();
promptTokens.add(chatFormat.beginOfText);
Expand All @@ -233,10 +229,9 @@ static void runInstructOnce(Llama model, Sampler sampler, Options options) {

Set<Integer> stopTokens = chatFormat.getStopTokens();
if (USE_TORNADOVM) {
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
// Call generateTokensGPU without the token consumer parameter
responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(),
sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
} else {
// CPU path still uses the token consumer
responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
public record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive,
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {

public static final int DEFAULT_MAX_TOKENS = 512;
public static final int DEFAULT_MAX_TOKENS = 1024;

public Options {
require(modelPath != null, "Missing argument: --model <path> is required");
Expand Down