diff --git a/eslint.config.js b/eslint.config.js index e1b4014..67c2d62 100644 --- a/eslint.config.js +++ b/eslint.config.js @@ -93,6 +93,7 @@ export default tseslint.config( "build/**", "*.config.js", "*.config.ts", + "src/**/*.test.{js,jsx,ts,tsx}", "vite-env.d.ts", ], }, diff --git a/package-lock.json b/package-lock.json index 6e5e7d7..6fe8bed 100644 --- a/package-lock.json +++ b/package-lock.json @@ -14,6 +14,7 @@ "@ai-sdk/openai": "^1.3.22", "@langchain/community": "^0.3.53", "@langchain/core": "^0.3.72", + "@mediapipe/tasks-genai": "^0.10.14", "ai": "^4.3.19", "dedent": "^1.7.0", "react-basic-contenteditable": "^1.0.6", @@ -2871,6 +2872,12 @@ "react": ">=16" } }, + "node_modules/@mediapipe/tasks-genai": { + "version": "0.10.25", + "resolved": "https://registry.npmjs.org/@mediapipe/tasks-genai/-/tasks-genai-0.10.25.tgz", + "integrity": "sha512-/ce/qNTC/CzzPJ5QRBz4MIoaYD+L65wkeE38ysb7saAZd38RZgzzVp1ZoLVFe/52LC3Ly/h2qcZWmkQf5PSJag==", + "license": "Apache-2.0" + }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", diff --git a/package.json b/package.json index 02966a9..1c4cca1 100644 --- a/package.json +++ b/package.json @@ -67,6 +67,7 @@ "@ai-sdk/openai": "^1.3.22", "@langchain/community": "^0.3.53", "@langchain/core": "^0.3.72", + "@mediapipe/tasks-genai": "^0.10.14", "ai": "^4.3.19", "dedent": "^1.7.0", "react-basic-contenteditable": "^1.0.6", diff --git a/src/providers/index.ts b/src/providers/index.ts index 3d96a74..cab2539 100644 --- a/src/providers/index.ts +++ b/src/providers/index.ts @@ -1 +1,2 @@ +export { MediaPipeProvider } from "./mediaPipeProvider"; export { UserTokenProvider } from "./userTokenProvider"; diff --git a/src/providers/mediaPipeProvider/cache/cache_constants.ts b/src/providers/mediaPipeProvider/cache/cache_constants.ts new file mode 100644 index 0000000..e708064 --- /dev/null +++ b/src/providers/mediaPipeProvider/cache/cache_constants.ts @@ -0,0 +1,4 @@ +// Shared constants for MediaPipe model caching +// Bump MODEL_CACHE_KEY_VERSION to invalidate existing logical entries without manually clearing DB. +export const MODEL_CACHE_DB_NAME = "clover-ai-mediapipe-provider-models"; +export const MODEL_CACHE_KEY_VERSION = "v1"; // increment when manifest/chunk schema or logic changes diff --git a/src/providers/mediaPipeProvider/cache/model_cache.ts b/src/providers/mediaPipeProvider/cache/model_cache.ts new file mode 100644 index 0000000..1b4ff63 --- /dev/null +++ b/src/providers/mediaPipeProvider/cache/model_cache.ts @@ -0,0 +1,131 @@ +/** + * Helper class to manage model caching with Web Worker + */ + +import { get_logger } from "../logger"; +const logger = get_logger("ModelCache"); + +interface ProgressCallback { + (loaded: number, total: number, percent: number, source?: "cache" | "network"): void; +} + +export class ModelCache { + #chunk_worker: Worker | null = null; + #worker: Worker | null = null; + + #cleanup(): void { + if (this.#worker) { + this.#worker.terminate(); + this.#worker = null; + } + } + + #cleanup_chunk() { + if (this.#chunk_worker) { + this.#chunk_worker.terminate(); + this.#chunk_worker = null; + } + } + + /** + * Cancel ongoing download and cleanup + */ + cancel(): void { + this.#cleanup(); + } + + /** + * Load model from cache or network in chunks using a Web Worker and return a ReadableStreamDefaultReader + */ + async load_model( + url: string, + modelKey: string, + onProgress?: ProgressCallback, + chunkSize: number = 8 * 1024 * 1024, + ): Promise> { + // Create a readable stream that we'll feed with worker chunk messages + let controller: ReadableStreamDefaultController; + const stream = new ReadableStream({ + start(c) { + controller = c; + }, + cancel: () => { + this.#cleanup_chunk(); + }, + }); + + return new Promise((resolve, reject) => { + this.#chunk_worker = new Worker(new URL("./model_cache_worker.ts", import.meta.url), { + type: "module", + }); + + let source: "cache" | "network" = "network"; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + this.#chunk_worker.onmessage = (event: MessageEvent) => { + const msg = event.data; + if (msg && msg.type) { + switch (msg.type) { + case "cache": + logger.debug("cache-hit manifest received", msg); + source = "cache"; + break; + case "progress": + if (msg.percent % 10 === 0) logger.debug("progress", msg.percent, "%"); + break; + case "chunk": + if (msg.index % 100 === 0 || msg.final) + logger.debug("chunk", msg.index, "final=", msg.final); + break; + case "complete": + logger.info( + "complete", + msg.parts, + "parts totalMB=", + (msg.totalSize / 1024 / 1024).toFixed(1), + ); + break; + case "error": + logger.error("error message", msg.error); + break; + } + } + + switch (msg.type) { + case "progress": + onProgress?.(msg.loaded, msg.total, msg.percent, source); + break; + case "cache": + onProgress?.(0, msg.totalSize, 0, source); + break; + case "chunk": + if (msg.arrayBuffer && msg.arrayBuffer.byteLength > 0) { + controller.enqueue(new Uint8Array(msg.arrayBuffer)); + } + if (msg.final) { + break; + } + break; + case "complete": + controller.close(); + resolve(stream.getReader()); + this.#cleanup_chunk(); + break; + case "error": + controller.error(new Error(msg.error)); + this.#cleanup_chunk(); + reject(new Error(msg.error)); + break; + } + }; + + this.#chunk_worker.onerror = (e) => { + controller.error(e); + this.#cleanup_chunk(); + reject(new Error(`Chunk worker error: ${e.message}`)); + }; + + this.#chunk_worker.postMessage({ type: "download", url, modelKey, chunkSize }); + }); + } +} diff --git a/src/providers/mediaPipeProvider/cache/model_cache_worker.ts b/src/providers/mediaPipeProvider/cache/model_cache_worker.ts new file mode 100644 index 0000000..bdfe890 --- /dev/null +++ b/src/providers/mediaPipeProvider/cache/model_cache_worker.ts @@ -0,0 +1,391 @@ +/** + * Chunked model downloader & cacher worker. + * Streams large model without ever materializing full ArrayBuffer. + * Messages: + * from main: + * { type: 'download', url, modelKey, chunkSize } + * to main: + * { type: 'progress', loaded, total, percent } + * { type: 'chunk', index, arrayBuffer, final } + * { type: 'complete', parts, totalSize, chunkSize } + * { type: 'error', error } + */ + +import { scoped } from "../logger"; +import { MODEL_CACHE_DB_NAME } from "./cache_constants"; +const logger = scoped("ChunkWorker"); + +interface DownloadMsg { + type: "download"; + url: string; + modelKey: string; + chunkSize: number; +} +interface ProgressMsg { + type: "progress"; + loaded: number; + total: number; + percent: number; +} +interface ChunkMsg { + type: "chunk"; + index: number; + arrayBuffer: ArrayBuffer; + final: boolean; +} +interface CompleteMsg { + type: "complete"; + parts: number; + totalSize: number; + chunkSize: number; +} +interface CacheMsg { + type: "cache"; + parts: number; + totalSize: number; + chunkSize: number; +} +interface ErrorMsg { + type: "error"; + error: string; +} + +// Use distinct constants to avoid name collisions if bundled together (shared) +const DB_NAME = MODEL_CACHE_DB_NAME; +const DB_VERSION = 1; // bump for chunk store +const CHUNK_STORE = "model_chunks"; +const MANIFEST_STORE = "model_manifests"; + +interface ManifestRecord { + key: string; // modelKey + totalSize: number; + parts: number; + chunkSize: number; + timestamp: number; + complete: boolean; +} + +class ChunkDB { + db: IDBDatabase | null = null; + + async delete_chunks(modelKey: string): Promise { + if (!this.db) throw new Error("db not init"); + return new Promise((res, rej) => { + const tx = this.db!.transaction(CHUNK_STORE, "readwrite"); + const store = tx.objectStore(CHUNK_STORE); + const prefix = `${modelKey}:part:`; + const req = store.openCursor(); + req.onerror = () => rej(req.error); + req.onsuccess = () => { + const cursor = req.result; + if (!cursor) { + res(); + return; + } + if (typeof cursor.key === "string" && cursor.key.startsWith(prefix)) { + cursor.delete(); + } + cursor.continue(); + }; + }); + } + + async delete_manifest(key: string): Promise { + if (!this.db) throw new Error("db not init"); + return new Promise((res, rej) => { + const tx = this.db!.transaction(MANIFEST_STORE, "readwrite"); + const store = tx.objectStore(MANIFEST_STORE); + const r = store.delete(key); + r.onerror = () => rej(r.error); + r.onsuccess = () => res(); + }); + } + + async get_chunk(modelKey: string, index: number): Promise { + if (!this.db) throw new Error("db not init"); + return new Promise((res, rej) => { + const tx = this.db!.transaction(CHUNK_STORE, "readonly"); + const store = tx.objectStore(CHUNK_STORE); + const key = `${modelKey}:part:${index}`; + const r = store.get(key); + r.onerror = () => rej(r.error); + r.onsuccess = () => res(r.result ? r.result.data : null); + }); + } + + async get_manifest(modelKey: string): Promise { + if (!this.db) throw new Error("db not init"); + return new Promise((res, rej) => { + const tx = this.db!.transaction(MANIFEST_STORE, "readonly"); + const store = tx.objectStore(MANIFEST_STORE); + const r = store.get(modelKey); + r.onerror = () => rej(r.error); + r.onsuccess = () => res(r.result || null); + }); + } + + async has_chunk(modelKey: string, index: number): Promise { + if (!this.db) throw new Error("db not init"); + return new Promise((res, rej) => { + const tx = this.db!.transaction(CHUNK_STORE, "readonly"); + const store = tx.objectStore(CHUNK_STORE); + const key = `${modelKey}:part:${index}`; + const r = store.count(key); + r.onerror = () => rej(r.error); + r.onsuccess = () => res(r.result > 0); + }); + } + + init(): Promise { + return new Promise((resolve, reject) => { + const req = indexedDB.open(DB_NAME, DB_VERSION); + req.onerror = () => reject(req.error); + req.onupgradeneeded = () => { + const db = req.result; + if (!db.objectStoreNames.contains(CHUNK_STORE)) { + db.createObjectStore(CHUNK_STORE, { keyPath: "key" }); + } + if (!db.objectStoreNames.contains(MANIFEST_STORE)) { + db.createObjectStore(MANIFEST_STORE, { keyPath: "key" }); + } + }; + req.onsuccess = () => { + this.db = req.result; + resolve(); + }; + }); + } + + async put_chunk(modelKey: string, index: number, ab: ArrayBuffer): Promise { + if (!this.db) throw new Error("db not init"); + return new Promise((res, rej) => { + const tx = this.db!.transaction(CHUNK_STORE, "readwrite"); + const store = tx.objectStore(CHUNK_STORE); + const r = store.put({ + key: `${modelKey}:part:${index}`, + data: ab, + size: ab.byteLength, + idx: index, + }); + r.onerror = () => rej(r.error); + r.onsuccess = () => res(); + }); + } + + async put_manifest(m: ManifestRecord): Promise { + if (!this.db) throw new Error("db not init"); + return new Promise((res, rej) => { + const tx = this.db!.transaction(MANIFEST_STORE, "readwrite"); + const store = tx.objectStore(MANIFEST_STORE); + const r = store.put(m); + r.onerror = () => rej(r.error); + r.onsuccess = () => res(); + }); + } +} + +const db = new ChunkDB(); + +async function stream_download(msg: DownloadMsg) { + const { url, modelKey, chunkSize } = msg; + logger.info(`init modelKey=${modelKey} chunkSizeMB=${(chunkSize / 1024 / 1024).toFixed(2)}`); + await db.init(); + + const manifest = await db.get_manifest(modelKey); + if (manifest) { + logger.info( + `cache manifest parts=${manifest.parts} totalMB=${(manifest.totalSize / 1024 / 1024).toFixed(1)} complete=${manifest.complete}`, + ); + if (!manifest.complete) { + logger.warn("manifest incomplete -> purge"); + await db.delete_manifest(modelKey); + await db.delete_chunks(modelKey); + } else { + let loaded = 0; + self.postMessage({ + type: "cache", + parts: manifest.parts, + totalSize: manifest.totalSize, + chunkSize: manifest.chunkSize, + } as CacheMsg); + let allGood = true; + for (let i = 0; i < manifest.parts; i++) { + const ab = await db.get_chunk(modelKey, i); + if (!ab) { + logger.warn(`missing cached chunk ${i}; will redownload`); + allGood = false; + break; + } + loaded += ab.byteLength; + if (i % 50 === 0 || i === manifest.parts - 1) { + logger.debug( + `cached chunk ${i + 1}/${manifest.parts} loadedMB=${(loaded / 1024 / 1024).toFixed(1)}`, + ); + } + const percent = manifest.totalSize ? Math.round((loaded / manifest.totalSize) * 100) : 0; + self.postMessage({ + type: "progress", + loaded, + total: manifest.totalSize, + percent, + } as ProgressMsg); + self.postMessage( + { type: "chunk", index: i, arrayBuffer: ab, final: i === manifest.parts - 1 } as ChunkMsg, + { transfer: [ab] }, + ); + } + if (allGood && loaded === manifest.totalSize) { + logger.info(`cache stream complete`); + self.postMessage({ + type: "complete", + parts: manifest.parts, + totalSize: manifest.totalSize, + chunkSize: manifest.chunkSize, + } as CompleteMsg); + return; + } + logger.warn("cache corrupted -> purge"); + await db.delete_manifest(modelKey); + await db.delete_chunks(modelKey); + } + } + + // Network path + const resp = await fetch(url); + if (!resp.ok || !resp.body) { + self.postMessage({ type: "error", error: `http ${resp.status}` } as ErrorMsg); + return; + } + + const totalHeader = resp.headers.get("content-length"); + const total = totalHeader ? parseInt(totalHeader, 10) : 0; + logger.info(`network start sizeMB=${total ? (total / 1024 / 1024).toFixed(1) : "unknown"}`); + const reader = resp.body.getReader(); + let loadedNet = 0; + let partIndex = 0; + let current = new Uint8Array(chunkSize); + let offset = 0; + let parts = 0; + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + if (!value) continue; + const src = value; + let srcOff = 0; + while (srcOff < src.length) { + const toCopy = Math.min(chunkSize - offset, src.length - srcOff); + current.set(src.subarray(srcOff, srcOff + toCopy), offset); + offset += toCopy; + srcOff += toCopy; + if (offset === chunkSize) { + const full = current.buffer.slice(0); + await db.put_chunk(modelKey, partIndex, full); + self.postMessage( + { type: "chunk", index: partIndex, arrayBuffer: full, final: false } as ChunkMsg, + { transfer: [full] }, + ); + if (partIndex % 50 === 0) { + logger.debug( + `stored chunk ${partIndex} loadedMB=${(loadedNet / 1024 / 1024).toFixed(1)}`, + ); + } + partIndex++; + parts++; + current = new Uint8Array(chunkSize); + offset = 0; + } + } + loadedNet += value.length; + const percent = total ? Math.round((loadedNet / total) * 100) : 0; + self.postMessage({ + type: "progress", + loaded: loadedNet, + total, + percent, + } as ProgressMsg); + if (percent % 10 === 0) + logger.debug(`progress ${percent}% loadedMB=${(loadedNet / 1024 / 1024).toFixed(1)}`); + } + if (offset > 0) { + const finalBuf = current.slice(0, offset).buffer; + await db.put_chunk(modelKey, partIndex, finalBuf); + self.postMessage( + { type: "chunk", index: partIndex, arrayBuffer: finalBuf, final: true } as ChunkMsg, + { transfer: [finalBuf] }, + ); + parts++; + } else if (parts > 0) { + self.postMessage({ + type: "chunk", + index: partIndex - 1, + arrayBuffer: new ArrayBuffer(0), + final: true, + } as ChunkMsg); + } + // Integrity verification (robust): attempt to confirm contiguous chunks exist. + try { + let verifiedParts = 0; + let verifiedTotal = 0; + for (let i = 0; i < parts; i++) { + const ab = await db.get_chunk(modelKey, i); + if (!ab) { + logger.warn(`integrity: missing chunk ${i} (expected parts=${parts}); truncating cache`); + break; + } + verifiedParts++; + verifiedTotal += ab.byteLength; + } + if (verifiedParts !== parts) { + logger.warn( + `integrity: adjusted parts from ${parts} -> ${verifiedParts} totalMB=${(verifiedTotal / 1024 / 1024).toFixed(1)}`, + ); + parts = verifiedParts; + loadedNet = verifiedTotal; // reflect actual stored size + } else { + logger.debug( + `integrity: verified ${verifiedParts} parts totalMB=${(verifiedTotal / 1024 / 1024).toFixed(1)}`, + ); + } + } catch (verErr) { + logger.warn("integrity verification failed; proceeding without adjustment", verErr); + } + try { + await db.put_manifest({ + key: modelKey, + totalSize: loadedNet, + parts, + chunkSize, + timestamp: Date.now(), + complete: true, + }); + } catch (mErr) { + logger.error("putManifest failed", mErr); + // Continue; main thread will treat as network path next time. + } + logger.info(`network complete parts=${parts} totalMB=${(loadedNet / 1024 / 1024).toFixed(1)}`); + self.postMessage({ + type: "complete", + parts, + totalSize: loadedNet, + chunkSize, + } as CompleteMsg); + } catch (e) { + logger.error("error", e); + self.postMessage({ + type: "error", + error: e instanceof Error ? e.message : "unknown", + } as ErrorMsg); + } finally { + reader.releaseLock(); + logger.debug("reader released"); + } +} + +self.onmessage = (ev: MessageEvent) => { + if (ev.data.type === "download") { + stream_download(ev.data); + } else { + self.postMessage({ type: "error", error: "unknown message" } as ErrorMsg); + } +}; diff --git a/src/providers/mediaPipeProvider/index.tsx b/src/providers/mediaPipeProvider/index.tsx new file mode 100644 index 0000000..89cf91f --- /dev/null +++ b/src/providers/mediaPipeProvider/index.tsx @@ -0,0 +1,395 @@ +import { Button, Heading } from "@components"; +import { FilesetResolver, LlmInference } from "@mediapipe/tasks-genai"; +import type { AssistantMessage, Message } from "@types"; +import { BaseProvider } from "../../plugin/base_provider"; +import { MODEL_CACHE_DB_NAME, MODEL_CACHE_KEY_VERSION } from "./cache/cache_constants"; +import { ModelCache } from "./cache/model_cache"; +import { get_logger, set_log_level, type LogLevel } from "./logger"; +import styles from "./style.module.css"; + +type MediaPipeOptions = { + max_num_images?: number; + max_tokens?: number; + model_asset_path?: string; + random_seed?: number; + temperature?: number; + top_k?: number; + wasm_base_path?: string; +}; + +type MediaPipeProviderOptions = { + log_level?: LogLevel; + media_pipe_options?: MediaPipeOptions; +}; + +/** + * A provider that runs a local (on-device) LLM using MediaPipe @mediapipe/tasks-genai. + */ +export class MediaPipeProvider extends BaseProvider { + #cached_model_reader: ReadableStreamDefaultReader | null = null; + #llm: LlmInference | null = null; + #logger = get_logger("MediaPipeProvider"); + #media_pipe_options: Required; + #model_cache: ModelCache; + #progress_message: string = ""; + #setup_state: "consenting" | "loading" | "error" = "consenting"; + + constructor(options: MediaPipeProviderOptions = { log_level: "info", media_pipe_options: {} }) { + super(); + if (options.log_level) { + set_log_level(options.log_level); + } + this.#model_cache = new ModelCache(); + this.#media_pipe_options = { + max_num_images: options.media_pipe_options?.max_num_images ?? 5, + max_tokens: options.media_pipe_options?.max_tokens ?? 1000, + model_asset_path: + options.media_pipe_options?.model_asset_path ?? + "https://huggingface.co/charlesLoder/gemma-3n-E4B-it-litert-lm/resolve/main/gemma-3n-E4B-it-int4-Web.litertlm", + random_seed: options.media_pipe_options?.random_seed ?? 101, + temperature: options.media_pipe_options?.temperature ?? 0.8, + top_k: options.media_pipe_options?.top_k ?? 40, + wasm_base_path: + options.media_pipe_options?.wasm_base_path ?? + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai@latest/wasm", + }; + } + + /** + * Build prompt sequence following MediaPipe's multimodal format + */ + #build_prompt( + messages: Message[], + conversationHistory: Message[], + ): Array { + const sequence: Array = []; + + // Add system prompt if present + const systemPrompt = this.plugin_state.systemPrompt; + if (systemPrompt) { + sequence.push(systemPrompt + "\n"); + } + + const allMessages = [...conversationHistory, ...messages]; + + for (const message of allMessages) { + if (message.role === "system") continue; + + if (message.role === "user") { + sequence.push("user\n"); + + for (const content of message.content) { + if (content.type === "text") { + sequence.push(content.content); + } else if (content.type === "media" && content.content.src) { + sequence.push({ imageSource: content.content.src }); + } + } + + sequence.push("\n"); + } else if (message.role === "assistant") { + if (message.type === "response") { + sequence.push(`model\n${message.content.content}\n`); + } + } + } + + // Start model turn for response + sequence.push("model\n"); + return sequence; + } + + /** + * Clear cached model data (IndexedDB database) then invoke a callback. + */ + #clear_cache(on_cleared: () => void) { + const MAX_RETRIES = 5; + const RETRY_DELAY_MS = 500; + let attempt = 0; + const tryDelete = () => { + attempt++; + try { + this.#logger.info( + `clearing model cache (IndexedDB ${MODEL_CACHE_DB_NAME}) attempt=${attempt}`, + ); + const req = indexedDB.deleteDatabase(MODEL_CACHE_DB_NAME); + req.onsuccess = () => { + this.#logger.info("cache cleared"); + on_cleared(); + }; + req.onerror = () => { + this.#logger.warn("cache clear error", req.error); + on_cleared(); // proceed anyway + }; + req.onblocked = () => { + this.#logger.warn("cache clear blocked (open connections)"); + if (attempt < MAX_RETRIES) { + setTimeout(tryDelete, RETRY_DELAY_MS); + } else { + this.#logger.warn("cache clear giving up after retries"); + on_cleared(); + } + }; + } catch (e) { + this.#logger.warn("cache clear threw", e); + on_cleared(); + } + }; + tryDelete(); + } + + async #load_model() { + try { + this.#logger.debug("#load_model start"); + this.#progress_message = "Initializing WASM runtime..."; + this.update_plugin_provider(); + const genai = await FilesetResolver.forGenAiTasks(this.#media_pipe_options.wasm_base_path); + + if (!this.#cached_model_reader) { + this.#logger.info("starting streaming download"); + this.#progress_message = "Preparing chunked download..."; + this.update_plugin_provider(); + + const model_key = `gemma-3n-E4B-it-${MODEL_CACHE_KEY_VERSION}-${this.#media_pipe_options.model_asset_path.split("/").pop()}`; + // Throttled progress state (closure locals) + let last_percent = -1; + let last_time = 0; + let last_loaded_mb_shown = -1; + const PERCENT_STEP_NETWORK = 2; // only update UI every 2% + const PERCENT_STEP_CACHE = 10; // cache loads are fast; every 10% + const MIN_INTERVAL_MS = 900; // or at least every 0.9s + const MIN_MB_STEP = 64; // if total unknown, every 64MB + + this.#cached_model_reader = await this.#model_cache.load_model( + this.#media_pipe_options.model_asset_path, + model_key, + (loaded, total, percent, source) => { + const now = performance.now(); + const from_cache = source === "cache"; + const loaded_mb = loaded / 1024 / 1024; + const percent_step = from_cache ? PERCENT_STEP_CACHE : PERCENT_STEP_NETWORK; + + function calc_update(): boolean { + // always show completion + if (percent === 100) { + return true; + } + + // first update + if (last_percent < 0) { + return true; + } + + if (percent - last_percent >= percent_step) { + return true; + } + + // Unknown total size: use MB step + if ( + !total && + (last_loaded_mb_shown < 0 || loaded_mb - last_loaded_mb_shown >= MIN_MB_STEP) + ) { + return true; + } + + return false; + } + + let should_update = calc_update(); + + // Time-based backstop + if (!should_update && now - last_time >= MIN_INTERVAL_MS) { + should_update = true; + } + + // skip noisy update + if (!should_update) { + return; + } + + last_percent = percent; + last_time = now; + last_loaded_mb_shown = loaded_mb; + + const loaded_mb_disp = Math.round(loaded_mb); + const total_mb_disp = total ? Math.round(total / 1024 / 1024) : undefined; + const label = from_cache ? "Load from cache" : "Downloading (stream)"; + + this.#progress_message = + total_mb_disp !== undefined && total_mb_disp > 0 + ? `${label}: ${loaded_mb_disp}MB / ${total_mb_disp}MB (${percent}%)` + : `${label}: ${loaded_mb_disp}MB received`; + + this.#logger.debug("progress(throttled)", { + percent, + loaded_mb_disp, + total_mb_disp, + source, + }); + + // Update UI + this.update_plugin_provider(); + }, + ); + + this.#logger.info("streaming reader ready"); + } + + this.#progress_message = "Initializing model (this may take a moment)..."; + this.update_plugin_provider(); + + if (!this.#cached_model_reader) { + throw new Error("Model reader missing after download step"); + } + + this.#logger.info("creating LlmInference", { streaming: !!this.#cached_model_reader }); + this.#llm = await LlmInference.createFromOptions(genai, { + baseOptions: { + modelAssetBuffer: this.#cached_model_reader, + }, + maxTokens: this.#media_pipe_options.max_tokens, + topK: this.#media_pipe_options.top_k, + temperature: this.#media_pipe_options.temperature, + randomSeed: this.#media_pipe_options.random_seed, + maxNumImages: this.#media_pipe_options.max_num_images, + }); + this.#logger.info("model createFromOptions complete"); + + this.#progress_message = "Model loaded successfully!"; + this.update_plugin_provider(); + } catch (err) { + this.#logger.error("MediaPipe model load failed", err); + this.#progress_message = + err instanceof Error ? `${err.message}` : "An unknown error occurred."; + this.#setup_state = "error"; + this.update_plugin_provider(); + } + } + + get status() { + return this.#llm ? "ready" : "initializing"; + } + + async generate_response(messages: Message[], conversationHistory: Message[]): Promise { + try { + if (!this.#llm) { + throw new Error("Model not loaded"); + } + + const promptSequence = this.#build_prompt(messages, conversationHistory); + this.set_conversation_state("assistant_responding"); + + const assistantMessage: AssistantMessage = { + role: "assistant", + type: "response", + content: { type: "text", content: "" }, + }; + this.add_messages([assistantMessage as Message]); + await this.#llm.generateResponse(promptSequence, (partial: string, done: boolean) => { + assistantMessage.content.content += partial; + this.update_last_message(assistantMessage); + if (done) { + this.set_conversation_state("idle"); + } + }); + } catch (err) { + this.#logger.error("Inference failed", err); + this.set_conversation_state("error"); + } + } + + SetupComponent() { + const handleConsent = () => { + this.#setup_state = "loading"; + this.update_plugin_provider(); + this.#load_model(); + }; + + const handleForceReload = () => { + // Reset any existing in-memory state + this.#cached_model_reader?.cancel?.().catch(() => {}); + this.#cached_model_reader = null; + this.#progress_message = "Clearing cached model..."; + this.update_plugin_provider(); + this.#clear_cache(() => { + this.#setup_state = "loading"; + this.#progress_message = "Preparing chunked download..."; + this.update_plugin_provider(); + this.#load_model(); + }); + }; + + const handleRetry = () => { + this.#setup_state = "consenting"; + this.#progress_message = ""; + this.update_plugin_provider(); + }; + + // Consenting screen + if (this.#setup_state === "consenting") { + return ( +
+ MediaPipe Local LLM +

+ This will download and run a Gemma 3N model locally in your browser. The model will be + cached for future use. +

+
+ Requirements: +
    +
  • WebGPU-compatible browser
  • +
  • ~1.7GB download (model will be cached locally)
  • +
  • Sufficient device memory for on-device inference
  • +
+
+

+ By proceeding, you agree to the{" "} + + Gemma Terms of Use + + . +

+ + +
+ ); + } + + if (this.#setup_state === "error") { + return ( +
+ Error Loading Model +

The following error occurred while loading the model:

+

{this.#progress_message}

+

+ Check the browser console for detailed error information. +

+ +
+ ); + } + + return ( +
+ Loading Model +

{this.#progress_message || "Starting download..."}

+
+ ); + } +} diff --git a/src/providers/mediaPipeProvider/logger.ts b/src/providers/mediaPipeProvider/logger.ts new file mode 100644 index 0000000..5cb1a29 --- /dev/null +++ b/src/providers/mediaPipeProvider/logger.ts @@ -0,0 +1,65 @@ +/** + * Simple scoped logger with global log level control. + * Levels: silent < error < warn < info < debug + */ +export type LogLevel = "silent" | "error" | "warn" | "info" | "debug"; + +const order: Record = { + silent: 0, + error: 1, + warn: 2, + info: 3, + debug: 4, +}; + +let current: LogLevel = "info"; + +export function set_log_level(level: LogLevel) { + if (order[level] == null) return; + current = level; +} + +export function get_log_level(): LogLevel { + return current; +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type Fn = (...args: any[]) => void; +export interface Logger { + debug: Fn; + info: Fn; + warn: Fn; + error: Fn; + scope: string; +} + +function make(method: "debug" | "info" | "warn" | "error", scope: string): Fn { + const needed: Record = { + debug: "debug", + info: "info", + warn: "warn", + error: "error", + } as const; + const min = order[needed[method]]; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return (...args: any[]) => { + if (order[current] < min) { + return; + } + + // eslint-disable-next-line no-console + console[method](`[${scope}]`, ...args); + }; +} + +export function get_logger(scope: string): Logger { + return { + scope, + debug: make("debug", scope), + info: make("info", scope), + warn: make("warn", scope), + error: make("error", scope), + }; +} + +export const scoped = get_logger; diff --git a/src/providers/mediaPipeProvider/style.module.css b/src/providers/mediaPipeProvider/style.module.css new file mode 100644 index 0000000..a37115f --- /dev/null +++ b/src/providers/mediaPipeProvider/style.module.css @@ -0,0 +1,56 @@ +.requirementsBox { + padding: var(--clover-ai-space-4); + background-color: var(--clover-ai-colors-secondaryMuted); + border-radius: var(--clover-ai-radii-3); + font-size: var(--clover-ai-fontSizes-2); + + ul { + color: var(--clover-ai-colors-primary); + list-style-position: inside; + padding: 0; + margin: 0; + } +} + +.termsText { + font-size: var(--clover-ai-fontSizes-1); + color: color-mix(in srgb, var(--clover-ai-colors-primary) 80%, transparent); + + a { + color: var(--clover-ai-colors-accent); + } + + a:hover { + color: var(--clover-ai-colors-accentAlt); + } +} + +.errorBox { + padding: var(--clover-ai-space-4); + background-color: var(--clover-ai-colors-errorMuted); + border-radius: var(--clover-ai-radii-3); + font-size: var(--clover-ai-fontSizes-2); + color: var(--clover-ai-colors-error); +} + +.errorNote { + font-size: var(--clover-ai-fontSizes-1); + opacity: 0.8; + color: var(--clover-ai-colors-primary); +} + +.modelInfo { + padding: calc(var(--clover-ai-space-3) * 0.75); + background-color: var(--clover-ai-colors-secondaryMuted); + border-radius: var(--clover-ai-radii-2); + font-size: var(--clover-ai-fontSizes-1); + opacity: 0.7; + color: var(--clover-ai-colors-primary); +} + +.modelCode { + font-family: monospace; + background-color: var(--clover-ai-colors-secondaryAlt); + padding: calc(var(--clover-ai-space-1) * 0.5) var(--clover-ai-space-1); + border-radius: var(--clover-ai-radii-1); +}