Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,41 @@ set(LLAMA_CURL ON)
if(WIN32)
set(LLAMA_BUILD_BORINGSSL ON CACHE BOOL "" FORCE)
endif()

# Disable AVX-512 for all builds.
#
# AVX-512 causes problems in several well-known scenarios, so we turn it off
# unconditionally to keep the distributed library safe and broadly compatible.
#
# 1. Many CPUs simply do not support AVX-512 at all. Loading a binary that
# contains AVX-512 instructions on such a machine crashes the JVM immediately
# with SIGILL (illegal instruction). A real example from CI: the AMD EPYC
# 7763 used by GitHub Actions runners has no AVX-512, while the Intel Xeon
# Platinum 8370C used in the manylinux build container does. A library built
# on the Xeon would crash on the EPYC.
#
# 2. Intel disabled AVX-512 on 12th-gen (Alder Lake) and later desktop CPUs.
# Alder Lake mixes big P-cores (which have AVX-512 hardware) with small
# E-cores (which do not). Because the OS scheduler can move a thread between
# core types at any time, running AVX-512 instructions is unsafe. Intel
# first locked it out via BIOS/microcode updates, then began physically fusing
# it off in silicon on new chips. All current shipping Alder Lake, Raptor
# Lake, and newer Intel desktop CPUs have AVX-512 permanently disabled.
#
# 3. Even CPUs that do support AVX-512 often throttle their clock speed while
# executing it. Older Intel server chips (Skylake-X era) can drop hundreds
# of MHz across all cores — even cores not running AVX-512. AMD Ryzen 9000
# (Zen 5) drops roughly 10 % in frequency when AVX-512 is active and takes
# more than 100 ms to ramp back up after the workload ends.
#
# 4. AVX-512's high instantaneous power draw can destabilize systems that are
# running near their voltage limits, leading to random crashes.
#
# Keeping AVX-512 off produces a binary that runs correctly on any x86-64
# machine with at least AVX2 support, which covers virtually all hardware
# built since 2013. FORCE prevents llama.cpp's own cmake defaults from
# re-enabling it regardless of the build flags passed on the command line.
#
# NOTE(review): this cache variable only gates ggml's explicit AVX-512 code
# paths. If GGML_NATIVE is ON, the compiler may still emit AVX-512 via
# -march=native on a build host that supports it — confirm GGML_NATIVE is
# OFF (or the build host lacks AVX-512) for distributed artifacts.
set(GGML_AVX512 OFF CACHE BOOL "" FORCE)
FetchContent_Declare(
llama.cpp
GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
Expand Down Expand Up @@ -206,12 +241,15 @@ if(BUILD_TESTING)
add_executable(jllama_test
src/test/cpp/test_utils.cpp
src/test/cpp/test_server.cpp
src/test/cpp/test_jni_helpers.cpp
)

target_include_directories(jllama_test PRIVATE
src/main/cpp
# mtmd.h is not always propagated transitively — add it explicitly
${llama.cpp_SOURCE_DIR}/tools/mtmd
# jni.h / jni_md.h needed by jni_helpers.hpp (mock JNI tests, no JVM required)
${JNI_INCLUDE_DIRS}
)
target_link_libraries(jllama_test PRIVATE common mtmd llama nlohmann_json GTest::gtest_main)
target_compile_features(jllama_test PRIVATE cxx_std_17)
Expand Down
142 changes: 50 additions & 92 deletions src/main/cpp/jllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "arg.h"
#include "json-schema-to-grammar.h"
#include "jni_helpers.hpp"
#include "llama.h"
#include "log.h"
#include "nlohmann/json.hpp"
Expand Down Expand Up @@ -32,20 +33,7 @@ static constexpr int N_PARALLEL_AUTO = -1;
// appropriate default and preserves pre-b7433 behaviour.
static constexpr int N_PARALLEL_DEFAULT = 1;

/**
* Wrapper that owns a server_context and the background worker thread.
* Stored as the Java-side `ctx` (jlong) pointer. Using a wrapper allows
* us to join the thread on close() instead of detaching it, which
* eliminates the race between thread teardown and JVM shutdown.
*/
struct jllama_context {
server_context *server = nullptr;
std::thread worker;
bool vocab_only = false;
// Signals that the worker thread has entered start_loop() and is ready.
// Without this, terminate() can race with start_loop() setting running=true.
std::atomic<bool> worker_ready{false};
};
// jllama_context is defined in jni_helpers.hpp.

JavaVM *g_vm = nullptr;

Expand Down Expand Up @@ -104,6 +92,16 @@ jobject o_log_format_json = nullptr;
jobject o_log_format_text = nullptr;
jobject o_log_callback = nullptr;

/**
 * Resolves the native server_context that backs a Java-side model object.
 *
 * Thin forwarding wrapper around get_server_context_impl(), supplying the
 * module-level pointer field ID (f_model_pointer) and the LlamaError class
 * (c_llama_error). When the model is not loaded, the impl throws the Java
 * exception itself and this returns nullptr, so callers only need a null
 * check before bailing out.
 */
static server_context *get_server_context(JNIEnv *env, jobject obj) {
    server_context *server = get_server_context_impl(env, obj, f_model_pointer, c_llama_error);
    return server;
}

/**
* Convert a Java string to a std::string
*/
Expand Down Expand Up @@ -539,8 +537,8 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
}

JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv *env, jobject obj, jstring jparams) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return 0;

std::string c_params = parse_jstring(env, jparams);
json data = json::parse(c_params);
Expand Down Expand Up @@ -597,15 +595,15 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv
}

JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_releaseTask(JNIEnv *env, jobject obj, jint id_task) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return;
ctx_server->queue_results.remove_waiting_task_id(id_task);
}

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletionJson(JNIEnv *env, jobject obj,
jint id_task) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

server_task_result_ptr result = ctx_server->queue_results.recv(id_task);

Expand All @@ -628,8 +626,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletionJson(
}

JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env, jobject obj, jstring jprompt) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

if (!ctx_server->params_base.embedding) {
env->ThrowNew(c_llama_error,
Expand Down Expand Up @@ -715,8 +713,8 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleRerank(JNIEnv *env, jobject obj, jstring jprompt,
jobjectArray documents) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

if (!ctx_server->params_base.embedding || ctx_server->params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
env->ThrowNew(c_llama_error,
Expand Down Expand Up @@ -779,8 +777,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleRerank(JNIEnv *e
}

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
const auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

std::string c_params = parse_jstring(env, jparams);
json data = json::parse(c_params);
Expand All @@ -796,12 +794,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleChatCompletions(JNIEnv *env, jobject obj,
jstring jparams) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
if (server_handle == 0) {
env->ThrowNew(c_llama_error, "Model is not loaded");
return nullptr;
}
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

std::string c_params = parse_jstring(env, jparams);
json body = json::parse(c_params);
Expand Down Expand Up @@ -887,12 +881,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleChatCompletions(

JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestChatCompletion(JNIEnv *env, jobject obj,
jstring jparams) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
if (server_handle == 0) {
env->ThrowNew(c_llama_error, "Model is not loaded");
return 0;
}
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return 0;

std::string c_params = parse_jstring(env, jparams);
json body = json::parse(c_params);
Expand Down Expand Up @@ -953,8 +943,8 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestChatCompletion(JNI
}

JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

const std::string c_prompt = parse_jstring(env, jprompt);

Expand All @@ -974,8 +964,8 @@ JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env,

JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_decodeBytes(JNIEnv *env, jobject obj,
jintArray java_tokens) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

jsize length = env->GetArrayLength(java_tokens);
jint *elements = env->GetIntArrayElements(java_tokens, nullptr);
Expand Down Expand Up @@ -1025,8 +1015,8 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_delete(JNIEnv *env, jobje
}

JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_cancelCompletion(JNIEnv *env, jobject obj, jint id_task) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return;
std::unordered_set<int> id_tasks = {id_task};
ctx_server->cancel_tasks(id_tasks);
ctx_server->queue_results.remove_waiting_task_id(id_task);
Expand Down Expand Up @@ -1067,12 +1057,8 @@ JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammar

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletions(JNIEnv *env, jobject obj,
jstring jparams) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
if (server_handle == 0) {
env->ThrowNew(c_llama_error, "Model is not loaded");
return nullptr;
}
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

std::string c_params = parse_jstring(env, jparams);
json data = json::parse(c_params);
Expand Down Expand Up @@ -1146,12 +1132,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletions(JNIE

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletionsOai(JNIEnv *env, jobject obj,
jstring jparams) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
if (server_handle == 0) {
env->ThrowNew(c_llama_error, "Model is not loaded");
return nullptr;
}
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

std::string c_params = parse_jstring(env, jparams);
json body = json::parse(c_params);
Expand Down Expand Up @@ -1226,12 +1208,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleCompletionsOai(J
}

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleInfill(JNIEnv *env, jobject obj, jstring jparams) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
if (server_handle == 0) {
env->ThrowNew(c_llama_error, "Model is not loaded");
return nullptr;
}
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

// Check model compatibility for infill
std::string err;
Expand Down Expand Up @@ -1341,12 +1319,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleInfill(JNIEnv *e

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEnv *env, jobject obj,
jstring jparams, jboolean joaiCompat) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
if (server_handle == 0) {
env->ThrowNew(c_llama_error, "Model is not loaded");
return nullptr;
}
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

if (!ctx_server->params_base.embedding) {
env->ThrowNew(c_llama_error,
Expand Down Expand Up @@ -1442,12 +1416,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleEmbeddings(JNIEn
JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleTokenize(JNIEnv *env, jobject obj, jstring jcontent,
jboolean jaddSpecial,
jboolean jwithPieces) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
if (server_handle == 0) {
env->ThrowNew(c_llama_error, "Model is not loaded");
return nullptr;
}
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

const std::string content = parse_jstring(env, jcontent);
const bool add_special = jaddSpecial;
Expand Down Expand Up @@ -1485,12 +1455,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleTokenize(JNIEnv

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleDetokenize(JNIEnv *env, jobject obj,
jintArray jtokens) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
if (server_handle == 0) {
env->ThrowNew(c_llama_error, "Model is not loaded");
return nullptr;
}
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

jsize length = env->GetArrayLength(jtokens);
jint *elements = env->GetIntArrayElements(jtokens, nullptr);
Expand All @@ -1512,12 +1478,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleDetokenize(JNIEn

JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleSlotAction(JNIEnv *env, jobject obj, jint action,
jint slotId, jstring jfilename) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
if (server_handle == 0) {
env->ThrowNew(c_llama_error, "Model is not loaded");
return nullptr;
}
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return nullptr;

switch (action) {
case 0: { // LIST — get slot info via metrics
Expand Down Expand Up @@ -1626,12 +1588,8 @@ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_handleSlotAction(JNIEn

JNIEXPORT jboolean JNICALL Java_de_kherud_llama_LlamaModel_configureParallelInference(JNIEnv *env, jobject obj,
jstring jconfig) {
jlong server_handle = env->GetLongField(obj, f_model_pointer);
if (server_handle == 0) {
env->ThrowNew(c_llama_error, "Model is not loaded");
return JNI_FALSE;
}
auto *ctx_server = reinterpret_cast<jllama_context *>(server_handle)->server; // NOLINT(*-no-int-to-ptr)
auto *ctx_server = get_server_context(env, obj);
if (!ctx_server) return JNI_FALSE;

std::string config_str = parse_jstring(env, jconfig);
json config = json::parse(config_str);
Expand Down
Loading
Loading