From 3e703de21ba21bd94aa8e9191a781b35bb3c154e Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 8 Apr 2026 07:56:39 +0000 Subject: [PATCH] Extract duplicate logic to reduce LOC across C++ and Java - jni_helpers.hpp: deduplicate embedding_to_jfloat_array_impl and tokens_to_jint_array_impl via a shared vec_to_jarray_impl template - jllama.cpp: extract require_embedding_support() guard used by embed() and handleEmbeddings() - JsonParameters.java: add mapToJsonObject() utility for map-to-JSON serialization, used by both setChatTemplateKwargs implementations - InferenceParameters.java: extract buildBiasPairArray() and buildDisablePairArray() helpers to deduplicate four near-identical token bias serialization methods https://claude.ai/code/session_01AGrb17h4jLeWr1de6RGubX --- src/main/cpp/jllama.cpp | 25 ++-- src/main/cpp/jni_helpers.hpp | 58 +++++---- .../de/kherud/llama/InferenceParameters.java | 112 ++++++------------ .../java/de/kherud/llama/JsonParameters.java | 12 ++ .../java/de/kherud/llama/ModelParameters.java | 10 +- 5 files changed, 95 insertions(+), 122 deletions(-) diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp index f8f1e4c3..c3e21f2e 100644 --- a/src/main/cpp/jllama.cpp +++ b/src/main/cpp/jllama.cpp @@ -350,6 +350,19 @@ static json parse_json_params(JNIEnv *env, jstring jparams) { return require_json_field_impl(env, data, field, c_llama_error); } +/** + * Throws if the model was not loaded with embedding support. Returns false + * (after throwing) when embedding is unavailable, true otherwise. + */ +[[nodiscard]] static bool require_embedding_support(JNIEnv *env, server_context *ctx_server) { + if (!ctx_server->params_base.embedding) { + env->ThrowNew(c_llama_error, + "Model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))"); + return false; + } + return true; +} + /** * Validates `jfilename`, builds a SAVE or RESTORE slot task, dispatches it, * and returns the result as a jstring. Shared by the SAVE (case 1) and @@ -849,11 +862,7 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletionJson( JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) { REQUIRE_SERVER_CONTEXT(nullptr); - if (!ctx_server->params_base.embedding) { - env->ThrowNew(c_llama_error, - "model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))"); - return nullptr; - } + if (!require_embedding_support(env, ctx_server)) return nullptr; const std::string prompt = parse_jstring(env, jprompt); @@ -1117,11 +1126,7 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn jstring jparams, jboolean joaiCompat) { REQUIRE_SERVER_CONTEXT(nullptr); - if (!ctx_server->params_base.embedding) { - env->ThrowNew(c_llama_error, - "Model was not loaded with embedding support (see ModelParameters#enableEmbedding())"); - return nullptr; - } + if (!require_embedding_support(env, ctx_server)) return nullptr; oaicompat_type oaicompat = joaiCompat ? OAICOMPAT_TYPE_EMBEDDING : OAICOMPAT_TYPE_NONE; diff --git a/src/main/cpp/jni_helpers.hpp b/src/main/cpp/jni_helpers.hpp index 001d7d0e..17614a07 100644 --- a/src/main/cpp/jni_helpers.hpp +++ b/src/main/cpp/jni_helpers.hpp @@ -353,48 +353,46 @@ inline void append_task(server_context *ctx_server, } // --------------------------------------------------------------------------- -// embedding_to_jfloat_array_impl +// vec_to_jarray_impl // -// Converts a float vector to a Java jfloatArray. -// -// On success: returns a new jfloatArray filled with the embedding values. +// Generic helper: converts a C++ vector to a JNI primitive array. +// Parameterized on JNI array/element types and the alloc/copy member fns. // On allocation failure: throws via JNI with oom_class and returns nullptr. // --------------------------------------------------------------------------- -[[nodiscard]] inline jfloatArray embedding_to_jfloat_array_impl( - JNIEnv *env, - const std::vector &values, - jclass oom_class) { +template +[[nodiscard]] inline JArray vec_to_jarray_impl( + JNIEnv *env, + const std::vector &values, + jclass oom_class, + const char *oom_msg, + JArray (JNIEnv_::*alloc)(jsize), + void (JNIEnv_::*copy)(JArray, jsize, jsize, const JElem *)) { const jsize len = static_cast(values.size()); - jfloatArray arr = env->NewFloatArray(len); + JArray arr = (env->*alloc)(len); if (arr == nullptr) { - env->ThrowNew(oom_class, "could not allocate embedding"); + env->ThrowNew(oom_class, oom_msg); return nullptr; } - env->SetFloatArrayRegion(arr, 0, len, - reinterpret_cast(values.data())); + (env->*copy)(arr, 0, len, reinterpret_cast(values.data())); return arr; } -// --------------------------------------------------------------------------- -// tokens_to_jint_array_impl -// -// Converts a token vector to a Java jintArray. Symmetric with -// embedding_to_jfloat_array_impl for the int (token) case. -// -// On success: returns a new jintArray filled with the token values. -// On allocation failure: throws via JNI with oom_class and returns nullptr. -// --------------------------------------------------------------------------- +// Converts a float vector to a Java jfloatArray. +[[nodiscard]] inline jfloatArray embedding_to_jfloat_array_impl( + JNIEnv *env, + const std::vector &values, + jclass oom_class) { + return vec_to_jarray_impl( + env, values, oom_class, "could not allocate embedding", + &JNIEnv_::NewFloatArray, &JNIEnv_::SetFloatArrayRegion); +} + +// Converts a token vector to a Java jintArray. [[nodiscard]] inline jintArray tokens_to_jint_array_impl( JNIEnv *env, const std::vector &tokens, jclass oom_class) { - const jsize len = static_cast(tokens.size()); - jintArray arr = env->NewIntArray(len); - if (arr == nullptr) { - env->ThrowNew(oom_class, "could not allocate token memory"); - return nullptr; - } - env->SetIntArrayRegion(arr, 0, len, - reinterpret_cast(tokens.data())); - return arr; + return vec_to_jarray_impl( + env, tokens, oom_class, "could not allocate token memory", + &JNIEnv_::NewIntArray, &JNIEnv_::SetIntArrayRegion); } diff --git a/src/main/java/de/kherud/llama/InferenceParameters.java b/src/main/java/de/kherud/llama/InferenceParameters.java index f0f6d24b..9d97b91a 100644 --- a/src/main/java/de/kherud/llama/InferenceParameters.java +++ b/src/main/java/de/kherud/llama/InferenceParameters.java @@ -408,23 +408,7 @@ public InferenceParameters setIgnoreEos(boolean ignoreEos) { */ public InferenceParameters setTokenIdBias(Map logitBias) { if (!logitBias.isEmpty()) { - StringBuilder builder = new StringBuilder(); - builder.append("["); - int i = 0; - for (Map.Entry entry : logitBias.entrySet()) { - Integer key = entry.getKey(); - Float value = entry.getValue(); - builder.append("[") - .append(key) - .append(", ") - .append(value) - .append("]"); - if (i++ < logitBias.size() - 1) { - builder.append(", "); - } - } - builder.append("]"); - parameters.put(PARAM_LOGIT_BIAS, builder.toString()); + parameters.put(PARAM_LOGIT_BIAS, buildBiasPairArray(logitBias, String::valueOf)); } return this; } @@ -444,21 +428,7 @@ public InferenceParameters setTokenIdBias(Map logitBias) { */ public InferenceParameters disableTokenIds(Collection tokenIds) { if (!tokenIds.isEmpty()) { - StringBuilder builder = new StringBuilder(); - builder.append("["); - int i = 0; - for (Integer token : tokenIds) { - builder.append("[") - .append(token) - .append(", ") - .append(false) - .append("]"); - if (i++ < tokenIds.size() - 1) { - builder.append(", "); - } - } - builder.append("]"); - parameters.put(PARAM_LOGIT_BIAS, builder.toString()); + parameters.put(PARAM_LOGIT_BIAS, buildDisablePairArray(tokenIds, String::valueOf)); } return this; } @@ -478,23 +448,7 @@ public InferenceParameters disableTokenIds(Collection tokenIds) { */ public InferenceParameters setTokenBias(Map logitBias) { if (!logitBias.isEmpty()) { - StringBuilder builder = new StringBuilder(); - builder.append("["); - int i = 0; - for (Map.Entry entry : logitBias.entrySet()) { - String key = entry.getKey(); - Float value = entry.getValue(); - builder.append("[") - .append(toJsonString(key)) - .append(", ") - .append(value) - .append("]"); - if (i++ < logitBias.size() - 1) { - builder.append(", "); - } - } - builder.append("]"); - parameters.put(PARAM_LOGIT_BIAS, builder.toString()); + parameters.put(PARAM_LOGIT_BIAS, buildBiasPairArray(logitBias, this::toJsonString)); } return this; } @@ -514,21 +468,7 @@ public InferenceParameters setTokenBias(Map logitBias) { */ public InferenceParameters disableTokens(Collection tokens) { if (!tokens.isEmpty()) { - StringBuilder builder = new StringBuilder(); - builder.append("["); - int i = 0; - for (String token : tokens) { - builder.append("[") - .append(toJsonString(token)) - .append(", ") - .append(false) - .append("]"); - if (i++ < tokens.size() - 1) { - builder.append(", "); - } - } - builder.append("]"); - parameters.put(PARAM_LOGIT_BIAS, builder.toString()); + parameters.put(PARAM_LOGIT_BIAS, buildDisablePairArray(tokens, this::toJsonString)); } return this; } @@ -627,15 +567,7 @@ public InferenceParameters setChatTemplate(String chatTemplate) { * @return this builder */ public InferenceParameters setChatTemplateKwargs(java.util.Map kwargs) { - StringBuilder sb = new StringBuilder("{"); - boolean first = true; - for (java.util.Map.Entry entry : kwargs.entrySet()) { - if (!first) sb.append(","); - sb.append("\"").append(entry.getKey()).append("\":").append(entry.getValue()); - first = false; - } - sb.append("}"); - parameters.put(PARAM_CHAT_TEMPLATE_KWARGS, sb.toString()); + parameters.put(PARAM_CHAT_TEMPLATE_KWARGS, mapToJsonObject(kwargs)); return this; } @@ -695,4 +627,38 @@ InferenceParameters setStream(boolean stream) { return this; } + private static String buildBiasPairArray(Map map, + java.util.function.Function keySerializer) { + StringBuilder builder = new StringBuilder("["); + int i = 0; + for (Map.Entry entry : map.entrySet()) { + builder.append("[") + .append(keySerializer.apply(entry.getKey())) + .append(", ") + .append(entry.getValue()) + .append("]"); + if (i++ < map.size() - 1) { + builder.append(", "); + } + } + builder.append("]"); + return builder.toString(); + } + + private static String buildDisablePairArray(Collection items, + java.util.function.Function serializer) { + StringBuilder builder = new StringBuilder("["); + int i = 0; + for (T item : items) { + builder.append("[") + .append(serializer.apply(item)) + .append(", false]"); + if (i++ < items.size() - 1) { + builder.append(", "); + } + } + builder.append("]"); + return builder.toString(); + } + } diff --git a/src/main/java/de/kherud/llama/JsonParameters.java b/src/main/java/de/kherud/llama/JsonParameters.java index e9916976..fc3e9dd2 100644 --- a/src/main/java/de/kherud/llama/JsonParameters.java +++ b/src/main/java/de/kherud/llama/JsonParameters.java @@ -35,6 +35,18 @@ public String toString() { return builder.toString(); } + static String mapToJsonObject(Map map) { + StringBuilder sb = new StringBuilder("{"); + boolean first = true; + for (Map.Entry entry : map.entrySet()) { + if (!first) sb.append(","); + sb.append("\"").append(entry.getKey()).append("\":").append(entry.getValue()); + first = false; + } + sb.append("}"); + return sb.toString(); + } + // taken from org.json.JSONObject#quote(String, Writer) String toJsonString(String text) { if (text == null) return null; diff --git a/src/main/java/de/kherud/llama/ModelParameters.java b/src/main/java/de/kherud/llama/ModelParameters.java index d288127f..b1659d74 100644 --- a/src/main/java/de/kherud/llama/ModelParameters.java +++ b/src/main/java/de/kherud/llama/ModelParameters.java @@ -1164,15 +1164,7 @@ public ModelParameters setChatTemplate(String chatTemplate) { * @return this builder */ public ModelParameters setChatTemplateKwargs(java.util.Map kwargs) { - StringBuilder sb = new StringBuilder("{"); - boolean first = true; - for (java.util.Map.Entry entry : kwargs.entrySet()) { - if (!first) sb.append(","); - sb.append("\"").append(entry.getKey()).append("\":").append(entry.getValue()); - first = false; - } - sb.append("}"); - parameters.put("--chat-template-kwargs", sb.toString()); + parameters.put("--chat-template-kwargs", JsonParameters.mapToJsonObject(kwargs)); return this; }