diff --git a/lsp/indexing/src/main/kotlin/org/appdevforall/codeonthego/indexing/SQLiteIndex.kt b/lsp/indexing/src/main/kotlin/org/appdevforall/codeonthego/indexing/SQLiteIndex.kt index 786dca5c03..4d8dae0627 100644 --- a/lsp/indexing/src/main/kotlin/org/appdevforall/codeonthego/indexing/SQLiteIndex.kt +++ b/lsp/indexing/src/main/kotlin/org/appdevforall/codeonthego/indexing/SQLiteIndex.kt @@ -3,15 +3,20 @@ package org.appdevforall.codeonthego.indexing import android.content.ContentValues import android.content.Context import android.database.sqlite.SQLiteDatabase +import android.os.Looper import androidx.sqlite.db.SupportSQLiteDatabase import androidx.sqlite.db.SupportSQLiteOpenHelper import androidx.sqlite.db.framework.FrameworkSQLiteOpenHelperFactory import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withContext +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import org.appdevforall.codeonthego.indexing.api.Index import org.appdevforall.codeonthego.indexing.api.IndexDescriptor import org.appdevforall.codeonthego.indexing.api.IndexQuery import org.appdevforall.codeonthego.indexing.api.Indexable +import org.slf4j.LoggerFactory import kotlin.collections.iterator /** @@ -58,6 +63,10 @@ class SQLiteIndex( override val name: String = "sqlite:${descriptor.name}", private val batchSize: Int = 500, ) : Index { + companion object { + private val log = LoggerFactory.getLogger(SQLiteIndex::class.java) + } + private val tableName = descriptor.name.replace(Regex("[^a-zA-Z0-9_]"), "_") @@ -71,6 +80,8 @@ class SQLiteIndex( .filter { it.prefixSearchable } .associate { it.name to "f_${it.name}_lower" } + private val mutex = Mutex() + @Volatile private var closed = false private val db: SupportSQLiteDatabase init { @@ -102,51 +113,59 @@ class SQLiteIndex( createTable(db) } - override fun query(query: IndexQuery): Sequence { - val (sql, args) = buildSelectQuery(query) - val cursor = db.query(sql, args.toTypedArray()) - return cursor.use { - val payloadIdx = it.getColumnIndexOrThrow("_payload") - buildList { - while (it.moveToNext()) { - add(descriptor.deserialize(it.getBlob(payloadIdx))) + override fun query(query: IndexQuery): Sequence = runBlocking { + ifOpen(emptySequence()) { + val (sql, args) = buildSelectQuery(query) + val cursor = db.query(sql, args.toTypedArray()) + cursor.use { + val payloadIdx = it.getColumnIndexOrThrow("_payload") + buildList { + while (it.moveToNext()) { + add(descriptor.deserialize(it.getBlob(payloadIdx))) + } } - } - }.asSequence() + }.asSequence() + } } override suspend fun get(key: String): T? = withContext(Dispatchers.IO) { - val cursor = db.query( - "SELECT _payload FROM $tableName WHERE _key = ? LIMIT 1", - arrayOf(key), - ) - cursor.use { - if (it.moveToFirst()) { - descriptor.deserialize(it.getBlob(0)) - } else null + ifOpen(null) { + val cursor = db.query( + "SELECT _payload FROM $tableName WHERE _key = ? LIMIT 1", + arrayOf(key), + ) + cursor.use { + if (it.moveToFirst()) { + descriptor.deserialize(it.getBlob(0)) + } else null + } } } override suspend fun containsSource(sourceId: String): Boolean = withContext(Dispatchers.IO) { - val cursor = db.query( - "SELECT 1 FROM $tableName WHERE _source_id = ? LIMIT 1", - arrayOf(sourceId), - ) - cursor.use { it.moveToFirst() } + ifOpen(false) { + val cursor = db.query( + "SELECT 1 FROM $tableName WHERE _source_id = ? LIMIT 1", + arrayOf(sourceId), + ) + cursor.use { it.moveToFirst() } + } } - override fun distinctValues(fieldName: String): Sequence { - val col = fieldColumns[fieldName] - ?: throw IllegalArgumentException("Unknown field: $fieldName") - val cursor = db.query("SELECT DISTINCT $col FROM $tableName WHERE $col IS NOT NULL") - return cursor.use { - buildList { - while (it.moveToNext()) { - add(it.getString(0)) + override fun distinctValues(fieldName: String): Sequence = runBlocking { + ifOpen(emptySequence()) { + val col = fieldColumns[fieldName] + ?: throw IllegalArgumentException("Unknown field: $fieldName") + val cursor = db.query("SELECT DISTINCT $col FROM $tableName WHERE $col IS NOT NULL") + cursor.use { + buildList { + while (it.moveToNext()) { + add(it.getString(0)) + } } - } - }.asSequence() + }.asSequence() + } } override suspend fun insertAll(entries: Sequence) = withContext(Dispatchers.IO) { @@ -154,34 +173,53 @@ class SQLiteIndex( for (entry in entries) { batch.add(entry) if (batch.size >= batchSize) { - insertBatch(batch) + ifOpen { insertBatchLocked(batch) } batch.clear() } } if (batch.isNotEmpty()) { - insertBatch(batch) + ifOpen { insertBatchLocked(batch) } } } override suspend fun insert(entry: T) = withContext(Dispatchers.IO) { - insertBatch(listOf(entry)) + ifOpen { insertBatchLocked(listOf(entry)) } } override suspend fun removeBySource(sourceId: String) = withContext(Dispatchers.IO) { - db.execSQL("DELETE FROM $tableName WHERE _source_id = ?", arrayOf(sourceId)) + ifOpen { db.execSQL("DELETE FROM $tableName WHERE _source_id = ?", arrayOf(sourceId)) } } override suspend fun clear() = withContext(Dispatchers.IO) { - db.execSQL("DELETE FROM $tableName") + ifOpen { db.execSQL("DELETE FROM $tableName") } } override fun close() { - db.close() + if (Looper.getMainLooper() == Looper.myLooper()) { + log.warn( + "SQLiteIndex.close() called on the main thread; waiting on mutex and closing db may block and cause ANR" + ) + } + runBlocking { + mutex.withLock { + if (closed) return@withLock + closed = true + db.close() + } + } } + private suspend inline fun ifOpen(default: R, crossinline block: () -> R): R = + mutex.withLock { if (closed) default else block() } + + private suspend inline fun ifOpen(crossinline block: () -> Unit) = + mutex.withLock { if (!closed) block() } + suspend fun size(): Int = withContext(Dispatchers.IO) { - val cursor = db.query("SELECT COUNT(*) FROM $tableName") - cursor.use { if (it.moveToFirst()) it.getInt(0) else 0 } + ifOpen(0) { + val cursor = db.query("SELECT COUNT(*) FROM $tableName") + cursor.use { if (it.moveToFirst()) it.getInt(0) else 0 } + } } private fun createTable(db: SupportSQLiteDatabase) { @@ -224,7 +262,7 @@ class SQLiteIndex( } } - private fun insertBatch(entries: List) { + private fun insertBatchLocked(entries: List) { db.beginTransaction() try { for (entry in entries) { diff --git a/lsp/indexing/src/main/kotlin/org/appdevforall/codeonthego/indexing/util/BackgroundIndexer.kt b/lsp/indexing/src/main/kotlin/org/appdevforall/codeonthego/indexing/util/BackgroundIndexer.kt index 3d2e01a6ae..a61a86f84b 100644 --- a/lsp/indexing/src/main/kotlin/org/appdevforall/codeonthego/indexing/util/BackgroundIndexer.kt +++ b/lsp/indexing/src/main/kotlin/org/appdevforall/codeonthego/indexing/util/BackgroundIndexer.kt @@ -9,6 +9,7 @@ import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.isActive import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking import org.appdevforall.codeonthego.indexing.api.Index import org.appdevforall.codeonthego.indexing.api.Indexable import org.slf4j.LoggerFactory @@ -43,11 +44,12 @@ sealed class IndexingEvent { */ class BackgroundIndexer( private val index: Index, - private val scope: CoroutineScope = CoroutineScope( - SupervisorJob() + Dispatchers.Default - ), + parentScope: CoroutineScope = CoroutineScope(Dispatchers.Default), ) : Closeable { + private val job = SupervisorJob(parentScope.coroutineContext[Job]) + private val scope = CoroutineScope(parentScope.coroutineContext + job) + companion object { private val log = LoggerFactory.getLogger(BackgroundIndexer::class.java) } @@ -163,7 +165,16 @@ class BackgroundIndexer( val activeJobCount: Int get() = activeJobs.size override fun close() { - activeJobs.values.forEach { it.cancel() } + val activeCount = activeJobCount + if (activeCount > 0) { + log.warn( + "Closing indexer with {} active job(s); cancellation is cooperative and close will wait for completion", + activeCount, + ) + } + runBlocking { + job.cancelAndJoin() + } activeJobs.clear() } } diff --git a/lsp/jvm-symbol-index/src/main/kotlin/org/appdevforall/codeonthego/indexing/jvm/JvmSymbolIndex.kt b/lsp/jvm-symbol-index/src/main/kotlin/org/appdevforall/codeonthego/indexing/jvm/JvmSymbolIndex.kt index 75b9019ba6..53dd8463ac 100644 --- a/lsp/jvm-symbol-index/src/main/kotlin/org/appdevforall/codeonthego/indexing/jvm/JvmSymbolIndex.kt +++ b/lsp/jvm-symbol-index/src/main/kotlin/org/appdevforall/codeonthego/indexing/jvm/JvmSymbolIndex.kt @@ -139,11 +139,7 @@ class JvmSymbolIndex( suspend fun awaitIndexing() = indexer.awaitAll() override fun close() { - super.close() - if (backing is AutoCloseable) { - backing.close() - } - indexer.close() + super.close() } }