Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions src/main/cpp/jllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,19 @@ static json parse_json_params(JNIEnv *env, jstring jparams) {
return require_json_field_impl(env, data, field, c_llama_error);
}

/**
 * Guard used by JNI entry points that need embedding output.
 *
 * Reads `ctx_server->params_base.embedding`. When the flag is set, nothing is
 * thrown and the function returns true. When it is not set, a Java exception
 * of class `c_llama_error` is raised on `env` via ThrowNew and the function
 * returns false, so the caller can return to Java right away.
 */
[[nodiscard]] static bool require_embedding_support(JNIEnv *env, server_context *ctx_server) {
    const bool embedding_enabled = ctx_server->params_base.embedding;
    if (embedding_enabled) {
        return true;
    }

    env->ThrowNew(c_llama_error,
                  "Model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))");
    return false;
}

/**
* Validates `jfilename`, builds a SAVE or RESTORE slot task, dispatches it,
* and returns the result as a jstring. Shared by the SAVE (case 1) and
Expand Down Expand Up @@ -849,11 +862,7 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletionJson(
JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) {
REQUIRE_SERVER_CONTEXT(nullptr);

if (!ctx_server->params_base.embedding) {
env->ThrowNew(c_llama_error,
"model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))");
return nullptr;
}
if (!require_embedding_support(env, ctx_server)) return nullptr;

const std::string prompt = parse_jstring(env, jprompt);

Expand Down Expand Up @@ -1117,11 +1126,7 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn
jstring jparams, jboolean joaiCompat) {
REQUIRE_SERVER_CONTEXT(nullptr);

if (!ctx_server->params_base.embedding) {
env->ThrowNew(c_llama_error,
"Model was not loaded with embedding support (see ModelParameters#enableEmbedding())");
return nullptr;
}
if (!require_embedding_support(env, ctx_server)) return nullptr;

oaicompat_type oaicompat = joaiCompat ? OAICOMPAT_TYPE_EMBEDDING : OAICOMPAT_TYPE_NONE;

Expand Down
58 changes: 28 additions & 30 deletions src/main/cpp/jni_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,48 +353,46 @@ inline void append_task(server_context *ctx_server,
}

// ---------------------------------------------------------------------------
// embedding_to_jfloat_array_impl
// vec_to_jarray_impl
//
// Converts a float vector to a Java jfloatArray.
//
// On success: returns a new jfloatArray filled with the embedding values.
// Generic helper: converts a C++ vector to a JNI primitive array.
// Parameterized on JNI array/element types and the alloc/copy member fns.
// On allocation failure: throws via JNI with oom_class and returns nullptr.
// ---------------------------------------------------------------------------
[[nodiscard]] inline jfloatArray embedding_to_jfloat_array_impl(
JNIEnv *env,
const std::vector<float> &values,
jclass oom_class) {
template <typename JArray, typename JElem, typename CppElem>
[[nodiscard]] inline JArray vec_to_jarray_impl(
JNIEnv *env,
const std::vector<CppElem> &values,
jclass oom_class,
const char *oom_msg,
JArray (JNIEnv_::*alloc)(jsize),
void (JNIEnv_::*copy)(JArray, jsize, jsize, const JElem *)) {
const jsize len = static_cast<jsize>(values.size());
jfloatArray arr = env->NewFloatArray(len);
JArray arr = (env->*alloc)(len);
if (arr == nullptr) {
env->ThrowNew(oom_class, "could not allocate embedding");
env->ThrowNew(oom_class, oom_msg);
return nullptr;
}
env->SetFloatArrayRegion(arr, 0, len,
reinterpret_cast<const jfloat *>(values.data()));
(env->*copy)(arr, 0, len, reinterpret_cast<const JElem *>(values.data()));
return arr;
}

// ---------------------------------------------------------------------------
// tokens_to_jint_array_impl
//
// Converts a token vector to a Java jintArray. Symmetric with
// embedding_to_jfloat_array_impl for the int (token) case.
//
// On success: returns a new jintArray filled with the token values.
// On allocation failure: throws via JNI with oom_class and returns nullptr.
// ---------------------------------------------------------------------------
// Float specialisation of vec_to_jarray_impl: hands it NewFloatArray as the
// allocator and SetFloatArrayRegion as the copy routine, together with
// oom_class and the message "could not allocate embedding".
// Returns whatever vec_to_jarray_impl returns for `values`.
[[nodiscard]] inline jfloatArray embedding_to_jfloat_array_impl(
    JNIEnv *env,
    const std::vector<float> &values,
    jclass oom_class) {
    const auto allocate = &JNIEnv_::NewFloatArray;
    const auto fill = &JNIEnv_::SetFloatArrayRegion;
    return vec_to_jarray_impl<jfloatArray, jfloat>(
        env, values, oom_class, "could not allocate embedding", allocate, fill);
}

// Converts a token vector to a Java jintArray.
[[nodiscard]] inline jintArray tokens_to_jint_array_impl(
JNIEnv *env,
const std::vector<int32_t> &tokens,
jclass oom_class) {
const jsize len = static_cast<jsize>(tokens.size());
jintArray arr = env->NewIntArray(len);
if (arr == nullptr) {
env->ThrowNew(oom_class, "could not allocate token memory");
return nullptr;
}
env->SetIntArrayRegion(arr, 0, len,
reinterpret_cast<const jint *>(tokens.data()));
return arr;
return vec_to_jarray_impl<jintArray, jint>(
env, tokens, oom_class, "could not allocate token memory",
&JNIEnv_::NewIntArray, &JNIEnv_::SetIntArrayRegion);
}
112 changes: 39 additions & 73 deletions src/main/java/de/kherud/llama/InferenceParameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -408,23 +408,7 @@ public InferenceParameters setIgnoreEos(boolean ignoreEos) {
*/
public InferenceParameters setTokenIdBias(Map<Integer, Float> logitBias) {
if (!logitBias.isEmpty()) {
StringBuilder builder = new StringBuilder();
builder.append("[");
int i = 0;
for (Map.Entry<Integer, Float> entry : logitBias.entrySet()) {
Integer key = entry.getKey();
Float value = entry.getValue();
builder.append("[")
.append(key)
.append(", ")
.append(value)
.append("]");
if (i++ < logitBias.size() - 1) {
builder.append(", ");
}
}
builder.append("]");
parameters.put(PARAM_LOGIT_BIAS, builder.toString());
parameters.put(PARAM_LOGIT_BIAS, buildBiasPairArray(logitBias, String::valueOf));
}
return this;
}
Expand All @@ -444,21 +428,7 @@ public InferenceParameters setTokenIdBias(Map<Integer, Float> logitBias) {
*/
public InferenceParameters disableTokenIds(Collection<Integer> tokenIds) {
if (!tokenIds.isEmpty()) {
StringBuilder builder = new StringBuilder();
builder.append("[");
int i = 0;
for (Integer token : tokenIds) {
builder.append("[")
.append(token)
.append(", ")
.append(false)
.append("]");
if (i++ < tokenIds.size() - 1) {
builder.append(", ");
}
}
builder.append("]");
parameters.put(PARAM_LOGIT_BIAS, builder.toString());
parameters.put(PARAM_LOGIT_BIAS, buildDisablePairArray(tokenIds, String::valueOf));
}
return this;
}
Expand All @@ -478,23 +448,7 @@ public InferenceParameters disableTokenIds(Collection<Integer> tokenIds) {
*/
public InferenceParameters setTokenBias(Map<String, Float> logitBias) {
if (!logitBias.isEmpty()) {
StringBuilder builder = new StringBuilder();
builder.append("[");
int i = 0;
for (Map.Entry<String, Float> entry : logitBias.entrySet()) {
String key = entry.getKey();
Float value = entry.getValue();
builder.append("[")
.append(toJsonString(key))
.append(", ")
.append(value)
.append("]");
if (i++ < logitBias.size() - 1) {
builder.append(", ");
}
}
builder.append("]");
parameters.put(PARAM_LOGIT_BIAS, builder.toString());
parameters.put(PARAM_LOGIT_BIAS, buildBiasPairArray(logitBias, this::toJsonString));
}
return this;
}
Expand All @@ -514,21 +468,7 @@ public InferenceParameters setTokenBias(Map<String, Float> logitBias) {
*/
public InferenceParameters disableTokens(Collection<String> tokens) {
if (!tokens.isEmpty()) {
StringBuilder builder = new StringBuilder();
builder.append("[");
int i = 0;
for (String token : tokens) {
builder.append("[")
.append(toJsonString(token))
.append(", ")
.append(false)
.append("]");
if (i++ < tokens.size() - 1) {
builder.append(", ");
}
}
builder.append("]");
parameters.put(PARAM_LOGIT_BIAS, builder.toString());
parameters.put(PARAM_LOGIT_BIAS, buildDisablePairArray(tokens, this::toJsonString));
}
return this;
}
Expand Down Expand Up @@ -627,15 +567,7 @@ public InferenceParameters setChatTemplate(String chatTemplate) {
* @return this builder
*/
public InferenceParameters setChatTemplateKwargs(java.util.Map<String, String> kwargs) {
StringBuilder sb = new StringBuilder("{");
boolean first = true;
for (java.util.Map.Entry<String, String> entry : kwargs.entrySet()) {
if (!first) sb.append(",");
sb.append("\"").append(entry.getKey()).append("\":").append(entry.getValue());
first = false;
}
sb.append("}");
parameters.put(PARAM_CHAT_TEMPLATE_KWARGS, sb.toString());
parameters.put(PARAM_CHAT_TEMPLATE_KWARGS, mapToJsonObject(kwargs));
return this;
}

Expand Down Expand Up @@ -695,4 +627,38 @@ InferenceParameters setStream(boolean stream) {
return this;
}

/**
 * Renders a map as a bracketed list of two-element pairs, for example
 * {@code [[k1, v1], [k2, v2]]}. Pairs appear in the iteration order of
 * {@code map.entrySet()} and are separated by {@code ", "}.
 *
 * @param map           entries to render; an empty map yields {@code "[]"}
 * @param keySerializer turns each key into the text written as the first pair element
 * @return the rendered array text; each value is written via its string conversion
 */
private static <K, V> String buildBiasPairArray(Map<K, V> map,
        java.util.function.Function<K, String> keySerializer) {
    java.util.StringJoiner pairs = new java.util.StringJoiner(", ", "[", "]");
    for (Map.Entry<K, V> pair : map.entrySet()) {
        String renderedKey = keySerializer.apply(pair.getKey());
        pairs.add("[" + renderedKey + ", " + pair.getValue() + "]");
    }
    return pairs.toString();
}

/**
 * Renders a collection as a bracketed list of {@code [item, false]} pairs,
 * for example {@code [[a, false], [b, false]]}. Pairs follow the collection's
 * iteration order and are separated by {@code ", "}.
 *
 * @param items      elements to render; an empty collection yields {@code "[]"}
 * @param serializer turns each element into the text written as the first pair element
 * @return the rendered array text
 */
private static <T> String buildDisablePairArray(Collection<T> items,
        java.util.function.Function<T, String> serializer) {
    StringBuilder out = new StringBuilder();
    out.append('[');
    // Empty before the first pair, ", " before every later one.
    String separator = "";
    for (T element : items) {
        out.append(separator)
           .append('[')
           .append(serializer.apply(element))
           .append(", false]");
        separator = ", ";
    }
    return out.append(']').toString();
}

}
12 changes: 12 additions & 0 deletions src/main/java/de/kherud/llama/JsonParameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ public String toString() {
return builder.toString();
}

/**
 * Serializes a map into a compact JSON object string such as
 * {@code {"a":1,"b":"x"}}, in the iteration order of {@code map.entrySet()}.
 *
 * Keys are written as JSON strings and are escaped, so a key containing a
 * quote, backslash or control character can no longer break out of its
 * string and corrupt the object. Values are appended verbatim: each value
 * must already be a valid JSON fragment (for example {@code true},
 * {@code 42} or {@code "\"text\""}).
 *
 * @param map entries to serialize; an empty map yields {@code "{}"}
 * @return the JSON object text
 */
static String mapToJsonObject(Map<String, String> map) {
    StringBuilder sb = new StringBuilder("{");
    boolean first = true;
    for (Map.Entry<String, String> entry : map.entrySet()) {
        if (!first) sb.append(",");
        sb.append("\"");
        // String.valueOf keeps the previous rendering of a null key ("null").
        appendJsonEscaped(sb, String.valueOf(entry.getKey()));
        sb.append("\":").append(entry.getValue());
        first = false;
    }
    sb.append("}");
    return sb.toString();
}

/** Appends {@code text} to {@code sb}, escaping characters that JSON forbids raw inside a string. */
private static void appendJsonEscaped(StringBuilder sb, String text) {
    for (int i = 0; i < text.length(); i++) {
        char c = text.charAt(i);
        switch (c) {
            case '"':
                sb.append("\\\"");
                break;
            case '\\':
                sb.append("\\\\");
                break;
            case '\n':
                sb.append("\\n");
                break;
            case '\r':
                sb.append("\\r");
                break;
            case '\t':
                sb.append("\\t");
                break;
            default:
                if (c < 0x20) {
                    // Remaining control characters must use the \\uXXXX form.
                    sb.append(String.format("\\u%04x", (int) c));
                } else {
                    sb.append(c);
                }
        }
    }
}

// taken from org.json.JSONObject#quote(String, Writer)
String toJsonString(String text) {
if (text == null) return null;
Expand Down
10 changes: 1 addition & 9 deletions src/main/java/de/kherud/llama/ModelParameters.java
Original file line number Diff line number Diff line change
Expand Up @@ -1164,15 +1164,7 @@ public ModelParameters setChatTemplate(String chatTemplate) {
* @return this builder
*/
public ModelParameters setChatTemplateKwargs(java.util.Map<String, String> kwargs) {
StringBuilder sb = new StringBuilder("{");
boolean first = true;
for (java.util.Map.Entry<String, String> entry : kwargs.entrySet()) {
if (!first) sb.append(",");
sb.append("\"").append(entry.getKey()).append("\":").append(entry.getValue());
first = false;
}
sb.append("}");
parameters.put("--chat-template-kwargs", sb.toString());
parameters.put("--chat-template-kwargs", JsonParameters.mapToJsonObject(kwargs));
return this;
}

Expand Down
Loading