Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2490,7 +2490,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"path to LoRA adapter (use comma-separated values to load multiple adapters)",
[](common_params & params, const std::string & value) {
for (const auto & item : parse_csv_row(value)) {
params.lora_adapters.push_back({ item, 1.0, "", "", nullptr });
params.lora_adapters.push_back({ item, 1.0, "", "", nullptr, true, {} });
}
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
Expand All @@ -2505,11 +2505,60 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
if (parts.size() != 2) {
throw std::invalid_argument("lora-scaled format: FNAME:SCALE");
}
params.lora_adapters.push_back({ parts[0], std::stof(parts[1]), "", "", nullptr });
params.lora_adapters.push_back({ parts[0], std::stof(parts[1]), "", "", nullptr, true, {} });
}
}
// we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
add_opt(common_arg(
{"--lora-modality"}, "INDEX:MODALITY[,MODALITY...]",
"Bind LoRA adapter to modality type(s). Adapter activates when any specified modality is present.\n"
"INDEX is the 0-based adapter index (order of --lora arguments).\n"
"MODALITY can be: image, audio (comma-separated for multiple).\n"
"Example: --lora-modality 0:image,audio",
[](common_params & params, const std::string & value) {
// Parse "INDEX:MODALITY[,MODALITY...]"
size_t colon_pos = value.find(':');
if (colon_pos == std::string::npos) {
throw std::invalid_argument("Invalid format for --lora-modality. Expected INDEX:MODALITY");
}

// Parse the index value
std::string index_str = value.substr(0, colon_pos);
const auto start_index = index_str.find_first_not_of(" \t\r\n");
if (start_index != std::string::npos) {
const auto end_index = index_str.find_last_not_of(" \t\r\n");
index_str = index_str.substr(start_index, end_index - start_index + 1);
}

int adapter_idx = std::stoi(index_str);
if (adapter_idx < 0 || adapter_idx >= (int)params.lora_adapters.size()) {
throw std::invalid_argument("Invalid adapter index: " + index_str);
}

// Parse comma-separated modalities
std::string modalities_str = value.substr(colon_pos + 1);
std::vector<std::string> modality_strs = string_split<std::string>(modalities_str, ',');

// Validate and store modality strings
auto & lora_info = params.lora_adapters[adapter_idx];
lora_info.mmlora_modality_types.clear();

for (auto mod_str : modality_strs) {
// Strip whitespace
const auto start_index = mod_str.find_first_not_of(" \t\r\n");
if (start_index != std::string::npos) {
auto end_index = mod_str.find_last_not_of(" \t\r\n");
mod_str = mod_str.substr(start_index, end_index - start_index + 1);
}
// Validate the string (simple string validation, enum conversion happens later)
if (mod_str != "image" && mod_str != "audio") {
throw std::invalid_argument("Invalid modality type: " + mod_str + " (must be 'image' or 'audio')");
}
lora_info.mmlora_modality_types.push_back(mod_str);
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--control-vector"}, "FNAME",
"add a control vector\nnote: use comma-separated values to add multiple control vectors",
Expand Down
15 changes: 14 additions & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,18 @@ common_init_result::common_init_result(common_params & params) :
la.task_name = buf;
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
la.prompt_prefix = buf;

// Validate MMLoRA configuration
if (!la.mmlora_modality_types.empty()) {
// Check for aLoRA conflict
const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(la.ptr);
if (n_alora_tokens > 0) {
LOG_ERR("%s: adapter '%s' cannot be both MMLoRA and aLoRA\n", __func__, la.path.c_str());
pimpl->model.reset(model);
return;
}
}

pimpl->lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
}

Expand Down Expand Up @@ -1439,7 +1451,8 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap

for (auto & la: lora) {
loras.push_back(la.ptr);
scales.push_back(la.scale);
// set scale to 0 if disabled
scales.push_back(la.enabled ? la.scale : 0.0);
}

llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data());
Expand Down
12 changes: 12 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ struct common_adapter_lora_info {
std::string prompt_prefix;

struct llama_adapter_lora * ptr;

// used to toggle on/off programmatically without changing scale
bool enabled;

// Multi-Modal LoRA activation (MMLoRA)
// Empty vector = not an MMLoRA adapter (always active)
// Non-empty = MMLoRA adapter (activates if ANY specified modality present - OR logic)
// Modality types stored as strings: "image", "audio"
std::vector<std::string> mmlora_modality_types;

// Helper to check if this is an MMLoRA adapter
bool is_mmlora() const { return !mmlora_modality_types.empty(); }
};

using llama_tokens = std::vector<llama_token>;
Expand Down
3 changes: 3 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ endif()
set(LLAMA_TEST_NAME test-mtmd-c-api)
llama_build_and_test(test-mtmd-c-api.c)
target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd)
set(LLAMA_TEST_NAME test-mmlora)
llama_build_and_test(test-mmlora.cpp)
target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd)
unset(LLAMA_TEST_NAME)

# GGUF model data fetcher library for tests that need real model metadata
Expand Down
119 changes: 119 additions & 0 deletions tests/test-mmlora.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Tests for Multi-Modal LoRA (MMLoRA) functionality
#include "mtmd.h"
#include "mtmd-helper.h"
#include "common.h"

#include <cstdio>
#include <cstdlib>
#include <vector>
#include <string>
#include <cstring>

#undef NDEBUG
#include <cassert>

#define LOG_INF(...) printf(__VA_ARGS__)

#define TEST(cond) do { \
if (!(cond)) { \
fprintf(stderr, "TEST FAILED: %s at %s:%d\n", #cond, __FILE__, __LINE__); \
return 1; \
} \
} while(0)

// Test string↔enum conversion functions
static int test_mmlora_modality_type_conversion() {
LOG_INF("Testing modality type string conversions...\n");

// Test enum to string
TEST(strcmp(mtmd_modality_type_to_str(MTMD_INPUT_CHUNK_TYPE_IMAGE), "image") == 0);
TEST(strcmp(mtmd_modality_type_to_str(MTMD_INPUT_CHUNK_TYPE_AUDIO), "audio") == 0);
TEST(strcmp(mtmd_modality_type_to_str(MTMD_INPUT_CHUNK_TYPE_TEXT), "text") == 0);
TEST(strcmp(mtmd_modality_type_to_str(MTMD_INPUT_CHUNK_TYPE_UNKNOWN), "unknown") == 0);
TEST(strcmp(mtmd_modality_type_to_str((enum mtmd_input_chunk_type)999), "unknown") == 0);

// Test string to enum
TEST(mtmd_modality_type_from_str("image") == MTMD_INPUT_CHUNK_TYPE_IMAGE);
TEST(mtmd_modality_type_from_str("audio") == MTMD_INPUT_CHUNK_TYPE_AUDIO);
TEST(mtmd_modality_type_from_str("text") == MTMD_INPUT_CHUNK_TYPE_TEXT);
TEST(mtmd_modality_type_from_str("unknown") == MTMD_INPUT_CHUNK_TYPE_UNKNOWN);
TEST(mtmd_modality_type_from_str("invalid") == MTMD_INPUT_CHUNK_TYPE_UNKNOWN);

// Test validation
TEST(mtmd_is_valid_modality_str("image") == true);
TEST(mtmd_is_valid_modality_str("audio") == true);
TEST(mtmd_is_valid_modality_str("text") == false);
TEST(mtmd_is_valid_modality_str("unknown") == false);
TEST(mtmd_is_valid_modality_str("invalid") == false);
TEST(mtmd_is_valid_modality_str(NULL) == false);

LOG_INF(" Passed modality type conversions\n");
return 0;
}

// Test common_adapter_lora_info MMLoRA helpers
static int test_mmlora_info_helpers() {
LOG_INF("Testing MMLoRA info helpers...\n");

// Test is_mmlora() with empty vector (not MMLoRA)
common_adapter_lora_info regular_lora;
TEST(regular_lora.is_mmlora() == false);
TEST(regular_lora.mmlora_modality_types.empty() == true);

// Test is_mmlora() with single modality
common_adapter_lora_info image_lora;
image_lora.mmlora_modality_types.push_back("image");
TEST(image_lora.is_mmlora() == true);
TEST(image_lora.mmlora_modality_types.size() == 1);

// Test is_mmlora() with multiple modalities
common_adapter_lora_info multi_lora;
multi_lora.mmlora_modality_types.push_back("image");
multi_lora.mmlora_modality_types.push_back("audio");
TEST(multi_lora.is_mmlora() == true);
TEST(multi_lora.mmlora_modality_types.size() == 2);

LOG_INF(" Passed MMLoRA info helpers\n");
return 0;
}

// Test CLI argument parsing (simulated)
static int test_mmlora_arg_parsing() {LOG_INF("Testing MMLoRA argument parsing logic...\n");

// Simulate parsing "0:image,audio"
std::string value = "0:image,audio";
size_t colon_pos = value.find(':');
TEST(colon_pos != std::string::npos);

std::string index_str = value.substr(0, colon_pos);
std::string modalities_str = value.substr(colon_pos + 1);

TEST(index_str == "0");
std::vector<std::string> modality_strs = string_split<std::string>(modalities_str, ',');
TEST(modality_strs.size() == 2);
TEST(modality_strs[0] == "image");
TEST(modality_strs[1] == "audio");

// Validate modalities
for (const auto & mod : modality_strs) {
TEST(mod == "image" || mod == "audio");
}

LOG_INF(" Passed argument parsing logic\n");
return 0;
}

int main(int argc, char ** argv) {
(void)argc;
(void)argv;

LOG_INF("MMLoRA unit tests starting...\n");

int result = 0;
result += test_mmlora_modality_type_conversion();
result += test_mmlora_info_helpers();
result += test_mmlora_arg_parsing();

LOG_INF("MMLoRA unit tests %s!\n", result == 0 ? "passed" : "failed");
return result;
}
25 changes: 25 additions & 0 deletions tools/mtmd/mtmd-helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,3 +535,28 @@ mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char *

return mtmd_helper_bitmap_init_from_buf(ctx, buf.data(), buf.size());
}

// Modality type string↔enum conversion functions
const char * mtmd_modality_type_to_str(enum mtmd_input_chunk_type type) {
switch (type) {
case MTMD_INPUT_CHUNK_TYPE_IMAGE: return "image";
case MTMD_INPUT_CHUNK_TYPE_AUDIO: return "audio";
case MTMD_INPUT_CHUNK_TYPE_TEXT: return "text";
default: return "unknown";
}
}

enum mtmd_input_chunk_type mtmd_modality_type_from_str(const char * str) {
if (strcmp(str, "image") == 0) {
return MTMD_INPUT_CHUNK_TYPE_IMAGE;
} else if (strcmp(str, "audio") == 0) {
return MTMD_INPUT_CHUNK_TYPE_AUDIO;
} else if (strcmp(str, "text") == 0) {
return MTMD_INPUT_CHUNK_TYPE_TEXT;
}
return MTMD_INPUT_CHUNK_TYPE_UNKNOWN; // default/invalid
}

bool mtmd_is_valid_modality_str(const char * str) {
return str && (strcmp(str, "image") == 0 || strcmp(str, "audio") == 0);
}
10 changes: 10 additions & 0 deletions tools/mtmd/mtmd-helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
int32_t n_batch,
llama_pos * new_n_past);

// Convert modality type enum to string
MTMD_API const char * mtmd_modality_type_to_str(enum mtmd_input_chunk_type type);

// Convert string to modality type enum
// Returns MTMD_INPUT_CHUNK_TYPE_TEXT on invalid input
MTMD_API enum mtmd_input_chunk_type mtmd_modality_type_from_str(const char * str);

// Validate if string is a valid modality type (image or audio)
MTMD_API bool mtmd_is_valid_modality_str(const char * str);

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
1 change: 1 addition & 0 deletions tools/mtmd/mtmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ enum mtmd_input_chunk_type {
MTMD_INPUT_CHUNK_TYPE_TEXT,
MTMD_INPUT_CHUNK_TYPE_IMAGE,
MTMD_INPUT_CHUNK_TYPE_AUDIO,
MTMD_INPUT_CHUNK_TYPE_UNKNOWN,
};

// opaque types
Expand Down
26 changes: 26 additions & 0 deletions tools/server/server-common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,32 @@ const mtmd::input_chunk_ptr & server_tokens::find_chunk(size_t idx) const {
throw std::runtime_error("Chunk not found");
}

bool server_tokens::has_modality_type(const std::string & modality_type) const {
if (!has_mtmd) {
return false;
}

enum mtmd_input_chunk_type target_type = mtmd_modality_type_from_str(modality_type.c_str());

for (const auto & [idx, chunk_ptr] : map_idx_to_media) {
if (!chunk_ptr) continue;
enum mtmd_input_chunk_type chunk_type = mtmd_input_chunk_get_type(chunk_ptr.get());
if (chunk_type == target_type) {
return true;
}
}
return false;
}

bool server_tokens::has_any_modality_type(const std::vector<std::string> & modality_types) const {
for (const auto & mod_type : modality_types) {
if (has_modality_type(mod_type)) {
return true;
}
}
return false;
}

void server_tokens::push_back(llama_token tok) {
if (tok == LLAMA_TOKEN_NULL) {
throw std::runtime_error("Invalid token");
Expand Down
4 changes: 4 additions & 0 deletions tools/server/server-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ struct server_tokens {

const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;

// Check if request contains any of the specified modality types
bool has_modality_type(const std::string & modality_type) const;
bool has_any_modality_type(const std::vector<std::string> & modality_types) const;

void push_back(llama_token tok);

// will create a copy of the chunk if it contains non-text data
Expand Down
18 changes: 18 additions & 0 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1279,6 +1279,24 @@ struct server_context_impl {
}
}

// Handle Multi-Modal LoRA (MMLoRA) activation
for (size_t i = 0; i < slot.lora.size(); ++i) {
auto & lora = slot.lora[i];

if (!lora.mmlora_modality_types.empty()) {
// Check if request has any of the required modalities (OR logic)
const bool has_modality = task.tokens.has_any_modality_type(lora.mmlora_modality_types);

if (!has_modality) {
SLT_DBG(slot, "MMLoRA %zu requires modality but not found, deactivating\n", i);
lora.enabled = false;
} else {
SLT_DBG(slot, "MMLoRA %zu activated (modality present)\n", i);
lora.enabled = true;
}
}
}

if (!task.tokens.validate(ctx)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
Expand Down
3 changes: 3 additions & 0 deletions tools/server/server-task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1951,6 +1951,9 @@ json server_task_result_get_lora::to_json() {
entry["alora_invocation_string"] = lora.alora_invocation_string;
entry["alora_invocation_tokens"] = lora.alora_invocation_tokens;
}
if (!lora.info.mmlora_modality_types.empty()) {
entry["mmlora_modality_types"] = lora.info.mmlora_modality_types;
}
result.push_back(std::move(entry));
}
return result;
Expand Down
Loading