Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import de.unistuttgart.iste.meitrex.common.service.JsonSchemaGeneratorService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import javax.annotation.Nullable;
Expand All @@ -17,7 +16,6 @@
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Matcher;
Expand Down Expand Up @@ -60,8 +58,7 @@ public String getTemplate(final String templateFileName) {
}
} catch (IOException e) {
log.error("Failed to read template file: {}", templateFileName, e);
final StringBuilder error = new StringBuilder("Failed to read template file: ").append(templateFileName);
throw new RuntimeException(error.toString(), e);
throw new RuntimeException("Failed to read template file: " + templateFileName, e);
}
}

Expand Down Expand Up @@ -107,25 +104,28 @@ public String fillTemplate(final String promptTemplate, final Map<String, String
}

/**
* Starts a query to the LLM by filling a prompt template, sending it to Ollama,
* and parsing the response into the given type.
* Executes a full LLM query cycle: loads a template from resources, fills it with arguments,
* requests a structured response from Ollama, and parses the result.
*
* @param responseType the target class to parse the response into
* @param prompt the template prompt text
* @param argMap A map of placeholder keys and their replacement values.
* @param error the fallback value if parsing or the request fails
* @param modelOverride the specific LLM model name to use for this request (e.g., "llama3:70b").
* If null or blank, the default model from the configuration is used.
* @return the parsed response or the fallback error value
* @param <ResponseType> the target type for the structured JSON response
* @param responseType the class of the target type to parse the response into
* @param templateFileName the name of the file in the prompt folder (e.g., "analysis_prompt.md")
* @param argMap a map of placeholder keys (without braces) and their replacement values
* @param error the fallback value to return if the query, parsing, or template loading fails
* @param modelOverride optional model name (e.g., "llama3.3:70b") to bypass the default config.
* If null/blank, the default from {@link OllamaConfig} is used.
* @return the parsed response object of type ResponseType, or the provided error fallback
*/
public <ResponseType> ResponseType startQuery(
final Class<ResponseType> responseType,
final String prompt,
final String templateFileName,
final Map<String, String> argMap,
final ResponseType error,
@Nullable final String modelOverride) {
try {
final String filledPrompt = fillTemplate(prompt, argMap);
final String promptTemplate = getTemplate(templateFileName);

final String filledPrompt = fillTemplate(promptTemplate, argMap);

final TypeReference<HashMap<String, Object>> typeRef = new TypeReference<>() {};
final String jsonSchema = jsonSchemaService.getJsonSchema(responseType);
Expand Down Expand Up @@ -159,6 +159,24 @@ public <ResponseType> ResponseType startQuery(
}
}

/**
 * Overloaded convenience method kept for backward compatibility with callers
 * that do not specify a model. Starts a query using the default model defined
 * in the configuration, by delegating to the five-argument
 * {@code startQuery} overload with a {@code null} model override.
 *
 * @param <ResponseType> the target type for the structured JSON response
 * @param responseType the target class to parse the response into
 * @param templateFileName the name of the file in the prompt folder (e.g., "analysis_prompt.md")
 * @param argMap a map of placeholder keys and their replacement values
 * @param error the fallback value returned if the query, parsing, or template loading fails
 * @return the parsed response or the fallback error value
 */
public <ResponseType> ResponseType startQuery(
final Class<ResponseType> responseType,
final String templateFileName,
final Map<String, String> argMap,
final ResponseType error) {
// null override selects the configured default model in the delegate.
return startQuery(responseType, templateFileName, argMap, error, null);
}

/**
* Sends the given request to the Ollama LLM endpoint and returns the raw response.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,10 @@ class OllamaClientTest {

@BeforeEach
void setUp() {
ollamaClient = new OllamaClient(config, jsonSchemaService, jsonMapper, httpClient);
OllamaClient realClient = new OllamaClient(config, jsonSchemaService, jsonMapper, httpClient);
ollamaClient = spy(realClient);
}

// --- Tests for fillTemplate ---

@Test
void testFillTemplateSuccess() {
String template = "Hello {{name}}, welcome to {{place}}!";
Expand All @@ -71,11 +70,14 @@ void testFillTemplateThrowsOnUnknownArgument() {

@Test
void testStartQuerySuccess() throws IOException, InterruptedException {
String prompt = "Calculate 1+1";
String templateFileName = "test_prompt.md";
String templateContent = "Calculate 1+1";
Map<String, String> args = Map.of();
String mockSchema = "{\"type\":\"object\"}";
String ollamaJsonResponse = "{\"response\": \"{\\\"result\\\": 2}\", \"done\": true}";

doReturn(templateContent).when(ollamaClient).getTemplate(templateFileName);

when(config.getModel()).thenReturn("llama3");
when(config.getUrl()).thenReturn("http://localhost:11434");
when(config.getEndpoint()).thenReturn("api/generate");
Expand All @@ -92,7 +94,7 @@ void testStartQuerySuccess() throws IOException, InterruptedException {

TestResponseDto result = ollamaClient.startQuery(
TestResponseDto.class,
prompt,
templateFileName,
args,
new TestResponseDto(0)
);
Expand All @@ -109,6 +111,9 @@ void testStartQuerySuccess() throws IOException, InterruptedException {

@Test
void testStartQueryHandlesNetworkError() throws IOException, InterruptedException {
String templateFileName = "error_test.md";
doReturn("some prompt").when(ollamaClient).getTemplate(templateFileName);

when(config.getModel()).thenReturn("llama3");
when(config.getUrl()).thenReturn("http://localhost:11434");
when(config.getEndpoint()).thenReturn("api/generate");
Expand All @@ -123,14 +128,17 @@ void testStartQueryHandlesNetworkError() throws IOException, InterruptedExceptio
TestResponseDto fallback = new TestResponseDto(-1);

TestResponseDto result = ollamaClient.startQuery(
TestResponseDto.class, "prompt", Map.of(), fallback
TestResponseDto.class, templateFileName, Map.of(), fallback
);

assertThat(result, sameInstance(fallback));
}

@Test
void testStartQueryHandlesOllamaError() throws IOException, InterruptedException {
String templateFileName = "ollama_error.md";
doReturn("some prompt").when(ollamaClient).getTemplate(templateFileName);

when(config.getModel()).thenReturn("llama3");
when(config.getUrl()).thenReturn("http://localhost:11434");
when(config.getEndpoint()).thenReturn("api/generate");
Expand All @@ -149,7 +157,7 @@ void testStartQueryHandlesOllamaError() throws IOException, InterruptedException
TestResponseDto fallback = new TestResponseDto(-1);

TestResponseDto result = ollamaClient.startQuery(
TestResponseDto.class, "prompt", Map.of(), fallback
TestResponseDto.class, templateFileName, Map.of(), fallback
);

assertThat(result, sameInstance(fallback));
Expand Down
Loading