From 1f13c87614b5f5f124b9c7080adcc9be4e2177b2 Mon Sep 17 00:00:00 2001
From: roopamtvam-gpu
Date: Thu, 4 Sep 2025 10:49:27 +0530
Subject: [PATCH] Updated llama.cpp and working Android app with thinking model

---
 androidApp/src/main/AndroidManifest.xml       |   3 +
 .../dilivva/inferkt/android/ChatMessage.kt    |   3 +-
 .../com/dilivva/inferkt/android/ChatScreen.kt | 148 ++++-
 .../dilivva/inferkt/android/ChatViewModel.kt  | 199 ++++++-
 .../src/main/res/drawable/chat_app_logo.png   | Bin 0 -> 7289 bytes
 .../res/drawable/ic_launcher_foreground.xml   |   9 +
 .../res/mipmap-anydpi-v26/ic_launcher.xml     |   6 +
 .../mipmap-anydpi-v26/ic_launcher_round.xml   |   6 +
 .../res/values/ic_launcher_background.xml     |   4 +
 androidApp/src/main/res/values/strings.xml    |   5 +
 gradle.properties                             |   4 +-
 gradle/libs.versions.toml                     |   2 +-
 library/build.gradle.kts                      |  13 +-
 library/src/androidMain/cpp/CMakeLists.txt    |  28 +-
 library/src/androidMain/cpp/infer-android.cpp | 514 ++++++++++++------
 .../kotlin/com/dilivva/inferkt/InferNative.kt |   5 +-
 library/src/native/llama.cpp                  |   2 +-
 library/src/native/src/Inference.cpp          | 104 +++-
 18 files changed, 852 insertions(+), 203 deletions(-)
 create mode 100644 androidApp/src/main/res/drawable/chat_app_logo.png
 create mode 100644 androidApp/src/main/res/drawable/ic_launcher_foreground.xml
 create mode 100644 androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher.xml
 create mode 100644 androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml
 create mode 100644 androidApp/src/main/res/values/ic_launcher_background.xml
 create mode 100644 androidApp/src/main/res/values/strings.xml

diff --git a/androidApp/src/main/AndroidManifest.xml b/androidApp/src/main/AndroidManifest.xml
index 22d1fac..791dd0e 100644
--- a/androidApp/src/main/AndroidManifest.xml
+++ b/androidApp/src/main/AndroidManifest.xml
@@ -4,6 +4,9 @@
diff --git a/androidApp/src/main/java/com/dilivva/inferkt/android/ChatMessage.kt b/androidApp/src/main/java/com/dilivva/inferkt/android/ChatMessage.kt
--- a/androidApp/src/main/java/com/dilivva/inferkt/android/ChatMessage.kt
+++ b/androidApp/src/main/java/com/dilivva/inferkt/android/ChatMessage.kt
     val message: MutableState<String>,
     val type: Type,
-    val id: String
+    val id: String,
+    val isThinking: Boolean = false
 ){
     enum class Type{
         User, Bot
diff --git a/androidApp/src/main/java/com/dilivva/inferkt/android/ChatScreen.kt b/androidApp/src/main/java/com/dilivva/inferkt/android/ChatScreen.kt
index 49bbad9..c573fc0 100644
--- a/androidApp/src/main/java/com/dilivva/inferkt/android/ChatScreen.kt
+++ b/androidApp/src/main/java/com/dilivva/inferkt/android/ChatScreen.kt
@@ -26,31 +26,54 @@ import android.net.Uri
 import androidx.activity.compose.ManagedActivityResultLauncher
 import androidx.activity.compose.rememberLauncherForActivityResult
 import androidx.activity.result.contract.ActivityResultContracts
+import androidx.compose.animation.core.LinearEasing
+import androidx.compose.animation.core.RepeatMode
+import androidx.compose.animation.core.StartOffset
+import androidx.compose.animation.core.animateFloat
+import androidx.compose.animation.core.infiniteRepeatable
+import androidx.compose.animation.core.rememberInfiniteTransition
+import androidx.compose.animation.core.tween
 import androidx.compose.foundation.background
 import androidx.compose.foundation.layout.Arrangement
 import androidx.compose.foundation.layout.Box
 import androidx.compose.foundation.layout.Column
 import androidx.compose.foundation.layout.PaddingValues
 import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.WindowInsets
+import androidx.compose.foundation.layout.systemBars
 import androidx.compose.foundation.layout.fillMaxSize
 import androidx.compose.foundation.layout.fillMaxWidth
 import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.size
+import androidx.compose.foundation.layout.imePadding
 import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.windowInsetsPadding
 import androidx.compose.foundation.lazy.LazyColumn
 import androidx.compose.foundation.lazy.items
 import androidx.compose.foundation.lazy.rememberLazyListState
+import androidx.compose.foundation.shape.CircleShape
 import androidx.compose.foundation.shape.RoundedCornerShape
 import androidx.compose.material3.Button
 import androidx.compose.material3.CircularProgressIndicator
+import androidx.compose.material3.Switch
+import androidx.compose.material3.TopAppBar
 import androidx.compose.material3.Text
 import androidx.compose.material3.TextField
+import androidx.compose.material3.ExperimentalMaterial3Api
 import androidx.compose.runtime.Composable
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.remember
+import androidx.compose.foundation.clickable
+import androidx.compose.foundation.text.BasicText
 import androidx.compose.ui.Alignment
 import androidx.compose.ui.Modifier
+import androidx.compose.ui.draw.clip
 import androidx.compose.ui.graphics.Color
 import androidx.compose.ui.platform.LocalContext
 import androidx.compose.ui.unit.dp
 import androidx.lifecycle.viewmodel.compose.viewModel
+import com.dilivva.inferkt.setGlobalThinkingMode
+import com.dilivva.inferkt.android.R
 
 @Composable
 fun ChatScreen() {
@@ -60,7 +83,13 @@ fun ChatScreen() {
         viewModel.setModelPath(context, it)
     }
 
-    Column(modifier = Modifier.fillMaxSize().background(Color.White).padding(16.dp)) {
+    Column(modifier = Modifier
+        .fillMaxSize()
+        .windowInsetsPadding(WindowInsets.systemBars)
+        .imePadding()
+        .background(Color.White)
+        .padding(16.dp)) {
+        TopBar(viewModel)
         Box(modifier = Modifier.weight(1f)) {
             LazyColumn(
                 state = rememberLazyListState(),
@@ -69,8 +98,8 @@ fun ChatScreen() {
                 contentPadding = PaddingValues(vertical = 10.dp),
                 verticalArrangement = Arrangement.spacedBy(5.dp)
             ) {
-                items(viewModel.messages, key = { it.id }) {
-                    ChatItem(it)
+                items(viewModel.messages, key = { it.id }) { msg ->
+                    ChatItem(msg)
                 }
             }
 
@@ -93,6 +122,26 @@ fun ChatScreen() {
 
 }
 
+@OptIn(ExperimentalMaterial3Api::class)
+@Composable
+private fun TopBar(viewModel: ChatViewModel) {
+    TopAppBar(
+        title = { Text(text = LocalContext.current.getString(R.string.app_name)) },
+        actions = {
+            Row(
+                verticalAlignment = Alignment.CenterVertically,
+                horizontalArrangement = Arrangement.spacedBy(6.dp)
+            ) {
+                Switch(checked = viewModel.isThinkingEnabled, onCheckedChange = {
+                    viewModel.isThinkingEnabled = it
+                    try { setGlobalThinkingMode(it) } catch (_: Throwable) {}
+                })
+                Text("Thinking")
+            }
+        }
+    )
+}
+
 @Composable
 private fun ChatItem(chatMessage: ChatMessage){
     Column(modifier = Modifier.fillMaxWidth()) {
@@ -108,18 +157,94 @@ private fun ChatItem(chatMessage: ChatMessage){
             }
             ChatMessage.Type.Bot -> {
-                Text(
-                    text = chatMessage.message.value,
-                    color = Color.Black,
+                Column(
                     modifier = Modifier.align(Alignment.Start)
-                        .background(Color.LightGray, RoundedCornerShape(8.dp))
-                        .padding(8.dp)
-                )
+                        .background(Color(0xFFF4F4F5), RoundedCornerShape(12.dp))
+                        .padding(10.dp)
+                ) {
+                    // Collapsible thinking bubble or final answer bubble
+                    if (chatMessage.isThinking) {
+                        CollapsibleThinking(text = chatMessage.message.value)
+                    } else if (chatMessage.message.value.isNotBlank()) {
+                        val isStreaming = chatMessage.message.value.length < 100000
+                        if (isStreaming) BasicText(text = chatMessage.message.value) else Text(text = chatMessage.message.value, color = Color.Black)
+                    }
+                }
             }
         }
     }
 }
 
+@Composable
+private fun CollapsibleThinking(text: String) {
+    val expanded = remember { mutableStateOf(false) }
+    Column(
+        modifier = Modifier
+            .background(Color(0xFFEFEFF1), RoundedCornerShape(10.dp))
+            .padding(8.dp)
+    ) {
+        Row(
+            verticalAlignment = Alignment.CenterVertically,
+            horizontalArrangement = Arrangement.SpaceBetween,
+            modifier = Modifier
+                .fillMaxWidth()
+                .clickable { expanded.value = !expanded.value }
+        ) {
+            Text("Thinking", color = Color.Gray)
+            Text(
+                text = if (expanded.value) "Hide" else "Show",
+                color = Color.Gray,
+                modifier = Modifier
+                    .clip(RoundedCornerShape(6.dp))
+                    .background(Color(0xFFDADAE0))
+                    .padding(horizontal = 6.dp, vertical = 2.dp)
+                    .padding(2.dp)
+            )
+        }
+        if (expanded.value && text.isNotBlank()) {
+            BasicText(text = text)
+        }
+    }
+}
+
+@Composable
+private fun ThinkingIndicator(label: String){
+    Row(
+        verticalAlignment = Alignment.CenterVertically,
+        horizontalArrangement = Arrangement.spacedBy(6.dp)
+    ) {
+        Dot()
+        Dot(delayMs = 150)
+        Dot(delayMs = 300)
+        Text(
+            text = label,
+            color = Color.Gray,
+            modifier = Modifier.padding(start = 2.dp)
+        )
+    }
+}
+
+@Composable
+private fun Dot(delayMs: Int = 0){
+    val transition = rememberInfiniteTransition()
+    val alpha = transition.animateFloat(
+        initialValue = 0.3f,
+        targetValue = 1f,
+        animationSpec = infiniteRepeatable(
+            animation = tween(durationMillis = 900, easing = LinearEasing),
+            repeatMode = RepeatMode.Reverse,
+            initialStartOffset = StartOffset(delayMs)
+        )
+    )
+    Box(
+        modifier = Modifier
+            .size(6.dp)
+            .clip(CircleShape)
+            .background(Color.Gray.copy(alpha = alpha.value))
+    )
+}
+
 @Composable
 private fun ModelControl(onClick: () -> Unit){
     Button(
@@ -132,7 +257,10 @@ private fun ModelControl(onClick: () -> Unit){
 
 @Composable
 private fun ChatControl(viewModel: ChatViewModel){
-    Row {
+    Row(
+        verticalAlignment = Alignment.CenterVertically,
+        horizontalArrangement = Arrangement.spacedBy(8.dp)
+    ) {
         TextField(
             viewModel.userMessage,
             onValueChange = { viewModel.userMessage = it },
diff --git a/androidApp/src/main/java/com/dilivva/inferkt/android/ChatViewModel.kt b/androidApp/src/main/java/com/dilivva/inferkt/android/ChatViewModel.kt
index cea85f6..7f1cd54 100644
--- a/androidApp/src/main/java/com/dilivva/inferkt/android/ChatViewModel.kt
+++ b/androidApp/src/main/java/com/dilivva/inferkt/android/ChatViewModel.kt
@@ -24,6 +24,8 @@ package com.dilivva.inferkt.android
 
 import android.content.Context
 import android.net.Uri
+import android.os.Handler
+import android.os.Looper
 import androidx.compose.runtime.getValue
 import androidx.compose.runtime.mutableStateOf
 import androidx.compose.runtime.setValue
@@ -31,6 +33,7 @@ import androidx.lifecycle.ViewModel
 import androidx.lifecycle.viewModelScope
 import com.dilivva.inferkt.GenerationEvent
 import com.dilivva.inferkt.ModelSettings
+import com.dilivva.inferkt.setGlobalThinkingMode
 import com.dilivva.inferkt.SamplingSettings
 import com.dilivva.inferkt.createInference
 import kotlinx.coroutines.CoroutineDispatcher
@@ -50,6 +53,7 @@ class ChatViewModel: ViewModel() {
     private val inference = createInference()
     private var modelPath = ""
    var userMessage by mutableStateOf("")
+    var isThinkingEnabled by mutableStateOf(true)
     var isLoading by mutableStateOf(false)
         private set
     var isLoaded by mutableStateOf(false)
@@ -85,6 +89,10 @@ class ChatViewModel: ViewModel() {
             numberOfGpuLayers = 99
         )
         val isLoaded = inference.preloadModel(modelSettings){ true }
+        // Set backend thinking mode to mirror UI default
+        try {
+            setGlobalThinkingMode(isThinkingEnabled)
+        } catch (_: Throwable) {}
         inference.setSamplingParams(SamplingSettings())
         withContext(Dispatchers.Main){
             this@ChatViewModel.isLoaded = isLoaded
@@ -95,7 +103,10 @@ class ChatViewModel: ViewModel() {
     @OptIn(ExperimentalUuidApi::class)
     fun sendMessage() = viewModelScope.launch(Dispatchers.IO){
         if (!isLoaded) return@launch
-        val prompt = userMessage
+        val prompt = sanitizeUserInput(userMessage)
+        val effectivePrompt = if (isThinkingEnabled) {
+            injectThinkingInstruction(prompt)
+        } else prompt
         userMessage = ""
         val userMessage = ChatMessage(
             message = mutableStateOf(prompt),
@@ -103,23 +114,152 @@ class ChatViewModel: ViewModel() {
             id = Uuid.random().toString()
         )
         messages += userMessage
-        val botMessage = ChatMessage(
-            message = mutableStateOf("Loading...\n\n"),
+        var answerIndex: Int
+        var thinkingIndex = -1
+        val answerMessage = ChatMessage(
+            message = mutableStateOf(""),
             type = ChatMessage.Type.Bot,
-            id = Uuid.random().toString()
+            id = Uuid.random().toString(),
+            isThinking = false
         )
-        messages += botMessage
-        val botMessageIndex = messages.size - 1
+        messages += answerMessage
+        answerIndex = messages.size - 1
+        // Stream: detect an initial <think> and optionally route text until </think>
+        var startedWithThink: Boolean? = null
+        var inThinking = false
+        val pending = StringBuilder()
+        val answerBuf = StringBuilder()
+        val thinkingBuf = StringBuilder()
+        var startNs = 0L
+        val mainHandler = Handler(Looper.getMainLooper())
+        var uiScheduled = false
+        val FLUSH_DELAY_MS = 16L
+        val opener = "<think>"
+        val closer = "</think>"
+        val openerLen = opener.length
+        val closerLen = closer.length
+        fun ensureThinkingBubble() {
+            if (thinkingIndex != -1) return
+            val thinkingMessage = ChatMessage(
+                message = mutableStateOf(""),
+                type = ChatMessage.Type.Bot,
+                id = Uuid.random().toString(),
+                isThinking = true
+            )
+            val list = messages.toMutableList()
+            list.add(answerIndex, thinkingMessage)
+            messages = list
+            thinkingIndex = answerIndex
+            answerIndex += 1
+        }
+        fun postToUi() {
+            if (answerBuf.isEmpty() && thinkingBuf.isEmpty()) return
+            if (uiScheduled) return
+            uiScheduled = true
+            mainHandler.postDelayed({
+                val a = if (answerBuf.isNotEmpty()) answerBuf.toString() else ""
+                val t = if (thinkingBuf.isNotEmpty()) thinkingBuf.toString() else ""
+                answerBuf.clear()
+                thinkingBuf.clear()
+                if (a.isNotEmpty()) messages[answerIndex].message.value += a
+                if (thinkingIndex != -1 && t.isNotEmpty()) messages[thinkingIndex].message.value += t
+                uiScheduled = false
+            }, FLUSH_DELAY_MS)
+        }
+        fun flush(force: Boolean = false) {
+            // Coalesce to at most one pending main-thread update
+            if (!force) {
+                postToUi()
+                return
+            }
+            // Force: push immediately on main
+            mainHandler.post {
+                val a = if (answerBuf.isNotEmpty()) answerBuf.toString() else ""
+                val t = if (thinkingBuf.isNotEmpty()) thinkingBuf.toString() else ""
+                answerBuf.clear()
+                thinkingBuf.clear()
+                if (a.isNotEmpty()) messages[answerIndex].message.value += a
+                if (thinkingIndex != -1 && t.isNotEmpty()) messages[thinkingIndex].message.value += t
+                uiScheduled = false
+            }
+        }
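+        // Routing note: the parser below assumes a reasoning model emits a leading
+        // "<think>...</think>" block before its final answer; a short tail is kept in
+        // `pending` so a closer tag split across two stream chunks is still matched.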
         inference.chat(
-            prompt = prompt,
+            prompt = effectivePrompt,
             maxTokens = 1024,
             onGenerate = {
                 when(it){
                     is GenerationEvent.Error -> println("Error: ${it.error}")
-                    GenerationEvent.Generated -> { isGenerating = false }
+                    GenerationEvent.Generated -> {
+                        isGenerating = false
+                        // final flush
+                        flush(force = true)
+                        // Strip placeholders and trailing stop/control artifacts, trim whitespace for the answer bubble
+                        messages[answerIndex].message.value = sanitizeModelOutput(messages[answerIndex].message.value)
+                    }
                     is GenerationEvent.Generating -> {
-                        //isGenerating = true
-                        messages[botMessageIndex].message.value += it.text
+                        if (startNs == 0L) startNs = System.nanoTime()
+                        pending.append(it.text)
+                        parse@ while (pending.isNotEmpty()) {
+                            if (startedWithThink == null) {
+                                // Need at least openerLen chars to decide
+                                // Skip leading whitespace but preserve it if not
+                                var i = 0
+                                while (i < pending.length && pending[i].isWhitespace()) i++
+                                if (pending.length - i < openerLen) break@parse
+                                if (pending.regionMatches(i, opener, 0, openerLen)) {
+                                    // Starts with <think>
+                                    // Emit any whitespace before the opener to thinking (for completeness)
+                                    if (i > 0) {
+                                        ensureThinkingBubble()
+                                        if (isThinkingEnabled) thinkingBuf.append(pending.substring(0, i))
+                                    }
+                                    pending.delete(0, i + openerLen)
+                                    startedWithThink = true
+                                    inThinking = true
+                                    ensureThinkingBubble()
+                                    continue@parse
+                                } else {
+                                    // Not a <think> start → everything is answer
+                                    startedWithThink = false
+                                    if (i > 0) {
+                                        // Move whitespace into answer
+                                        answerBuf.append(pending.substring(0, i))
+                                        pending.delete(0, i)
+                                    }
+                                }
+                            }
+                            if (inThinking) {
+                                val idx = pending.indexOf(closer)
+                                if (idx >= 0) {
+                                    // Emit thought up to the closer
+                                    if (idx > 0 && isThinkingEnabled) thinkingBuf.append(pending.substring(0, idx))
+                                    pending.delete(0, idx + closerLen)
+                                    inThinking = false
+                                    // Remainder goes to answer
+                                    if (pending.isNotEmpty()) {
+                                        answerBuf.append(pending)
+                                        pending.setLength(0)
+                                    }
+                                    break@parse
+                                } else {
+                                    // Emit safe prefix, leaving a tail to catch a split closer
+                                    val safe = pending.length - (closerLen - 1)
+                                    if (safe > 0) {
+                                        val part = pending.substring(0, safe)
+                                        if (isThinkingEnabled) thinkingBuf.append(part)
+                                        pending.delete(0, safe)
+                                    }
+                                    break@parse
+                                }
+                            } else {
+                                // Answer mode: append and clear
+                                answerBuf.append(pending)
+                                pending.setLength(0)
+                                break@parse
+                            }
+                        }
+                        // periodic flush
+                        flush()
                     }
                     GenerationEvent.Loading -> {
                         isGenerating = true
@@ -158,4 +298,43 @@ class ChatViewModel: ViewModel() {
         super.onCleared()
     }
 
+    private fun sanitizeUserInput(input: String): String {
+        // Remove fallbacks/placeholders and trim
+        val stripped = input
+            .replace(Regex("<\\|im_start\\|>\\w+", RegexOption.IGNORE_CASE), "")
+            .replace(Regex("<\\|im_end\\|>", RegexOption.IGNORE_CASE), "")
+            .replace(Regex("</?think>", RegexOption.IGNORE_CASE), "")
+        return stripped.trim()
+    }
+
+    private fun sanitizeModelOutput(output: String): String {
+        var text = output
+        val stops = listOf(
+            "<|im_end|>",
+            "<|eot_id|>",
+            "<|eom_id|>",
+            "</s>",
+            "<|endoftext|>",
+            "<|end|>"
+        )
+        for (s in stops) {
+            val idx = text.indexOf(s)
+            if (idx >= 0) {
+                text = text.substring(0, idx)
+            }
+        }
+        // If thinking is disabled, strip any residual <think>...</think> the model might emit
+        if (!isThinkingEnabled) {
+            text = text.replace(Regex("<think>[\\s\\S]*?</think>", RegexOption.IGNORE_CASE), "")
+        }
+        // Remove leading role/sentinel artifacts that might slip in
+        text = text.replace(Regex("^(\\s*)(assistant:|<\\|assistant\\|>|assistant)\\s*", RegexOption.IGNORE_CASE), "")
+        return text.trim()
+    }
+
+    private fun injectThinkingInstruction(user: String): String {
+        val instruction = "Wrap internal reasoning in <think> and </think>. After </think>, provide the final answer clearly without additional preamble."
+ return "$instruction\n\n$user" + } + } \ No newline at end of file diff --git a/androidApp/src/main/res/drawable/chat_app_logo.png b/androidApp/src/main/res/drawable/chat_app_logo.png new file mode 100644 index 0000000000000000000000000000000000000000..5f1c6d35f4cd404e3ef15dbbae5eec3190789b62 GIT binary patch literal 7289 zcmY*;by!rHv^Al(eoF@%5&DJ@;X&aoe;Vn04{q4533 z4%7CP#-p)1f#5F$>ruyZRWb5FLnH3|*PgtciN-@idlG4{5B3CWsDq#`PW%vS7b_UQ zuaoPeHyWCxjIS#M>Hza(wu0H(J483j2=wvs;r9{dcX77`3W|%10|kVDLPC6x9(*2t&YlooK4%Zs{{Se!JfQCOuAcTT z&dmP+AyzJ4o>BlWFMDf{4a7#o%Gz3l&k7=J%_k^m4dH`W3k&jD3kzF|i3q+FvJtcb z{3pDpz0LoR@9gm(8XgG&{u==b@(cV+`JW8Za<_*)X7Mkiw4mgFkN+Q=B=BD%|0^-+ z|7Ure>!b8P%>1M2f7pdNKT_)R$Q^HT8UY#_ZPaUpSNgu@$jrM+*!G)Mnf>|OOpham z1S#P^y})YdJ6_i5`=^OpVwhOmDr%hhV{@Sj;{fen*t8{}aZdYhFMo%rzDfxxd7jKZ z-tPisHNtyMpu%uG&-5VLlXJ)Ac*tRsF}G!|Ix8=RLwLiu@6U|2u6k-@a9|ggK&vx>xu}bqp=&RjuDp>^&$TrAY<;wlc!d<&D zTmj7Jgt(>-_yXt&ov!_Rt%rsWn7C+T*%Gl7JtgsT1;bPuE&CN5xOj{aMy_i0-z~79{SmjQ>TF#HCFLG{wvE`(aUZhK1PM{9}FFfeB%_%$Tp(*d~|m{ z9~UHxfJj4Ggncp73>^8Fdv9C@@f6oL{Bk#c$bTp}PGZeR?=&r0kh+s}JNzhVx8$)Y zDZ*`|g!hFeh9Wii+++?}JC3Ms6PsK;BVJ#4le#MTYPz#TPwLHqP(%SNU4k~T_(Ld- z0Xomqt<%LiQ5L<_ju!*gzXRHKeKzMP8}mpFCQ}NbPi-dk&W@23Wz-B2SoFcG()i(F zI=me8!1kfzJ+7fHIM?SiZDrg(pt;^Ce~*!@#}5^r&1)MWwZ?u7x-(%C5xqe_iIJ)B zsVT377cKU@`7@Ka4$?phV`95VVfv+0f09X2>i*)6;L7Bt+S+HeLH%Pn0Z%gcI1P3l z_IyY;IlTbvd;|eu~6s1YjTrg8<^0Jsxnn!;Rhu0{m zn+st1tQT1Mq9fvVwdpm@pQ4!;dVFRSa+cvZexB>0S|35C9VVxJK=Sa05uZ2ifR(s}k!)`FZ*J&_t)y-6b=5O$!D*C~ zjNk~YJcs4XA_`r%XzOCa**nSFt2#R_?Uk0vI9|i3Xb!tV5wG1euDgb_jol5NN_(n| zkYx`k%5dFKcZ5V$Q7`I^mFI}^i1J=j59%vbG^hfzgjL6L}j8RMP&WTOBu6Q`N^EoPWuE(D6ZiOwUR!qR4qx8$@hgPeAaQCuRw&u8wgHM zz<2FzV0~l(_a=+8_6W$;=vV1ZxlZ(WHQDYu|4pofGt2lHj8U(N+{X{A1@2cd(wY>ohKPNzCg-qu|nKYSKZ-1yD8BV_r7Xk;5vcg@wB;l%GJ*DgF>)pe&p1X}SRK(1%UF zL|Ryb{<{0Ecz~amu3Zm?F}o-$zQV-wQY*C%f&$!${g4D!LRL2V;XT2K&wJ|+I<+>r z>)CmUW-M;egbC;58RXgcc9cVDutSGWF{h>nZ$FL~g_^z!PlS!@8OknTm-E0ivIT7p zeO4rZKCS%1{tG4!Vt{^pX!CEf2=!TLEoThAeeq2ZYP^w3#-Caxy&?pa!`{bc8S*_c2 z@=VJrPi|zS>8fCLec|Ei+fvn&2C-k;ZG_z)C*H5bnFL?U$WcRO1x^R0zC=3i^PPMn(mU$QiXZvR_v4v>QRu<38 zZK`~(1_&)*+|FMn$l&N)&-@7br0Y+ip<6*&RTCZ0VY+g9%(NSz?~Q2dowvBNZ?T$A z&evI$a%qd*58CqS408!Q7VYLS*bL?(0q2KqWX{lTIR;iG|Q1V^Fz+wmPHeb#S;CJTz=4zq?Ex8!sFG_bsB=mCz z@gA|*!)b6oZiN}V_uAsGMDYQ;cdF5WjNVP?_+bzinD> z8>{ymaldwp9A#VudKAbPc;VX78g(ioypv14V}V;;b5eoosQLUdOOdE{IqJS7IGqAR z^YxSkqzJ<{c(-jXE$h2NVCyAKgii{XIR!1RoXG|OXgnEz~6{(x-UNB-kynIESZ=rnL# zB03mx3E=D(=9q2!fEI&8F2er=xf1d|=N>?Md?PPYR9@eYuNR_HsO+7@$Bn~u&)w#g zdEW~^p%$cXb#P9Ru8LWGJQBJKo`WG+-OXncwbbELOiNnjZ*TorWjpg&sqlW&9nB9r z*5**xtt>1gTC;t@%~ha${VHq0d<(rTJtgmBi%v2ztM46jRo=eqDnEA$9KwteiKIjbm5$PmbWrgc8PIh z)gX@~bPbJUlgHUT-nU9Sj?GFW1B=s=lVGj1BKX6Hw#SFI)2A`2p5K-Y)cW1VUAzzc z^FGv;;WIgr@ov3zDa`iVZVUPl&=gF;W*TzGwR?Ehlh=)9Tdn znV!>4llWe|zwIS6j(I=$%%sev!-Yvy3qIdjg z4FBvXAY`I+Nvsx^Wj_`WyAS8`Il2v3o@@4uNA6d`qF{~B(JRML(c5#&()b@Ok@1N_ zYGaGt-|{c$Q$=2$T>;&9iqEXBivw9nR~7Qj@2fHV)IAzbx|=DY-s=Q_w*ZU6m>^P0 z@Eg_WASFZpk?D8LwA4cy6M5Fq2nr~4_0?G5UYqJ?gD3!ELF%5jCl-N3^HY-FO{rJ! 
z)AQc?)#9NxmFdr_sp-t~77T#ucpqpxoG2R`NHd#xWisV9BHOmWf z78GmgHu8|dKaquk7BRyF^wcW4tM2G~{?v-sQP`yOvD_6qU>16?Z9?)i29| zQp16FE~x`KuqA2kD%OF}hJ!`_GQZ?K(b@4h z+jL{+%B^rn|0%NL_v8Lz#f><$HDbJ_?J0%=c==h3LF~Dn;qjlZ%W`qi zCWRV|&i|4-hZM$c!Cry2D^0W5HAL>&2&hVINxoe?JF6+i{@DZbT{H4F_l?imK-9qR zcqJB<=f&>{B?jVa&boC{fdNFTZP;y?3s?J>+l-~1+DVP3nW+LB8B9SqK9cSRVU=>; z##iKDTK#kmi(1(R(&`Sd8)g11cvimdq-?CArC##d3}7_os$0{!Piv>-3BAzVj`;XW z>>K!}6Q3NG^>K>sz9c=VtJQYLiU@Pug2RNBX8`$ zLSk~bOxT`-*vnNV<=+BH0zMoM73VY&@z|NP`N{W?{Z2&$y(jkFJ=P$n+IPa|bT1pFLd^VbP~xKZs~8qLJLZTD-O^bv*4t%dZ=l$?A`9IO)M<}S_!&&IlBY-kz9EzOZJ#Q+0(yHN*zD`4l#90nX$2i0TC}56 zbD$P;9A=e{6c$5w*dp_ER|FhyLoxCqqVPHR%B3vweY3c?0d)Mv`yIa2m?PYfnJ)Gkax1Qa4pzajD;H&j%s1Y^Y{=_2O zo_p7@)7T&+61xFuHZgODOlSm>$we%tyy{BvXd!=0btN9U)Sjz1cNlX6mFbui1dJpoE_D&64Jqc`eJEdu0 zdO1#+s5a8ly`tWPzC%g(`c(ZkF^k&Ous4d=3{mzy**u1e!Taz1U1&inT^9@VUE)k4 zjb?uaivA8elJS}n$hgAuSne3hO0dYs&V1PMale0{hj}2-ah&>jaZrIb>E7q%lO`OB z>&a*qDEGhZ??*13$Ja{C;WUlGf3UV@4>F;mPrI0_?z!z{ok{t! zHZJ}6q5d_NrF-r{?;#W8=pH9C0KwiUbBArV`mFQMTHxtRvbVf>k}sU)oY~7oWu}z% z+L4$o<>@6(V*<`M*TINvpLXSh4KIeauHhpI*~3iG+Pz&a-3Mj{mnQ*OpZ(w6ZCTQ?R}#Q>~uMJC-O)kF{? z<;ZfBMm`<$J-t(cLF{pvJ&~>V=7= z==-mj36qdonu$wpa)90DGg<0vtq(g&&F6pWqJs)%f3E9(qGf#n%;v62Hgt*RiTLrS zu9=4_hkP~{@UY8LzJ=wqKuGrtcXD6vAtreoVJT|ZSb9`_O}{1VSlwwu2k!XkIIaCs zPmQN=^FB#DPlU_x(qjMAnv^u1#6d)19bS#z&8xLA14bKsE7Fd)D5QETRn2u$W+Y!ziq0Mht>!Ac6*AWIN_@HOG1@VEke&)d zM*J)xtH+~iE=^$p*#j#_`c>_Pkc52iK5fJD>)~J(-^BC$YQR|#7ej&PO>BCQ(#W>9 ze@iTFaLZ37UM0#=8Ee zeolY)+=kf6y{y(L!m{sR9oxrO@BBGifTKm+OsvNycJSuh;|}7`Xt}YtGKk_Ch-x2q zF-Oo@+=~poj6b+u3(k+%6;()QF11n^0*e(I$;0EW8$9cy*$==>)SGGEMSdzXLYw;K z-xUQ5DkUM{6|`Syir3PrK+~OV@GjS#DdiEBYJ!>8E>+cB{y|%4O|@cU^_Q%Ni^#Y) zafzp5GU~Qj+@-W!AMQ;Na^q5Nq2!GeraY&~GtMSRyM*dKzeV}+(7!zzNie=Tt!BAT z@tl6a(_g#Q!<0#2*vsrQd^;^475);%@a@<Giy1&*Dede(EfCyr`Tyk(EFbnRkqmeRe3Dv(JSB4j6 zywE_E_g1v=(Nkt8!7;m=%LspqPQy%%C_p=p@2 z^1f|co8{l|!YPBEfA&286?55k7uj@Mtafhp1n4i<*SkbpkkVC9AaF=fhsrRqF6ABY zys6D>Kvw2|V)(5eB{pg>Rm2cZt!71nt%lFp7gK{)o`%EhJW_>p@BO~KKY^v`VKj^<5v&WwJ%FCGgkPs0+IjZo{ zHbd2KW{R(2nlF760)1cOCRMPS%iA#Lci(Z^->4}h7CFzj^5-4U?au+tAJqJCii08Y hWa8(;!$9Q&LA(KrRNC^&*1w;DuN5^FD&#D~{tJtuA|(I- literal 0 HcmV?d00001 diff --git a/androidApp/src/main/res/drawable/ic_launcher_foreground.xml b/androidApp/src/main/res/drawable/ic_launcher_foreground.xml new file mode 100644 index 0000000..1da2f17 --- /dev/null +++ b/androidApp/src/main/res/drawable/ic_launcher_foreground.xml @@ -0,0 +1,9 @@ + + + + + + + diff --git a/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher.xml b/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher.xml new file mode 100644 index 0000000..d570ea9 --- /dev/null +++ b/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml b/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml new file mode 100644 index 0000000..d570ea9 --- /dev/null +++ b/androidApp/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml @@ -0,0 +1,6 @@ + + + + + + diff --git a/androidApp/src/main/res/values/ic_launcher_background.xml b/androidApp/src/main/res/values/ic_launcher_background.xml new file mode 100644 index 0000000..728b1b9 --- /dev/null +++ b/androidApp/src/main/res/values/ic_launcher_background.xml @@ -0,0 +1,4 @@ + + #FFFFFF + + diff --git a/androidApp/src/main/res/values/strings.xml b/androidApp/src/main/res/values/strings.xml new file mode 100644 index 0000000..9e58faa --- /dev/null +++ b/androidApp/src/main/res/values/strings.xml @@ -0,0 +1,5 @@ + + ChatApp + Dr. 
Roopam K Gupta</string>
+</resources>

diff --git a/gradle.properties b/gradle.properties
index 9382da3..f845d31 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -1,5 +1,5 @@
 #Gradle
-org.gradle.jvmargs=-Xmx2048M -Dfile.encoding=UTF-8 -Dkotlin.daemon.jvm.options\="-Xmx2048M"
+org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 -Dkotlin.daemon.jvm.options="-Xmx2048m"
 #org.gradle.caching=true
 #org.gradle.configuration-cache=true
 
@@ -8,6 +8,8 @@ kotlin.code.style=official
 kotlin.mpp.enableCInteropCommonization=true
 #kotlin.native.cacheKind=none
 
+kotlin.native.ignoreDisabledTargets=true
+
 #Android
 android.useAndroidX=true
 android.nonTransitiveRClass=true
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index 9e4a1bb..7631f95 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -1,5 +1,5 @@
 [versions]
-agp = "8.10.1"
+agp = "8.7.3"
 kotlin = "2.1.21"
 compose = "1.5.4"
 compose-material3 = "1.1.2"
diff --git a/library/build.gradle.kts b/library/build.gradle.kts
index 8525a70..847074e 100644
--- a/library/build.gradle.kts
+++ b/library/build.gradle.kts
@@ -8,9 +8,10 @@ plugins {
 
 kotlin {
     androidTarget {
-        publishAllLibraryVariants()
-        compilerOptions {
-            jvmTarget.set(JvmTarget.JVM_11)
+        compilations.all {
+            kotlinOptions {
+                jvmTarget = "11"
+            }
         }
     }
 
@@ -72,4 +73,10 @@ android {
         sourceCompatibility = JavaVersion.VERSION_11
         targetCompatibility = JavaVersion.VERSION_11
     }
+    publishing {
+        singleVariant("release") {
+            withSourcesJar()
+            withJavadocJar()
+        }
+    }
 }
diff --git a/library/src/androidMain/cpp/CMakeLists.txt b/library/src/androidMain/cpp/CMakeLists.txt
index e6ad2e7..01d7fdb 100644
--- a/library/src/androidMain/cpp/CMakeLists.txt
+++ b/library/src/androidMain/cpp/CMakeLists.txt
@@ -1,16 +1,34 @@
 cmake_minimum_required(VERSION 3.22.1)
-
 project(inferkt-android)
 
-add_subdirectory(../../native ../../native/build)
+# Configure the llama.cpp submodule build
+set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libraries" FORCE)
+set(LLAMA_BUILD_TESTS OFF CACHE BOOL "" FORCE)
+set(LLAMA_BUILD_TOOLS OFF CACHE BOOL "" FORCE)
+set(LLAMA_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
+set(LLAMA_BUILD_SERVER OFF CACHE BOOL "" FORCE)
+set(LLAMA_FATAL_WARNINGS OFF CACHE BOOL "" FORCE)
+set(LLAMA_CURL OFF CACHE BOOL "" FORCE)
+
+# Add the llama.cpp submodule
+add_subdirectory(../../native/llama.cpp ../../native/llama.cpp/build)
 
+# Build the Inferkt JNI wrapper library
 add_library(
         inferkt-android
         SHARED
         infer-android.cpp
 )
 
-target_link_libraries(inferkt-android inferkt log android)
-target_compile_options(inferkt PRIVATE -O3 -DNDEBUG)
+# Tell our wrapper where to find the llama.cpp headers
+target_include_directories(inferkt-android
+        PUBLIC
+        ../../native/llama.cpp/include
+)
 
-target_compile_options(inferkt PRIVATE -DGGML_USE_CPU -DGGML_USE_CPU_AARCH64 -pthread)
\ No newline at end of file
+# Link everything together
+target_link_libraries(inferkt-android
+        llama
+        log
+        android
+        )
\ No newline at end of file
diff --git a/library/src/androidMain/cpp/infer-android.cpp b/library/src/androidMain/cpp/infer-android.cpp
index 10f0799..03173e9 100644
--- a/library/src/androidMain/cpp/infer-android.cpp
+++ b/library/src/androidMain/cpp/infer-android.cpp
@@ -1,205 +1,393 @@
-
-
-#include <jni.h>
-#include <string>
-#include <android/log.h>
-#include <__algorithm/min.h>
-#include "inferkt.h"
-#include <unistd.h>
+#include <jni.h>
+#include <string>
+#include <string_view>
+#include <vector>
+#include <thread>
+#include <cstring>
+#include "llama.h"
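+
+// NOTE: this JNI layer now calls llama.cpp's C API directly; the old inferkt
+// C wrapper (whose calls are removed below) is no longer part of the Android build.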
 
-struct InferContext {
-    JNIEnv* env;
-    jobject callback;
-    jmethodID methodID;
-};
+// Global state to hold the model and its context
+static llama_model *model = nullptr;
+static llama_context *ctx = nullptr;
+static bool is_running_inference = false;
+// Sampling and thinking globals (must be declared before use)
+static float g_temp = 0.8f;
+static float g_top_p = 0.95f;
+static float g_min_p = 0.05f;
+static int g_top_k = 40;
+static bool g_thinking_enabled = true; // controls instruction injection and think-tag passthrough
+// No rolling chat history or transcript is maintained here, to keep the native layer simple
 
-extern "C"
-JNIEXPORT jlong JNICALL
-Java_com_dilivva_inferkt_InferNativeKt_init(JNIEnv *env, jclass clazz) {
-    __android_log_print(ANDROID_LOG_DEBUG, "INFERKT", "System properties: %s", llama_print_system_info());
-    return init();
-}
+// ===== Compatibility layer for existing Kotlin externals in InferNative.kt =====
 
-bool model_load_progress(float progress, void* user_data) {
-    auto* ctx = static_cast<InferContext*>(user_data);
-    ctx->env->CallVoidMethod(ctx->callback, ctx->methodID, progress);
-    if (progress >= 1.0f) {
-        ctx->env->DeleteLocalRef(ctx->callback);
-        delete ctx;
+static void send_generation_with_event(JNIEnv *env, jobject callback_obj, const char *text, jint event) {
+    if (!callback_obj) return;
+    jclass callback_class = env->GetObjectClass(callback_obj);
+    if (!callback_class) return;
+    jmethodID on_generation = env->GetMethodID(callback_class, "onGeneration", "(Ljava/lang/String;I)V");
+    if (on_generation) {
+        jstring jtext = env->NewStringUTF(text ? text : "");
+        env->CallVoidMethod(callback_obj, on_generation, jtext, event);
+        env->DeleteLocalRef(jtext);
     }
-    return true;
+    env->DeleteLocalRef(callback_class);
 }
 
-extern "C"
-JNIEXPORT jboolean JNICALL
+struct ProgressCtx { JNIEnv *env; jobject callback; jmethodID on_progress; };
+
+static bool progress_cb(float progress, void *user_data) {
+    auto *pc = reinterpret_cast<ProgressCtx *>(user_data);
+    if (!pc || !pc->callback) return true;
+    jfloat p = progress;
+    pc->env->CallVoidMethod(pc->callback, pc->on_progress, p);
+    return true; // continue loading
+}
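+
+// The progress callback fires synchronously inside llama_model_load_from_file,
+// so the stack-allocated ProgressCtx and its JNI references below stay valid
+// for the duration of the load.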
+
+extern "C" JNIEXPORT jlong JNICALL
+Java_com_dilivva_inferkt_InferNativeKt_init(JNIEnv *env, jclass) {
+    // No opaque pointer is needed anymore; return a non-zero handle
+    return 1;
+}
+
+extern "C" JNIEXPORT jboolean JNICALL
 Java_com_dilivva_inferkt_InferNativeKt_loadModel(JNIEnv *env,
-                                                 jclass clazz,
-                                                 jlong inference_ptr,
-                                                 jstring path,
-                                                 jint num_of_gpu,
-                                                 jboolean use_mmap,
-                                                 jboolean use_mlock,
-                                                 jobject callback
-                                                 ) {
-    auto path_to_model = env->GetStringUTFChars(path, 0);
-    jclass callbackClass = env->GetObjectClass(callback);
-    jmethodID onProgressCallback = env->GetMethodID(callbackClass, "onProgressCallback", "(F)V");
-    if (onProgressCallback == nullptr) {
-        __android_log_print(ANDROID_LOG_ERROR, "INFERKT", "Method onProgressCallback not found");
+                                                 jclass,
+                                                 jlong /*inference_ptr*/,
+                                                 jstring path,
+                                                 jint numberOfGpu,
+                                                 jboolean useMmap,
+                                                 jboolean useMlock,
+                                                 jobject callback) {
+    const char *model_path = env->GetStringUTFChars(path, nullptr);
+
+    if (ctx) { llama_free(ctx); ctx = nullptr; }
+    if (model) { llama_model_free(model); model = nullptr; }
+
+    llama_backend_init();
+
+    ProgressCtx pc{};
+    pc.env = env;
+    pc.callback = callback;
+    jclass cb_cls = env->GetObjectClass(callback);
+    pc.on_progress = cb_cls ? env->GetMethodID(cb_cls, "onProgressCallback", "(F)V") : nullptr;
+    if (cb_cls) env->DeleteLocalRef(cb_cls);
+
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = numberOfGpu;
+    model_params.use_mmap = useMmap;
+    model_params.use_mlock = useMlock;
+    model_params.progress_callback = progress_cb;
+    model_params.progress_callback_user_data = &pc;
+
+    model = llama_model_load_from_file(model_path, model_params);
+    env->ReleaseStringUTFChars(path, model_path);
+    if (model == nullptr) {
         return false;
     }
-    auto* ctx = new InferContext{env, callback, onProgressCallback };
-    auto is_model_loaded = load_model(
-            inference_ptr,
-            path_to_model,
-            num_of_gpu,
-            use_mmap,
-            use_mlock,
-            model_load_progress,
-            ctx
-    );
-    env->ReleaseStringUTFChars(path, path_to_model);
-
-    return is_model_loaded;
-}
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_dilivva_inferkt_InferNativeKt_setSamplingParams(JNIEnv *env,
-                                                         jclass clazz,
-                                                         jlong inference_ptr,
-                                                         jfloat temp,
-                                                         jfloat top_p,
-                                                         jfloat min_p,
-                                                         jint top_k) {
-    set_sampling_params(
-            inference_ptr,
-            temp,
-            top_p,
-            min_p,
-            top_k
-    );
-}
-extern "C"
-JNIEXPORT jboolean JNICALL
-Java_com_dilivva_inferkt_InferNativeKt_setContextParams(JNIEnv *env,
-                                                        jclass clazz,
-                                                        jlong inference_ptr,
-                                                        jint context_window,
-                                                        jint batch,
-                                                        jint num_of_threads) {
-    if (num_of_threads == 0){
-        num_of_threads = std::max(1, std::min(8, (int) sysconf(_SC_NPROCESSORS_ONLN) - 2));
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx = 2048;
+    ctx_params.n_threads = (int) std::thread::hardware_concurrency();
+    ctx_params.n_threads_batch = (int) std::thread::hardware_concurrency();
+    ctx = llama_init_from_model(model, ctx_params);
+    if (ctx == nullptr) {
+        llama_model_free(model);
+        model = nullptr;
+        return false;
     }
-    return set_context_params(
-            inference_ptr,
-            context_window,
-            batch,
-            num_of_threads
-    );
+    return true;
 }
 
+extern "C" JNIEXPORT void JNICALL
+Java_com_dilivva_inferkt_InferNativeKt_setSamplingParams(JNIEnv *, jclass, jlong /*inference_ptr*/, jfloat temp, jfloat topP, jfloat minP, jint topK) {
+    g_temp = temp; g_top_p = topP; g_min_p = minP; g_top_k = topK;
+}
 
-void generation_callback(const char *message, enum generation_event event, void *user_data) {
-    auto* ctx = static_cast<InferContext*>(user_data);
-    jstring messageStr = ctx->env->NewStringUTF(message);
-    jint j_event = static_cast<jint>(event);
-    ctx->env->CallVoidMethod(ctx->callback, ctx->methodID, messageStr, j_event);
-    ctx->env->DeleteLocalRef(messageStr);
-    if (event != LOADING && event != GENERATING){
-        ctx->env->DeleteLocalRef(ctx->callback);
-        delete ctx;
-    }
+extern "C" JNIEXPORT void JNICALL
+Java_com_dilivva_inferkt_InferNativeKt_setGlobalThinkingMode(JNIEnv *, jclass, jboolean enabled) {
+    g_thinking_enabled = enabled;
 }
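+
+// NOTE: g_thinking_enabled is only stored here for now; the think-tag routing and
+// the instruction injection both happen on the Kotlin side (ChatViewModel).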
 
+extern "C" JNIEXPORT jboolean JNICALL
+Java_com_dilivva_inferkt_InferNativeKt_setContextParams(JNIEnv *, jclass, jlong /*inference_ptr*/, jint /*context_window*/, jint /*batch*/, jint numberOfThreads) {
+    if (!ctx) return false;
+    int nt = numberOfThreads > 0 ? numberOfThreads : (int) std::thread::hardware_concurrency();
+    llama_set_n_threads(ctx, nt, nt);
+    return true;
+}
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_dilivva_inferkt_InferNativeKt_completion(JNIEnv *env,
-                                                  jclass clazz,
-                                                  jlong inference_ptr,
-                                                  jstring prompt,
-                                                  jint max_generation_count,
-                                                  jobject callback) {
+static void run_generation_with_callback_ex(JNIEnv *env, const char *prompt, int max_tokens, jobject callback, bool skip_until_assistant) {
+    send_generation_with_event(env, callback, "", 0); // Loading
 
-    jclass callbackClass = env->GetObjectClass(callback);
-    jmethodID onGenerationMethod = env->GetMethodID(callbackClass, "onGeneration", "(Ljava/lang/String;I)V");
-    if (onGenerationMethod == nullptr) {
-        __android_log_print(ANDROID_LOG_ERROR, "INFERKT", "Method onGeneration not found");
+    if (ctx == nullptr || model == nullptr) {
+        send_generation_with_event(env, callback, "", 4); // DECODE_ERROR as model not ready
+        send_generation_with_event(env, callback, "", 5);
         return;
     }
-    auto prompt_char = env->GetStringUTFChars(prompt, 0);
-    __android_log_print(ANDROID_LOG_ERROR, "INFERKT", "Prompt: %s", prompt_char);
-    __android_log_print(ANDROID_LOG_DEBUG, "INFERKT", "Infering...");
-    auto * ctx = new InferContext{env, callback, onGenerationMethod };
-    complete(inference_ptr, prompt_char, max_generation_count, generation_callback, ctx);
-}
+    const struct llama_vocab * vocab = llama_model_get_vocab(model);
+    // Start a fresh sequence using the public memory API
+    llama_memory_clear(llama_get_memory(ctx), false);
+
+    std::vector<llama_token> tokens_list;
+    tokens_list.resize(2048);
+    // Parse special tokens so chat template markers become special IDs (e.g., <|im_start|>, <|im_end|>)
+    int n_tokens = llama_tokenize(vocab, prompt, (int) strlen(prompt), tokens_list.data(), (int) tokens_list.size(), true, true);
+    if (n_tokens < 0) {
+        send_generation_with_event(env, callback, "", 2); // TOKENIZE_ERROR
+        send_generation_with_event(env, callback, "", 5);
+        return;
+    }
+    tokens_list.resize(n_tokens);
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_dilivva_inferkt_InferNativeKt_chat(JNIEnv *env,
-                                            jclass clazz,
-                                            jlong inference_ptr,
-                                            jstring prompt,
-                                            jint max_generation_count,
-                                            jobject callback) {
-
-    jclass callbackClass = env->GetObjectClass(callback);
-    jmethodID onGenerationMethod = env->GetMethodID(callbackClass, "onGeneration", "(Ljava/lang/String;I)V");
-    if (onGenerationMethod == nullptr) {
-        __android_log_print(ANDROID_LOG_ERROR, "INFERKT", "Method onGeneration not found");
+    if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), n_tokens)) != 0) {
+        send_generation_with_event(env, callback, "", 4); // DECODE_ERROR
+        send_generation_with_event(env, callback, "", 5);
         return;
     }
-    auto prompt_char = env->GetStringUTFChars(prompt, 0);
-    __android_log_print(ANDROID_LOG_ERROR, "INFERKT", "Prompt: %s", prompt_char);
-    auto * ctx = new InferContext{env, callback, onGenerationMethod };
-    chat(inference_ptr, prompt_char, max_generation_count, generation_callback, ctx);
+    // Build a simple sampler chain honoring temperature/top-p/top-k when possible
+    struct llama_sampler_chain_params sp = llama_sampler_chain_default_params();
+    struct llama_sampler * chain = llama_sampler_chain_init(sp);
+    if (g_top_k > 0) llama_sampler_chain_add(chain, llama_sampler_init_top_k(g_top_k));
+    if (g_top_p > 0.0f && g_top_p < 1.0f) llama_sampler_chain_add(chain, llama_sampler_init_top_p(g_top_p, 1));
+    if (g_min_p > 0.0f && g_min_p < 1.0f) llama_sampler_chain_add(chain, llama_sampler_init_min_p(g_min_p, 1));
+    if (g_temp > 0.0f) llama_sampler_chain_add(chain, llama_sampler_init_temp(g_temp));
+    // end with greedy/dist to pick a token
+    llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+
+    int produced = 0;
+    std::string pending;
+    pending.reserve(1024);
+
+    // Stop sequences aligned with ChatML and common chat templates
+    const char *stops[] = {
+        "<|im_end|>",
+        "<|eot_id|>",
+        "<|eom_id|>",
+        "<|endoftext|>",
+        "<|end|>",
+        "<|im_start|>user"
+    };
+    const size_t n_stops = sizeof(stops) / sizeof(stops[0]);
+    size_t max_stop_len = 0;
+    for (size_t i = 0; i < n_stops; ++i) {
+        size_t len = strlen(stops[i]);
+        if (len > max_stop_len) max_stop_len = len;
+    }
+    std::string stream_buf;
+    stream_buf.reserve(2048);
+
+    // Sentinels that commonly denote the start of assistant output across templates
+    // We will drop any text until we encounter one of these when skip_until_assistant == true
+    const char *sentinels[] = {
+        "\nassistant\n",   // smollm-style header
+        "assistant\n",     // at BOM
+        "<|assistant|>",   // chatml / phi3-like
+        "Assistant:",      // vicuna/openchat-like
+        "[|assistant|]",   // exaone
+        "model",           // gemma uses 'model' for assistant
+        "assistant"
+    };
+    const size_t sentinel_count = sizeof(sentinels) / sizeof(sentinels[0]);
+    while (produced < max_tokens && is_running_inference) {
+        const llama_token new_token_id = llama_sampler_sample(chain, ctx, -1);
+        if (llama_vocab_is_eog(vocab, new_token_id)) {
+            break;
+        }
+        char buf[512];
+        // Do not render special/control tokens in user-visible output
+        const int32_t n = llama_token_to_piece(vocab, new_token_id, buf, (int32_t) sizeof(buf) - 1, 0, false);
+        if (n > 0) {
+            buf[n] = '\0';
+            if (skip_until_assistant) {
+                pending.append(buf, (size_t) n);
+                // look for any sentinel
+                size_t cut = std::string::npos;
+                for (size_t i = 0; i < sentinel_count; ++i) {
+                    size_t pos = pending.find(sentinels[i]);
+                    if (pos != std::string::npos) {
+                        size_t end = pos + strlen(sentinels[i]);
+                        if (cut == std::string::npos || end > cut) cut = end;
+                    }
+                }
+                if (cut != std::string::npos) {
+                    // emit the remainder after the sentinel and stop skipping
+                    std::string_view remainder(pending.data() + cut, pending.size() - cut);
+                    if (!remainder.empty()) {
+                        // After we've found assistant, start routing through the stop-filtering buffer
+                        stream_buf.append(remainder.data(), remainder.size());
+                        // try stop detection immediately in case the remainder already contains a stop
+                        size_t stop_pos = std::string::npos;
+                        for (size_t i = 0; i < n_stops; ++i) {
+                            size_t p = stream_buf.find(stops[i]);
+                            if (p != std::string::npos && (stop_pos == std::string::npos || p < stop_pos)) stop_pos = p;
+                        }
+                        if (stop_pos != std::string::npos) {
+                            if (stop_pos > 0) {
+                                std::string to_emit = stream_buf.substr(0, stop_pos);
+                                if (!to_emit.empty()) {
+                                    send_generation_with_event(env, callback, to_emit.c_str(), 1);
+                                }
+                            }
+                            // early stop
+                            llama_sampler_free(chain);
+                            send_generation_with_event(env, callback, "", 5);
+                            return;
+                        }
+                        // no stop yet, flush the safe part, leaving a tail for partial matches
+                        if (stream_buf.size() > max_stop_len) {
+                            size_t safe = stream_buf.size() - max_stop_len;
+                            if (safe > 0) {
+                                std::string to_emit = stream_buf.substr(0, safe);
+                                send_generation_with_event(env, callback, to_emit.c_str(), 1);
+                                stream_buf.erase(0, safe);
+                            }
+                        }
+                    }
+                    pending.clear();
+                    skip_until_assistant = false;
+                } else if (pending.size() > 4096) {
+                    // avoid unbounded growth; keep the last 1KB
+                    pending.erase(0, pending.size() - 1024);
+                }
+            } else {
+                // Append to buffer and check for stop sequences
+                stream_buf.append(buf, (size_t) n);
+                size_t stop_pos = std::string::npos;
+                for (size_t i = 0; i < n_stops; ++i) {
+                    size_t p = stream_buf.find(stops[i]);
+                    if (p != std::string::npos && (stop_pos == std::string::npos || p < stop_pos)) stop_pos = p;
+                }
+                if (stop_pos != std::string::npos) {
+                    if (stop_pos > 0) {
+                        std::string to_emit = stream_buf.substr(0, stop_pos);
+                        if (!to_emit.empty()) {
+                            send_generation_with_event(env, callback, to_emit.c_str(), 1);
+                        }
+                    }
+                    // clear buffer to avoid double-flush and end generation cleanly
+                    stream_buf.clear();
+                    break;
+                }
+                // flush safe portion leaving a tail to catch partial stop tokens across boundaries
+                if (stream_buf.size() > max_stop_len) {
+                    size_t safe = stream_buf.size() - max_stop_len;
+                    if (safe > 0) {
+                        std::string to_emit = stream_buf.substr(0, safe);
+                        send_generation_with_event(env, callback, to_emit.c_str(), 1);
+                        stream_buf.erase(0, safe);
+                    }
+                }
+            }
+        }
+        tokens_list.push_back(new_token_id);
+        if (llama_decode(ctx, llama_batch_get_one(&tokens_list.back(), 1)) != 0) {
+            send_generation_with_event(env, callback, "", 4); // DECODE_ERROR
+            break;
+        }
+        llama_sampler_accept(chain, new_token_id);
+        produced++;
+    }
+
+    llama_sampler_free(chain);
+    // Flush any remaining buffered text (that does not contain stops)
+    if (!stream_buf.empty()) {
+        send_generation_with_event(env, callback, stream_buf.c_str(), 1);
+        stream_buf.clear();
+    }
+    send_generation_with_event(env, callback, "", 5); // GENERATED
+}
 
-extern "C"
-JNIEXPORT jobject JNICALL
-Java_com_dilivva_inferkt_InferNativeKt_getModelDetails(JNIEnv *env,
-                                                       jclass clazz,
-                                                       jstring path) {
+static void run_generation_with_callback(JNIEnv *env, const char *prompt, int max_tokens, jobject callback) {
+    run_generation_with_callback_ex(env, prompt, max_tokens, callback, /*skip_until_assistant=*/false);
+}
 
-    auto path_to_model = env->GetStringUTFChars(path, 0);
-    auto model_details = get_model_details(path_to_model);
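+
+// For a ChatML-family model, llama_chat_apply_template() below produces roughly
+//   "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+// so sampling begins inside the assistant turn and no sentinel skipping is needed.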
+// Use model-provided chat template to construct a proper conversation prompt (single-turn)
+static void run_chat_with_callback(JNIEnv *env, const char *user_prompt, int max_tokens, jobject callback) {
+    send_generation_with_event(env, callback, "", 0); // Loading
 
-    jclass modelDetails = env->FindClass("com/dilivva/inferkt/ModelDetails");
-    if (!modelDetails) return nullptr;
+    if (ctx == nullptr || model == nullptr) {
+        send_generation_with_event(env, callback, "", 4); // DECODE_ERROR as model not ready
+        send_generation_with_event(env, callback, "", 5);
+        return;
+    }
 
-    jmethodID ctor = env->GetMethodID(modelDetails, "<init>", "(ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;)V");
-    if (!ctor) {
-        env->DeleteLocalRef(modelDetails);
-        return nullptr;
+    const char *tmpl = llama_model_chat_template(model, nullptr);
+    if (tmpl == nullptr) {
+        // Fall back to raw generation if no chat template is available
+        run_generation_with_callback_ex(env, user_prompt, max_tokens, callback, /*skip_until_assistant=*/true);
+        return;
+    }
+
+    // Build a single-turn chat with add_assistant set so generation starts the assistant turn
+    llama_chat_message msgs[1] = { { "user", user_prompt } };
+    std::vector<char> formatted((size_t) llama_n_ctx(ctx));
+    int new_len = llama_chat_apply_template(tmpl, msgs, 1, /*add_ass=*/true, formatted.data(), (int) formatted.size());
+    if (new_len < 0) {
+        // If templating fails, fall back to the raw prompt
+        run_generation_with_callback_ex(env, user_prompt, max_tokens, callback, /*skip_until_assistant=*/true);
+        return;
+    }
+    if ((size_t) new_len > formatted.size()) {
+        formatted.resize((size_t) new_len);
+        new_len = llama_chat_apply_template(tmpl, msgs, 1, /*add_ass=*/true, formatted.data(), (int) formatted.size());
+        if (new_len < 0) {
+            run_generation_with_callback_ex(env, user_prompt, max_tokens, callback, /*skip_until_assistant=*/true);
+            return;
+        }
     }
 
-    jstring name = env->NewStringUTF(model_details.name);
-    jstring arch = env->NewStringUTF(model_details.architecture);
-    jstring context_length = env->NewStringUTF(model_details.context_length);
-    jint version = model_details.version;
+    std::string prompt(formatted.data(), (size_t) new_len);
+    // We used the chat template; generation should already be positioned at the assistant turn.
+    run_generation_with_callback_ex(env, prompt.c_str(), max_tokens, callback, /*skip_until_assistant=*/false);
+}
 
-    jobject dataObj = env->NewObject(modelDetails, ctor, version, arch, name, context_length);
+extern "C" JNIEXPORT void JNICALL
+Java_com_dilivva_inferkt_InferNativeKt_completion(JNIEnv *env, jclass, jlong /*inference_ptr*/, jstring prompt, jint max_generation_count, jobject callback) {
+    is_running_inference = true;
+    const char *c_prompt = env->GetStringUTFChars(prompt, nullptr);
+    run_generation_with_callback(env, c_prompt, max_generation_count, callback);
+    env->ReleaseStringUTFChars(prompt, c_prompt);
+    is_running_inference = false;
+}
 
-    env->DeleteLocalRef(name);
-    env->DeleteLocalRef(arch);
-    env->DeleteLocalRef(context_length);
-    env->DeleteLocalRef(modelDetails);
+extern "C" JNIEXPORT void JNICALL
+Java_com_dilivva_inferkt_InferNativeKt_chat(JNIEnv *env, jclass, jlong /*inference_ptr*/, jstring prompt, jint max_generation_count, jobject callback) {
+    is_running_inference = true;
+    const char *c_prompt = env->GetStringUTFChars(prompt, nullptr);
+    run_chat_with_callback(env, c_prompt, max_generation_count, callback);
+    env->ReleaseStringUTFChars(prompt, c_prompt);
+    is_running_inference = false;
+}
 
-    return dataObj;
+extern "C" JNIEXPORT void JNICALL
+Java_com_dilivva_inferkt_InferNativeKt_cancelGeneration(JNIEnv *, jclass, jlong /*inference_ptr*/) {
+    is_running_inference = false;
 }
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_dilivva_inferkt_InferNativeKt_cancelGeneration(JNIEnv *env,jclass clazz,jlong inference_ptr) {
-    cancel_inference(inference_ptr);
+extern "C" JNIEXPORT void JNICALL
+Java_com_dilivva_inferkt_InferNativeKt_cleanUp(JNIEnv *, jclass, jlong /*inference_ptr*/) {
+    if (ctx) { llama_free(ctx); ctx = nullptr; }
+    if (model) { llama_model_free(model); model = nullptr; }
+    llama_backend_free();
 }
 
-extern "C"
-JNIEXPORT void JNICALL
-Java_com_dilivva_inferkt_InferNativeKt_cleanUp(JNIEnv *env, jclass clazz, jlong inference_ptr) {
-    clean_up(inference_ptr);
+
+extern "C" JNIEXPORT jobject JNICALL
+Java_com_dilivva_inferkt_InferNativeKt_getModelDetails(JNIEnv *env, jclass, jstring /*path*/) {
+    jclass modelDetails = env->FindClass("com/dilivva/inferkt/ModelDetails");
+    if (!modelDetails) return nullptr;
+    jmethodID ctor = env->GetMethodID(modelDetails, "<init>", "(ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;)V");
+    if (!ctor) { env->DeleteLocalRef(modelDetails); return nullptr; }
+    jint version = 0;
+    jstring arch = env->NewStringUTF("unknown");
+    jstring name = env->NewStringUTF("unknown");
+    jstring ctx_len = env->NewStringUTF("0");
+    jobject obj = env->NewObject(modelDetails, ctor, version, arch, name, ctx_len);
+    env->DeleteLocalRef(arch);
+    env->DeleteLocalRef(name);
+    env->DeleteLocalRef(ctx_len);
+    env->DeleteLocalRef(modelDetails);
+    return obj;
 }
\ No newline at end of file
diff --git a/library/src/androidMain/kotlin/com/dilivva/inferkt/InferNative.kt b/library/src/androidMain/kotlin/com/dilivva/inferkt/InferNative.kt
index 25c540f..f20c16a 100644
--- a/library/src/androidMain/kotlin/com/dilivva/inferkt/InferNative.kt
+++ b/library/src/androidMain/kotlin/com/dilivva/inferkt/InferNative.kt
@@ -25,4 +25,7 @@ external fun getModelDetails(path: String): ModelDetails
 
 external fun cancelGeneration(inferencePtr: Long)
 
-external fun cleanUp(inferencePtr: Long)
\ No newline at end of file
+external fun cleanUp(inferencePtr: Long)
+
+// Global runtime options
+external fun setGlobalThinkingMode(enabled: Boolean)
\ No newline at end of file
diff --git a/library/src/native/llama.cpp b/library/src/native/llama.cpp
index e6a39a8..9961d24 160000
--- a/library/src/native/llama.cpp
+++ b/library/src/native/llama.cpp
@@ -1 +1 @@
-Subproject commit e6a39a8d6e9eef13b5970877ba4dd9c54e2f6151
+Subproject commit 9961d244f2df6baf40af2f1ddc0927f8d91578c8
diff --git a/library/src/native/src/Inference.cpp b/library/src/native/src/Inference.cpp
index b2fa061..8a0e327 100644
--- a/library/src/native/src/Inference.cpp
+++ b/library/src/native/src/Inference.cpp
@@ -162,6 +162,8 @@ void Inference::completion(
     generation_cancelled = false;
 
     common_batch_clear(batch);
+    // Ensure we don't carry over previous-turn KV cache into this completion
+    llama_kv_self_clear(ctx);
 
     // evaluate the initial prompt
     for (auto i = 0; i < tokens.size(); i++) {
@@ -176,6 +178,23 @@ void Inference::completion(
         callback("", DECODE_ERROR, user_data);
     }
 
+    // Stop sequences common across chat templates
+    const char *stops[] = {
+        "<|im_end|>",
+        "<|eot_id|>",
+        "<|eom_id|>",
+        "</s>",
+        "<|endoftext|>",
+        "<|end|>",
+        "user",
+        "<|im_start|>user"
+    };
+    const size_t n_stops = sizeof(stops) / sizeof(stops[0]);
+    size_t max_stop_len = 0;
+    for (size_t i = 0; i < n_stops; ++i) { size_t l = strlen(stops[i]); if (l > max_stop_len) max_stop_len = l; }
+    std::string stream_buf;
+    stream_buf.reserve(2048);
+
     // Generate response tokens
     for (size_t i = batch.n_tokens; i <= max_tokens; ++i) {
         if (generation_cancelled){
@@ -189,9 +208,32 @@ void Inference::completion(
             callback("", END_OF_GENERATION, user_data);
             break;
         }
-        // Decode the token
-        auto generated = common_token_to_piece(ctx, new_token_id);
-        callback(generated.c_str(), GENERATING, user_data);
+        // Decode the token and check stop sequences (do not render specials)
+        auto generated = common_token_to_piece(ctx, new_token_id, false);
+        if (!generated.empty()) {
+            stream_buf.append(generated);
+            size_t stop_pos = std::string::npos;
+            for (size_t j = 0; j < n_stops; ++j) {
+                size_t p = stream_buf.find(stops[j]);
+                if (p != std::string::npos && (stop_pos == std::string::npos || p < stop_pos)) stop_pos = p;
+            }
+            if (stop_pos != std::string::npos) {
+                if (stop_pos > 0) {
+                    std::string to_emit = stream_buf.substr(0, stop_pos);
+                    if (!to_emit.empty()) callback(to_emit.c_str(), GENERATING, user_data);
+                }
+                callback("", END_OF_GENERATION, user_data);
+                break;
+            }
+            if (stream_buf.size() > max_stop_len) {
+                size_t safe = stream_buf.size() - max_stop_len;
+                if (safe > 0) {
+                    std::string to_emit = stream_buf.substr(0, safe);
+                    callback(to_emit.c_str(), GENERATING, user_data);
+                    stream_buf.erase(0, safe);
+                }
+            }
+        }
 
         // Prepare batch for next token
         common_batch_clear(batch);
@@ -213,6 +255,7 @@ void Inference::chat(
     callback("", LOADING, user_data);
     generation_cancelled = false;
     const char * tmpl = llama_model_chat_template(model, nullptr);
+    // Use the model's chat template and do not force-add a custom system message
     messages.push_back({"user", strdup(prompt.c_str())});
     int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
@@ -231,6 +274,22 @@ void Inference::chat(
     auto generate = [&](const std::string & prompt) {
         std::string response;
 
+        // Stop sequences common across chat templates
+        const char *stops[] = {
+            "<|im_end|>",
+            "<|eot_id|>",
+            "<|eom_id|>",
+            "</s>",
+            "<|endoftext|>",
+            "<|end|>",
+            "user",
+            "<|im_start|>user"
+        };
+        const size_t n_stops = sizeof(stops) / sizeof(stops[0]);
+        size_t max_stop_len = 0;
+        for (size_t i = 0; i < n_stops; ++i) { size_t l = strlen(stops[i]); if (l > max_stop_len) max_stop_len = l; }
+        std::string stream_buf;
+        stream_buf.reserve(2048);
 
         const bool is_first = llama_kv_self_used_cells(ctx) == 0;
 
@@ -276,17 +335,43 @@ void Inference::chat(
                 break;
             }
 
-            // convert the token to a string, print it and add it to the response
+            // convert the token to a string, print it and add it to the response; do not render specials
             char buf[256];
-            int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
+            int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, false);
             if (n < 0) {
                 callback("", DECODE_ERROR, user_data);
                 break;
             }
             std::string piece(buf, n);
-            callback(piece.c_str(), GENERATING, user_data);
+            // append the piece to the stream buffer and check for stop sequences
+            stream_buf.append(piece);
+            size_t stop_pos = std::string::npos;
+            for (size_t j = 0; j < n_stops; ++j) {
+                size_t p = stream_buf.find(stops[j]);
+                if (p != std::string::npos && (stop_pos == std::string::npos || p < stop_pos)) stop_pos = p;
+            }
+            if (stop_pos != std::string::npos) {
+                if (stop_pos > 0) {
+                    std::string to_emit = stream_buf.substr(0, stop_pos);
+                    if (!to_emit.empty()) {
+                        callback(to_emit.c_str(), GENERATING, user_data);
+                        response += to_emit;
+                    }
+                }
+                callback("", END_OF_GENERATION, user_data);
+                break;
+            }
+            // flush safe portion
+            if (stream_buf.size() > max_stop_len) {
+                size_t safe = stream_buf.size() - max_stop_len;
+                if (safe > 0) {
+                    std::string to_emit = stream_buf.substr(0, safe);
+                    callback(to_emit.c_str(), GENERATING, user_data);
+                    response += to_emit;
+                    stream_buf.erase(0, safe);
+                }
+            }
 
             count++;
-            response += piece;
 
             if (count >= max_tokens && max_tokens > 0) {
                 callback("", END_OF_GENERATION, user_data);
@@ -297,6 +382,11 @@ void Inference::chat(
 
             batch = llama_batch_get_one(&new_token_id, 1);
         }
+        // flush remaining buffer (no stops detected)
+        if (!stream_buf.empty()) {
+            callback(stream_buf.c_str(), GENERATING, user_data);
+            response += stream_buf;
+        }
 
         return response;
     };
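+    // Safe-flush invariant: at most max_stop_len bytes are ever held back, so
+    // streaming stays responsive while a partial stop string can never leak
+    // into the transcript.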