diff --git a/common/arg.cpp b/common/arg.cpp index 099f0aeab24..2555fa680c6 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2490,7 +2490,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "path to LoRA adapter (use comma-separated values to load multiple adapters)", [](common_params & params, const std::string & value) { for (const auto & item : parse_csv_row(value)) { - params.lora_adapters.push_back({ item, 1.0, "", "", nullptr }); + params.lora_adapters.push_back({ item, 1.0, "", "", nullptr, true, {} }); } } // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg @@ -2505,11 +2505,67 @@ common_params_context common_params_parser_init(common_params & params, llama_ex if (parts.size() != 2) { throw std::invalid_argument("lora-scaled format: FNAME:SCALE"); } - params.lora_adapters.push_back({ parts[0], std::stof(parts[1]), "", "", nullptr }); + params.lora_adapters.push_back({ parts[0], std::stof(parts[1]), "", "", nullptr, true, {} }); } } // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA})); + add_opt(common_arg( + {"--lora-modality"}, "INDEX:MODALITY[,MODALITY...]", + "Bind LoRA adapter to modality type(s). Adapter activates when any specified modality is present.\n" + "INDEX is the 0-based adapter index (order of --lora arguments).\n" + "MODALITY can be: image, audio (comma-separated for multiple).\n" + "Example: --lora-modality 0:image,audio", + [](common_params & params, const std::string & value) { + // Parse "INDEX:MODALITY[,MODALITY...]" + size_t colon_pos = value.find(':'); + if (colon_pos == std::string::npos) { + throw std::invalid_argument("Invalid format for --lora-modality. 
Expected INDEX:MODALITY"); + } + + // Parse the index value + std::string index_str = value.substr(0, colon_pos); + const auto start_index = index_str.find_first_not_of(" \t\r\n"); + if (start_index != std::string::npos) { + const auto end_index = index_str.find_last_not_of(" \t\r\n"); + index_str = index_str.substr(start_index, end_index - start_index + 1); + } + + // Parse strictly: plain std::stoi would accept "1x" as 1 and throw an opaque library exception for non-numeric input + size_t n_parsed = 0; + int adapter_idx = -1; + try { + adapter_idx = std::stoi(index_str, &n_parsed); + } catch (const std::exception &) { + n_parsed = 0; + } + if (n_parsed == 0 || n_parsed != index_str.size() || adapter_idx < 0 || adapter_idx >= (int)params.lora_adapters.size()) { + throw std::invalid_argument("Invalid adapter index: " + index_str); + } + + // Parse comma-separated modalities + std::string modalities_str = value.substr(colon_pos + 1); + std::vector<std::string> modality_strs = string_split<std::string>(modalities_str, ','); + + // Validate and store modality strings + auto & lora_info = params.lora_adapters[adapter_idx]; + lora_info.mmlora_modality_types.clear(); + + for (auto mod_str : modality_strs) { + // Strip whitespace + const auto start_index = mod_str.find_first_not_of(" \t\r\n"); + if (start_index != std::string::npos) { + auto end_index = mod_str.find_last_not_of(" \t\r\n"); + mod_str = mod_str.substr(start_index, end_index - start_index + 1); + } + // Validate the string (simple string validation, enum conversion happens later) + if (mod_str != "image" && mod_str != "audio") { + throw std::invalid_argument("Invalid modality type: " + mod_str + " (must be 'image' or 'audio')"); + } + lora_info.mmlora_modality_types.push_back(mod_str); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--control-vector"}, "FNAME", "add a control vector\nnote: use comma-separated values to add multiple control vectors", diff --git a/common/common.cpp b/common/common.cpp index 6cde71d819a..9fae53bd851 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1180,6 +1180,18 @@ common_init_result::common_init_result(common_params & params) : la.task_name = buf; llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); la.prompt_prefix = 
buf; + + // Validate MMLoRA configuration + if (!la.mmlora_modality_types.empty()) { + // Check for aLoRA conflict + const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(la.ptr); + if (n_alora_tokens > 0) { + LOG_ERR("%s: adapter '%s' cannot be both MMLoRA and aLoRA\n", __func__, la.path.c_str()); + pimpl->model.reset(model); + return; + } + } + pimpl->lora.emplace_back(std::move(lora)); // copy to list of loaded adapters } @@ -1439,7 +1451,8 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector mmlora_modality_types; + + // Helper to check if this is an MMLoRA adapter + bool is_mmlora() const { return !mmlora_modality_types.empty(); } }; using llama_tokens = std::vector<llama_token>; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b282c3239f0..8591dacf121 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -261,6 +261,9 @@ endif() set(LLAMA_TEST_NAME test-mtmd-c-api) llama_build_and_test(test-mtmd-c-api.c) target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd) +set(LLAMA_TEST_NAME test-mmlora) +llama_build_and_test(test-mmlora.cpp) +target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd) unset(LLAMA_TEST_NAME) # GGUF model data fetcher library for tests that need real model metadata diff --git a/tests/test-mmlora.cpp b/tests/test-mmlora.cpp new file mode 100644 index 00000000000..5e3f50e2d05 --- /dev/null +++ b/tests/test-mmlora.cpp @@ -0,0 +1,119 @@ +// Tests for Multi-Modal LoRA (MMLoRA) functionality +#include "mtmd.h" +#include "mtmd-helper.h" +#include "common.h" + +#include <cstddef> +#include <cstdio> +#include <cstring> +#include <string> +#include <vector> + +#undef NDEBUG +#include <cassert> + +#define LOG_INF(...) 
printf(__VA_ARGS__) + +#define TEST(cond) do { \ + if (!(cond)) { \ + fprintf(stderr, "TEST FAILED: %s at %s:%d\n", #cond, __FILE__, __LINE__); \ + return 1; \ + } \ +} while(0) + +// Test string↔enum conversion functions +static int test_mmlora_modality_type_conversion() { + LOG_INF("Testing modality type string conversions...\n"); + + // Test enum to string + TEST(strcmp(mtmd_modality_type_to_str(MTMD_INPUT_CHUNK_TYPE_IMAGE), "image") == 0); + TEST(strcmp(mtmd_modality_type_to_str(MTMD_INPUT_CHUNK_TYPE_AUDIO), "audio") == 0); + TEST(strcmp(mtmd_modality_type_to_str(MTMD_INPUT_CHUNK_TYPE_TEXT), "text") == 0); + TEST(strcmp(mtmd_modality_type_to_str(MTMD_INPUT_CHUNK_TYPE_UNKNOWN), "unknown") == 0); + TEST(strcmp(mtmd_modality_type_to_str((enum mtmd_input_chunk_type)999), "unknown") == 0); + + // Test string to enum + TEST(mtmd_modality_type_from_str("image") == MTMD_INPUT_CHUNK_TYPE_IMAGE); + TEST(mtmd_modality_type_from_str("audio") == MTMD_INPUT_CHUNK_TYPE_AUDIO); + TEST(mtmd_modality_type_from_str("text") == MTMD_INPUT_CHUNK_TYPE_TEXT); + TEST(mtmd_modality_type_from_str("unknown") == MTMD_INPUT_CHUNK_TYPE_UNKNOWN); + TEST(mtmd_modality_type_from_str("invalid") == MTMD_INPUT_CHUNK_TYPE_UNKNOWN); + + // Test validation + TEST(mtmd_is_valid_modality_str("image") == true); + TEST(mtmd_is_valid_modality_str("audio") == true); + TEST(mtmd_is_valid_modality_str("text") == false); + TEST(mtmd_is_valid_modality_str("unknown") == false); + TEST(mtmd_is_valid_modality_str("invalid") == false); + TEST(mtmd_is_valid_modality_str(NULL) == false); + + LOG_INF(" Passed modality type conversions\n"); + return 0; +} + +// Test common_adapter_lora_info MMLoRA helpers +static int test_mmlora_info_helpers() { + LOG_INF("Testing MMLoRA info helpers...\n"); + + // Test is_mmlora() with empty vector (not MMLoRA) + common_adapter_lora_info regular_lora; + TEST(regular_lora.is_mmlora() == false); + TEST(regular_lora.mmlora_modality_types.empty() == true); + + // Test is_mmlora() with 
single modality + common_adapter_lora_info image_lora; + image_lora.mmlora_modality_types.push_back("image"); + TEST(image_lora.is_mmlora() == true); + TEST(image_lora.mmlora_modality_types.size() == 1); + + // Test is_mmlora() with multiple modalities + common_adapter_lora_info multi_lora; + multi_lora.mmlora_modality_types.push_back("image"); + multi_lora.mmlora_modality_types.push_back("audio"); + TEST(multi_lora.is_mmlora() == true); + TEST(multi_lora.mmlora_modality_types.size() == 2); + + LOG_INF(" Passed MMLoRA info helpers\n"); + return 0; +} + +// Test CLI argument parsing (simulated) +static int test_mmlora_arg_parsing() {LOG_INF("Testing MMLoRA argument parsing logic...\n"); + + // Simulate parsing "0:image,audio" + std::string value = "0:image,audio"; + size_t colon_pos = value.find(':'); + TEST(colon_pos != std::string::npos); + + std::string index_str = value.substr(0, colon_pos); + std::string modalities_str = value.substr(colon_pos + 1); + + TEST(index_str == "0"); + std::vector<std::string> modality_strs = string_split<std::string>(modalities_str, ','); + TEST(modality_strs.size() == 2); + TEST(modality_strs[0] == "image"); + TEST(modality_strs[1] == "audio"); + + // Validate modalities + for (const auto & mod : modality_strs) { + TEST(mod == "image" || mod == "audio"); + } + + LOG_INF(" Passed argument parsing logic\n"); + return 0; +} + +int main(int argc, char ** argv) { + (void)argc; + (void)argv; + + LOG_INF("MMLoRA unit tests starting...\n"); + + int result = 0; + result += test_mmlora_modality_type_conversion(); + result += test_mmlora_info_helpers(); + result += test_mmlora_arg_parsing(); + + LOG_INF("MMLoRA unit tests %s!\n", result == 0 ? 
"passed" : "failed"); + return result; +} diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 40940741637..dd8d8ff8d12 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -535,3 +535,31 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * return mtmd_helper_bitmap_init_from_buf(ctx, buf.data(), buf.size()); } + +// Modality type string↔enum conversion functions +const char * mtmd_modality_type_to_str(enum mtmd_input_chunk_type type) { + switch (type) { + case MTMD_INPUT_CHUNK_TYPE_IMAGE: return "image"; + case MTMD_INPUT_CHUNK_TYPE_AUDIO: return "audio"; + case MTMD_INPUT_CHUNK_TYPE_TEXT: return "text"; + default: return "unknown"; + } +} + +enum mtmd_input_chunk_type mtmd_modality_type_from_str(const char * str) { + if (str == nullptr) { + return MTMD_INPUT_CHUNK_TYPE_UNKNOWN; // NULL is treated like any other unrecognized name (strcmp on NULL is undefined) + } + if (strcmp(str, "image") == 0) { + return MTMD_INPUT_CHUNK_TYPE_IMAGE; + } else if (strcmp(str, "audio") == 0) { + return MTMD_INPUT_CHUNK_TYPE_AUDIO; + } else if (strcmp(str, "text") == 0) { + return MTMD_INPUT_CHUNK_TYPE_TEXT; + } + return MTMD_INPUT_CHUNK_TYPE_UNKNOWN; // default/invalid +} + +bool mtmd_is_valid_modality_str(const char * str) { + return str && (strcmp(str, "image") == 0 || strcmp(str, "audio") == 0); +} diff --git a/tools/mtmd/mtmd-helper.h b/tools/mtmd/mtmd-helper.h index 57da78a754f..89a45a83ff9 100644 --- a/tools/mtmd/mtmd-helper.h +++ b/tools/mtmd/mtmd-helper.h @@ -89,6 +89,16 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx, int32_t n_batch, llama_pos * new_n_past); +// Convert modality type enum to string +MTMD_API const char * mtmd_modality_type_to_str(enum mtmd_input_chunk_type type); + +// Convert string to modality type enum +// Returns MTMD_INPUT_CHUNK_TYPE_UNKNOWN on NULL or unrecognized input +MTMD_API enum mtmd_input_chunk_type mtmd_modality_type_from_str(const char * str); + +// Validate if string is a valid modality type (image or audio) +MTMD_API bool mtmd_is_valid_modality_str(const char * str); + #ifdef __cplusplus } // extern "C" 
#endif diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 6e36cb8ec8c..c4f64a8a363 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -54,6 +54,7 @@ enum mtmd_input_chunk_type { MTMD_INPUT_CHUNK_TYPE_TEXT, MTMD_INPUT_CHUNK_TYPE_IMAGE, MTMD_INPUT_CHUNK_TYPE_AUDIO, + MTMD_INPUT_CHUNK_TYPE_UNKNOWN, }; // opaque types diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index cae64884b36..65a83fe25d7 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -344,6 +344,32 @@ const mtmd::input_chunk_ptr & server_tokens::find_chunk(size_t idx) const { throw std::runtime_error("Chunk not found"); } +bool server_tokens::has_modality_type(const std::string & modality_type) const { + if (!has_mtmd) { + return false; + } + + enum mtmd_input_chunk_type target_type = mtmd_modality_type_from_str(modality_type.c_str()); + + for (const auto & [idx, chunk_ptr] : map_idx_to_media) { + if (!chunk_ptr) continue; + enum mtmd_input_chunk_type chunk_type = mtmd_input_chunk_get_type(chunk_ptr.get()); + if (chunk_type == target_type) { + return true; + } + } + return false; +} + +bool server_tokens::has_any_modality_type(const std::vector<std::string> & modality_types) const { + for (const auto & mod_type : modality_types) { + if (has_modality_type(mod_type)) { + return true; + } + } + return false; +} + void server_tokens::push_back(llama_token tok) { if (tok == LLAMA_TOKEN_NULL) { throw std::runtime_error("Invalid token"); diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 093a43453c2..472f3acd5e6 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -178,6 +178,10 @@ struct server_tokens { const mtmd::input_chunk_ptr & find_chunk(size_t idx) const; + // Check if request contains any of the specified modality types + bool has_modality_type(const std::string & modality_type) const; + bool has_any_modality_type(const std::vector<std::string> & modality_types) const; + void push_back(llama_token tok); 
// will create a copy of the chunk if it contains non-text data diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index a5372572f01..b1889852c79 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1279,6 +1279,24 @@ struct server_context_impl { } } + // Handle Multi-Modal LoRA (MMLoRA) activation + for (size_t i = 0; i < slot.lora.size(); ++i) { + auto & lora = slot.lora[i]; + + if (!lora.mmlora_modality_types.empty()) { + // Check if request has any of the required modalities (OR logic) + const bool has_modality = task.tokens.has_any_modality_type(lora.mmlora_modality_types); + + if (!has_modality) { + SLT_DBG(slot, "MMLoRA %zu requires modality but not found, deactivating\n", i); + lora.enabled = false; + } else { + SLT_DBG(slot, "MMLoRA %zu activated (modality present)\n", i); + lora.enabled = true; + } + } + } + if (!task.tokens.validate(ctx)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 2187b8d21b5..e0e9f87d418 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -1951,6 +1951,9 @@ json server_task_result_get_lora::to_json() { entry["alora_invocation_string"] = lora.alora_invocation_string; entry["alora_invocation_tokens"] = lora.alora_invocation_tokens; } + if (!lora.info.mmlora_modality_types.empty()) { + entry["mmlora_modality_types"] = lora.info.mmlora_modality_types; + } result.push_back(std::move(entry)); } return result;