diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp index d3c19f06..1e9604f0 100644 --- a/src/main/cpp/server.hpp +++ b/src/main/cpp/server.hpp @@ -3624,17 +3624,26 @@ struct server_context { } json model_meta() const { + // Read optional string metadata from GGUF headers; empty string if absent. + auto read_meta_str = [&](const char * key) -> std::string { + char buf[512] = {}; + int32_t n = llama_model_meta_val_str(model, key, buf, sizeof(buf)); + return n >= 0 ? std::string(buf, n) : std::string(); + }; + return json{ - {"vocab_type", llama_vocab_type(vocab)}, - {"n_vocab", llama_vocab_n_tokens(vocab)}, - {"n_ctx_train", llama_model_n_ctx_train(model)}, - {"n_embd", llama_model_n_embd(model)}, - {"n_params", llama_model_n_params(model)}, - {"size", llama_model_size(model)}, - {"modalities", json{ + {"vocab_type", llama_vocab_type(vocab)}, + {"n_vocab", llama_vocab_n_tokens(vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd(model)}, + {"n_params", llama_model_n_params(model)}, + {"size", llama_model_size(model)}, + {"modalities", json{ {"vision", mctx ? mtmd_support_vision(mctx) : false}, {"audio", mctx ? mtmd_support_audio(mctx) : false}, }}, + {"architecture", read_meta_str("general.architecture")}, + {"name", read_meta_str("general.name")}, }; } }; diff --git a/src/main/java/de/kherud/llama/InferenceParameters.java b/src/main/java/de/kherud/llama/InferenceParameters.java index 9d97b91a..70e94401 100644 --- a/src/main/java/de/kherud/llama/InferenceParameters.java +++ b/src/main/java/de/kherud/llama/InferenceParameters.java @@ -368,16 +368,7 @@ public InferenceParameters setPenaltyPrompt(String penaltyPrompt) { */ public InferenceParameters setPenaltyPrompt(int[] tokens) { if (tokens.length > 0) { - StringBuilder builder = new StringBuilder(); - builder.append("["); - for (int i = 0; i < tokens.length; i++) { - builder.append(tokens[i]); - if (i < tokens.length - 1) { - builder.append(", "); - } - } - builder.append("]"); - parameters.put(PARAM_PENALTY_PROMPT, builder.toString()); + parameters.put(PARAM_PENALTY_PROMPT, serializer.buildIntArray(tokens).toString()); } return this; } @@ -408,7 +399,7 @@ public InferenceParameters setIgnoreEos(boolean ignoreEos) { */ public InferenceParameters setTokenIdBias(Map logitBias) { if (!logitBias.isEmpty()) { - parameters.put(PARAM_LOGIT_BIAS, buildBiasPairArray(logitBias, String::valueOf)); + parameters.put(PARAM_LOGIT_BIAS, serializer.buildTokenIdBiasArray(logitBias).toString()); } return this; } @@ -428,7 +419,7 @@ public InferenceParameters setTokenIdBias(Map logitBias) { */ public InferenceParameters disableTokenIds(Collection tokenIds) { if (!tokenIds.isEmpty()) { - parameters.put(PARAM_LOGIT_BIAS, buildDisablePairArray(tokenIds, String::valueOf)); + parameters.put(PARAM_LOGIT_BIAS, serializer.buildDisableTokenIdArray(tokenIds).toString()); } return this; } @@ -448,7 +439,7 @@ public InferenceParameters disableTokenIds(Collection tokenIds) { */ public InferenceParameters setTokenBias(Map logitBias) { if (!logitBias.isEmpty()) { - parameters.put(PARAM_LOGIT_BIAS, buildBiasPairArray(logitBias, this::toJsonString)); + parameters.put(PARAM_LOGIT_BIAS, serializer.buildTokenStringBiasArray(logitBias).toString()); } return this; } @@ -468,7 +459,7 @@ public InferenceParameters setTokenBias(Map logitBias) { */ public InferenceParameters disableTokens(Collection tokens) { if (!tokens.isEmpty()) { - parameters.put(PARAM_LOGIT_BIAS, buildDisablePairArray(tokens, this::toJsonString)); + parameters.put(PARAM_LOGIT_BIAS, serializer.buildDisableTokenStringArray(tokens).toString()); } return this; } @@ -481,16 +472,7 @@ public InferenceParameters disableTokens(Collection tokens) { */ public InferenceParameters setStopStrings(String... stopStrings) { if (stopStrings.length > 0) { - StringBuilder builder = new StringBuilder(); - builder.append("["); - for (int i = 0; i < stopStrings.length; i++) { - builder.append(toJsonString(stopStrings[i])); - if (i < stopStrings.length - 1) { - builder.append(", "); - } - } - builder.append("]"); - parameters.put(PARAM_STOP, builder.toString()); + parameters.put(PARAM_STOP, serializer.buildStopStrings(stopStrings).toString()); } return this; } @@ -503,29 +485,7 @@ public InferenceParameters setStopStrings(String... stopStrings) { */ public InferenceParameters setSamplers(Sampler... samplers) { if (samplers.length > 0) { - StringBuilder builder = new StringBuilder(); - builder.append("["); - for (int i = 0; i < samplers.length; i++) { - switch (samplers[i]) { - case TOP_K: - builder.append("\"top_k\""); - break; - case TOP_P: - builder.append("\"top_p\""); - break; - case MIN_P: - builder.append("\"min_p\""); - break; - case TEMPERATURE: - builder.append("\"temperature\""); - break; - } - if (i < samplers.length - 1) { - builder.append(", "); - } - } - builder.append("]"); - parameters.put(PARAM_SAMPLERS, builder.toString()); + parameters.put(PARAM_SAMPLERS, serializer.buildSamplers(samplers).toString()); } return this; } @@ -567,7 +527,7 @@ public InferenceParameters setChatTemplate(String chatTemplate) { * @return this builder */ public InferenceParameters setChatTemplateKwargs(java.util.Map kwargs) { - parameters.put(PARAM_CHAT_TEMPLATE_KWARGS, mapToJsonObject(kwargs)); + parameters.put(PARAM_CHAT_TEMPLATE_KWARGS, serializer.buildRawValueObject(kwargs).toString()); return this; } @@ -581,44 +541,7 @@ public InferenceParameters setChatTemplateKwargs(java.util.Map k * @return this builder */ public InferenceParameters setMessages(String systemMessage, List> messages) { - StringBuilder messagesBuilder = new StringBuilder(); - messagesBuilder.append("["); - - // Add system message (if provided) - if (systemMessage != null && !systemMessage.isEmpty()) { - messagesBuilder.append("{\"role\": \"system\", \"content\": ") - .append(toJsonString(systemMessage)) - .append("}"); - if (!messages.isEmpty()) { - messagesBuilder.append(", "); - } - } - - // Add user/assistant messages - for (int i = 0; i < messages.size(); i++) { - Pair message = messages.get(i); - String role = message.getKey(); - String content = message.getValue(); - - if (!role.equals("user") && !role.equals("assistant")) { - throw new IllegalArgumentException("Invalid role: " + role + ". Role must be 'user' or 'assistant'."); - } - - messagesBuilder.append("{\"role\":") - .append(toJsonString(role)) - .append(", \"content\": ") - .append(toJsonString(content)) - .append("}"); - - if (i < messages.size() - 1) { - messagesBuilder.append(", "); - } - } - - messagesBuilder.append("]"); - - // Convert ArrayNode to a JSON string and store it in parameters - parameters.put(PARAM_MESSAGES, messagesBuilder.toString()); + parameters.put(PARAM_MESSAGES, serializer.buildMessages(systemMessage, messages).toString()); return this; } @@ -627,38 +550,5 @@ InferenceParameters setStream(boolean stream) { return this; } - private static String buildBiasPairArray(Map map, - java.util.function.Function keySerializer) { - StringBuilder builder = new StringBuilder("["); - int i = 0; - for (Map.Entry entry : map.entrySet()) { - builder.append("[") - .append(keySerializer.apply(entry.getKey())) - .append(", ") - .append(entry.getValue()) - .append("]"); - if (i++ < map.size() - 1) { - builder.append(", "); - } - } - builder.append("]"); - return builder.toString(); - } - - private static String buildDisablePairArray(Collection items, - java.util.function.Function serializer) { - StringBuilder builder = new StringBuilder("["); - int i = 0; - for (T item : items) { - builder.append("[") - .append(serializer.apply(item)) - .append(", false]"); - if (i++ < items.size() - 1) { - builder.append(", "); - } - } - builder.append("]"); - return builder.toString(); - } - } + diff --git a/src/main/java/de/kherud/llama/JsonParameters.java b/src/main/java/de/kherud/llama/JsonParameters.java index fc3e9dd2..aaa87297 100644 --- a/src/main/java/de/kherud/llama/JsonParameters.java +++ b/src/main/java/de/kherud/llama/JsonParameters.java @@ -1,5 +1,7 @@ package de.kherud.llama; +import de.kherud.llama.json.ParameterJsonSerializer; + import java.util.HashMap; import java.util.Map; @@ -14,6 +16,8 @@ abstract class JsonParameters { // The JNI code for a proper Java-typed data object is comparatively too complex and hard to maintain. final Map parameters = new HashMap<>(); + protected final ParameterJsonSerializer serializer = new ParameterJsonSerializer(); + @Override public String toString() { StringBuilder builder = new StringBuilder(); @@ -35,73 +39,8 @@ public String toString() { return builder.toString(); } - static String mapToJsonObject(Map map) { - StringBuilder sb = new StringBuilder("{"); - boolean first = true; - for (Map.Entry entry : map.entrySet()) { - if (!first) sb.append(","); - sb.append("\"").append(entry.getKey()).append("\":").append(entry.getValue()); - first = false; - } - sb.append("}"); - return sb.toString(); - } - - // taken from org.json.JSONObject#quote(String, Writer) String toJsonString(String text) { if (text == null) return null; - StringBuilder builder = new StringBuilder((text.length()) + 2); - - char b; - char c = 0; - String hhhh; - int i; - int len = text.length(); - - builder.append('"'); - for (i = 0; i < len; i += 1) { - b = c; - c = text.charAt(i); - switch (c) { - case '\\': - case '"': - builder.append('\\'); - builder.append(c); - break; - case '/': - if (b == '<') { - builder.append('\\'); - } - builder.append(c); - break; - case '\b': - builder.append("\\b"); - break; - case '\t': - builder.append("\\t"); - break; - case '\n': - builder.append("\\n"); - break; - case '\f': - builder.append("\\f"); - break; - case '\r': - builder.append("\\r"); - break; - default: - if (c < ' ' || (c >= '\u0080' && c < '\u00a0') || (c >= '\u2000' && c < '\u2100')) { - builder.append("\\u"); - hhhh = Integer.toHexString(c); - builder.append("0000", 0, 4 - hhhh.length()); - builder.append(hhhh); - } - else { - builder.append(c); - } - } - } - builder.append('"'); - return builder.toString(); + return serializer.toJsonString(text); } } diff --git a/src/main/java/de/kherud/llama/LlamaIterable.java b/src/main/java/de/kherud/llama/LlamaIterable.java index 7e6dff89..24384db4 100644 --- a/src/main/java/de/kherud/llama/LlamaIterable.java +++ b/src/main/java/de/kherud/llama/LlamaIterable.java @@ -3,13 +3,44 @@ import org.jetbrains.annotations.NotNull; /** - * An iterable used by {@link LlamaModel#generate(InferenceParameters)} that specifically returns a {@link LlamaIterator}. + * An {@link Iterable} wrapper around {@link LlamaIterator} returned by + * {@link LlamaModel#generate(InferenceParameters)} and {@link LlamaModel#generateChat(InferenceParameters)}. + * + *

Implements {@link AutoCloseable} so that a try-with-resources block automatically cancels + * any in-progress generation when the loop exits early (e.g. via {@code break}), preventing the + * native task slot from leaking: + * + *

{@code
+ * try (LlamaIterable it = model.generate(params)) {
+ *     for (LlamaOutput o : it) {
+ *         if (done) break;   // close() cancels the native task automatically
+ *     }
+ * }
+ * }
+ * + *

A plain for-each loop without try-with-resources continues to work; the {@link #close()} + * method just will not be called on early exit in that case. */ -@FunctionalInterface -public interface LlamaIterable extends Iterable { +public final class LlamaIterable implements Iterable, AutoCloseable { + + private final LlamaIterator iterator; + + LlamaIterable(LlamaIterator iterator) { + this.iterator = iterator; + } @NotNull @Override - LlamaIterator iterator(); + public LlamaIterator iterator() { + return iterator; + } + /** + * Cancels any in-progress generation. Delegates to {@link LlamaIterator#close()}. + * Safe to call multiple times. + */ + @Override + public void close() { + iterator.close(); + } } diff --git a/src/main/java/de/kherud/llama/LlamaIterator.java b/src/main/java/de/kherud/llama/LlamaIterator.java index c238298e..b431a54b 100644 --- a/src/main/java/de/kherud/llama/LlamaIterator.java +++ b/src/main/java/de/kherud/llama/LlamaIterator.java @@ -1,5 +1,6 @@ package de.kherud.llama; +import de.kherud.llama.json.CompletionResponseParser; import java.util.Iterator; import java.util.NoSuchElementException; @@ -7,11 +8,16 @@ * This iterator is used by {@link LlamaModel#generate(InferenceParameters)} and * {@link LlamaModel#generateChat(InferenceParameters)}. In addition to implementing {@link Iterator}, * it allows to cancel ongoing inference (see {@link #cancel()}). + * + *

{@link LlamaIterator} implements {@link AutoCloseable}. When used via {@link LlamaIterable} + * inside a try-with-resources block, {@link #close()} is called automatically on early exit + * (e.g. {@code break}), preventing the native task slot from leaking. */ -public final class LlamaIterator implements Iterator { +public final class LlamaIterator implements Iterator, AutoCloseable { private final LlamaModel model; private final int taskId; + private final CompletionResponseParser completionParser = new CompletionResponseParser(); private boolean hasNext = true; @@ -38,7 +44,7 @@ public LlamaOutput next() { throw new NoSuchElementException(); } String json = model.receiveCompletionJson(taskId); - LlamaOutput output = LlamaOutput.fromJson(json); + LlamaOutput output = completionParser.parse(json); hasNext = !output.stop; if (output.stop) { model.releaseTask(taskId); @@ -53,4 +59,18 @@ public void cancel() { model.cancelCompletion(taskId); hasNext = false; } + + /** + * Cancels any in-progress generation if the iterator has not yet reached a stop token. + * Safe to call multiple times — subsequent calls are no-ops. + * + *

Prefer using the enclosing {@link LlamaIterable} in a try-with-resources block rather + * than calling this directly. + */ + @Override + public void close() { + if (hasNext) { + cancel(); + } + } } diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index c633aa48..31b8a763 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -1,6 +1,9 @@ package de.kherud.llama; import de.kherud.llama.args.LogFormat; +import de.kherud.llama.json.ChatResponseParser; +import de.kherud.llama.json.CompletionResponseParser; +import de.kherud.llama.json.RerankResponseParser; import java.lang.annotation.Native; import java.nio.charset.StandardCharsets; import java.util.HashMap; @@ -30,6 +33,10 @@ public class LlamaModel implements AutoCloseable { @Native private long ctx; + private final CompletionResponseParser completionParser = new CompletionResponseParser(); + private final ChatResponseParser chatParser = new ChatResponseParser(); + private final RerankResponseParser rerankParser = new RerankResponseParser(); + /** * Load with the given {@link ModelParameters}. Make sure to either set *

    @@ -56,7 +63,7 @@ public String complete(InferenceParameters parameters) { parameters.setStream(false); int taskId = requestCompletion(parameters.toString()); String json = receiveCompletionJson(taskId); - return LlamaOutput.getContentFromJson(json); + return completionParser.parse(json).text; } /** @@ -67,7 +74,7 @@ public String complete(InferenceParameters parameters) { * @return iterable LLM outputs */ public LlamaIterable generate(InferenceParameters parameters) { - return () -> new LlamaIterator(this, parameters); + return new LlamaIterable(new LlamaIterator(this, parameters)); } @@ -157,7 +164,7 @@ public static String jsonSchemaToGrammar(String schema) { */ public List> rerank(boolean reRank, String query, String... documents) { String json = handleRerank(query, documents); - List> rankedDocuments = LlamaOutput.parseRerankResults(json); + List> rankedDocuments = rerankParser.parse(json); if (reRank) { rankedDocuments.sort((a, b) -> Float.compare(b.getValue(), a.getValue())); } @@ -174,12 +181,12 @@ public List> rerank(boolean reRank, String query, String... */ public LlamaOutput rerank(String query, String... documents) { String json = handleRerank(query, documents); - List> results = LlamaOutput.parseRerankResults(json); + List> results = rerankParser.parse(json); Map probabilities = new HashMap<>(); for (Pair pair : results) { probabilities.put(pair.getKey(), pair.getValue()); } - return new LlamaOutput(query, probabilities, true); + return new LlamaOutput(query, probabilities, true, StopReason.EOS); } native String handleRerank(String query, String... documents) throws LlamaException; @@ -222,6 +229,20 @@ public String chatComplete(InferenceParameters parameters) { return handleChatCompletions(parameters.toString()); } + /** + * Run an OpenAI-compatible chat completion and return only the assistant's text content. + * This is the plain-string equivalent of {@link #chatComplete(InferenceParameters)}, which + * returns the raw OAI JSON. Use this when you want the generated text directly, the same + * way {@link #complete(InferenceParameters)} works for raw completions. + * + * @param parameters the inference parameters including messages + * @return the assistant's reply text (extracted from {@code choices[0].message.content}) + * @throws LlamaException if the model was loaded in embedding mode or if inference fails + */ + public String chatCompleteText(InferenceParameters parameters) { + return chatParser.extractChoiceContent(chatComplete(parameters)); + } + /** * Stream an OpenAI-compatible chat completion token by token. The parameters must contain a * "messages" array in the standard OpenAI chat format. The model's chat template is automatically applied. @@ -245,7 +266,7 @@ public String chatComplete(InferenceParameters parameters) { * @throws LlamaException if inference fails */ public LlamaIterable generateChat(InferenceParameters parameters) { - return () -> new LlamaIterator(this, parameters, true); + return new LlamaIterable(new LlamaIterator(this, parameters, true)); } /** diff --git a/src/main/java/de/kherud/llama/LlamaOutput.java b/src/main/java/de/kherud/llama/LlamaOutput.java index 00793235..328413f9 100644 --- a/src/main/java/de/kherud/llama/LlamaOutput.java +++ b/src/main/java/de/kherud/llama/LlamaOutput.java @@ -2,10 +2,6 @@ import org.jetbrains.annotations.NotNull; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; import java.util.Map; /** @@ -27,153 +23,25 @@ public final class LlamaOutput { @NotNull public final Map probabilities; - final boolean stop; + /** Whether this is the final token of the generation. */ + public final boolean stop; - LlamaOutput(@NotNull String text, @NotNull Map probabilities, boolean stop) { + /** + * The reason generation stopped. {@link StopReason#NONE} on intermediate streaming tokens. + * Only meaningful when {@link #stop} is {@code true}. + */ + @NotNull + public final StopReason stopReason; + + public LlamaOutput(@NotNull String text, @NotNull Map probabilities, boolean stop, @NotNull StopReason stopReason) { this.text = text; this.probabilities = probabilities; this.stop = stop; + this.stopReason = stopReason; } @Override public String toString() { return text; } - - /** - * Parse a LlamaOutput from a JSON string returned by the native receiveCompletionJson method. - * The JSON has the structure: {"content": "...", "stop": true/false, ...} - */ - static LlamaOutput fromJson(String json) { - String content = getContentFromJson(json); - boolean stop = json.contains("\"stop\":true"); - Map probabilities = parseProbabilities(json); - return new LlamaOutput(content, probabilities, stop); - } - - /** - * Extract the "content" field from a JSON response string. - */ - static String getContentFromJson(String json) { - // Find "content":"..." or "content": "..." - int keyIdx = json.indexOf("\"content\""); - if (keyIdx < 0) { - return ""; - } - int colonIdx = json.indexOf(':', keyIdx + 9); - if (colonIdx < 0) { - return ""; - } - int startQuote = json.indexOf('"', colonIdx + 1); - if (startQuote < 0) { - return ""; - } - StringBuilder sb = new StringBuilder(); - for (int i = startQuote + 1; i < json.length(); i++) { - char c = json.charAt(i); - if (c == '\\' && i + 1 < json.length()) { - char next = json.charAt(i + 1); - switch (next) { - case '"': sb.append('"'); i++; break; - case '\\': sb.append('\\'); i++; break; - case '/': sb.append('/'); i++; break; - case 'n': sb.append('\n'); i++; break; - case 'r': sb.append('\r'); i++; break; - case 't': sb.append('\t'); i++; break; - case 'b': sb.append('\b'); i++; break; - case 'f': sb.append('\f'); i++; break; - case 'u': - if (i + 5 < json.length()) { - String hex = json.substring(i + 2, i + 6); - sb.append((char) Integer.parseInt(hex, 16)); - i += 5; - } - break; - default: sb.append('\\').append(next); i++; break; - } - } else if (c == '"') { - break; - } else { - sb.append(c); - } - } - return sb.toString(); - } - - /** - * Parse token probabilities from a JSON response. Returns an empty map if no probabilities are present. - */ - private static Map parseProbabilities(String json) { - if (!json.contains("\"completion_probabilities\"")) { - return Collections.emptyMap(); - } - // For now, return empty map. Full probability parsing can be added later if needed. - // The probabilities data is available in the raw JSON for advanced users. - return Collections.emptyMap(); - } - - /** - * Parse rerank results from a JSON array string. - * Expected format: [{"document": "...", "index": 0, "score": 0.95}, ...] - */ - static List> parseRerankResults(String json) { - List> results = new ArrayList<>(); - // Simple parser for the known JSON array structure - int idx = 0; - while ((idx = json.indexOf("\"document\"", idx)) >= 0) { - // Extract document string - int colonIdx = json.indexOf(':', idx + 10); - int startQuote = json.indexOf('"', colonIdx + 1); - int endQuote = findEndQuote(json, startQuote + 1); - String document = unescapeJson(json.substring(startQuote + 1, endQuote)); - - // Extract score - int scoreIdx = json.indexOf("\"score\"", endQuote); - if (scoreIdx < 0) break; - int scoreColon = json.indexOf(':', scoreIdx + 7); - int scoreStart = scoreColon + 1; - // Skip whitespace - while (scoreStart < json.length() && json.charAt(scoreStart) == ' ') scoreStart++; - int scoreEnd = scoreStart; - while (scoreEnd < json.length() && (Character.isDigit(json.charAt(scoreEnd)) || json.charAt(scoreEnd) == '.' || json.charAt(scoreEnd) == '-' || json.charAt(scoreEnd) == 'e' || json.charAt(scoreEnd) == 'E' || json.charAt(scoreEnd) == '+')) scoreEnd++; - float score = Float.parseFloat(json.substring(scoreStart, scoreEnd)); - - results.add(new Pair<>(document, score)); - idx = scoreEnd; - } - return results; - } - - private static int findEndQuote(String s, int from) { - for (int i = from; i < s.length(); i++) { - if (s.charAt(i) == '\\') { - i++; // skip escaped char - } else if (s.charAt(i) == '"') { - return i; - } - } - return s.length(); - } - - private static String unescapeJson(String s) { - if (!s.contains("\\")) return s; - StringBuilder sb = new StringBuilder(s.length()); - for (int i = 0; i < s.length(); i++) { - char c = s.charAt(i); - if (c == '\\' && i + 1 < s.length()) { - char next = s.charAt(i + 1); - switch (next) { - case '"': sb.append('"'); i++; break; - case '\\': sb.append('\\'); i++; break; - case 'n': sb.append('\n'); i++; break; - case 'r': sb.append('\r'); i++; break; - case 't': sb.append('\t'); i++; break; - default: sb.append(c); break; - } - } else { - sb.append(c); - } - } - return sb.toString(); - } } diff --git a/src/main/java/de/kherud/llama/ModelMeta.java b/src/main/java/de/kherud/llama/ModelMeta.java index 0e31ae38..9981603d 100644 --- a/src/main/java/de/kherud/llama/ModelMeta.java +++ b/src/main/java/de/kherud/llama/ModelMeta.java @@ -60,6 +60,23 @@ public boolean supportsAudio() { return node.at("/modalities/audio").asBoolean(false); } + /** + * The model architecture string from GGUF {@code general.architecture} metadata + * (e.g. {@code "llama"}, {@code "gemma3"}, {@code "mistral"}). + * Returns an empty string if the field is absent in the GGUF file. + */ + public String getArchitecture() { + return node.path("architecture").asText(""); + } + + /** + * The human-readable model name from GGUF {@code general.name} metadata. + * Returns an empty string if the field is absent in the GGUF file. + */ + public String getModelName() { + return node.path("name").asText(""); + } + /** * Returns the underlying {@link JsonNode} for direct access to any field, * including fields added in future llama.cpp versions. diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java index b1659d74..0aabec8c 100644 --- a/src/main/java/de/kherud/llama/ModelParameters.java +++ b/src/main/java/de/kherud/llama/ModelParameters.java @@ -1,6 +1,7 @@ package de.kherud.llama; import de.kherud.llama.args.*; +import de.kherud.llama.json.ParameterJsonSerializer; /*** * Parameters used for initializing a {@link LlamaModel}. @@ -8,6 +9,8 @@ @SuppressWarnings("unused") public final class ModelParameters extends CliParameters { + private final ParameterJsonSerializer serializer = new ParameterJsonSerializer(); + private static final String ARG_FIT = "--fit"; static final String ARG_POOLING = "--pooling"; public static final String FIT_ON = "on"; @@ -233,8 +236,7 @@ public ModelParameters setKeep(int keep) { * @return this builder */ public ModelParameters disableContextShift() { - parameters.put("--no-context-shift", null); - return this; + return setFlag(ModelFlag.NO_CONTEXT_SHIFT); } /** @@ -243,8 +245,7 @@ public ModelParameters disableContextShift() { * @return this builder */ public ModelParameters enableFlashAttn() { - parameters.put("--flash-attn", null); - return this; + return setFlag(ModelFlag.FLASH_ATTN); } /** @@ -253,8 +254,7 @@ public ModelParameters enableFlashAttn() { * @return this builder */ public ModelParameters disablePerf() { - parameters.put("--no-perf", null); - return this; + return setFlag(ModelFlag.NO_PERF); } /** @@ -263,8 +263,7 @@ public ModelParameters disablePerf() { * @return this builder */ public ModelParameters enableEscape() { - parameters.put("--escape", null); - return this; + return setFlag(ModelFlag.ESCAPE); } /** @@ -273,8 +272,7 @@ public ModelParameters enableEscape() { * @return this builder */ public ModelParameters disableEscape() { - parameters.put("--no-escape", null); - return this; + return setFlag(ModelFlag.NO_ESCAPE); } /** @@ -283,8 +281,7 @@ public ModelParameters disableEscape() { * @return this builder */ public ModelParameters enableSpecial() { - parameters.put("--special", null); - return this; + return setFlag(ModelFlag.SPECIAL); } /** @@ -293,8 +290,7 @@ public ModelParameters enableSpecial() { * @return this builder */ public ModelParameters skipWarmup() { - parameters.put("--no-warmup", null); - return this; + return setFlag(ModelFlag.NO_WARMUP); } /** @@ -304,8 +300,7 @@ public ModelParameters skipWarmup() { * @return this builder */ public ModelParameters setSpmInfill() { - parameters.put("--spm-infill", null); - return this; + return setFlag(ModelFlag.SPM_INFILL); } /** @@ -318,8 +313,7 @@ public ModelParameters setSamplers(Sampler... samplers) { if (samplers.length > 0) { StringBuilder builder = new StringBuilder(); for (int i = 0; i < samplers.length; i++) { - Sampler sampler = samplers[i]; - builder.append(sampler.name().toLowerCase()); + builder.append(samplers[i].getArgValue()); if (i < samplers.length - 1) { builder.append(";"); } @@ -346,8 +340,7 @@ public ModelParameters setSeed(long seed) { * @return this builder */ public ModelParameters ignoreEos() { - parameters.put("--ignore-eos", null); - return this; + return setFlag(ModelFlag.IGNORE_EOS); } /** @@ -561,7 +554,7 @@ public ModelParameters setDynatempExponent(float dynatempExponent) { * @return this builder */ public ModelParameters setMirostat(MiroStat mirostat) { - parameters.put("--mirostat", String.valueOf(mirostat.ordinal())); + parameters.put("--mirostat", mirostat.getArgValue()); return this; } @@ -774,8 +767,7 @@ public ModelParameters setGrpAttnW(int grpAttnW) { * @return this builder */ public ModelParameters enableDumpKvCache() { - parameters.put("--dump-kv-cache", null); - return this; + return setFlag(ModelFlag.DUMP_KV_CACHE); } /** @@ -784,8 +776,7 @@ public ModelParameters enableDumpKvCache() { * @return this builder */ public ModelParameters disableKvOffload() { - parameters.put("--no-kv-offload", null); - return this; + return setFlag(ModelFlag.NO_KV_OFFLOAD); } /** @@ -795,7 +786,7 @@ public ModelParameters disableKvOffload() { * @return this builder */ public ModelParameters setCacheTypeK(CacheType type) { - parameters.put("--cache-type-k", type.name().toLowerCase()); + parameters.put("--cache-type-k", type.getArgValue()); return this; } @@ -806,7 +797,7 @@ public ModelParameters setCacheTypeK(CacheType type) { * @return this builder */ public ModelParameters setCacheTypeV(CacheType type) { - parameters.put("--cache-type-v", type.name().toLowerCase()); + parameters.put("--cache-type-v", type.getArgValue()); return this; } @@ -838,8 +829,7 @@ public ModelParameters setParallel(int nParallel) { * @return this builder */ public ModelParameters enableContBatching() { - parameters.put("--cont-batching", null); - return this; + return setFlag(ModelFlag.CONT_BATCHING); } /** @@ -848,8 +838,7 @@ public ModelParameters enableContBatching() { * @return this builder */ public ModelParameters disableContBatching() { - parameters.put("--no-cont-batching", null); - return this; + return setFlag(ModelFlag.NO_CONT_BATCHING); } /** @@ -858,8 +847,7 @@ public ModelParameters disableContBatching() { * @return this builder */ public ModelParameters enableMlock() { - parameters.put("--mlock", null); - return this; + return setFlag(ModelFlag.MLOCK); } /** @@ -868,8 +856,7 @@ public ModelParameters enableMlock() { * @return this builder */ public ModelParameters disableMmap() { - parameters.put("--no-mmap", null); - return this; + return setFlag(ModelFlag.NO_MMAP); } /** @@ -879,7 +866,7 @@ public ModelParameters disableMmap() { * @return this builder */ public ModelParameters setNuma(NumaStrategy numaStrategy) { - parameters.put("--numa", numaStrategy.name().toLowerCase()); + parameters.put("--numa", numaStrategy.getArgValue()); return this; } @@ -912,7 +899,7 @@ public ModelParameters setGpuLayers(int gpuLayers) { * @return this builder */ public ModelParameters setSplitMode(GpuSplitMode splitMode) { - parameters.put("--split-mode", splitMode.name().toLowerCase()); + parameters.put("--split-mode", splitMode.getArgValue()); return this; } @@ -944,8 +931,7 @@ public ModelParameters setMainGpu(int mainGpu) { * @return this builder */ public ModelParameters enableCheckTensors() { - parameters.put("--check-tensors", null); - return this; + return setFlag(ModelFlag.CHECK_TENSORS); } /** @@ -1100,8 +1086,7 @@ public ModelParameters setHfToken(String hfToken) { * @return this builder */ public ModelParameters enableEmbedding() { - parameters.put("--embedding", null); - return this; + return setFlag(ModelFlag.EMBEDDING); } /** @@ -1110,8 +1095,7 @@ public ModelParameters enableEmbedding() { * @return this builder */ public ModelParameters enableReranking() { - parameters.put("--reranking", null); - return this; + return setFlag(ModelFlag.RERANKING); } /** @@ -1164,7 +1148,7 @@ public ModelParameters setChatTemplate(String chatTemplate) { * @return this builder */ public ModelParameters setChatTemplateKwargs(java.util.Map kwargs) { - parameters.put("--chat-template-kwargs", JsonParameters.mapToJsonObject(kwargs)); + parameters.put("--chat-template-kwargs", serializer.buildRawValueObject(kwargs).toString()); return this; } @@ -1185,8 +1169,7 @@ public ModelParameters setSlotPromptSimilarity(float similarity) { * @return this builder */ public ModelParameters setLoraInitWithoutApply() { - parameters.put("--lora-init-without-apply", null); - return this; + return setFlag(ModelFlag.LORA_INIT_WITHOUT_APPLY); } /** @@ -1195,8 +1178,7 @@ public ModelParameters setLoraInitWithoutApply() { * @return this builder */ public ModelParameters disableLog() { - parameters.put("--log-disable", null); - return this; + return setFlag(ModelFlag.LOG_DISABLE); } /** @@ -1216,8 +1198,7 @@ public ModelParameters setLogFile(String logFile) { * @return this builder */ public ModelParameters setVerbose() { - parameters.put("--verbose", null); - return this; + return setFlag(ModelFlag.VERBOSE); } /** @@ -1237,8 +1218,7 @@ public ModelParameters setLogVerbosity(int verbosity) { * @return this builder */ public ModelParameters enableLogPrefix() { - parameters.put("--log-prefix", null); - return this; + return setFlag(ModelFlag.LOG_PREFIX); } /** @@ -1247,8 +1227,7 @@ public ModelParameters enableLogPrefix() { * @return this builder */ public ModelParameters enableLogTimestamps() { - parameters.put("--log-timestamps", null); - return this; + return setFlag(ModelFlag.LOG_TIMESTAMPS); } /** @@ -1334,8 +1313,7 @@ public ModelParameters setModelDraft(String modelDraft) { * @return this builder */ public ModelParameters enableJinja() { - parameters.put("--jinja", null); - return this; + return setFlag(ModelFlag.JINJA); } /** @@ -1346,8 +1324,7 @@ public ModelParameters enableJinja() { * @return this builder */ public ModelParameters setVocabOnly() { - parameters.put("--vocab-only", null); - return this; + return setFlag(ModelFlag.VOCAB_ONLY); } /** @@ -1361,8 +1338,8 @@ public ModelParameters setVocabOnly() { * @return this builder */ public ModelParameters setKvUnified(boolean kvUnified) { - parameters.put(kvUnified ? "--kv-unified" : "--no-kv-unified", null); - parameters.remove(kvUnified ? "--no-kv-unified" : "--kv-unified"); + setFlag(kvUnified ? ModelFlag.KV_UNIFIED : ModelFlag.NO_KV_UNIFIED); + clearFlag(kvUnified ? ModelFlag.NO_KV_UNIFIED : ModelFlag.KV_UNIFIED); return this; } @@ -1399,8 +1376,32 @@ public ModelParameters setCacheRamMib(int cacheRamMib) { * @return this builder */ public ModelParameters setClearIdle(boolean clearIdle) { - parameters.put(clearIdle ? "--clear-idle" : "--no-clear-idle", null); - parameters.remove(clearIdle ? "--no-clear-idle" : "--clear-idle"); + setFlag(clearIdle ? ModelFlag.CLEAR_IDLE : ModelFlag.NO_CLEAR_IDLE); + clearFlag(clearIdle ? ModelFlag.NO_CLEAR_IDLE : ModelFlag.CLEAR_IDLE); + return this; + } + + /** + * Enable the given flag, adding it to the active parameter set. + * Equivalent to calling the specific named method (e.g. {@link #enableFlashAttn()} + * for {@link ModelFlag#FLASH_ATTN}). + * + * @param flag the flag to enable + * @return this builder + */ + public ModelParameters setFlag(ModelFlag flag) { + parameters.put(flag.getCliFlag(), null); + return this; + } + + /** + * Remove the given flag from the active parameter set. + * + * @param flag the flag to remove + * @return this builder + */ + public ModelParameters clearFlag(ModelFlag flag) { + parameters.remove(flag.getCliFlag()); return this; } diff --git a/src/main/java/de/kherud/llama/StopReason.java b/src/main/java/de/kherud/llama/StopReason.java new file mode 100644 index 00000000..9e595b0e --- /dev/null +++ b/src/main/java/de/kherud/llama/StopReason.java @@ -0,0 +1,62 @@ +package de.kherud.llama; + +/** + * The reason why token generation stopped for a {@link LlamaOutput}. + * + *
      + *
    • {@link #NONE} — generation has not stopped yet (intermediate streaming token); + * {@link #getStopType()} returns {@code null}.
    • + *
    • {@link #EOS} — the model produced the end-of-sequence token.
    • + *
    • {@link #STOP_STRING} — a caller-specified stop string was matched.
    • + *
    • {@link #MAX_TOKENS} — the token budget ({@code nPredict} or context limit) was exhausted; + * the response was truncated.
    • + *
    + */ +public enum StopReason { + + /** No stop yet; the {@code "stop_type"} field is absent for intermediate tokens. */ + NONE(null), + + /** End-of-sequence token produced. Server {@code "stop_type"} value: {@code "eos"}. */ + EOS("eos"), + + /** A caller-supplied stop string was matched. Server {@code "stop_type"} value: {@code "word"}. */ + STOP_STRING("word"), + + /** Token budget exhausted. Server {@code "stop_type"} value: {@code "limit"}. */ + MAX_TOKENS("limit"); + + private final String stopType; + + StopReason(String stopType) { + this.stopType = stopType; + } + + /** + * Returns the {@code "stop_type"} string used by the native server for this constant, + * or {@code null} for {@link #NONE} (intermediate tokens carry no stop-type field). + * + * @return the stop-type string, or {@code null} for {@link #NONE} + */ + public String getStopType() { + return stopType; + } + + /** + * Map a raw {@code "stop_type"} string from the native server to a {@link StopReason}. + * Pass the already-extracted field value, e.g. + * {@code node.path("stop_type").asText("")}. + * + * @param stopType the raw stop-type string, or {@code null} / empty for absent field + * @return the corresponding {@link StopReason}, or {@link #NONE} if unrecognised + */ + public static StopReason fromStopType(String stopType) { + if (stopType == null) return NONE; + switch (stopType) { + case "eos": return EOS; + case "word": return STOP_STRING; + case "limit": return MAX_TOKENS; + default: return NONE; + } + } +} diff --git a/src/main/java/de/kherud/llama/args/CacheType.java b/src/main/java/de/kherud/llama/args/CacheType.java index 8404ed75..4e18b4f7 100644 --- a/src/main/java/de/kherud/llama/args/CacheType.java +++ b/src/main/java/de/kherud/llama/args/CacheType.java @@ -1,15 +1,28 @@ package de.kherud.llama.args; -public enum CacheType { - - F32, - F16, - BF16, - Q8_0, - Q4_0, - Q4_1, - IQ4_NL, - Q5_0, - Q5_1 +/** + * KV cache quantization type for {@code --cache-type-k} and {@code --cache-type-v}. + */ +public enum CacheType implements CliArg { + F32("f32"), + F16("f16"), + BF16("bf16"), + Q8_0("q8_0"), + Q4_0("q4_0"), + Q4_1("q4_1"), + IQ4_NL("iq4_nl"), + Q5_0("q5_0"), + Q5_1("q5_1"); + + private final String argValue; + + CacheType(String argValue) { + this.argValue = argValue; + } + + @Override + public String getArgValue() { + return argValue; + } } diff --git a/src/main/java/de/kherud/llama/args/CliArg.java b/src/main/java/de/kherud/llama/args/CliArg.java new file mode 100644 index 00000000..285be24c --- /dev/null +++ b/src/main/java/de/kherud/llama/args/CliArg.java @@ -0,0 +1,19 @@ +package de.kherud.llama.args; + +/** + * Implemented by every enum in this package that maps to a CLI argument value. + * + *

    The contract: {@link #getArgValue()} returns the exact string accepted by the + * corresponding llama.cpp CLI argument (e.g. {@code "q8_0"} for {@code --cache-type-k q8_0}). + * Callers pass this string directly to {@code parameters.put("--flag", arg.getArgValue())} + * without any post-processing. + */ +public interface CliArg { + + /** + * Returns the CLI argument value string for this constant. + * + * @return the value string accepted by the corresponding llama.cpp CLI argument + */ + String getArgValue(); +} diff --git a/src/main/java/de/kherud/llama/args/GpuSplitMode.java b/src/main/java/de/kherud/llama/args/GpuSplitMode.java index 0c0cd934..8994c417 100644 --- a/src/main/java/de/kherud/llama/args/GpuSplitMode.java +++ b/src/main/java/de/kherud/llama/args/GpuSplitMode.java @@ -1,8 +1,22 @@ package de.kherud.llama.args; -public enum GpuSplitMode { +/** + * GPU tensor split mode for {@code --split-mode}. + */ +public enum GpuSplitMode implements CliArg { - NONE, - LAYER, - ROW + NONE("none"), + LAYER("layer"), + ROW("row"); + + private final String argValue; + + GpuSplitMode(String argValue) { + this.argValue = argValue; + } + + @Override + public String getArgValue() { + return argValue; + } } diff --git a/src/main/java/de/kherud/llama/args/MiroStat.java b/src/main/java/de/kherud/llama/args/MiroStat.java index 5268d9bc..262e0427 100644 --- a/src/main/java/de/kherud/llama/args/MiroStat.java +++ b/src/main/java/de/kherud/llama/args/MiroStat.java @@ -1,8 +1,25 @@ package de.kherud.llama.args; -public enum MiroStat { +/** + * Mirostat sampling mode for {@code --mirostat}. + * + *

    The arg values ({@code "0"}, {@code "1"}, {@code "2"}) are the integer strings + * accepted by the CLI flag, matching llama.cpp's {@code MIROSTAT_*} constants. + */ +public enum MiroStat implements CliArg { - DISABLED, - V1, - V2 + DISABLED("0"), + V1("1"), + V2("2"); + + private final String argValue; + + MiroStat(String argValue) { + this.argValue = argValue; + } + + @Override + public String getArgValue() { + return argValue; + } } diff --git a/src/main/java/de/kherud/llama/args/ModelFlag.java b/src/main/java/de/kherud/llama/args/ModelFlag.java new file mode 100644 index 00000000..056b9260 --- /dev/null +++ b/src/main/java/de/kherud/llama/args/ModelFlag.java @@ -0,0 +1,115 @@ +package de.kherud.llama.args; + +/** + * Boolean CLI flags for {@link de.kherud.llama.ModelParameters}. + * + *

    Each constant maps to a single CLI argument that takes no value — its presence + * alone enables the behaviour. Pass to + * {@link de.kherud.llama.ModelParameters#setFlag(ModelFlag)} / + * {@link de.kherud.llama.ModelParameters#clearFlag(ModelFlag)} for programmatic control, + * or use the named convenience methods (e.g. {@link de.kherud.llama.ModelParameters#enableFlashAttn()}). + */ +public enum ModelFlag { + + /** Disable context shift on infinite text generation. */ + NO_CONTEXT_SHIFT("--no-context-shift"), + + /** Enable Flash Attention. */ + FLASH_ATTN("--flash-attn"), + + /** Disable internal libllama performance timings. */ + NO_PERF("--no-perf"), + + /** Process escape sequences (e.g. {@code \\n}, {@code \\t}). */ + ESCAPE("--escape"), + + /** Do not process escape sequences. */ + NO_ESCAPE("--no-escape"), + + /** Enable special tokens in output. */ + SPECIAL("--special"), + + /** Skip warming up the model with an empty run. */ + NO_WARMUP("--no-warmup"), + + /** Use Suffix/Prefix/Middle infill pattern instead of Prefix/Suffix/Middle. */ + SPM_INFILL("--spm-infill"), + + /** Ignore end-of-stream token and continue generating. */ + IGNORE_EOS("--ignore-eos"), + + /** Enable verbose printing of the KV cache. */ + DUMP_KV_CACHE("--dump-kv-cache"), + + /** Disable KV offload. */ + NO_KV_OFFLOAD("--no-kv-offload"), + + /** Enable continuous (dynamic) batching. */ + CONT_BATCHING("--cont-batching"), + + /** Disable continuous batching. */ + NO_CONT_BATCHING("--no-cont-batching"), + + /** Force system to keep model in RAM rather than swapping or compressing. */ + MLOCK("--mlock"), + + /** Do not memory-map model (slower load but may reduce pageouts if not using mlock). */ + NO_MMAP("--no-mmap"), + + /** Enable checking model tensor data for invalid values. */ + CHECK_TENSORS("--check-tensors"), + + /** Enable embedding use case; use only with dedicated embedding models. */ + EMBEDDING("--embedding"), + + /** Enable reranking endpoint on server. */ + RERANKING("--reranking"), + + /** Load LoRA adapters without applying them (apply later via POST /lora-adapters). */ + LORA_INIT_WITHOUT_APPLY("--lora-init-without-apply"), + + /** Disable logging. */ + LOG_DISABLE("--log-disable"), + + /** Set verbosity level to infinity (log all messages). */ + VERBOSE("--verbose"), + + /** Enable prefix in log messages. */ + LOG_PREFIX("--log-prefix"), + + /** Enable timestamps in log messages. */ + LOG_TIMESTAMPS("--log-timestamps"), + + /** Enable Jinja templating for chat templates. */ + JINJA("--jinja"), + + /** Only load the vocabulary for tokenization; no weights are loaded. */ + VOCAB_ONLY("--vocab-only"), + + /** Enable a single unified KV buffer shared across all sequences. */ + KV_UNIFIED("--kv-unified"), + + /** Disable the unified KV buffer. */ + NO_KV_UNIFIED("--no-kv-unified"), + + /** Enable saving and clearing idle slots when a new task starts. */ + CLEAR_IDLE("--clear-idle"), + + /** Disable saving and clearing idle slots. */ + NO_CLEAR_IDLE("--no-clear-idle"); + + private final String cliFlag; + + ModelFlag(String cliFlag) { + this.cliFlag = cliFlag; + } + + /** + * Returns the CLI argument string for this flag (e.g. {@code "--flash-attn"}). + * + * @return the CLI flag string + */ + public String getCliFlag() { + return cliFlag; + } +} diff --git a/src/main/java/de/kherud/llama/args/NumaStrategy.java b/src/main/java/de/kherud/llama/args/NumaStrategy.java index fa7a61b0..c0bd9682 100644 --- a/src/main/java/de/kherud/llama/args/NumaStrategy.java +++ b/src/main/java/de/kherud/llama/args/NumaStrategy.java @@ -1,8 +1,22 @@ package de.kherud.llama.args; -public enum NumaStrategy { +/** + * NUMA optimization strategy for {@code --numa}. + */ +public enum NumaStrategy implements CliArg { - DISTRIBUTE, - ISOLATE, - NUMACTL + DISTRIBUTE("distribute"), + ISOLATE("isolate"), + NUMACTL("numactl"); + + private final String argValue; + + NumaStrategy(String argValue) { + this.argValue = argValue; + } + + @Override + public String getArgValue() { + return argValue; + } } diff --git a/src/main/java/de/kherud/llama/args/PoolingType.java b/src/main/java/de/kherud/llama/args/PoolingType.java index f582898f..e38d13f6 100644 --- a/src/main/java/de/kherud/llama/args/PoolingType.java +++ b/src/main/java/de/kherud/llama/args/PoolingType.java @@ -17,7 +17,7 @@ * @see * llama.cpp b8609 – include/llama.h: {@code llama_pooling_type} enum */ -public enum PoolingType { +public enum PoolingType implements CliArg { /** * Use the model's built-in default pooling type. diff --git a/src/main/java/de/kherud/llama/args/RopeScalingType.java b/src/main/java/de/kherud/llama/args/RopeScalingType.java index 138d05be..ae6bdc0a 100644 --- a/src/main/java/de/kherud/llama/args/RopeScalingType.java +++ b/src/main/java/de/kherud/llama/args/RopeScalingType.java @@ -1,6 +1,6 @@ package de.kherud.llama.args; -public enum RopeScalingType { +public enum RopeScalingType implements CliArg { UNSPECIFIED("unspecified"), NONE("none"), diff --git a/src/main/java/de/kherud/llama/args/Sampler.java b/src/main/java/de/kherud/llama/args/Sampler.java index 564a2e6f..21f9abcd 100644 --- a/src/main/java/de/kherud/llama/args/Sampler.java +++ b/src/main/java/de/kherud/llama/args/Sampler.java @@ -1,15 +1,28 @@ package de.kherud.llama.args; -public enum Sampler { - - DRY, - TOP_K, - TOP_P, - TYP_P, - MIN_P, - TEMPERATURE, - XTC, - INFILL, - PENALTIES +/** + * Sampling algorithm for {@code --samplers} (CLI) and the {@code "samplers"} JSON field. + */ +public enum Sampler implements CliArg { + DRY("dry"), + TOP_K("top_k"), + TOP_P("top_p"), + TYP_P("typ_p"), + MIN_P("min_p"), + TEMPERATURE("temperature"), + XTC("xtc"), + INFILL("infill"), + PENALTIES("penalties"); + + private final String argValue; + + Sampler(String argValue) { + this.argValue = argValue; + } + + @Override + public String getArgValue() { + return argValue; + } } diff --git a/src/main/java/de/kherud/llama/json/ChatResponseParser.java b/src/main/java/de/kherud/llama/json/ChatResponseParser.java new file mode 100644 index 00000000..ce7ce230 --- /dev/null +++ b/src/main/java/de/kherud/llama/json/ChatResponseParser.java @@ -0,0 +1,89 @@ +package de.kherud.llama.json; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.io.IOException; + +/** + * Pure JSON transforms for OAI-compatible chat completion responses. + * + *

    All methods are stateless and have zero dependency on JNI, native libraries, or llama + * model state — they can be tested with JSON string literals alone (see + * {@code ChatResponseParserTest}). + * + *

    The native server produces an OAI-compatible chat completion JSON: + *

    {@code
    + * {
    + *   "id": "chatcmpl-...",
    + *   "object": "chat.completion",
    + *   "choices": [
    + *     {
    + *       "index": 0,
    + *       "message": {"role": "assistant", "content": "Hello!"},
    + *       "finish_reason": "stop"
    + *     }
    + *   ],
    + *   "usage": {"prompt_tokens": 12, "completion_tokens": 5, "total_tokens": 17}
    + * }
    + * }
    + */ +public class ChatResponseParser { + + /** Shared Jackson mapper; thread-safe and reused across all instances. */ + public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + /** + * Extract the assistant's reply text from an OAI chat completion JSON string. + * Navigates {@code choices[0].message.content} via Jackson. + * + *

    Returns an empty string when: the JSON is malformed, {@code choices} is absent + * or empty, or {@code content} is null/absent. + * + * @param json OAI-compatible chat completion JSON string + * @return the assistant content string, or {@code ""} on any failure + */ + public String extractChoiceContent(String json) { + try { + return extractChoiceContent(OBJECT_MAPPER.readTree(json)); + } catch (IOException e) { + return ""; + } + } + + /** + * Extract the assistant's reply text from a pre-parsed OAI chat completion node. + * Navigates {@code choices[0].message.content} via Jackson path API. + * + * @param node pre-parsed OAI chat completion response node + * @return the assistant content string, or {@code ""} if absent + */ + public String extractChoiceContent(JsonNode node) { + return node.path("choices").path(0).path("message").path("content").asText(""); + } + + /** + * Read a numeric usage field from the {@code "usage"} object in a chat completion node. + * Common field names: {@code "prompt_tokens"}, {@code "completion_tokens"}, + * {@code "total_tokens"}. + * + * @param node the parsed chat completion response + * @param field the field name within {@code "usage"} + * @return the integer value, or {@code 0} if the field or the {@code "usage"} object is absent + */ + public int extractUsageField(JsonNode node, String field) { + return node.path("usage").path(field).asInt(0); + } + + /** + * Count the number of choices returned in the response. + * Returns {@code 0} when the {@code "choices"} array is absent or not an array. + * + * @param node pre-parsed OAI chat completion response node + * @return the number of choices, or {@code 0} if absent + */ + public int countChoices(JsonNode node) { + JsonNode choices = node.path("choices"); + return choices.isArray() ? choices.size() : 0; + } +} diff --git a/src/main/java/de/kherud/llama/json/CompletionResponseParser.java b/src/main/java/de/kherud/llama/json/CompletionResponseParser.java new file mode 100644 index 00000000..61591b01 --- /dev/null +++ b/src/main/java/de/kherud/llama/json/CompletionResponseParser.java @@ -0,0 +1,117 @@ +package de.kherud.llama.json; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import de.kherud.llama.LlamaOutput; +import de.kherud.llama.StopReason; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Pure JSON transforms for native completion/streaming responses. + * + *

    All methods are stateless and have zero dependency on JNI, native libraries, or llama + * model state — they can be tested with JSON string literals alone (see + * {@code CompletionResponseParserTest}). + * + *

    The native server produces one JSON object per streamed token: + *

    {@code
    + * {
    + *   "content": "Hello",
    + *   "stop": false,
    + *   "stop_type": "none",
    + *   "completion_probabilities": [
    + *     {"token": "Hello", "bytes": [...], "id": 15043, "prob": 0.82,
    + *      "top_probs": [{"token": "Hi", "bytes": [...], "id": 9932, "prob": 0.1}]}
    + *   ]
    + * }
    + * }
    + * + *

    This is the Java analogue of {@code json_helpers.hpp} in the C++ layer. + */ +public class CompletionResponseParser { + + /** Shared Jackson mapper; thread-safe and reused across all instances. */ + public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + /** + * Parse a {@link LlamaOutput} from a raw JSON string returned by the native + * {@code receiveCompletionJson} method. Delegates to {@link #parse(JsonNode)} after + * a single {@code readTree} call so the string is parsed only once. + * + * @param json raw JSON string from the native completion response + * @return parsed {@link LlamaOutput}; empty output on parse failure + */ + public LlamaOutput parse(String json) { + try { + return parse(OBJECT_MAPPER.readTree(json)); + } catch (IOException e) { + return new LlamaOutput("", Collections.emptyMap(), false, StopReason.NONE); + } + } + + /** + * Parse a {@link LlamaOutput} from a pre-parsed {@link JsonNode}. + * Callers that already hold a parsed node should prefer this overload to avoid re-parsing. + * + * @param node pre-parsed completion response node + * @return parsed {@link LlamaOutput} + */ + public LlamaOutput parse(JsonNode node) { + String content = extractContent(node); + boolean stop = node.path("stop").asBoolean(false); + Map probabilities = parseProbabilities(node); + StopReason stopReason = stop ? StopReason.fromStopType(node.path("stop_type").asText("")) : StopReason.NONE; + return new LlamaOutput(content, probabilities, stop, stopReason); + } + + /** + * Extract the {@code "content"} string from a completion response node. + * Returns an empty string if the field is absent. + * + * @param node completion response node + * @return the content string, or {@code ""} if absent + */ + public String extractContent(JsonNode node) { + return node.path("content").asText(""); + } + + /** + * Parse the {@code completion_probabilities} array into a {@code token → probability} map. + * + *

    Each array entry carries the generated token and either a {@code "prob"} value + * (post-sampling mode) or {@code "logprob"} (pre-sampling mode). The nested + * {@code top_probs}/{@code top_logprobs} arrays are invisible at the outer entry level + * and do not interfere with field lookup. + * + *

    Returns an empty map when the field is absent or the array is empty. + * Requires {@code InferenceParameters#setNProbs(int)} to be configured before inference. + * + * @param root the top-level completion response node + * @return map from token string to probability; empty when no probability data is present + */ + public Map parseProbabilities(JsonNode root) { + JsonNode array = root.path("completion_probabilities"); + if (!array.isArray() || array.size() == 0) { + return Collections.emptyMap(); + } + Map result = new HashMap(); + for (JsonNode entry : array) { + String token = entry.path("token").asText(""); + if (token.isEmpty()) continue; + + // "prob" (post-sampling) or "logprob" (pre-sampling) + JsonNode probNode = entry.path("prob"); + if (probNode.isMissingNode() || probNode.isNull()) { + probNode = entry.path("logprob"); + } + if (probNode.isMissingNode() || probNode.isNull()) continue; + + result.put(token, (float) probNode.asDouble(0.0)); + } + return result.isEmpty() ? Collections.emptyMap() : result; + } +} diff --git a/src/main/java/de/kherud/llama/json/ParameterJsonSerializer.java b/src/main/java/de/kherud/llama/json/ParameterJsonSerializer.java new file mode 100644 index 00000000..b09d8749 --- /dev/null +++ b/src/main/java/de/kherud/llama/json/ParameterJsonSerializer.java @@ -0,0 +1,244 @@ +package de.kherud.llama.json; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import de.kherud.llama.Pair; +import de.kherud.llama.args.Sampler; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +/** + * Pure JSON builders for inference request parameters. + * + *

    All methods are stateless and have zero dependency on JNI, native libraries, or llama + * model state — they can be tested with plain Java values alone (see + * {@code ParameterJsonSerializerTest}). + * + *

    Methods return Jackson {@link ArrayNode} or {@link ObjectNode}. Callers that need a JSON + * string (e.g. callers in {@code JsonParameters}) call {@code node.toString()}. + * + *

    This class replaces hand-rolled {@code StringBuilder} loops and the + * {@code org.json}-derived {@code toJsonString()} escaper previously embedded in + * {@code JsonParameters}. + */ +public class ParameterJsonSerializer { + + /** Shared Jackson mapper; thread-safe and reused across all instances. */ + public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + // ------------------------------------------------------------------ + // String escaping + // ------------------------------------------------------------------ + + /** + * Serialize a Java string to a quoted, properly escaped JSON string literal + * (e.g. {@code "hello\nworld"} → {@code "\"hello\\nworld\""}). + * Returns {@code "null"} for a {@code null} input. + * + *

    Replaces the hand-rolled {@code toJsonString()} method in {@code JsonParameters}. + * + * @param value the Java string to serialize, or {@code null} + * @return a JSON string literal, or {@code "null"} if the input is {@code null} + */ + public String toJsonString(String value) { + if (value == null) return "null"; + try { + return OBJECT_MAPPER.writeValueAsString(value); + } catch (JsonProcessingException e) { + return "null"; + } + } + + // ------------------------------------------------------------------ + // Message array + // ------------------------------------------------------------------ + + /** + * Build an OAI-compatible {@code messages} array node. + * + *

    An optional system message is prepended when non-null and non-empty. + * Each message in {@code messages} must have role {@code "user"} or {@code "assistant"}. + * + * @param systemMessage optional system prompt; skipped when {@code null} or empty + * @param messages list of user/assistant message pairs (role as key, content as value) + * @return a Jackson {@link ArrayNode} of {@code {"role", "content"}} message objects + * @throws IllegalArgumentException if any message has an invalid role + */ + public ArrayNode buildMessages(String systemMessage, List> messages) { + ArrayNode arr = OBJECT_MAPPER.createArrayNode(); + if (systemMessage != null && !systemMessage.isEmpty()) { + ObjectNode sys = OBJECT_MAPPER.createObjectNode(); + sys.put("role", "system"); + sys.put("content", systemMessage); + arr.add(sys); + } + for (Pair message : messages) { + String role = message.getKey(); + String content = message.getValue(); + if (!"user".equals(role) && !"assistant".equals(role)) { + throw new IllegalArgumentException( + "Invalid role: " + role + ". Role must be 'user' or 'assistant'."); + } + ObjectNode msg = OBJECT_MAPPER.createObjectNode(); + msg.put("role", role); + msg.put("content", content); + arr.add(msg); + } + return arr; + } + + // ------------------------------------------------------------------ + // Simple array builders + // ------------------------------------------------------------------ + + /** + * Build a JSON string array from the given stop strings + * (e.g. {@code ["<|endoftext|>", "\n"]}). + * + * @param stops one or more stop strings + * @return a Jackson {@link ArrayNode} of stop string values + */ + public ArrayNode buildStopStrings(String... stops) { + ArrayNode arr = OBJECT_MAPPER.createArrayNode(); + for (String stop : stops) arr.add(stop); + return arr; + } + + /** + * Build a JSON string array from the given sampler sequence + * (e.g. {@code ["top_k", "top_p", "temperature"]}). + * + * @param samplers one or more samplers in the desired order + * @return a Jackson {@link ArrayNode} of sampler name strings + */ + public ArrayNode buildSamplers(Sampler... samplers) { + ArrayNode arr = OBJECT_MAPPER.createArrayNode(); + for (Sampler sampler : samplers) { + arr.add(sampler.getArgValue()); + } + return arr; + } + + /** + * Build a JSON integer array from a primitive {@code int[]} + * (used for penalty-prompt token sequences). + * + * @param values the token IDs to include + * @return a Jackson {@link ArrayNode} of integer values + */ + public ArrayNode buildIntArray(int[] values) { + ArrayNode arr = OBJECT_MAPPER.createArrayNode(); + for (int v : values) arr.add(v); + return arr; + } + + // ------------------------------------------------------------------ + // Logit-bias pair arrays — [[key, value], ...] + // ------------------------------------------------------------------ + + /** + * Build a logit-bias array for integer token IDs: + * {@code [[15043, 1.0], [50256, -0.5]]}. + * + * @param biases map from token ID to logit bias value + * @return a Jackson {@link ArrayNode} of {@code [tokenId, biasValue]} pairs + */ + public ArrayNode buildTokenIdBiasArray(Map biases) { + ArrayNode arr = OBJECT_MAPPER.createArrayNode(); + for (Map.Entry entry : biases.entrySet()) { + ArrayNode pair = OBJECT_MAPPER.createArrayNode(); + pair.add(entry.getKey()); + pair.add(entry.getValue()); + arr.add(pair); + } + return arr; + } + + /** + * Build a logit-bias array for string tokens: + * {@code [["Hello", 1.0], [" world", -0.5]]}. + * + * @param biases map from token string to logit bias value + * @return a Jackson {@link ArrayNode} of {@code ["token", biasValue]} pairs + */ + public ArrayNode buildTokenStringBiasArray(Map biases) { + ArrayNode arr = OBJECT_MAPPER.createArrayNode(); + for (Map.Entry entry : biases.entrySet()) { + ArrayNode pair = OBJECT_MAPPER.createArrayNode(); + pair.add(entry.getKey()); + pair.add(entry.getValue()); + arr.add(pair); + } + return arr; + } + + /** + * Build a disable-token array for integer token IDs: + * {@code [[15043, false], [50256, false]]}. + * + * @param ids collection of integer token IDs to disable + * @return a Jackson {@link ArrayNode} of {@code [tokenId, false]} pairs + */ + public ArrayNode buildDisableTokenIdArray(Collection ids) { + ArrayNode arr = OBJECT_MAPPER.createArrayNode(); + for (Integer id : ids) { + ArrayNode pair = OBJECT_MAPPER.createArrayNode(); + pair.add(id); + pair.add(false); + arr.add(pair); + } + return arr; + } + + /** + * Build a disable-token array for string tokens: + * {@code [["Hello", false], [" world", false]]}. + * + * @param tokens collection of token strings to disable + * @return a Jackson {@link ArrayNode} of {@code ["token", false]} pairs + */ + public ArrayNode buildDisableTokenStringArray(Collection tokens) { + ArrayNode arr = OBJECT_MAPPER.createArrayNode(); + for (String token : tokens) { + ArrayNode pair = OBJECT_MAPPER.createArrayNode(); + pair.add(token); + pair.add(false); + arr.add(pair); + } + return arr; + } + + // ------------------------------------------------------------------ + // Object with pre-serialized JSON values + // ------------------------------------------------------------------ + + /** + * Build a JSON object where each map value is a pre-serialized JSON string + * (not a plain Java string). For example, a map entry {@code ("enable_thinking", "true")} + * produces {@code {"enable_thinking": true}}, not {@code {"enable_thinking": "true"}}. + * + *

    Used for {@code chat_template_kwargs} which stores raw JSON values. + * If a value cannot be parsed as JSON, it is stored as a JSON string literal. + * + * @param map map of key to pre-serialized JSON value strings + * @return a Jackson {@link ObjectNode} with each value embedded as a parsed JSON node + */ + public ObjectNode buildRawValueObject(Map map) { + ObjectNode node = OBJECT_MAPPER.createObjectNode(); + for (Map.Entry entry : map.entrySet()) { + try { + JsonNode val = OBJECT_MAPPER.readTree(entry.getValue()); + node.set(entry.getKey(), val); + } catch (IOException e) { + node.put(entry.getKey(), entry.getValue()); + } + } + return node; + } +} diff --git a/src/main/java/de/kherud/llama/json/RerankResponseParser.java b/src/main/java/de/kherud/llama/json/RerankResponseParser.java new file mode 100644 index 00000000..87fd2e13 --- /dev/null +++ b/src/main/java/de/kherud/llama/json/RerankResponseParser.java @@ -0,0 +1,67 @@ +package de.kherud.llama.json; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import de.kherud.llama.Pair; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Pure JSON transforms for native rerank responses. + * + *

    All methods are stateless and have zero dependency on JNI, native libraries, or llama + * model state — they can be tested with JSON string literals alone (see + * {@code RerankResponseParserTest}). + * + *

    The native server produces a JSON array of reranked results: + *

    {@code
    + * [
    + *   {"document": "The quick brown fox", "index": 0, "score": 0.92},
    + *   {"document": "Another document",    "index": 1, "score": 0.43}
    + * ]
    + * }
    + */ +public class RerankResponseParser { + + /** Shared Jackson mapper; thread-safe and reused across all instances. */ + public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + /** + * Parse rerank results from a raw JSON array string. Delegates to {@link #parse(JsonNode)} + * after a single {@code readTree} call. + * + * @param json raw JSON array string from the native rerank response + * @return list of document/score pairs; empty list on parse failure or empty array + */ + public List> parse(String json) { + try { + return parse(OBJECT_MAPPER.readTree(json)); + } catch (IOException e) { + return Collections.emptyList(); + } + } + + /** + * Parse rerank results from a pre-parsed {@link JsonNode} array. + * Each element must contain {@code "document"} (string) and {@code "score"} (number). + * Returns an empty list when the node is not an array or is empty. + * + * @param arr pre-parsed {@link JsonNode} array of rerank result objects + * @return list of document/score pairs; empty list if the node is not an array or is empty + */ + public List> parse(JsonNode arr) { + if (!arr.isArray() || arr.size() == 0) { + return Collections.emptyList(); + } + List> results = new ArrayList>(); + for (JsonNode entry : arr) { + String doc = entry.path("document").asText(""); + float score = (float) entry.path("score").asDouble(0.0); + results.add(new Pair(doc, score)); + } + return results; + } +} diff --git a/src/test/java/de/kherud/llama/ChatAdvancedTest.java b/src/test/java/de/kherud/llama/ChatAdvancedTest.java index 4b75e9f0..bffc019e 100644 --- a/src/test/java/de/kherud/llama/ChatAdvancedTest.java +++ b/src/test/java/de/kherud/llama/ChatAdvancedTest.java @@ -8,6 +8,7 @@ import de.kherud.llama.args.MiroStat; import de.kherud.llama.args.Sampler; +import de.kherud.llama.json.CompletionResponseParser; import org.junit.AfterClass; import org.junit.Assert; import org.junit.Assume; @@ -42,6 +43,7 @@ public class ChatAdvancedTest { private static final int N_PREDICT = 10; + private final CompletionResponseParser completionParser = new CompletionResponseParser(); private static final String SIMPLE_PROMPT = "def hello():"; private static LlamaModel model; @@ -144,7 +146,7 @@ public void testSetNProbsStreamingJsonHasProbabilities() { while (!done) { String json = model.receiveCompletionJson(taskId); Assert.assertNotNull("receiveCompletionJson must not be null", json); - LlamaOutput output = LlamaOutput.fromJson(json); + LlamaOutput output = completionParser.parse(json); if (json.contains("\"completion_probabilities\"")) { foundProbabilities = true; } @@ -338,7 +340,7 @@ public void testRequestCompletionDirectStreaming() { while (!stopped) { String json = model.receiveCompletionJson(taskId); Assert.assertNotNull("receiveCompletionJson must not return null", json); - LlamaOutput output = LlamaOutput.fromJson(json); + LlamaOutput output = completionParser.parse(json); sb.append(output.text); tokens++; if (output.stop) { diff --git a/src/test/java/de/kherud/llama/ChatScenarioTest.java b/src/test/java/de/kherud/llama/ChatScenarioTest.java index 9e21128b..17d6decd 100644 --- a/src/test/java/de/kherud/llama/ChatScenarioTest.java +++ b/src/test/java/de/kherud/llama/ChatScenarioTest.java @@ -7,6 +7,8 @@ import java.util.List; import de.kherud.llama.args.PoolingType; +import de.kherud.llama.json.ChatResponseParser; +import de.kherud.llama.json.CompletionResponseParser; import org.junit.AfterClass; import org.junit.Assert; import org.junit.Assume; @@ -41,6 +43,8 @@ public class ChatScenarioTest { private static final int N_PREDICT = 10; + private final CompletionResponseParser completionParser = new CompletionResponseParser(); + private final ChatResponseParser chatParser = new ChatResponseParser(); private static LlamaModel model; @@ -100,6 +104,50 @@ public void testChatCompleteResponseJsonStructure() { response.contains("\"assistant\"") || response.contains("assistant")); } + /** + * chatCompleteText() must return only the assistant's plain text, not the OAI JSON wrapper. + * The result should be non-empty and must NOT contain the JSON key "choices". + */ + @Test + public void testChatCompleteTextReturnsPlainString() { + List> messages = new ArrayList<>(); + messages.add(new Pair<>("user", "Say the word OK.")); + + InferenceParameters params = new InferenceParameters("") + .setMessages(null, messages) + .setNPredict(N_PREDICT) + .setSeed(42) + .setTemperature(0.0f); + + String text = model.chatCompleteText(params); + + Assert.assertNotNull(text); + Assert.assertFalse("chatCompleteText must not be empty", text.isEmpty()); + Assert.assertFalse("chatCompleteText must not contain OAI JSON wrapper", text.contains("\"choices\"")); + } + + /** + * chatCompleteText() must return the same content as extracting choices[0].message.content + * from the raw chatComplete() JSON. + */ + @Test + public void testChatCompleteTextMatchesChatCompleteContent() { + List> messages = new ArrayList<>(); + messages.add(new Pair<>("user", "What is 2 plus 2?")); + + InferenceParameters params = new InferenceParameters("") + .setMessages("You are a helpful assistant.", messages) + .setNPredict(N_PREDICT) + .setSeed(42) + .setTemperature(0.0f); + + String rawJson = model.chatComplete(params); + String text = model.chatCompleteText(params); + + String expected = chatParser.extractChoiceContent(rawJson); + Assert.assertEquals("chatCompleteText must match choices[0].message.content", expected, text); + } + /** * handleChatCompletions can be called directly with a raw JSON string. * Verify the response contains valid OAI chat completion fields. @@ -147,7 +195,7 @@ public void testRequestChatCompletionDirectStreaming() { while (!stopped) { String json = model.receiveCompletionJson(taskId); Assert.assertNotNull("receiveCompletionJson must not return null", json); - LlamaOutput output = LlamaOutput.fromJson(json); + LlamaOutput output = completionParser.parse(json); sb.append(output.text); tokens++; if (output.stop) { @@ -226,7 +274,7 @@ public void testChatCompleteWithStopString() { .setSeed(42) .setTemperature(0.0f); String unJson = model.chatComplete(unconstrained); - String unContent = extractChoiceContent(unJson); + String unContent = chatParser.extractChoiceContent(unJson); // Stopped at "3" InferenceParameters stopped = new InferenceParameters("") @@ -236,7 +284,7 @@ public void testChatCompleteWithStopString() { .setTemperature(0.0f) .setStopStrings("4"); String stJson = model.chatComplete(stopped); - String stContent = extractChoiceContent(stJson); + String stContent = chatParser.extractChoiceContent(stJson); Assert.assertNotNull("Stop-string response must not be null", stJson); // Content with stop should be shorter (or at most equal) @@ -312,7 +360,7 @@ public void testChatCompleteMultiTurnThreeTurns() { .setTemperature(0.0f); String json = model.chatComplete(params); - String content = extractChoiceContent(json); + String content = chatParser.extractChoiceContent(json); Assert.assertNotNull("Turn " + turn + ": response must not be null", json); Assert.assertFalse("Turn " + turn + ": content must not be empty", content.isEmpty()); @@ -421,8 +469,8 @@ public void testHandleInfillDirect() { String prefix = "def greet(name):\n \"\"\" "; String suffix = "\n return greeting\n"; - String json = "{\"input_prefix\": " + toJsonString(prefix) + - ", \"input_suffix\": " + toJsonString(suffix) + + String json = "{\"input_prefix\": " + jsonStr(prefix) + + ", \"input_suffix\": " + jsonStr(suffix) + ", \"n_predict\": " + N_PREDICT + ", \"seed\": 42, \"temperature\": 0.0}"; @@ -493,8 +541,8 @@ public void testHandleTokenizeWithSpecialTokens() { Assert.assertNotNull(withoutSpecial); Assert.assertTrue("Both responses must contain 'tokens'", withSpecial.contains("\"tokens\"")); - int countWith = countTokensInJson(withSpecial); - int countWithout = countTokensInJson(withoutSpecial); + int countWith = tokenCount(withSpecial); + int countWithout = tokenCount(withoutSpecial); Assert.assertTrue( "addSpecial=true should produce at least as many tokens as addSpecial=false " + @@ -522,7 +570,7 @@ public void testHandleDetokenizeRoundTrip() { Assert.assertTrue("handleDetokenize response must contain 'content'", response.contains("\"content\"")); // Extract the detokenized text (simple search for content field value) - String detokenized = LlamaOutput.getContentFromJson(response); + String detokenized = completionParser.parse(response).text; // The tokenizer typically prepends a space; check the meaningful content Assert.assertTrue( "Detokenized text should contain original content (got: '" + detokenized + "')", @@ -590,7 +638,7 @@ public void testChatCompleteNPredictOne() { String response = model.chatComplete(params); Assert.assertNotNull(response); Assert.assertFalse("nPredict=1 must still return a non-empty response", response.isEmpty()); - String content = extractChoiceContent(response); + String content = chatParser.extractChoiceContent(response); // Content should be at most one token long — just verify it doesn't crash Assert.assertNotNull("Content must not be null for nPredict=1", content); } @@ -640,52 +688,24 @@ public void testGenerateChatStopFlagOnFinalToken() { // Helpers // ------------------------------------------------------------------ - /** - * Extract the assistant's content string from an OAI-compatible chat - * completion JSON response. - * Expected structure: {"choices":[{"message":{"role":"assistant","content":"..."}}]} - */ - private static String extractChoiceContent(String json) { - // Find choices[0].message.content - int choicesIdx = json.indexOf("\"choices\""); - if (choicesIdx < 0) { - // Fallback: try plain "content" field (non-OAI response shape) - return LlamaOutput.getContentFromJson(json); - } - // Find "content" after "choices" - int contentIdx = json.indexOf("\"content\"", choicesIdx); - if (contentIdx < 0) { - return ""; + /** Serialize a string to a JSON string literal using Jackson. */ + private static String jsonStr(String s) { + try { + return CompletionResponseParser.OBJECT_MAPPER.writeValueAsString(s); + } catch (Exception e) { + return "null"; } - // Reuse LlamaOutput's JSON extractor on a substring starting at "content" - return LlamaOutput.getContentFromJson(json.substring(contentIdx)); } - /** - * Count the number of comma-separated elements in the JSON array value - * of the "tokens" field. This is a best-effort heuristic — it works for - * the simple integer-array format returned by handleTokenize. - */ - private static int countTokensInJson(String json) { - int tokensIdx = json.indexOf("\"tokens\""); - if (tokensIdx < 0) return 0; - int openBracket = json.indexOf('[', tokensIdx); - int closeBracket = json.indexOf(']', openBracket); - if (openBracket < 0 || closeBracket < 0) return 0; - String array = json.substring(openBracket + 1, closeBracket).trim(); - if (array.isEmpty()) return 0; - return array.split(",").length; - } - - /** Minimal JSON string escaping for test helper strings. */ - private static String toJsonString(String s) { - if (s == null) return "null"; - return "\"" + s - .replace("\\", "\\\\") - .replace("\"", "\\\"") - .replace("\n", "\\n") - .replace("\r", "\\r") - .replace("\t", "\\t") - + "\""; + /** Count elements in the {@code "tokens"} array of a tokenize response. */ + private static int tokenCount(String json) { + try { + com.fasterxml.jackson.databind.JsonNode node = + CompletionResponseParser.OBJECT_MAPPER.readTree(json); + com.fasterxml.jackson.databind.JsonNode arr = node.path("tokens"); + return arr.isArray() ? arr.size() : 0; + } catch (Exception e) { + return 0; + } } } diff --git a/src/test/java/de/kherud/llama/InferenceParametersTest.java b/src/test/java/de/kherud/llama/InferenceParametersTest.java index d63acef9..211706a2 100644 --- a/src/test/java/de/kherud/llama/InferenceParametersTest.java +++ b/src/test/java/de/kherud/llama/InferenceParametersTest.java @@ -276,7 +276,7 @@ public void testSetStopStringsSingle() { @Test public void testSetStopStringsMultiple() { InferenceParameters params = new InferenceParameters("").setStopStrings("stop1", "stop2"); - assertEquals("[\"stop1\", \"stop2\"]", params.parameters.get("stop")); + assertEquals("[\"stop1\",\"stop2\"]", params.parameters.get("stop")); } @Test @@ -299,7 +299,7 @@ public void testSetSamplersSingle() { @Test public void testSetSamplersMultiple() { InferenceParameters params = new InferenceParameters("").setSamplers(Sampler.TOP_K, Sampler.TOP_P, Sampler.TEMPERATURE); - assertEquals("[\"top_k\", \"top_p\", \"temperature\"]", params.parameters.get("samplers")); + assertEquals("[\"top_k\",\"top_p\",\"temperature\"]", params.parameters.get("samplers")); } @Test @@ -396,7 +396,7 @@ public void testDisableTokensEmpty() { @Test public void testSetPenaltyPromptTokenIds() { InferenceParameters params = new InferenceParameters("").setPenaltyPrompt(new int[]{1, 2, 3}); - assertEquals("[1, 2, 3]", params.parameters.get("penalty_prompt")); + assertEquals("[1,2,3]", params.parameters.get("penalty_prompt")); } @Test @@ -532,11 +532,12 @@ public void testToJsonStringNull() { } @Test - public void testToJsonStringEscapesSlashAfterLt() { - // '"); String value = params.parameters.get("prompt"); - assertTrue(value.contains("<\\/")); + assertTrue(value.contains("")); + assertFalse(value.contains("<\\/")); } // ------------------------------------------------------------------------- diff --git a/src/test/java/de/kherud/llama/LlamaModelTest.java b/src/test/java/de/kherud/llama/LlamaModelTest.java index 0bd34ccd..80adfa34 100644 --- a/src/test/java/de/kherud/llama/LlamaModelTest.java +++ b/src/test/java/de/kherud/llama/LlamaModelTest.java @@ -183,6 +183,30 @@ public void testCancelGenerating() { ); } + /** + * LlamaIterable implements AutoCloseable. Breaking out of a for-each loop early inside a + * try-with-resources block must not throw and must not leave the task slot hanging — the + * iterator's close() cancels the native task automatically. + */ + @Test + public void testGenerateAutoCloseOnEarlyBreak() throws Exception { + InferenceParameters params = new InferenceParameters(prefix).setNPredict(nPredict); + + int collected = 0; + try (LlamaIterable iterable = model.generate(params)) { + for (LlamaOutput ignored : iterable) { + collected++; + break; // exit before stop token + } + } // close() must cancel without throwing + + Assert.assertTrue("Should have collected at least one token before break", collected >= 1); + + // The model must still be usable after an early-exit close + String result = model.complete(new InferenceParameters(prefix).setNPredict(5)); + Assert.assertNotNull("Model must be functional after autoclosed iterator", result); + } + @Test public void testEmbedding() { float[] embedding = model.embed(prefix); @@ -946,6 +970,15 @@ public void testGetModelMeta() throws LlamaException { Assert.assertTrue("modalities field must be present", meta.asJson().has("modalities")); Assert.assertTrue("vocab_type field must be present", meta.asJson().has("vocab_type")); + // Architecture and name from GGUF general.* metadata + String architecture = meta.getArchitecture(); + Assert.assertNotNull("getArchitecture() must not return null", architecture); + Assert.assertFalse("CodeLlama GGUF must have general.architecture set", architecture.isEmpty()); + + // general.name may or may not be present in the GGUF; just verify the getter does not throw + String modelName = meta.getModelName(); + Assert.assertNotNull("getModelName() must not return null", modelName); + // Round-trip: toString() must produce valid compact JSON containing all top-level keys String json = meta.toString(); Assert.assertNotNull(json); @@ -958,10 +991,7 @@ public void testGetModelMeta() throws LlamaException { Assert.assertTrue(json.contains("\"modalities\"")); Assert.assertTrue(json.contains("\"vision\"")); Assert.assertTrue(json.contains("\"audio\"")); - - // Fill in the expected value from the failure message and re-run to pin exact output: - Assert.assertEquals("{\"vocab_type\":1,\"n_vocab\":32016,\"n_ctx_train\":16384," - + "\"n_embd\":4096,\"n_params\":6738546688,\"size\":2825274880," - + "\"modalities\":{\"vision\":false,\"audio\":false}}", json); + Assert.assertTrue(json.contains("\"architecture\"")); + Assert.assertTrue(json.contains("\"name\"")); } } diff --git a/src/test/java/de/kherud/llama/LlamaOutputTest.java b/src/test/java/de/kherud/llama/LlamaOutputTest.java index e5fa8529..4a4bf730 100644 --- a/src/test/java/de/kherud/llama/LlamaOutputTest.java +++ b/src/test/java/de/kherud/llama/LlamaOutputTest.java @@ -1,5 +1,6 @@ package de.kherud.llama; +import de.kherud.llama.json.CompletionResponseParser; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -9,27 +10,29 @@ import static org.junit.Assert.*; @ClaudeGenerated( - purpose = "Verify that LlamaOutput correctly stores text, the probability map and stop flag " + - "unchanged, and that toString() delegates to the text field." + purpose = "Verify that LlamaOutput correctly stores text, the probability map, stop flag, " + + "and stopReason, and that toString() delegates to the text field." ) public class LlamaOutputTest { + private final CompletionResponseParser parser = new CompletionResponseParser(); + @Test public void testTextFromString() { - LlamaOutput output = new LlamaOutput("hello", Collections.emptyMap(), false); + LlamaOutput output = new LlamaOutput("hello", Collections.emptyMap(), false, StopReason.NONE); assertEquals("hello", output.text); } @Test public void testEmptyText() { - LlamaOutput output = new LlamaOutput("", Collections.emptyMap(), false); + LlamaOutput output = new LlamaOutput("", Collections.emptyMap(), false, StopReason.NONE); assertEquals("", output.text); } @Test public void testUtf8MultibyteText() { String original = "héllo wörld"; - LlamaOutput output = new LlamaOutput(original, Collections.emptyMap(), false); + LlamaOutput output = new LlamaOutput(original, Collections.emptyMap(), false, StopReason.NONE); assertEquals(original, output.text); } @@ -38,7 +41,7 @@ public void testProbabilitiesStored() { Map probs = new HashMap<>(); probs.put("hello", 0.9f); probs.put("world", 0.1f); - LlamaOutput output = new LlamaOutput("", probs, false); + LlamaOutput output = new LlamaOutput("", probs, false, StopReason.NONE); assertEquals(2, output.probabilities.size()); assertEquals(0.9f, output.probabilities.get("hello"), 0.0001f); assertEquals(0.1f, output.probabilities.get("world"), 0.0001f); @@ -46,38 +49,38 @@ public void testProbabilitiesStored() { @Test public void testEmptyProbabilities() { - LlamaOutput output = new LlamaOutput("", Collections.emptyMap(), false); + LlamaOutput output = new LlamaOutput("", Collections.emptyMap(), false, StopReason.NONE); assertTrue(output.probabilities.isEmpty()); } @Test public void testStopFlagFalse() { - LlamaOutput output = new LlamaOutput("", Collections.emptyMap(), false); + LlamaOutput output = new LlamaOutput("", Collections.emptyMap(), false, StopReason.NONE); assertFalse(output.stop); } @Test public void testStopFlagTrue() { - LlamaOutput output = new LlamaOutput("", Collections.emptyMap(), true); + LlamaOutput output = new LlamaOutput("", Collections.emptyMap(), true, StopReason.EOS); assertTrue(output.stop); } @Test public void testToStringReturnsText() { - LlamaOutput output = new LlamaOutput("generated text", Collections.emptyMap(), false); + LlamaOutput output = new LlamaOutput("generated text", Collections.emptyMap(), false, StopReason.NONE); assertEquals("generated text", output.toString()); } @Test public void testToStringEmptyText() { - LlamaOutput output = new LlamaOutput("", Collections.emptyMap(), false); + LlamaOutput output = new LlamaOutput("", Collections.emptyMap(), false, StopReason.NONE); assertEquals("", output.toString()); } @Test public void testFromJson() { String json = "{\"content\":\"hello world\",\"stop\":true}"; - LlamaOutput output = LlamaOutput.fromJson(json); + LlamaOutput output = parser.parse(json); assertEquals("hello world", output.text); assertTrue(output.stop); } @@ -85,14 +88,122 @@ public void testFromJson() { @Test public void testFromJsonWithEscapes() { String json = "{\"content\":\"line1\\nline2\\t\\\"quoted\\\"\",\"stop\":false}"; - LlamaOutput output = LlamaOutput.fromJson(json); + LlamaOutput output = parser.parse(json); assertEquals("line1\nline2\t\"quoted\"", output.text); assertFalse(output.stop); } + @Test + public void testFromJsonWithUnicodeEscape() { + String json = "{\"content\":\"caf\\u00e9\",\"stop\":false}"; + LlamaOutput output = parser.parse(json); + assertEquals("café", output.text); + assertFalse(output.stop); + } + + @Test + public void testFromJsonMalformedReturnsEmptyNonStop() { + LlamaOutput output = parser.parse("{not valid json"); + assertEquals("", output.text); + assertFalse(output.stop); + assertEquals(StopReason.NONE, output.stopReason); + assertTrue(output.probabilities.isEmpty()); + } + @Test public void testGetContentFromJsonEmpty() { String json = "{\"content\":\"\",\"stop\":true}"; - assertEquals("", LlamaOutput.getContentFromJson(json)); + assertEquals("", parser.parse(json).text); + } + + // --- parseProbabilities tests --- + + @Test + public void testProbabilitiesAbsentWhenNoProbsKey() { + String json = "{\"content\":\"hi\",\"stop\":true,\"stop_type\":\"eos\"}"; + LlamaOutput output = parser.parse(json); + assertTrue("No completion_probabilities key → empty map", output.probabilities.isEmpty()); + } + + @Test + public void testProbabilitiesParsedPostSampling() { + // post_sampling_probs=true → "prob" key + String json = "{\"content\":\"hi\",\"stop\":true,\"stop_type\":\"eos\"," + + "\"completion_probabilities\":[" + + "{\"token\":\"Hello\",\"bytes\":[72],\"id\":15043,\"prob\":0.82," + + "\"top_probs\":[{\"token\":\"Hi\",\"bytes\":[72],\"id\":9932,\"prob\":0.1}]}," + + "{\"token\":\" world\",\"bytes\":[32,119],\"id\":1917,\"prob\":0.65," + + "\"top_probs\":[{\"token\":\" World\",\"bytes\":[32,87],\"id\":2304,\"prob\":0.2}]}" + + "]}"; + LlamaOutput output = parser.parse(json); + assertEquals(2, output.probabilities.size()); + assertEquals(0.82f, output.probabilities.get("Hello"), 0.001f); + assertEquals(0.65f, output.probabilities.get(" world"), 0.001f); + } + + @Test + public void testProbabilitiesParsedPreSampling() { + // post_sampling_probs=false → "logprob" key + String json = "{\"content\":\"hi\",\"stop\":true,\"stop_type\":\"eos\"," + + "\"completion_probabilities\":[" + + "{\"token\":\"Hello\",\"bytes\":[72],\"id\":15043,\"logprob\":-0.2," + + "\"top_logprobs\":[{\"token\":\"Hi\",\"bytes\":[72],\"id\":9932,\"logprob\":-2.3}]}" + + "]}"; + LlamaOutput output = parser.parse(json); + assertEquals(1, output.probabilities.size()); + assertEquals(-0.2f, output.probabilities.get("Hello"), 0.001f); + } + + @Test + public void testProbabilitiesTokenWithEscapedChars() { + String json = "{\"content\":\"hi\",\"stop\":true,\"stop_type\":\"eos\"," + + "\"completion_probabilities\":[" + + "{\"token\":\"say \\\"yes\\\"\",\"bytes\":[],\"id\":1,\"prob\":0.5," + + "\"top_probs\":[]}" + + "]}"; + LlamaOutput output = parser.parse(json); + assertEquals(1, output.probabilities.size()); + assertEquals(0.5f, output.probabilities.get("say \"yes\""), 0.001f); + } + + // --- StopReason tests --- + + @Test + public void testStopReasonNoneOnIntermediateToken() { + LlamaOutput output = new LlamaOutput("token", Collections.emptyMap(), false, StopReason.NONE); + assertEquals(StopReason.NONE, output.stopReason); + } + + @Test + public void testStopReasonFromJsonEos() { + String json = "{\"content\":\"done\",\"stop\":true,\"stop_type\":\"eos\"}"; + LlamaOutput output = parser.parse(json); + assertTrue(output.stop); + assertEquals(StopReason.EOS, output.stopReason); + } + + @Test + public void testStopReasonFromJsonWord() { + String json = "{\"content\":\"done\",\"stop\":true,\"stop_type\":\"word\",\"stopping_word\":\"END\"}"; + LlamaOutput output = parser.parse(json); + assertTrue(output.stop); + assertEquals(StopReason.STOP_STRING, output.stopReason); + } + + @Test + public void testStopReasonFromJsonLimit() { + String json = "{\"content\":\"truncated\",\"stop\":true,\"stop_type\":\"limit\",\"truncated\":true}"; + LlamaOutput output = parser.parse(json); + assertTrue(output.stop); + assertEquals(StopReason.MAX_TOKENS, output.stopReason); + } + + @Test + public void testStopReasonNoneWhenStopFalse() { + String json = "{\"content\":\"partial\",\"stop\":false,\"stop_type\":\"eos\"}"; + LlamaOutput output = parser.parse(json); + assertFalse(output.stop); + // stopReason is NONE for non-final tokens regardless of stop_type + assertEquals(StopReason.NONE, output.stopReason); } } diff --git a/src/test/java/de/kherud/llama/ModelMetaTest.java b/src/test/java/de/kherud/llama/ModelMetaTest.java new file mode 100644 index 00000000..0b4b87ee --- /dev/null +++ b/src/test/java/de/kherud/llama/ModelMetaTest.java @@ -0,0 +1,118 @@ +package de.kherud.llama; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link ModelMeta} typed getters. + * Constructs {@code ModelMeta} directly from JSON strings — no native library or model file required. + */ +@ClaudeGenerated( + purpose = "Verify that ModelMeta typed getters map correctly from the underlying JsonNode, " + + "including the new architecture and name fields from GGUF general.* metadata." +) +public class ModelMetaTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private ModelMeta parse(String json) throws Exception { + return new ModelMeta(MAPPER.readTree(json)); + } + + @Test + public void testNumericGetters() throws Exception { + ModelMeta meta = parse("{\"vocab_type\":1,\"n_vocab\":32016,\"n_ctx_train\":16384," + + "\"n_embd\":4096,\"n_params\":6738546688,\"size\":2825274880," + + "\"modalities\":{\"vision\":false,\"audio\":false}," + + "\"architecture\":\"llama\",\"name\":\"CodeLlama-7B\"}"); + + assertEquals(1, meta.getVocabType()); + assertEquals(32016, meta.getNVocab()); + assertEquals(16384, meta.getNCtxTrain()); + assertEquals(4096, meta.getNEmbd()); + assertEquals(6738546688L, meta.getNParams()); + assertEquals(2825274880L, meta.getSize()); + } + + @Test + public void testModalityGetters() throws Exception { + ModelMeta textOnly = parse("{\"vocab_type\":1,\"n_vocab\":100,\"n_ctx_train\":4096," + + "\"n_embd\":512,\"n_params\":1000000,\"size\":500000," + + "\"modalities\":{\"vision\":false,\"audio\":false}," + + "\"architecture\":\"llama\",\"name\":\"\"}"); + assertFalse(textOnly.supportsVision()); + assertFalse(textOnly.supportsAudio()); + + ModelMeta multimodal = parse("{\"vocab_type\":1,\"n_vocab\":100,\"n_ctx_train\":4096," + + "\"n_embd\":512,\"n_params\":1000000,\"size\":500000," + + "\"modalities\":{\"vision\":true,\"audio\":true}," + + "\"architecture\":\"gemma3\",\"name\":\"Gemma-3\"}"); + assertTrue(multimodal.supportsVision()); + assertTrue(multimodal.supportsAudio()); + } + + @Test + public void testGetArchitecture() throws Exception { + ModelMeta meta = parse("{\"vocab_type\":1,\"n_vocab\":32016,\"n_ctx_train\":16384," + + "\"n_embd\":4096,\"n_params\":6738546688,\"size\":2825274880," + + "\"modalities\":{\"vision\":false,\"audio\":false}," + + "\"architecture\":\"llama\",\"name\":\"CodeLlama-7B\"}"); + + assertEquals("llama", meta.getArchitecture()); + } + + @Test + public void testGetModelName() throws Exception { + ModelMeta meta = parse("{\"vocab_type\":1,\"n_vocab\":32016,\"n_ctx_train\":16384," + + "\"n_embd\":4096,\"n_params\":6738546688,\"size\":2825274880," + + "\"modalities\":{\"vision\":false,\"audio\":false}," + + "\"architecture\":\"mistral\",\"name\":\"Mistral-7B-v0.1\"}"); + + assertEquals("Mistral-7B-v0.1", meta.getModelName()); + } + + @Test + public void testGetArchitectureEmptyWhenAbsent() throws Exception { + ModelMeta meta = parse("{\"vocab_type\":1,\"n_vocab\":100,\"n_ctx_train\":4096," + + "\"n_embd\":512,\"n_params\":1000000,\"size\":500000," + + "\"modalities\":{\"vision\":false,\"audio\":false}}"); + + assertEquals("", meta.getArchitecture()); + } + + @Test + public void testGetModelNameEmptyWhenAbsent() throws Exception { + ModelMeta meta = parse("{\"vocab_type\":1,\"n_vocab\":100,\"n_ctx_train\":4096," + + "\"n_embd\":512,\"n_params\":1000000,\"size\":500000," + + "\"modalities\":{\"vision\":false,\"audio\":false}}"); + + assertEquals("", meta.getModelName()); + } + + @Test + public void testGetArchitectureVariousModels() throws Exception { + for (String arch : new String[]{"llama", "gemma3", "mistral", "falcon", "phi3"}) { + ModelMeta meta = parse("{\"vocab_type\":1,\"n_vocab\":100,\"n_ctx_train\":4096," + + "\"n_embd\":512,\"n_params\":1000000,\"size\":500000," + + "\"modalities\":{\"vision\":false,\"audio\":false}," + + "\"architecture\":\"" + arch + "\",\"name\":\"\"}"); + assertEquals(arch, meta.getArchitecture()); + } + } + + @Test + public void testToStringContainsNewFields() throws Exception { + ModelMeta meta = parse("{\"vocab_type\":1,\"n_vocab\":32016,\"n_ctx_train\":16384," + + "\"n_embd\":4096,\"n_params\":6738546688,\"size\":2825274880," + + "\"modalities\":{\"vision\":false,\"audio\":false}," + + "\"architecture\":\"llama\",\"name\":\"CodeLlama-7B\"}"); + + String json = meta.toString(); + assertTrue(json.contains("\"architecture\"")); + assertTrue(json.contains("\"name\"")); + assertTrue(json.contains("\"llama\"")); + assertTrue(json.contains("\"CodeLlama-7B\"")); + } +} diff --git a/src/test/java/de/kherud/llama/StopReasonTest.java b/src/test/java/de/kherud/llama/StopReasonTest.java new file mode 100644 index 00000000..0f8af234 --- /dev/null +++ b/src/test/java/de/kherud/llama/StopReasonTest.java @@ -0,0 +1,66 @@ +package de.kherud.llama; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; + +import static org.junit.Assert.*; + +/** + * Round-trip tests for {@link StopReason}. + * + *

    The parameterised suite drives one test per enum constant: it obtains the + * constant's {@code "stop_type"} string via {@link StopReason#getStopType()} and + * verifies that feeding it back into {@link StopReason#fromStopType(String)} returns + * the original constant. The data provider is {@link StopReason#values()} so the + * suite automatically covers any future constant added to the enum. + * + *

    Edge cases (null, empty string, unknown value) are tested in separate + * {@code @Test} methods below the round-trip test. + */ +@RunWith(Parameterized.class) +public class StopReasonTest { + + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(StopReason.values()); + } + + private final StopReason reason; + + public StopReasonTest(StopReason reason) { + this.reason = reason; + } + + @Test + public void testRoundTrip() { + assertSame(reason, StopReason.fromStopType(reason.getStopType())); + } + + // ------------------------------------------------------------------ + // Edge cases — tested separately from the round-trip + // ------------------------------------------------------------------ + + @Test + public void testFromStopType_nullReturnsNone() { + assertSame(StopReason.NONE, StopReason.fromStopType(null)); + } + + @Test + public void testFromStopType_emptyStringReturnsNone() { + assertSame(StopReason.NONE, StopReason.fromStopType("")); + } + + @Test + public void testFromStopType_unknownReturnsNone() { + assertSame(StopReason.NONE, StopReason.fromStopType("something_else")); + } + + @Test + public void testEnumCount() { + assertEquals(4, StopReason.values().length); + } +} diff --git a/src/test/java/de/kherud/llama/args/CacheTypeTest.java b/src/test/java/de/kherud/llama/args/CacheTypeTest.java index 76bbee9d..1979db34 100644 --- a/src/test/java/de/kherud/llama/args/CacheTypeTest.java +++ b/src/test/java/de/kherud/llama/args/CacheTypeTest.java @@ -1,72 +1,62 @@ package de.kherud.llama.args; -import de.kherud.llama.ClaudeGenerated; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; import static org.junit.Assert.*; -@ClaudeGenerated( - purpose = "Verify CacheType enum values, count, and lowercase name convention used by ModelParameters.", - model = "claude-opus-4-6" -) +@RunWith(Parameterized.class) public class CacheTypeTest { - @Test - public void testEnumCount() { - assertEquals(9, CacheType.values().length); - } - - @Test - public void testF32() { - assertEquals("f32", CacheType.F32.name().toLowerCase()); - } - - @Test - public void testF16() { - assertEquals("f16", CacheType.F16.name().toLowerCase()); - } - - @Test - public void testBF16() { - assertEquals("bf16", CacheType.BF16.name().toLowerCase()); + @Parameterized.Parameters(name = "{0} -> {1}") + public static Collection data() { + return Arrays.asList(new Object[][]{ + {CacheType.F32, "f32"}, + {CacheType.F16, "f16"}, + {CacheType.BF16, "bf16"}, + {CacheType.Q8_0, "q8_0"}, + {CacheType.Q4_0, "q4_0"}, + {CacheType.Q4_1, "q4_1"}, + {CacheType.IQ4_NL, "iq4_nl"}, + {CacheType.Q5_0, "q5_0"}, + {CacheType.Q5_1, "q5_1"}, + }); } - @Test - public void testQ8_0() { - assertEquals("q8_0", CacheType.Q8_0.name().toLowerCase()); - } + private final CacheType cacheType; + private final String expectedArgValue; - @Test - public void testQ4_0() { - assertEquals("q4_0", CacheType.Q4_0.name().toLowerCase()); + public CacheTypeTest(CacheType cacheType, String expectedArgValue) { + this.cacheType = cacheType; + this.expectedArgValue = expectedArgValue; } @Test - public void testQ4_1() { - assertEquals("q4_1", CacheType.Q4_1.name().toLowerCase()); + public void testGetArgValue() { + assertEquals(expectedArgValue, cacheType.getArgValue()); } - @Test - public void testIQ4_NL() { - assertEquals("iq4_nl", CacheType.IQ4_NL.name().toLowerCase()); - } + // ------------------------------------------------------------------ + // Structural invariants — tested separately from the per-value check + // ------------------------------------------------------------------ @Test - public void testQ5_0() { - assertEquals("q5_0", CacheType.Q5_0.name().toLowerCase()); + public void testEnumCount() { + assertEquals(9, CacheType.values().length); } @Test - public void testQ5_1() { - assertEquals("q5_1", CacheType.Q5_1.name().toLowerCase()); + public void testImplementsCliArg() { + assertTrue(cacheType instanceof CliArg); } @Test - public void testAllValuesHaveNonEmptyLowercaseName() { - for (CacheType ct : CacheType.values()) { - String lower = ct.name().toLowerCase(); - assertNotNull(lower); - assertFalse("CacheType " + ct + " has empty lowercase name", lower.isEmpty()); - } + public void testArgValueNonEmpty() { + assertNotNull(cacheType.getArgValue()); + assertFalse(cacheType.getArgValue().isEmpty()); } } diff --git a/src/test/java/de/kherud/llama/args/GpuSplitModeTest.java b/src/test/java/de/kherud/llama/args/GpuSplitModeTest.java index 429e88da..9e40363c 100644 --- a/src/test/java/de/kherud/llama/args/GpuSplitModeTest.java +++ b/src/test/java/de/kherud/llama/args/GpuSplitModeTest.java @@ -1,42 +1,56 @@ package de.kherud.llama.args; -import de.kherud.llama.ClaudeGenerated; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; import static org.junit.Assert.*; -@ClaudeGenerated( - purpose = "Verify GpuSplitMode enum values, count, and lowercase name convention used by ModelParameters.", - model = "claude-opus-4-6" -) +@RunWith(Parameterized.class) public class GpuSplitModeTest { - @Test - public void testEnumCount() { - assertEquals(3, GpuSplitMode.values().length); + @Parameterized.Parameters(name = "{0} -> {1}") + public static Collection data() { + return Arrays.asList(new Object[][]{ + {GpuSplitMode.NONE, "none"}, + {GpuSplitMode.LAYER, "layer"}, + {GpuSplitMode.ROW, "row"}, + }); + } + + private final GpuSplitMode gpuSplitMode; + private final String expectedArgValue; + + public GpuSplitModeTest(GpuSplitMode gpuSplitMode, String expectedArgValue) { + this.gpuSplitMode = gpuSplitMode; + this.expectedArgValue = expectedArgValue; } @Test - public void testNone() { - assertEquals("none", GpuSplitMode.NONE.name().toLowerCase()); + public void testGetArgValue() { + assertEquals(expectedArgValue, gpuSplitMode.getArgValue()); } + // ------------------------------------------------------------------ + // Structural invariants + // ------------------------------------------------------------------ + @Test - public void testLayer() { - assertEquals("layer", GpuSplitMode.LAYER.name().toLowerCase()); + public void testEnumCount() { + assertEquals(3, GpuSplitMode.values().length); } @Test - public void testRow() { - assertEquals("row", GpuSplitMode.ROW.name().toLowerCase()); + public void testImplementsCliArg() { + assertTrue(gpuSplitMode instanceof CliArg); } @Test - public void testAllValuesHaveNonEmptyLowercaseName() { - for (GpuSplitMode mode : GpuSplitMode.values()) { - String lower = mode.name().toLowerCase(); - assertNotNull(lower); - assertFalse("GpuSplitMode " + mode + " has empty lowercase name", lower.isEmpty()); - } + public void testArgValueNonEmpty() { + assertNotNull(gpuSplitMode.getArgValue()); + assertFalse(gpuSplitMode.getArgValue().isEmpty()); } } diff --git a/src/test/java/de/kherud/llama/args/MiroStatTest.java b/src/test/java/de/kherud/llama/args/MiroStatTest.java index 5610215e..49a91e89 100644 --- a/src/test/java/de/kherud/llama/args/MiroStatTest.java +++ b/src/test/java/de/kherud/llama/args/MiroStatTest.java @@ -1,40 +1,56 @@ package de.kherud.llama.args; -import de.kherud.llama.ClaudeGenerated; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; import static org.junit.Assert.*; -@ClaudeGenerated( - purpose = "Verify MiroStat enum values and count.", - model = "claude-opus-4-6" -) +@RunWith(Parameterized.class) public class MiroStatTest { - @Test - public void testEnumCount() { - assertEquals(3, MiroStat.values().length); + @Parameterized.Parameters(name = "{0} -> {1}") + public static Collection data() { + return Arrays.asList(new Object[][]{ + {MiroStat.DISABLED, "0"}, + {MiroStat.V1, "1"}, + {MiroStat.V2, "2"}, + }); + } + + private final MiroStat miroStat; + private final String expectedArgValue; + + public MiroStatTest(MiroStat miroStat, String expectedArgValue) { + this.miroStat = miroStat; + this.expectedArgValue = expectedArgValue; } @Test - public void testDisabledOrdinal() { - assertEquals(0, MiroStat.DISABLED.ordinal()); + public void testGetArgValue() { + assertEquals(expectedArgValue, miroStat.getArgValue()); } + // ------------------------------------------------------------------ + // Structural invariants + // ------------------------------------------------------------------ + @Test - public void testV1Ordinal() { - assertEquals(1, MiroStat.V1.ordinal()); + public void testEnumCount() { + assertEquals(3, MiroStat.values().length); } @Test - public void testV2Ordinal() { - assertEquals(2, MiroStat.V2.ordinal()); + public void testImplementsCliArg() { + assertTrue(miroStat instanceof CliArg); } @Test - public void testValueOf() { - assertEquals(MiroStat.DISABLED, MiroStat.valueOf("DISABLED")); - assertEquals(MiroStat.V1, MiroStat.valueOf("V1")); - assertEquals(MiroStat.V2, MiroStat.valueOf("V2")); + public void testArgValueNonEmpty() { + assertNotNull(miroStat.getArgValue()); + assertFalse(miroStat.getArgValue().isEmpty()); } } diff --git a/src/test/java/de/kherud/llama/args/ModelFlagTest.java b/src/test/java/de/kherud/llama/args/ModelFlagTest.java new file mode 100644 index 00000000..16ce3e44 --- /dev/null +++ b/src/test/java/de/kherud/llama/args/ModelFlagTest.java @@ -0,0 +1,81 @@ +package de.kherud.llama.args; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; + +import static org.junit.Assert.*; + +@RunWith(Parameterized.class) +public class ModelFlagTest { + + @Parameterized.Parameters(name = "{0} -> {1}") + public static Collection data() { + return Arrays.asList(new Object[][]{ + {ModelFlag.NO_CONTEXT_SHIFT, "--no-context-shift"}, + {ModelFlag.FLASH_ATTN, "--flash-attn"}, + {ModelFlag.NO_PERF, "--no-perf"}, + {ModelFlag.ESCAPE, "--escape"}, + {ModelFlag.NO_ESCAPE, "--no-escape"}, + {ModelFlag.SPECIAL, "--special"}, + {ModelFlag.NO_WARMUP, "--no-warmup"}, + {ModelFlag.SPM_INFILL, "--spm-infill"}, + {ModelFlag.IGNORE_EOS, "--ignore-eos"}, + {ModelFlag.DUMP_KV_CACHE, "--dump-kv-cache"}, + {ModelFlag.NO_KV_OFFLOAD, "--no-kv-offload"}, + {ModelFlag.CONT_BATCHING, "--cont-batching"}, + {ModelFlag.NO_CONT_BATCHING, "--no-cont-batching"}, + {ModelFlag.MLOCK, "--mlock"}, + {ModelFlag.NO_MMAP, "--no-mmap"}, + {ModelFlag.CHECK_TENSORS, "--check-tensors"}, + {ModelFlag.EMBEDDING, "--embedding"}, + {ModelFlag.RERANKING, "--reranking"}, + {ModelFlag.LORA_INIT_WITHOUT_APPLY,"--lora-init-without-apply"}, + {ModelFlag.LOG_DISABLE, "--log-disable"}, + {ModelFlag.VERBOSE, "--verbose"}, + {ModelFlag.LOG_PREFIX, "--log-prefix"}, + {ModelFlag.LOG_TIMESTAMPS, "--log-timestamps"}, + {ModelFlag.JINJA, "--jinja"}, + {ModelFlag.VOCAB_ONLY, "--vocab-only"}, + {ModelFlag.KV_UNIFIED, "--kv-unified"}, + {ModelFlag.NO_KV_UNIFIED, "--no-kv-unified"}, + {ModelFlag.CLEAR_IDLE, "--clear-idle"}, + {ModelFlag.NO_CLEAR_IDLE, "--no-clear-idle"}, + }); + } + + private final ModelFlag flag; + private final String expectedCliFlag; + + public ModelFlagTest(ModelFlag flag, String expectedCliFlag) { + this.flag = flag; + this.expectedCliFlag = expectedCliFlag; + } + + @Test + public void testGetCliFlag() { + assertEquals(expectedCliFlag, flag.getCliFlag()); + } + + // ------------------------------------------------------------------ + // Structural invariants + // ------------------------------------------------------------------ + + @Test + public void testEnumCount() { + assertEquals(29, ModelFlag.values().length); + } + + @Test + public void testCliFlagStartsWithDoubleDash() { + assertTrue("Flag " + flag + " must start with --", flag.getCliFlag().startsWith("--")); + } + + @Test + public void testCliFlagNonEmpty() { + assertFalse("Flag " + flag + " has empty CLI string", flag.getCliFlag().isEmpty()); + } +} diff --git a/src/test/java/de/kherud/llama/args/NumaStrategyTest.java b/src/test/java/de/kherud/llama/args/NumaStrategyTest.java index 4c3477f3..fd6fcc6f 100644 --- a/src/test/java/de/kherud/llama/args/NumaStrategyTest.java +++ b/src/test/java/de/kherud/llama/args/NumaStrategyTest.java @@ -1,42 +1,56 @@ package de.kherud.llama.args; -import de.kherud.llama.ClaudeGenerated; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; import static org.junit.Assert.*; -@ClaudeGenerated( - purpose = "Verify NumaStrategy enum values, count, and lowercase name convention used by ModelParameters.", - model = "claude-opus-4-6" -) +@RunWith(Parameterized.class) public class NumaStrategyTest { - @Test - public void testEnumCount() { - assertEquals(3, NumaStrategy.values().length); + @Parameterized.Parameters(name = "{0} -> {1}") + public static Collection data() { + return Arrays.asList(new Object[][]{ + {NumaStrategy.DISTRIBUTE, "distribute"}, + {NumaStrategy.ISOLATE, "isolate"}, + {NumaStrategy.NUMACTL, "numactl"}, + }); + } + + private final NumaStrategy numaStrategy; + private final String expectedArgValue; + + public NumaStrategyTest(NumaStrategy numaStrategy, String expectedArgValue) { + this.numaStrategy = numaStrategy; + this.expectedArgValue = expectedArgValue; } @Test - public void testDistribute() { - assertEquals("distribute", NumaStrategy.DISTRIBUTE.name().toLowerCase()); + public void testGetArgValue() { + assertEquals(expectedArgValue, numaStrategy.getArgValue()); } + // ------------------------------------------------------------------ + // Structural invariants + // ------------------------------------------------------------------ + @Test - public void testIsolate() { - assertEquals("isolate", NumaStrategy.ISOLATE.name().toLowerCase()); + public void testEnumCount() { + assertEquals(3, NumaStrategy.values().length); } @Test - public void testNumactl() { - assertEquals("numactl", NumaStrategy.NUMACTL.name().toLowerCase()); + public void testImplementsCliArg() { + assertTrue(numaStrategy instanceof CliArg); } @Test - public void testAllValuesHaveNonEmptyLowercaseName() { - for (NumaStrategy ns : NumaStrategy.values()) { - String lower = ns.name().toLowerCase(); - assertNotNull(lower); - assertFalse("NumaStrategy " + ns + " has empty lowercase name", lower.isEmpty()); - } + public void testArgValueNonEmpty() { + assertNotNull(numaStrategy.getArgValue()); + assertFalse(numaStrategy.getArgValue().isEmpty()); } } diff --git a/src/test/java/de/kherud/llama/args/PoolingTypeTest.java b/src/test/java/de/kherud/llama/args/PoolingTypeTest.java index 605402bd..1daeb71c 100644 --- a/src/test/java/de/kherud/llama/args/PoolingTypeTest.java +++ b/src/test/java/de/kherud/llama/args/PoolingTypeTest.java @@ -1,57 +1,59 @@ package de.kherud.llama.args; -import de.kherud.llama.ClaudeGenerated; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; import static org.junit.Assert.*; -@ClaudeGenerated( - purpose = "Verify that every PoolingType enum constant returns the exact CLI argument " + - "string expected by llama.cpp (e.g. MEAN -> \"mean\", RANK -> \"rank\") via " + - "getArgValue(), and that the enum has the expected number of constants." -) +@RunWith(Parameterized.class) public class PoolingTypeTest { - @Test - public void testUnspecifiedArgValue() { - assertEquals("unspecified", PoolingType.UNSPECIFIED.getArgValue()); - } - - @Test - public void testNoneArgValue() { - assertEquals("none", PoolingType.NONE.getArgValue()); - } - - @Test - public void testMeanArgValue() { - assertEquals("mean", PoolingType.MEAN.getArgValue()); - } - - @Test - public void testClsArgValue() { - assertEquals("cls", PoolingType.CLS.getArgValue()); - } - - @Test - public void testLastArgValue() { - assertEquals("last", PoolingType.LAST.getArgValue()); - } - - @Test - public void testRankArgValue() { - assertEquals("rank", PoolingType.RANK.getArgValue()); - } - - @Test - public void testAllValuesHaveArgValue() { - for (PoolingType type : PoolingType.values()) { - assertNotNull("getArgValue() should not be null for " + type, type.getArgValue()); - assertFalse("getArgValue() should not be empty for " + type, type.getArgValue().isEmpty()); - } - } - - @Test - public void testEnumCount() { - assertEquals(6, PoolingType.values().length); - } + @Parameterized.Parameters(name = "{0} -> {1}") + public static Collection data() { + return Arrays.asList(new Object[][]{ + {PoolingType.UNSPECIFIED, "unspecified"}, + {PoolingType.NONE, "none"}, + {PoolingType.MEAN, "mean"}, + {PoolingType.CLS, "cls"}, + {PoolingType.LAST, "last"}, + {PoolingType.RANK, "rank"}, + }); + } + + private final PoolingType poolingType; + private final String expectedArgValue; + + public PoolingTypeTest(PoolingType poolingType, String expectedArgValue) { + this.poolingType = poolingType; + this.expectedArgValue = expectedArgValue; + } + + @Test + public void testGetArgValue() { + assertEquals(expectedArgValue, poolingType.getArgValue()); + } + + // ------------------------------------------------------------------ + // Structural invariants + // ------------------------------------------------------------------ + + @Test + public void testEnumCount() { + assertEquals(6, PoolingType.values().length); + } + + @Test + public void testImplementsCliArg() { + assertTrue(poolingType instanceof CliArg); + } + + @Test + public void testArgValueNonEmpty() { + assertNotNull(poolingType.getArgValue()); + assertFalse(poolingType.getArgValue().isEmpty()); + } } diff --git a/src/test/java/de/kherud/llama/args/RopeScalingTypeTest.java b/src/test/java/de/kherud/llama/args/RopeScalingTypeTest.java index fce82846..e1635a8b 100644 --- a/src/test/java/de/kherud/llama/args/RopeScalingTypeTest.java +++ b/src/test/java/de/kherud/llama/args/RopeScalingTypeTest.java @@ -1,57 +1,59 @@ package de.kherud.llama.args; -import de.kherud.llama.ClaudeGenerated; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; import static org.junit.Assert.*; -@ClaudeGenerated( - purpose = "Verify that every RopeScalingType enum constant returns the exact CLI argument " + - "string expected by llama.cpp (e.g. YARN2 -> \"yarn\", LONGROPE -> \"longrope\") " + - "via getArgValue(), and that the enum has the expected number of constants." -) +@RunWith(Parameterized.class) public class RopeScalingTypeTest { - @Test - public void testUnspecifiedArgValue() { - assertEquals("unspecified", RopeScalingType.UNSPECIFIED.getArgValue()); - } - - @Test - public void testNoneArgValue() { - assertEquals("none", RopeScalingType.NONE.getArgValue()); - } - - @Test - public void testLinearArgValue() { - assertEquals("linear", RopeScalingType.LINEAR.getArgValue()); - } - - @Test - public void testYarn2ArgValue() { - assertEquals("yarn", RopeScalingType.YARN2.getArgValue()); - } - - @Test - public void testLongRopeArgValue() { - assertEquals("longrope", RopeScalingType.LONGROPE.getArgValue()); - } - - @Test - public void testMaxValueArgValue() { - assertEquals("maxvalue", RopeScalingType.MAX_VALUE.getArgValue()); - } - - @Test - public void testAllValuesHaveArgValue() { - for (RopeScalingType type : RopeScalingType.values()) { - assertNotNull("getArgValue() should not be null for " + type, type.getArgValue()); - assertFalse("getArgValue() should not be empty for " + type, type.getArgValue().isEmpty()); - } - } - - @Test - public void testEnumCount() { - assertEquals(6, RopeScalingType.values().length); - } + @Parameterized.Parameters(name = "{0} -> {1}") + public static Collection data() { + return Arrays.asList(new Object[][]{ + {RopeScalingType.UNSPECIFIED, "unspecified"}, + {RopeScalingType.NONE, "none"}, + {RopeScalingType.LINEAR, "linear"}, + {RopeScalingType.YARN2, "yarn"}, + {RopeScalingType.LONGROPE, "longrope"}, + {RopeScalingType.MAX_VALUE, "maxvalue"}, + }); + } + + private final RopeScalingType ropeScalingType; + private final String expectedArgValue; + + public RopeScalingTypeTest(RopeScalingType ropeScalingType, String expectedArgValue) { + this.ropeScalingType = ropeScalingType; + this.expectedArgValue = expectedArgValue; + } + + @Test + public void testGetArgValue() { + assertEquals(expectedArgValue, ropeScalingType.getArgValue()); + } + + // ------------------------------------------------------------------ + // Structural invariants + // ------------------------------------------------------------------ + + @Test + public void testEnumCount() { + assertEquals(6, RopeScalingType.values().length); + } + + @Test + public void testImplementsCliArg() { + assertTrue(ropeScalingType instanceof CliArg); + } + + @Test + public void testArgValueNonEmpty() { + assertNotNull(ropeScalingType.getArgValue()); + assertFalse(ropeScalingType.getArgValue().isEmpty()); + } } diff --git a/src/test/java/de/kherud/llama/args/SamplerTest.java b/src/test/java/de/kherud/llama/args/SamplerTest.java index 846c6667..b0518af8 100644 --- a/src/test/java/de/kherud/llama/args/SamplerTest.java +++ b/src/test/java/de/kherud/llama/args/SamplerTest.java @@ -1,73 +1,62 @@ package de.kherud.llama.args; -import de.kherud.llama.ClaudeGenerated; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; import static org.junit.Assert.*; -@ClaudeGenerated( - purpose = "Verify Sampler enum values, count, and lowercase name convention used by " + - "ModelParameters.setSamplers() semicolon-separated serialization.", - model = "claude-opus-4-6" -) +@RunWith(Parameterized.class) public class SamplerTest { - @Test - public void testEnumCount() { - assertEquals(9, Sampler.values().length); - } - - @Test - public void testDry() { - assertEquals("dry", Sampler.DRY.name().toLowerCase()); - } - - @Test - public void testTopK() { - assertEquals("top_k", Sampler.TOP_K.name().toLowerCase()); - } - - @Test - public void testTopP() { - assertEquals("top_p", Sampler.TOP_P.name().toLowerCase()); + @Parameterized.Parameters(name = "{0} -> {1}") + public static Collection data() { + return Arrays.asList(new Object[][]{ + {Sampler.DRY, "dry"}, + {Sampler.TOP_K, "top_k"}, + {Sampler.TOP_P, "top_p"}, + {Sampler.TYP_P, "typ_p"}, + {Sampler.MIN_P, "min_p"}, + {Sampler.TEMPERATURE, "temperature"}, + {Sampler.XTC, "xtc"}, + {Sampler.INFILL, "infill"}, + {Sampler.PENALTIES, "penalties"}, + }); } - @Test - public void testTypP() { - assertEquals("typ_p", Sampler.TYP_P.name().toLowerCase()); - } + private final Sampler sampler; + private final String expectedArgValue; - @Test - public void testMinP() { - assertEquals("min_p", Sampler.MIN_P.name().toLowerCase()); + public SamplerTest(Sampler sampler, String expectedArgValue) { + this.sampler = sampler; + this.expectedArgValue = expectedArgValue; } @Test - public void testTemperature() { - assertEquals("temperature", Sampler.TEMPERATURE.name().toLowerCase()); + public void testGetArgValue() { + assertEquals(expectedArgValue, sampler.getArgValue()); } - @Test - public void testXtc() { - assertEquals("xtc", Sampler.XTC.name().toLowerCase()); - } + // ------------------------------------------------------------------ + // Structural invariants + // ------------------------------------------------------------------ @Test - public void testInfill() { - assertEquals("infill", Sampler.INFILL.name().toLowerCase()); + public void testEnumCount() { + assertEquals(9, Sampler.values().length); } @Test - public void testPenalties() { - assertEquals("penalties", Sampler.PENALTIES.name().toLowerCase()); + public void testImplementsCliArg() { + assertTrue(sampler instanceof CliArg); } @Test - public void testAllValuesHaveNonEmptyLowercaseName() { - for (Sampler s : Sampler.values()) { - String lower = s.name().toLowerCase(); - assertNotNull(lower); - assertFalse("Sampler " + s + " has empty lowercase name", lower.isEmpty()); - } + public void testArgValueNonEmpty() { + assertNotNull(sampler.getArgValue()); + assertFalse(sampler.getArgValue().isEmpty()); } } diff --git a/src/test/java/de/kherud/llama/json/ChatResponseParserTest.java b/src/test/java/de/kherud/llama/json/ChatResponseParserTest.java new file mode 100644 index 00000000..69572862 --- /dev/null +++ b/src/test/java/de/kherud/llama/json/ChatResponseParserTest.java @@ -0,0 +1,163 @@ +package de.kherud.llama.json; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link ChatResponseParser}. + * No JVM native library or model file needed — JSON string literals only. + */ +public class ChatResponseParserTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private final ChatResponseParser parser = new ChatResponseParser(); + + // ------------------------------------------------------------------ + // extractChoiceContent(String) + // ------------------------------------------------------------------ + + @Test + public void testExtractChoiceContent_typical() { + String json = "{\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"OK\"}," + + "\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":1}}"; + assertEquals("OK", parser.extractChoiceContent(json)); + } + + @Test + public void testExtractChoiceContent_emptyContent() { + String json = "{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"\"}}]}"; + assertEquals("", parser.extractChoiceContent(json)); + } + + @Test + public void testExtractChoiceContent_escapedContent() { + String json = "{\"choices\":[{\"message\":{\"role\":\"assistant\"," + + "\"content\":\"line1\\nline2\\t\\\"quoted\\\"\"}}]}"; + assertEquals("line1\nline2\t\"quoted\"", parser.extractChoiceContent(json)); + } + + @Test + public void testExtractChoiceContent_unicodeInContent() { + String json = "{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"caf\\u00e9\"}}]}"; + assertEquals("café", parser.extractChoiceContent(json)); + } + + @Test + public void testExtractChoiceContent_missingChoices() { + String json = "{\"id\":\"x\",\"object\":\"chat.completion\"}"; + assertEquals("", parser.extractChoiceContent(json)); + } + + @Test + public void testExtractChoiceContent_emptyChoicesArray() { + String json = "{\"choices\":[]}"; + assertEquals("", parser.extractChoiceContent(json)); + } + + @Test + public void testExtractChoiceContent_missingContent() { + String json = "{\"choices\":[{\"message\":{\"role\":\"assistant\"}}]}"; + assertEquals("", parser.extractChoiceContent(json)); + } + + @Test + public void testExtractChoiceContent_malformedJson() { + assertEquals("", parser.extractChoiceContent("{not json")); + } + + @Test + public void testExtractChoiceContent_multilineResponse() { + String content = "First line.\\nSecond line.\\nThird line."; + String json = "{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"" + content + "\"}}]}"; + assertEquals("First line.\nSecond line.\nThird line.", parser.extractChoiceContent(json)); + } + + // ------------------------------------------------------------------ + // extractChoiceContent(JsonNode) + // ------------------------------------------------------------------ + + @Test + public void testExtractChoiceContent_node() throws Exception { + JsonNode node = MAPPER.readTree( + "{\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Hello\"}}]}"); + assertEquals("Hello", parser.extractChoiceContent(node)); + } + + @Test + public void testExtractChoiceContent_nodeMultipleChoices_takesFirst() throws Exception { + JsonNode node = MAPPER.readTree( + "{\"choices\":[" + + "{\"message\":{\"content\":\"First\"}}," + + "{\"message\":{\"content\":\"Second\"}}" + + "]}"); + assertEquals("First", parser.extractChoiceContent(node)); + } + + // ------------------------------------------------------------------ + // extractUsageField + // ------------------------------------------------------------------ + + @Test + public void testExtractUsageField_promptTokens() throws Exception { + JsonNode node = MAPPER.readTree( + "{\"usage\":{\"prompt_tokens\":12,\"completion_tokens\":5,\"total_tokens\":17}}"); + assertEquals(12, parser.extractUsageField(node, "prompt_tokens")); + } + + @Test + public void testExtractUsageField_completionTokens() throws Exception { + JsonNode node = MAPPER.readTree( + "{\"usage\":{\"prompt_tokens\":12,\"completion_tokens\":5,\"total_tokens\":17}}"); + assertEquals(5, parser.extractUsageField(node, "completion_tokens")); + } + + @Test + public void testExtractUsageField_totalTokens() throws Exception { + JsonNode node = MAPPER.readTree( + "{\"usage\":{\"prompt_tokens\":12,\"completion_tokens\":5,\"total_tokens\":17}}"); + assertEquals(17, parser.extractUsageField(node, "total_tokens")); + } + + @Test + public void testExtractUsageField_missingUsage_returnsZero() throws Exception { + JsonNode node = MAPPER.readTree("{\"id\":\"x\"}"); + assertEquals(0, parser.extractUsageField(node, "prompt_tokens")); + } + + @Test + public void testExtractUsageField_missingField_returnsZero() throws Exception { + JsonNode node = MAPPER.readTree("{\"usage\":{}}"); + assertEquals(0, parser.extractUsageField(node, "prompt_tokens")); + } + + // ------------------------------------------------------------------ + // countChoices + // ------------------------------------------------------------------ + + @Test + public void testCountChoices_one() throws Exception { + JsonNode node = MAPPER.readTree("{\"choices\":[{\"message\":{\"content\":\"hi\"}}]}"); + assertEquals(1, parser.countChoices(node)); + } + + @Test + public void testCountChoices_multiple() throws Exception { + JsonNode node = MAPPER.readTree("{\"choices\":[{},{},{}]}"); + assertEquals(3, parser.countChoices(node)); + } + + @Test + public void testCountChoices_empty() throws Exception { + JsonNode node = MAPPER.readTree("{\"choices\":[]}"); + assertEquals(0, parser.countChoices(node)); + } + + @Test + public void testCountChoices_absent() throws Exception { + JsonNode node = MAPPER.readTree("{\"id\":\"x\"}"); + assertEquals(0, parser.countChoices(node)); + } +} diff --git a/src/test/java/de/kherud/llama/json/CompletionResponseParserTest.java b/src/test/java/de/kherud/llama/json/CompletionResponseParserTest.java new file mode 100644 index 00000000..812c381b --- /dev/null +++ b/src/test/java/de/kherud/llama/json/CompletionResponseParserTest.java @@ -0,0 +1,205 @@ +package de.kherud.llama.json; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import de.kherud.llama.LlamaOutput; +import de.kherud.llama.StopReason; +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link CompletionResponseParser}. + * + *

    All tests use JSON string literals — no JVM native library or model file is needed. + * This mirrors the pattern established by {@code test_json_helpers.cpp} on the C++ side. + */ +public class CompletionResponseParserTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private final CompletionResponseParser parser = new CompletionResponseParser(); + + // ------------------------------------------------------------------ + // parse(String) + // ------------------------------------------------------------------ + + @Test + public void testParseString_text() throws Exception { + String json = "{\"content\":\"Hello world\",\"stop\":false}"; + LlamaOutput out = parser.parse(json); + assertEquals("Hello world", out.text); + } + + @Test + public void testParseString_stopFalse() { + String json = "{\"content\":\"partial\",\"stop\":false}"; + LlamaOutput out = parser.parse(json); + assertFalse(out.stop); + assertEquals(StopReason.NONE, out.stopReason); + } + + @Test + public void testParseString_stopTrueEos() { + String json = "{\"content\":\"done\",\"stop\":true,\"stop_type\":\"eos\"}"; + LlamaOutput out = parser.parse(json); + assertTrue(out.stop); + assertEquals(StopReason.EOS, out.stopReason); + } + + @Test + public void testParseString_stopTrueWord() { + String json = "{\"content\":\"end\",\"stop\":true,\"stop_type\":\"word\",\"stopping_word\":\"END\"}"; + LlamaOutput out = parser.parse(json); + assertTrue(out.stop); + assertEquals(StopReason.STOP_STRING, out.stopReason); + } + + @Test + public void testParseString_stopTrueLimit() { + String json = "{\"content\":\"truncated\",\"stop\":true,\"stop_type\":\"limit\",\"truncated\":true}"; + LlamaOutput out = parser.parse(json); + assertTrue(out.stop); + assertEquals(StopReason.MAX_TOKENS, out.stopReason); + } + + @Test + public void testParseString_malformedReturnsEmptyNonStop() { + LlamaOutput out = parser.parse("{not valid json"); + assertEquals("", out.text); + assertFalse(out.stop); + assertEquals(StopReason.NONE, out.stopReason); + assertTrue(out.probabilities.isEmpty()); + } + + @Test + public void testParseString_escapedContent() { + String json = "{\"content\":\"line1\\nline2\\t\\\"quoted\\\"\",\"stop\":false}"; + LlamaOutput out = parser.parse(json); + assertEquals("line1\nline2\t\"quoted\"", out.text); + } + + @Test + public void testParseString_unicodeEscape() { + String json = "{\"content\":\"caf\\u00e9\",\"stop\":false}"; + LlamaOutput out = parser.parse(json); + assertEquals("café", out.text); + } + + @Test + public void testParseString_emptyContent() { + String json = "{\"content\":\"\",\"stop\":true,\"stop_type\":\"eos\"}"; + LlamaOutput out = parser.parse(json); + assertEquals("", out.text); + assertTrue(out.stop); + } + + // ------------------------------------------------------------------ + // parse(JsonNode) + // ------------------------------------------------------------------ + + @Test + public void testParseNode_delegatesCorrectly() throws Exception { + JsonNode node = MAPPER.readTree("{\"content\":\"hi\",\"stop\":true,\"stop_type\":\"eos\"}"); + LlamaOutput out = parser.parse(node); + assertEquals("hi", out.text); + assertTrue(out.stop); + assertEquals(StopReason.EOS, out.stopReason); + } + + // ------------------------------------------------------------------ + // extractContent + // ------------------------------------------------------------------ + + @Test + public void testExtractContent_present() throws Exception { + JsonNode node = MAPPER.readTree("{\"content\":\"hello\",\"stop\":false}"); + assertEquals("hello", parser.extractContent(node)); + } + + @Test + public void testExtractContent_absent() throws Exception { + JsonNode node = MAPPER.readTree("{\"stop\":false}"); + assertEquals("", parser.extractContent(node)); + } + + @Test + public void testExtractContent_empty() throws Exception { + JsonNode node = MAPPER.readTree("{\"content\":\"\",\"stop\":true}"); + assertEquals("", parser.extractContent(node)); + } + + // ------------------------------------------------------------------ + // parseProbabilities + // ------------------------------------------------------------------ + + @Test + public void testParseProbabilities_absentKey() throws Exception { + JsonNode node = MAPPER.readTree("{\"content\":\"hi\",\"stop\":true}"); + assertTrue(parser.parseProbabilities(node).isEmpty()); + } + + @Test + public void testParseProbabilities_emptyArray() throws Exception { + JsonNode node = MAPPER.readTree("{\"content\":\"hi\",\"stop\":true,\"completion_probabilities\":[]}"); + assertTrue(parser.parseProbabilities(node).isEmpty()); + } + + @Test + public void testParseProbabilities_postSampling() throws Exception { + String json = "{\"content\":\"hi\",\"stop\":true," + + "\"completion_probabilities\":[" + + "{\"token\":\"Hello\",\"bytes\":[72],\"id\":15043,\"prob\":0.82," + + "\"top_probs\":[{\"token\":\"Hi\",\"bytes\":[72],\"id\":9932,\"prob\":0.1}]}," + + "{\"token\":\" world\",\"bytes\":[32,119],\"id\":1917,\"prob\":0.65," + + "\"top_probs\":[]}" + + "]}"; + JsonNode node = MAPPER.readTree(json); + Map probs = parser.parseProbabilities(node); + assertEquals(2, probs.size()); + assertEquals(0.82f, probs.get("Hello"), 0.001f); + assertEquals(0.65f, probs.get(" world"), 0.001f); + } + + @Test + public void testParseProbabilities_preSampling() throws Exception { + String json = "{\"content\":\"hi\",\"stop\":true," + + "\"completion_probabilities\":[" + + "{\"token\":\"Hello\",\"bytes\":[72],\"id\":15043,\"logprob\":-0.2," + + "\"top_logprobs\":[{\"token\":\"Hi\",\"bytes\":[72],\"id\":9932,\"logprob\":-2.3}]}" + + "]}"; + JsonNode node = MAPPER.readTree(json); + Map probs = parser.parseProbabilities(node); + assertEquals(1, probs.size()); + assertEquals(-0.2f, probs.get("Hello"), 0.001f); + } + + @Test + public void testParseProbabilities_escapedToken() throws Exception { + String json = "{\"content\":\"hi\",\"stop\":true," + + "\"completion_probabilities\":[" + + "{\"token\":\"say \\\"yes\\\"\",\"bytes\":[],\"id\":1,\"prob\":0.5," + + "\"top_probs\":[]}" + + "]}"; + JsonNode node = MAPPER.readTree(json); + Map probs = parser.parseProbabilities(node); + assertEquals(1, probs.size()); + assertEquals(0.5f, probs.get("say \"yes\""), 0.001f); + } + + @Test + public void testParseProbabilities_topProbs_notIncluded() throws Exception { + // top_probs entries must NOT appear in the outer map — only the outer token/prob + String json = "{\"content\":\"hi\",\"stop\":true," + + "\"completion_probabilities\":[" + + "{\"token\":\"A\",\"bytes\":[],\"id\":1,\"prob\":0.9," + + "\"top_probs\":[{\"token\":\"B\",\"bytes\":[],\"id\":2,\"prob\":0.05}]}" + + "]}"; + JsonNode node = MAPPER.readTree(json); + Map probs = parser.parseProbabilities(node); + assertEquals(1, probs.size()); + assertTrue("only outer token 'A' should be present", probs.containsKey("A")); + assertFalse("inner top_probs token 'B' must not appear", probs.containsKey("B")); + } +} diff --git a/src/test/java/de/kherud/llama/json/ParameterJsonSerializerTest.java b/src/test/java/de/kherud/llama/json/ParameterJsonSerializerTest.java new file mode 100644 index 00000000..e97e7225 --- /dev/null +++ b/src/test/java/de/kherud/llama/json/ParameterJsonSerializerTest.java @@ -0,0 +1,336 @@ +package de.kherud.llama.json; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import de.kherud.llama.Pair; +import de.kherud.llama.args.Sampler; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link ParameterJsonSerializer}. + * No JVM native library or model file needed — plain Java values only. + */ +public class ParameterJsonSerializerTest { + + private final ParameterJsonSerializer serializer = new ParameterJsonSerializer(); + + // ------------------------------------------------------------------ + // toJsonString + // ------------------------------------------------------------------ + + @Test + public void testToJsonString_simple() { + assertEquals("\"hello\"", serializer.toJsonString("hello")); + } + + @Test + public void testToJsonString_null() { + assertEquals("null", serializer.toJsonString(null)); + } + + @Test + public void testToJsonString_emptyString() { + assertEquals("\"\"", serializer.toJsonString("")); + } + + @Test + public void testToJsonString_newline() { + assertEquals("\"line1\\nline2\"", serializer.toJsonString("line1\nline2")); + } + + @Test + public void testToJsonString_tab() { + assertEquals("\"a\\tb\"", serializer.toJsonString("a\tb")); + } + + @Test + public void testToJsonString_quote() { + assertEquals("\"say \\\"hi\\\"\"", serializer.toJsonString("say \"hi\"")); + } + + @Test + public void testToJsonString_backslash() { + assertEquals("\"path\\\\file\"", serializer.toJsonString("path\\file")); + } + + @Test + public void testToJsonString_unicode() { + assertEquals("\"café\"", serializer.toJsonString("café")); + } + + // ------------------------------------------------------------------ + // buildMessages + // ------------------------------------------------------------------ + + @Test + public void testBuildMessages_withSystemMessage() { + List> msgs = Collections.singletonList(new Pair<>("user", "Hello")); + ArrayNode arr = serializer.buildMessages("You are helpful.", msgs); + assertEquals(2, arr.size()); + assertEquals("system", arr.get(0).path("role").asText()); + assertEquals("You are helpful.", arr.get(0).path("content").asText()); + assertEquals("user", arr.get(1).path("role").asText()); + assertEquals("Hello", arr.get(1).path("content").asText()); + } + + @Test + public void testBuildMessages_withoutSystemMessage() { + List> msgs = Arrays.asList( + new Pair<>("user", "Hi"), + new Pair<>("assistant", "Hello there") + ); + ArrayNode arr = serializer.buildMessages(null, msgs); + assertEquals(2, arr.size()); + assertEquals("user", arr.get(0).path("role").asText()); + assertEquals("assistant", arr.get(1).path("role").asText()); + } + + @Test + public void testBuildMessages_emptySystemMessage_skipped() { + List> msgs = Collections.singletonList(new Pair<>("user", "Hi")); + ArrayNode arr = serializer.buildMessages("", msgs); + assertEquals(1, arr.size()); + assertEquals("user", arr.get(0).path("role").asText()); + } + + @Test + public void testBuildMessages_specialCharsInContent() { + List> msgs = Collections.singletonList( + new Pair<>("user", "line1\nline2\t\"quoted\"")); + ArrayNode arr = serializer.buildMessages(null, msgs); + assertEquals("line1\nline2\t\"quoted\"", arr.get(0).path("content").asText()); + } + + @Test(expected = IllegalArgumentException.class) + public void testBuildMessages_invalidRole_throws() { + List> msgs = Collections.singletonList(new Pair<>("system", "oops")); + serializer.buildMessages(null, msgs); + } + + @Test + public void testBuildMessages_roundtripsAsJson() throws Exception { + List> msgs = Collections.singletonList(new Pair<>("user", "Hello")); + ArrayNode arr = serializer.buildMessages("Sys", msgs); + String json = arr.toString(); + JsonNode parsed = serializer.OBJECT_MAPPER.readTree(json); + assertEquals("system", parsed.get(0).path("role").asText()); + assertEquals("Sys", parsed.get(0).path("content").asText()); + assertEquals("user", parsed.get(1).path("role").asText()); + assertEquals("Hello", parsed.get(1).path("content").asText()); + } + + // ------------------------------------------------------------------ + // buildStopStrings + // ------------------------------------------------------------------ + + @Test + public void testBuildStopStrings_single() { + ArrayNode arr = serializer.buildStopStrings("<|endoftext|>"); + assertEquals(1, arr.size()); + assertEquals("<|endoftext|>", arr.get(0).asText()); + } + + @Test + public void testBuildStopStrings_multiple() { + ArrayNode arr = serializer.buildStopStrings("stop1", "stop2", "stop3"); + assertEquals(3, arr.size()); + assertEquals("stop1", arr.get(0).asText()); + assertEquals("stop3", arr.get(2).asText()); + } + + @Test + public void testBuildStopStrings_withSpecialChars() { + ArrayNode arr = serializer.buildStopStrings("line\nnewline", "tab\there"); + assertEquals("line\nnewline", arr.get(0).asText()); + assertEquals("tab\there", arr.get(1).asText()); + } + + @Test + public void testBuildStopStrings_roundtripsAsJson() throws Exception { + ArrayNode arr = serializer.buildStopStrings("a", "b"); + JsonNode parsed = serializer.OBJECT_MAPPER.readTree(arr.toString()); + assertTrue(parsed.isArray()); + assertEquals("a", parsed.get(0).asText()); + } + + // ------------------------------------------------------------------ + // buildSamplers + // ------------------------------------------------------------------ + + @Test + public void testBuildSamplers_allTypes() { + ArrayNode arr = serializer.buildSamplers( + Sampler.TOP_K, Sampler.TOP_P, Sampler.MIN_P, Sampler.TEMPERATURE); + assertEquals(4, arr.size()); + assertEquals("top_k", arr.get(0).asText()); + assertEquals("top_p", arr.get(1).asText()); + assertEquals("min_p", arr.get(2).asText()); + assertEquals("temperature", arr.get(3).asText()); + } + + @Test + public void testBuildSamplers_single() { + ArrayNode arr = serializer.buildSamplers(Sampler.TEMPERATURE); + assertEquals(1, arr.size()); + assertEquals("temperature", arr.get(0).asText()); + } + + // ------------------------------------------------------------------ + // buildIntArray + // ------------------------------------------------------------------ + + @Test + public void testBuildIntArray_values() { + ArrayNode arr = serializer.buildIntArray(new int[]{1, 2, 3}); + assertEquals(3, arr.size()); + assertEquals(1, arr.get(0).asInt()); + assertEquals(3, arr.get(2).asInt()); + } + + @Test + public void testBuildIntArray_empty() { + ArrayNode arr = serializer.buildIntArray(new int[]{}); + assertEquals(0, arr.size()); + } + + @Test + public void testBuildIntArray_roundtripsAsJson() throws Exception { + ArrayNode arr = serializer.buildIntArray(new int[]{10, 20}); + JsonNode parsed = serializer.OBJECT_MAPPER.readTree(arr.toString()); + assertTrue(parsed.isArray()); + assertEquals(10, parsed.get(0).asInt()); + } + + // ------------------------------------------------------------------ + // buildTokenIdBiasArray + // ------------------------------------------------------------------ + + @Test + public void testBuildTokenIdBiasArray_structure() { + Map biases = new LinkedHashMap<>(); + biases.put(15043, 1.0f); + biases.put(50256, -0.5f); + ArrayNode arr = serializer.buildTokenIdBiasArray(biases); + assertEquals(2, arr.size()); + assertEquals(15043, arr.get(0).get(0).asInt()); + assertEquals(1.0, arr.get(0).get(1).asDouble(), 0.001); + assertEquals(50256, arr.get(1).get(0).asInt()); + assertEquals(-0.5, arr.get(1).get(1).asDouble(), 0.001); + } + + @Test + public void testBuildTokenIdBiasArray_empty() { + ArrayNode arr = serializer.buildTokenIdBiasArray(Collections.emptyMap()); + assertEquals(0, arr.size()); + } + + // ------------------------------------------------------------------ + // buildTokenStringBiasArray + // ------------------------------------------------------------------ + + @Test + public void testBuildTokenStringBiasArray_structure() { + Map biases = new LinkedHashMap<>(); + biases.put("Hello", 1.0f); + biases.put(" world", -0.5f); + ArrayNode arr = serializer.buildTokenStringBiasArray(biases); + assertEquals(2, arr.size()); + assertEquals("Hello", arr.get(0).get(0).asText()); + assertEquals(1.0, arr.get(0).get(1).asDouble(), 0.001); + assertEquals(" world", arr.get(1).get(0).asText()); + } + + @Test + public void testBuildTokenStringBiasArray_specialCharsInKey() { + Map biases = new LinkedHashMap<>(); + biases.put("line\nnewline", 2.0f); + ArrayNode arr = serializer.buildTokenStringBiasArray(biases); + assertEquals("line\nnewline", arr.get(0).get(0).asText()); + } + + // ------------------------------------------------------------------ + // buildDisableTokenIdArray + // ------------------------------------------------------------------ + + @Test + public void testBuildDisableTokenIdArray_structure() { + ArrayNode arr = serializer.buildDisableTokenIdArray(Arrays.asList(100, 200, 300)); + assertEquals(3, arr.size()); + for (int i = 0; i < arr.size(); i++) { + assertFalse(arr.get(i).get(1).asBoolean()); + } + assertEquals(100, arr.get(0).get(0).asInt()); + } + + @Test + public void testBuildDisableTokenIdArray_empty() { + ArrayNode arr = serializer.buildDisableTokenIdArray(Collections.emptyList()); + assertEquals(0, arr.size()); + } + + // ------------------------------------------------------------------ + // buildDisableTokenStringArray + // ------------------------------------------------------------------ + + @Test + public void testBuildDisableTokenStringArray_structure() { + ArrayNode arr = serializer.buildDisableTokenStringArray(Arrays.asList("foo", "bar")); + assertEquals(2, arr.size()); + assertEquals("foo", arr.get(0).get(0).asText()); + assertFalse(arr.get(0).get(1).asBoolean()); + assertEquals("bar", arr.get(1).get(0).asText()); + } + + // ------------------------------------------------------------------ + // buildRawValueObject + // ------------------------------------------------------------------ + + @Test + public void testBuildRawValueObject_booleanValue() { + Map map = Collections.singletonMap("enable_thinking", "true"); + ObjectNode node = serializer.buildRawValueObject(map); + assertTrue(node.path("enable_thinking").isBoolean()); + assertTrue(node.path("enable_thinking").asBoolean()); + } + + @Test + public void testBuildRawValueObject_numberValue() { + Map map = Collections.singletonMap("temperature", "0.7"); + ObjectNode node = serializer.buildRawValueObject(map); + assertEquals(0.7, node.path("temperature").asDouble(), 0.001); + } + + @Test + public void testBuildRawValueObject_stringValue() { + Map map = Collections.singletonMap("mode", "\"fast\""); + ObjectNode node = serializer.buildRawValueObject(map); + assertEquals("fast", node.path("mode").asText()); + } + + @Test + public void testBuildRawValueObject_invalidJsonFallsBackToString() { + Map map = Collections.singletonMap("key", "not-valid-json{{{"); + ObjectNode node = serializer.buildRawValueObject(map); + assertEquals("not-valid-json{{{", node.path("key").asText()); + } + + @Test + public void testBuildRawValueObject_roundtripsAsJson() throws Exception { + Map map = new LinkedHashMap<>(); + map.put("flag", "true"); + map.put("count", "3"); + ObjectNode node = serializer.buildRawValueObject(map); + JsonNode parsed = serializer.OBJECT_MAPPER.readTree(node.toString()); + assertTrue(parsed.path("flag").asBoolean()); + assertEquals(3, parsed.path("count").asInt()); + } +} diff --git a/src/test/java/de/kherud/llama/json/RerankResponseParserTest.java b/src/test/java/de/kherud/llama/json/RerankResponseParserTest.java new file mode 100644 index 00000000..84cc285e --- /dev/null +++ b/src/test/java/de/kherud/llama/json/RerankResponseParserTest.java @@ -0,0 +1,123 @@ +package de.kherud.llama.json; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import de.kherud.llama.Pair; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.*; + +/** + * Unit tests for {@link RerankResponseParser}. + * No JVM native library or model file needed — JSON string literals only. + */ +public class RerankResponseParserTest { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + private final RerankResponseParser parser = new RerankResponseParser(); + + // ------------------------------------------------------------------ + // parse(String) + // ------------------------------------------------------------------ + + @Test + public void testParseString_singleEntry() { + String json = "[{\"document\":\"The quick brown fox\",\"index\":0,\"score\":0.92}]"; + List> result = parser.parse(json); + assertEquals(1, result.size()); + assertEquals("The quick brown fox", result.get(0).getKey()); + assertEquals(0.92f, result.get(0).getValue(), 0.001f); + } + + @Test + public void testParseString_multipleEntries() { + String json = "[" + + "{\"document\":\"First\",\"index\":0,\"score\":0.9}," + + "{\"document\":\"Second\",\"index\":1,\"score\":0.5}," + + "{\"document\":\"Third\",\"index\":2,\"score\":0.1}" + + "]"; + List> result = parser.parse(json); + assertEquals(3, result.size()); + assertEquals("First", result.get(0).getKey()); + assertEquals("Second", result.get(1).getKey()); + assertEquals("Third", result.get(2).getKey()); + assertEquals(0.9f, result.get(0).getValue(), 0.001f); + assertEquals(0.5f, result.get(1).getValue(), 0.001f); + assertEquals(0.1f, result.get(2).getValue(), 0.001f); + } + + @Test + public void testParseString_emptyArray() { + List> result = parser.parse("[]"); + assertTrue(result.isEmpty()); + } + + @Test + public void testParseString_malformed() { + List> result = parser.parse("{not json"); + assertTrue(result.isEmpty()); + } + + @Test + public void testParseString_notAnArray() { + List> result = parser.parse("{\"document\":\"x\",\"score\":0.5}"); + assertTrue(result.isEmpty()); + } + + @Test + public void testParseString_documentWithSpecialChars() { + String json = "[{\"document\":\"line1\\nline2\\t\\\"quoted\\\"\",\"index\":0,\"score\":0.75}]"; + List> result = parser.parse(json); + assertEquals(1, result.size()); + assertEquals("line1\nline2\t\"quoted\"", result.get(0).getKey()); + } + + @Test + public void testParseString_scoreZero() { + String json = "[{\"document\":\"irrelevant\",\"index\":0,\"score\":0.0}]"; + List> result = parser.parse(json); + assertEquals(1, result.size()); + assertEquals(0.0f, result.get(0).getValue(), 0.001f); + } + + // ------------------------------------------------------------------ + // parse(JsonNode) + // ------------------------------------------------------------------ + + @Test + public void testParseNode_preservesOrder() throws Exception { + String json = "[" + + "{\"document\":\"A\",\"index\":0,\"score\":0.8}," + + "{\"document\":\"B\",\"index\":1,\"score\":0.3}" + + "]"; + JsonNode arr = MAPPER.readTree(json); + List> result = parser.parse(arr); + assertEquals(2, result.size()); + assertEquals("A", result.get(0).getKey()); + assertEquals("B", result.get(1).getKey()); + } + + @Test + public void testParseNode_notArray() throws Exception { + JsonNode obj = MAPPER.readTree("{\"document\":\"x\",\"score\":0.5}"); + assertTrue(parser.parse(obj).isEmpty()); + } + + @Test + public void testParseNode_missingScore_defaultsToZero() throws Exception { + JsonNode arr = MAPPER.readTree("[{\"document\":\"doc\",\"index\":0}]"); + List> result = parser.parse(arr); + assertEquals(1, result.size()); + assertEquals(0.0f, result.get(0).getValue(), 0.001f); + } + + @Test + public void testParseNode_missingDocument_defaultsToEmpty() throws Exception { + JsonNode arr = MAPPER.readTree("[{\"index\":0,\"score\":0.5}]"); + List> result = parser.parse(arr); + assertEquals(1, result.size()); + assertEquals("", result.get(0).getKey()); + } +}