From 19b046fd14e595b041a5546089bad14cb57df807 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Fri, 17 Apr 2026 17:08:38 +0800 Subject: [PATCH 01/52] refactor: replace markdown_widget with mixin_markdown_widget and update related code --- lib/widgets/markdown.dart | 445 ++++++++++--------------- macos/Runner.xcodeproj/project.pbxproj | 2 +- pubspec.lock | 73 ++-- pubspec.yaml | 3 +- 4 files changed, 223 insertions(+), 300 deletions(-) diff --git a/lib/widgets/markdown.dart b/lib/widgets/markdown.dart index 1e2ff2dc2d..01936fc36b 100644 --- a/lib/widgets/markdown.dart +++ b/lib/widgets/markdown.dart @@ -1,49 +1,43 @@ +import 'dart:io'; + import 'package:flutter/material.dart'; -import 'package:html/dom.dart' as h; -import 'package:html/dom_parsing.dart'; -import 'package:html/parser.dart'; -import 'package:markdown/markdown.dart' as m; -import 'package:markdown_widget/markdown_widget.dart'; -import 'package:mixin_logger/mixin_logger.dart'; +import 'package:hooks_riverpod/hooks_riverpod.dart'; +import 'package:mixin_markdown_widget/mixin_markdown_widget.dart'; +import '../ui/provider/setting_provider.dart'; import '../utils/extension/extension.dart'; import '../utils/uri_utils.dart'; -import 'high_light_text.dart'; import 'mixin_image.dart'; -class MarkdownColumn extends StatelessWidget { +class MarkdownColumn extends ConsumerWidget { const MarkdownColumn({required this.data, super.key}); final String data; @override - Widget build(BuildContext context) { - final widgets = - MarkdownGenerator( - textGenerator: (node, config, visitor) => - CustomTextNode(node.textContent, config, visitor), - generators: _kMixinGenerators, - richTextBuilder: CustomText.rich, - ).buildWidgets( - data, - config: _createMarkdownConfig( - context: context, - darkMode: context.brightness == Brightness.dark, - ), - ); + Widget build(BuildContext context, WidgetRef ref) { + final chatFontSizeDelta = ref.watch( + settingProvider.select((value) => value.chatFontSizeDelta), + ); + return ClipRect( - child: DefaultTextStyle.merge( - style: TextStyle(color: context.theme.text), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: widgets, - ), + child: MarkdownWidget( + data: data, + useColumn: true, + selectable: false, + padding: EdgeInsets.zero, + theme: _createMarkdownTheme(context, chatFontSizeDelta), + imageBuilder: _buildMarkdownImage, + onTapLink: (destination, title, label) { + if (destination.isEmpty) return; + openUri(context, destination); + }, ), ); } } -class Markdown extends StatelessWidget { +class Markdown extends ConsumerWidget { const Markdown({ required this.data, super.key, @@ -56,266 +50,187 @@ class Markdown extends StatelessWidget { final ScrollPhysics? physics; @override - Widget build(BuildContext context) => DefaultTextStyle.merge( - style: TextStyle(color: context.theme.text), - child: MarkdownWidget( + Widget build(BuildContext context, WidgetRef ref) { + final chatFontSizeDelta = ref.watch( + settingProvider.select((value) => value.chatFontSizeDelta), + ); + + return MarkdownWidget( data: data, padding: padding, physics: physics, - config: _createMarkdownConfig( - context: context, - darkMode: context.brightness == Brightness.dark, - ), - markdownGenerator: MarkdownGenerator( - textGenerator: (node, config, visitor) => - CustomTextNode(node.textContent, config, visitor), - generators: _kMixinGenerators, - richTextBuilder: CustomText.rich, - ), - ), - ); -} - -MarkdownConfig _createMarkdownConfig({ - required BuildContext context, - required bool darkMode, -}) => MarkdownConfig( - configs: [ - if (darkMode) ...[ - HrConfig.darkConfig, - H2Config.darkConfig, - H3Config.darkConfig, - H4Config.darkConfig, - H5Config.darkConfig, - H6Config.darkConfig, - PreConfig.darkConfig, - PConfig.darkConfig, - CodeConfig.darkConfig, - ], - _MixinH1Config(darkMode), - ImgConfig( - builder: (url, attributes) { - double? width; - double? height; - if (attributes['width'] != null) { - width = double.parse(attributes['width']!); - } - if (attributes['height'] != null) { - height = double.parse(attributes['height']!); - } - final imageUrl = url; - return ConstrainedBox( - constraints: const BoxConstraints(maxWidth: 400), - child: MixinImage.network(imageUrl, width: width, height: height), - ); + theme: _createMarkdownTheme(context, chatFontSizeDelta), + imageBuilder: _buildMarkdownImage, + onTapLink: (destination, title, label) { + if (destination.isEmpty) return; + openUri(context, destination); }, - ), - LinkConfig( - style: TextStyle(color: context.theme.accent), - onTap: (href) { - if (href.isEmpty) return; - openUri(context, href); - }, - ), - ListConfig( - marker: (isOrdered, depth, index) { - final style = DefaultTextStyle.of(context).style; - final height = (style.fontSize ?? 16) * (style.height ?? 1.25); - return getDefaultMarker( - isOrdered, - depth, - context.theme.text, - index, - height / 2 + 1, - MarkdownConfig(), - ); - }, - ), - ], -); - -class _MixinH1Config extends HeadingConfig { - _MixinH1Config(this.dark); + ); + } +} - final bool dark; +Widget _buildMarkdownImage( + BuildContext context, + ImageBlock block, + MarkdownThemeData theme, +) { + final uri = Uri.tryParse(block.url); + final width = _tryParseImageDimension(uri, 'w', 'width'); + final height = _tryParseImageDimension(uri, 'h', 'height'); + + Widget errorBuilder(BuildContext context, Object error, StackTrace? stack) { + final iconColor = theme.bodyStyle.color?.withValues(alpha: 0.72); + if (width != null && height != null) { + return Container( + width: width, + height: height, + color: theme.imagePlaceholderBackgroundColor, + alignment: Alignment.center, + child: Icon(Icons.broken_image_outlined, color: theme.dividerColor), + ); + } - @override - HeadingDivider? get divider => null; + return Container( + padding: const EdgeInsets.symmetric(horizontal: 10, vertical: 6), + decoration: BoxDecoration( + color: theme.imagePlaceholderBackgroundColor, + borderRadius: theme.imageBorderRadius, + ), + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + Icon(Icons.broken_image_outlined, size: 18, color: iconColor), + const SizedBox(width: 8), + Flexible( + child: Text( + block.alt?.isNotEmpty == true ? block.alt! : 'Image', + style: theme.bodyStyle.copyWith(color: iconColor), + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + ), + ], + ), + ); + } - @override - TextStyle get style => TextStyle( - fontSize: 32, - height: 40 / 32, - color: dark ? Colors.white : null, - fontWeight: FontWeight.bold, + final image = _buildMixinImageForUrl( + block.url, + width: width, + height: height, + errorBuilder: errorBuilder, ); - @override - String get tag => MarkdownTag.h1.name; + return ClipRRect( + borderRadius: theme.imageBorderRadius, + child: image, + ); } -final RegExp htmlRep = RegExp('<[^>]*>', multiLine: true); +double? _tryParseImageDimension(Uri? uri, String shortKey, String fullKey) { + if (uri == null) return null; + final value = uri.queryParameters[shortKey] ?? uri.queryParameters[fullKey]; + return value == null ? null : double.tryParse(value); +} -/// parse [m.Node] to [h.Node] -/// https://github.com/asjqkkkk/markdown_widget/blob/1d549fd5c2d6b0172281d8bb66e367654b9d60f0/example/lib/markdown_custom/html_support.dart -List _parseHtml( - m.Text node, { - ValueCallback? onError, - WidgetVisitor? visitor, - TextStyle? parentStyle, +Widget _buildMixinImageForUrl( + String url, { + double? width, + double? height, + ImageErrorWidgetBuilder? errorBuilder, }) { - try { - final text = node.textContent.replaceAll( - RegExp(r'(\r?\n)|(\r?\t)|(\r)'), - '', + final uri = Uri.tryParse(url); + if (uri != null && (uri.scheme == 'http' || uri.scheme == 'https')) { + return MixinImage.network( + url, + width: width, + height: height, + errorBuilder: errorBuilder, ); - if (!text.contains(htmlRep)) return [TextNode(text: node.text)]; - final document = parseFragment(text); - return HtmlToSpanVisitor( - visitor: visitor, - parentStyle: parentStyle, - ).toVisit(document.nodes.toList()); - } catch (e) { - onError?.call(e); - return [TextNode(text: node.text)]; } -} -class HtmlElement extends m.Element { - HtmlElement(super.tag, super.children, this.textContent); - - @override - final String textContent; -} - -class HtmlToSpanVisitor extends TreeVisitor { - HtmlToSpanVisitor({WidgetVisitor? visitor, TextStyle? parentStyle}) - : visitor = visitor ?? WidgetVisitor(), - parentStyle = parentStyle ?? const TextStyle(); - final List _spans = []; - final List _spansStack = []; - final WidgetVisitor visitor; - final TextStyle parentStyle; - - List toVisit(List nodes) { - _spans.clear(); - for (final node in nodes) { - final emptyNode = ConcreteElementNode(style: parentStyle); - _spans.add(emptyNode); - _spansStack.add(emptyNode); - visit(node); - _spansStack.removeLast(); - } - final result = List.of(_spans); - _spans.clear(); - _spansStack.clear(); - return result; + if (uri != null && uri.scheme == 'file') { + return MixinImage.file( + File.fromUri(uri), + width: width, + height: height, + errorBuilder: errorBuilder, + ); } - @override - void visitText(h.Text node) { - final last = _spansStack.last; - if (last is ElementNode) { - final textNode = TextNode(text: node.text); - last.accept(textNode); - } + final file = File(url); + if (file.isAbsolute) { + return MixinImage.file( + file, + width: width, + height: height, + errorBuilder: errorBuilder, + ); } - @override - void visitElement(h.Element node) { - final localName = node.localName ?? ''; - final mdElement = m.Element(localName, []); - mdElement.attributes.addAll(node.attributes.cast()); - var spanNode = visitor.getNodeByElement(mdElement, visitor.config); - if (spanNode is! ElementNode) { - final n = ConcreteElementNode(tag: localName)..accept(spanNode); - spanNode = n; - } - final last = _spansStack.last; - if (last is ElementNode) { - last.accept(spanNode); - } - _spansStack.add(spanNode); - node.nodes.toList(growable: false).forEach(visit); - _spansStack.removeLast(); - } + return MixinImage.asset( + url, + width: width, + height: height, + errorBuilder: errorBuilder, + ); } -class CustomTextNode extends ElementNode { - CustomTextNode(this.text, this.config, this.visitor); - - final String text; - final MarkdownConfig config; - final WidgetVisitor visitor; - - @override - void onAccepted(SpanNode parent) { - final textStyle = config.p.textStyle.merge(parentStyle); - children.clear(); - if (!text.contains(htmlRep)) { - accept(TextNode(text: text, style: textStyle)); - return; - } - _parseHtml( - m.Text(text), - visitor: WidgetVisitor( - config: visitor.config, - generators: visitor.generators, - ), - parentStyle: parentStyle, - ).forEach(accept); +MarkdownThemeData _createMarkdownTheme( + BuildContext context, + double chatFontSizeDelta, +) { + final base = MarkdownThemeData.fallback(context); + final textColor = context.theme.text; + final accentColor = context.theme.accent; + final codeBlockBackgroundColor = context.theme.chatBackground; + + TextStyle applyTextColor(TextStyle style) => style.copyWith(color: textColor); + TextStyle applyFontSizeDelta(TextStyle style) { + final fontSize = style.fontSize; + if (fontSize == null) return style; + return style.copyWith(fontSize: fontSize + chatFontSizeDelta); } -} - -final _kMixinGenerators = [ - SpanNodeGeneratorWithTag( - tag: MarkdownTag.pre.name, - generator: (e, config, visitor) => - _MixinCodeBlockNode(e, config.pre, visitor), - ), -]; -class _MixinCodeBlockNode extends CodeBlockNode { - _MixinCodeBlockNode(super.content, super.preConfig, super.visitor); + TextStyle applyTextStyle(TextStyle style) => + applyTextColor(applyFontSizeDelta(style)); - @override - InlineSpan build() { - var language = preConfig.language; - try { - final languageValue = - (element.children!.first as m.Element).attributes['class']!; - language = languageValue.split('-').last; - } catch (e) { - i('get language error:$e'); - } - final splitContents = content.trim().split(RegExp(r'(\r?\n)|(\r?\t)|(\r)')); - if (splitContents.last.isEmpty) splitContents.removeLast(); - final widget = Container( - decoration: preConfig.decoration, - margin: preConfig.margin, - padding: preConfig.padding, - width: double.infinity, - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: List.generate(splitContents.length, (index) { - final currentContent = splitContents[index]; - return ProxyRichText( - TextSpan( - children: highLightSpans( - currentContent, - language: preConfig.language, - theme: preConfig.theme, - textStyle: style, - styleNotMatched: preConfig.styleNotMatched, - ), - ), - richTextBuilder: visitor.richTextBuilder, - ); - }), + return base.copyWith( + bodyStyle: applyTextStyle(base.bodyStyle), + quoteStyle: applyFontSizeDelta( + base.quoteStyle.copyWith( + color: textColor.withValues(alpha: 0.82), ), - ); - return WidgetSpan( - child: preConfig.wrapper?.call(widget, content, language) ?? widget, - ); - } + ), + linkStyle: base.linkStyle.copyWith( + color: accentColor, + decorationColor: accentColor, + fontSize: + (base.linkStyle.fontSize ?? base.bodyStyle.fontSize ?? 16) + + chatFontSizeDelta, + ), + inlineCodeStyle: applyTextStyle(base.inlineCodeStyle), + codeBlockStyle: applyTextStyle(base.codeBlockStyle), + codeBlockBackgroundColor: codeBlockBackgroundColor, + inlineCodeBackgroundColor: codeBlockBackgroundColor, + quoteBackgroundColor: codeBlockBackgroundColor, + tableHeaderStyle: applyTextStyle(base.tableHeaderStyle), + heading1Style: applyTextStyle( + applyFontSizeDelta( + base.heading1Style.copyWith( + fontSize: 32, + height: 40 / 32, + fontWeight: FontWeight.bold, + ), + ), + ), + heading2Style: applyTextStyle(base.heading2Style), + heading3Style: applyTextStyle(base.heading3Style), + heading4Style: applyTextStyle(base.heading4Style), + heading5Style: applyTextStyle(base.heading5Style), + heading6Style: applyTextStyle(base.heading6Style), + quoteBorderColor: accentColor.withValues(alpha: 0.4), + selectionColor: accentColor.withValues(alpha: 0.24), + showHeading1Divider: false, + ); } diff --git a/macos/Runner.xcodeproj/project.pbxproj b/macos/Runner.xcodeproj/project.pbxproj index ef954225eb..078af80835 100644 --- a/macos/Runner.xcodeproj/project.pbxproj +++ b/macos/Runner.xcodeproj/project.pbxproj @@ -3,7 +3,7 @@ archiveVersion = 1; classes = { }; - objectVersion = 60; + objectVersion = 54; objects = { /* Begin PBXAggregateTarget section */ diff --git a/pubspec.lock b/pubspec.lock index f422cb1752..dd5b1cd2ef 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -667,14 +667,6 @@ packages: url: "https://pub.dev" source: hosted version: "9.1.1" - flutter_highlight: - dependency: transitive - description: - name: flutter_highlight - sha256: "7b96333867aa07e122e245c033b8ad622e4e3a42a1a2372cbb098a2541d8782c" - url: "https://pub.dev" - source: hosted - version: "0.7.0" flutter_hooks: dependency: "direct main" description: @@ -728,6 +720,14 @@ packages: description: flutter source: sdk version: "0.0.0" + flutter_math_fork: + dependency: transitive + description: + name: flutter_math_fork + sha256: "6d5f2f1aa57ae539ffb0a04bb39d2da67af74601d685a161aff7ce5bda5fa407" + url: "https://pub.dev" + source: hosted + version: "0.7.4" flutter_plugin_android_lifecycle: dependency: transitive description: @@ -843,14 +843,6 @@ packages: url: "https://pub.dev" source: hosted version: "0.2.0" - highlight: - dependency: transitive - description: - name: highlight - sha256: "5353a83ffe3e3eca7df0abfb72dcf3fa66cc56b953728e7113ad4ad88497cf21" - url: "https://pub.dev" - source: hosted - version: "0.7.0" hive: dependency: "direct main" description: @@ -1227,14 +1219,6 @@ packages: url: "https://pub.dev" source: hosted version: "7.3.0" - markdown_widget: - dependency: "direct main" - description: - name: markdown_widget - sha256: b52c13d3ee4d0e60c812e15b0593f142a3b8a2003cde1babb271d001a1dbdc1c - url: "https://pub.dev" - source: hosted - version: "2.3.2+8" matcher: dependency: transitive description: @@ -1283,6 +1267,13 @@ packages: url: "https://pub.dev" source: hosted version: "0.1.3" + mixin_markdown_widget: + dependency: "direct main" + description: + path: "/Users/yangbin/workspace/mixin/flutter-plugins/packages/mixin_markdown_widget" + relative: false + source: path + version: "0.1.0" msix: dependency: "direct dev" description: @@ -1596,6 +1587,14 @@ packages: url: "https://pub.dev" source: hosted version: "6.0.2" + pretext: + dependency: transitive + description: + name: pretext + sha256: "414ef08acce07d877ec62348a56c998d8d45eab4cc6bfab35a7287520c6e4c7c" + url: "https://pub.dev" + source: hosted + version: "0.1.0" pretty_qr_code: dependency: "direct main" description: @@ -1700,6 +1699,14 @@ packages: url: "https://pub.dev" source: hosted version: "2.2.3" + re_highlight: + dependency: transitive + description: + name: re_highlight + sha256: "6c4ac3f76f939fb7ca9df013df98526634e17d8f7460e028bd23a035870024f2" + url: "https://pub.dev" + source: hosted + version: "0.0.3" recase: dependency: "direct main" description: @@ -1772,14 +1779,6 @@ packages: url: "https://pub.dev" source: hosted version: "0.2.0" - scroll_to_index: - dependency: transitive - description: - name: scroll_to_index - sha256: b707546e7500d9f070d63e5acf74fd437ec7eeeb68d3412ef7b0afada0b4f176 - url: "https://pub.dev" - source: hosted - version: "3.0.1" scrollable_positioned_list: dependency: "direct main" description: @@ -2009,6 +2008,14 @@ packages: url: "https://pub.dev" source: hosted version: "0.11.0" + tuple: + dependency: transitive + description: + name: tuple + sha256: a97ce2013f240b2f3807bcbaf218765b6f301c3eff91092bcfa23a039e7dd151 + url: "https://pub.dev" + source: hosted + version: "2.0.2" typed_data: dependency: transitive description: @@ -2358,4 +2365,4 @@ packages: version: "3.1.3" sdks: dart: ">=3.11.0 <4.0.0" - flutter: ">=3.38.1" + flutter: ">=3.38.1 <4.0.0" diff --git a/pubspec.yaml b/pubspec.yaml index d3b63c7f00..6456500b46 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -96,7 +96,8 @@ dependencies: local_auth: ^3.0.1 lottie: ^3.3.3 map: ^2.0.2 - markdown_widget: ^2.3.2+2 + mixin_markdown_widget: + path: /Users/yangbin/workspace/mixin/flutter-plugins/packages/mixin_markdown_widget mime: ^2.0.0 mixin_bot_sdk_dart: ^1.5.0 mixin_logger: ^0.1.3 From e940d9bf1ec6f20ec3a61195ab90af68dd2bec67 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Mon, 20 Apr 2026 10:42:41 +0800 Subject: [PATCH 02/52] update --- lib/ai/ai_chat_controller.dart | 470 +++++++ lib/ai/model/ai_mode_state.dart | 23 + lib/ai/model/ai_prompt_message.dart | 6 + lib/ai/model/ai_provider_config.dart | 60 + lib/ai/model/ai_provider_type.dart | 15 + lib/db/dao/ai_chat_message_dao.dart | 66 + lib/db/dao/ai_chat_message_dao.g.dart | 13 + lib/db/dao/asset_dao.g.dart | 3 + lib/db/dao/chain_dao.g.dart | 3 + lib/db/dao/circle_conversation_dao.g.dart | 3 + lib/db/dao/circle_dao.g.dart | 3 + lib/db/dao/conversation_dao.g.dart | 3 + lib/db/dao/expired_message_dao.g.dart | 3 + lib/db/dao/favorite_app_dao.g.dart | 3 + lib/db/dao/flood_message_dao.g.dart | 3 + lib/db/dao/inscription_collection_dao.g.dart | 3 + lib/db/dao/inscription_item_dao.g.dart | 3 + lib/db/dao/message_dao.g.dart | 3 + lib/db/dao/participant_dao.g.dart | 3 + lib/db/dao/participant_session_dao.g.dart | 3 + lib/db/dao/pin_message_dao.g.dart | 3 + lib/db/dao/property_dao.g.dart | 3 + lib/db/dao/safe_snapshot_dao.g.dart | 3 + lib/db/dao/snapshot_dao.g.dart | 3 + lib/db/dao/sticker_album_dao.g.dart | 3 + lib/db/dao/sticker_dao.g.dart | 3 + lib/db/dao/sticker_relationship_dao.g.dart | 3 + lib/db/dao/token_dao.g.dart | 3 + lib/db/dao/transcript_message_dao.g.dart | 3 + lib/db/dao/user_dao.g.dart | 3 + lib/db/database.dart | 3 + lib/db/mixin_database.dart | 20 +- lib/db/mixin_database.g.dart | 1086 +++++++++++++++++ lib/db/moor/mixin.drift | 19 +- lib/ui/home/bloc/message_bloc.dart | 157 +++ lib/ui/home/chat/chat_page.dart | 117 +- lib/ui/home/chat/input_container.dart | 135 +- lib/ui/provider/ai_input_mode_provider.dart | 20 + .../responsive_navigator_provider.dart | 9 + lib/ui/setting/ai_provider_edit_page.dart | 173 +++ lib/ui/setting/ai_settings_page.dart | 113 ++ lib/ui/setting/setting_page.dart | 7 + lib/utils/property/setting_property.dart | 64 + lib/widgets/ai/ai_message_card.dart | 75 ++ lib/widgets/markdown.dart | 9 +- lib/widgets/message/message_day_time.dart | 21 +- 46 files changed, 2693 insertions(+), 57 deletions(-) create mode 100644 lib/ai/ai_chat_controller.dart create mode 100644 lib/ai/model/ai_mode_state.dart create mode 100644 lib/ai/model/ai_prompt_message.dart create mode 100644 lib/ai/model/ai_provider_config.dart create mode 100644 lib/ai/model/ai_provider_type.dart create mode 100644 lib/db/dao/ai_chat_message_dao.dart create mode 100644 lib/db/dao/ai_chat_message_dao.g.dart create mode 100644 lib/ui/provider/ai_input_mode_provider.dart create mode 100644 lib/ui/setting/ai_provider_edit_page.dart create mode 100644 lib/ui/setting/ai_settings_page.dart create mode 100644 lib/widgets/ai/ai_message_card.dart diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart new file mode 100644 index 0000000000..6c970a9842 --- /dev/null +++ b/lib/ai/ai_chat_controller.dart @@ -0,0 +1,470 @@ +import 'dart:async'; +import 'dart:convert'; + +import 'package:dio/dio.dart'; +import 'package:drift/drift.dart'; +import 'package:mixin_logger/mixin_logger.dart'; +import 'package:uuid/uuid.dart'; + +import '../db/dao/ai_chat_message_dao.dart'; +import '../db/database.dart'; +import '../db/mixin_database.dart'; +import '../utils/proxy.dart'; +import 'model/ai_prompt_message.dart'; +import 'model/ai_provider_config.dart'; +import 'model/ai_provider_type.dart'; + +const _kAiRoleUser = 'user'; +const _kAiRoleAssistant = 'assistant'; +const _kAiStatusPending = 'pending'; +const _kAiStatusDone = 'done'; +const _kAiStatusError = 'error'; +const _kAiContextMessageLimit = 30; +const _kAiHistoryLimit = 12; +const _kAiStreamFlushChars = 32; +const _kAiStreamFlushInterval = Duration(milliseconds: 80); + +class AiChatController { + AiChatController(this.database); + + final Database database; + final _uuid = const Uuid(); + static const _openAiStrategy = _OpenAiCompatibleStrategy(); + static const _anthropicStrategy = _AnthropicStrategy(); + + Future send({ + required String conversationId, + required String input, + AiProviderConfig? provider, + }) async { + final config = provider ?? database.settingProperties.selectedAiProvider; + if (config == null) { + throw Exception('No AI provider configured'); + } + + final now = DateTime.now(); + final userMessageId = _uuid.v4(); + final assistantMessageId = _uuid.v4(); + final anchorMessage = await database.messageDao + .messagesByConversationId(conversationId, 1) + .getSingleOrNull(); + + await database.aiChatMessageDao.insertMessage( + AiChatMessagesCompanion.insert( + id: userMessageId, + conversationId: conversationId, + role: _kAiRoleUser, + providerId: config.id, + anchorMessageId: Value(anchorMessage?.messageId), + anchorCreatedAt: Value(anchorMessage?.createdAt), + content: input, + status: _kAiStatusDone, + model: Value(config.model), + createdAt: now, + updatedAt: now, + ), + ); + + await database.aiChatMessageDao.insertMessage( + AiChatMessagesCompanion.insert( + id: assistantMessageId, + conversationId: conversationId, + role: _kAiRoleAssistant, + providerId: config.id, + anchorMessageId: Value(anchorMessage?.messageId), + anchorCreatedAt: Value(anchorMessage?.createdAt), + content: '', + status: _kAiStatusPending, + model: Value(config.model), + createdAt: now, + updatedAt: now, + ), + ); + + try { + final messages = await _buildPromptMessages(conversationId, input); + final updater = _StreamingMessageUpdater( + dao: database.aiChatMessageDao, + messageId: assistantMessageId, + ); + final result = await _streamRequest( + config, + messages, + onContent: updater.append, + ); + await updater.flush(contentOverride: result, force: true); + await database.aiChatMessageDao.updateMessageStatus( + assistantMessageId, + _kAiStatusDone, + updatedAt: DateTime.now(), + ); + } catch (error, stacktrace) { + e('AI chat error: $error, $stacktrace'); + await database.aiChatMessageDao.updateMessageStatus( + assistantMessageId, + _kAiStatusError, + updatedAt: DateTime.now(), + errorText: error.toString(), + ); + rethrow; + } + } + + Future> _buildPromptMessages( + String conversationId, + String input, + ) async { + final recentMessages = await database.messageDao + .messagesByConversationId(conversationId, _kAiContextMessageLimit) + .get(); + final aiMessages = await database.aiChatMessageDao.conversationMessages( + conversationId, + ); + + final promptMessages = [ + AiPromptMessage( + role: 'system', + content: + 'You are a local AI assistant inside a chat application. ' + 'Only use the provided current conversation context. ' + 'Help summarize, answer questions about the conversation, and draft replies. ' + 'Be concise and practical.', + ), + ]; + + if (recentMessages.isNotEmpty) { + final lines = recentMessages.reversed + .map((message) { + final sender = message.userFullName ?? message.userId; + final content = _messagePlainText(message); + return '[${message.createdAt.toIso8601String()}] $sender: $content'; + }) + .join('\n'); + promptMessages.add( + AiPromptMessage( + role: 'system', + content: 'Current conversation recent messages:\n$lines', + ), + ); + } + + final history = aiMessages + .where((element) => element.status != _kAiStatusPending) + .takeLast(_kAiHistoryLimit); + for (final item in history) { + promptMessages.add( + AiPromptMessage(role: item.role, content: item.content), + ); + } + + promptMessages.add(AiPromptMessage(role: _kAiRoleUser, content: input)); + return promptMessages; + } + + String _messagePlainText(MessageItem message) { + if (message.content?.trim().isNotEmpty == true) { + return message.content!.trim(); + } + if (message.mediaName?.isNotEmpty == true) { + return '[${message.type}] ${message.mediaName}'; + } + return '[${message.type}]'; + } + + Future _streamRequest( + AiProviderConfig config, + List messages, { + required Future Function(String chunk) onContent, + }) async { + final dio = Dio( + BaseOptions( + baseUrl: config.baseUrl, + connectTimeout: const Duration(seconds: 20), + receiveTimeout: const Duration(minutes: 5), + sendTimeout: const Duration(seconds: 20), + headers: _strategyFor(config.type).headers(config), + ), + )..applyProxy(database.settingProperties.activatedProxy); + + return _strategyFor(config.type).streamResponse( + dio: dio, + config: config, + messages: messages, + onContent: onContent, + ); + } + + _AiProviderStrategy _strategyFor(AiProviderType type) => switch (type) { + AiProviderType.openaiCompatible => _openAiStrategy, + AiProviderType.anthropic => _anthropicStrategy, + }; +} + +extension on Iterable { + Iterable takeLast(int count) { + if (count <= 0) return const []; + final list = toList(); + if (list.length <= count) { + return list; + } + return list.sublist(list.length - count); + } +} + +abstract interface class _AiProviderStrategy { + const _AiProviderStrategy(); + + Map headers(AiProviderConfig config); + + Future streamResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required Future Function(String chunk) onContent, + }); +} + +class _OpenAiCompatibleStrategy implements _AiProviderStrategy { + const _OpenAiCompatibleStrategy(); + + @override + Map headers(AiProviderConfig config) => { + 'Authorization': 'Bearer ${config.apiKey}', + 'Content-Type': 'application/json', + }; + + @override + Future streamResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required Future Function(String chunk) onContent, + }) async { + final response = await dio.post( + '/chat/completions', + data: { + 'model': config.model, + 'stream': true, + 'messages': messages + .map( + (message) => { + 'role': message.role, + 'content': message.content, + }, + ) + .toList(), + }, + options: Options(responseType: ResponseType.stream), + ); + + final body = response.data; + if (body == null) { + throw Exception('Empty AI response'); + } + + final buffer = StringBuffer(); + await for (final data in _decodeSse(body.stream)) { + if (data == '[DONE]') { + continue; + } + + final json = jsonDecode(data); + if (json is! Map) { + continue; + } + + final choices = json['choices'] as List?; + if (choices == null || choices.isEmpty) { + continue; + } + + final first = choices.first; + if (first is! Map) { + continue; + } + + final delta = first['delta']; + if (delta is! Map) { + continue; + } + + final content = delta['content']; + if (content is String && content.isNotEmpty) { + buffer.write(content); + await onContent(content); + } + } + + final text = buffer.toString().trim(); + if (text.isEmpty) { + throw Exception('Empty AI response'); + } + return text; + } +} + +class _AnthropicStrategy implements _AiProviderStrategy { + const _AnthropicStrategy(); + + @override + Map headers(AiProviderConfig config) => { + 'x-api-key': config.apiKey, + 'anthropic-version': '2023-06-01', + 'content-type': 'application/json', + }; + + @override + Future streamResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required Future Function(String chunk) onContent, + }) async { + final response = await dio.post( + '/messages', + data: { + 'model': config.model, + 'max_tokens': 1024, + 'stream': true, + 'messages': messages + .where((message) => message.role != 'system') + .map( + (message) => { + 'role': message.role, + 'content': message.content, + }, + ) + .toList(), + 'system': messages + .where((message) => message.role == 'system') + .map((message) => message.content) + .join('\n\n'), + }, + options: Options(responseType: ResponseType.stream), + ); + + final body = response.data; + if (body == null) { + throw Exception('Empty AI response'); + } + + final buffer = StringBuffer(); + await for (final data in _decodeSse(body.stream)) { + final json = jsonDecode(data); + if (json is! Map) { + continue; + } + + final type = json['type'] as String?; + if (type == 'error') { + final error = json['error']; + if (error is Map) { + throw Exception(error['message'] ?? 'Anthropic request failed'); + } + throw Exception('Anthropic request failed'); + } + + if (type != 'content_block_delta') { + continue; + } + + final delta = json['delta']; + if (delta is! Map) { + continue; + } + + if (delta['type'] != 'text_delta') { + continue; + } + + final text = delta['text']; + if (text is String && text.isNotEmpty) { + buffer.write(text); + await onContent(text); + } + } + + final text = buffer.toString().trim(); + if (text.isEmpty) { + throw Exception('Empty AI response'); + } + return text; + } +} + +class _StreamingMessageUpdater { + _StreamingMessageUpdater({ + required this.dao, + required this.messageId, + }); + + final AiChatMessageDao dao; + final String messageId; + final _buffer = StringBuffer(); + + String _persistedContent = ''; + DateTime _lastFlushedAt = DateTime.fromMillisecondsSinceEpoch(0); + + Future append(String chunk) async { + if (chunk.isEmpty) return; + + _buffer.write(chunk); + final now = DateTime.now(); + final pendingChars = _buffer.length - _persistedContent.length; + if (pendingChars < _kAiStreamFlushChars && + now.difference(_lastFlushedAt) < _kAiStreamFlushInterval) { + return; + } + + await flush(); + } + + Future flush({ + String? contentOverride, + bool force = false, + }) async { + final content = contentOverride ?? _buffer.toString(); + if (!force && content == _persistedContent) { + return; + } + + _persistedContent = content; + _lastFlushedAt = DateTime.now(); + await dao.updateMessageContent( + messageId, + content, + updatedAt: _lastFlushedAt, + ); + } +} + +Stream _decodeSse(Stream> stream) async* { + final buffer = StringBuffer(); + await for (final bytes in stream) { + final chunk = utf8.decode(bytes); + buffer.write(chunk.replaceAll('\r\n', '\n').replaceAll('\r', '\n')); + while (true) { + final current = buffer.toString(); + final separatorIndex = current.indexOf('\n\n'); + if (separatorIndex < 0) { + break; + } + + final rawEvent = current.substring(0, separatorIndex); + final remaining = current.substring(separatorIndex + 2); + buffer + ..clear() + ..write(remaining); + + final payload = rawEvent + .split('\n') + .where((line) => line.startsWith('data:')) + .map((line) => line.substring(5).trimLeft()) + .join('\n') + .trim(); + if (payload.isNotEmpty) { + yield payload; + } + } + } +} diff --git a/lib/ai/model/ai_mode_state.dart b/lib/ai/model/ai_mode_state.dart new file mode 100644 index 0000000000..b938cba063 --- /dev/null +++ b/lib/ai/model/ai_mode_state.dart @@ -0,0 +1,23 @@ +import 'package:equatable/equatable.dart'; + +class AiModeState extends Equatable { + const AiModeState({ + this.enabled = false, + this.providerId, + }); + + final bool enabled; + final String? providerId; + + @override + List get props => [enabled, providerId]; + + AiModeState copyWith({ + bool? enabled, + String? providerId, + bool clearProviderId = false, + }) => AiModeState( + enabled: enabled ?? this.enabled, + providerId: clearProviderId ? null : (providerId ?? this.providerId), + ); +} diff --git a/lib/ai/model/ai_prompt_message.dart b/lib/ai/model/ai_prompt_message.dart new file mode 100644 index 0000000000..c7ab2f94a4 --- /dev/null +++ b/lib/ai/model/ai_prompt_message.dart @@ -0,0 +1,6 @@ +class AiPromptMessage { + AiPromptMessage({required this.role, required this.content}); + + final String role; + final String content; +} diff --git a/lib/ai/model/ai_provider_config.dart b/lib/ai/model/ai_provider_config.dart new file mode 100644 index 0000000000..c7a7159b40 --- /dev/null +++ b/lib/ai/model/ai_provider_config.dart @@ -0,0 +1,60 @@ +import 'ai_provider_type.dart'; + +class AiProviderConfig { + AiProviderConfig({ + required this.id, + required this.name, + required this.type, + required this.baseUrl, + required this.apiKey, + required this.model, + this.enabled = true, + }); + + factory AiProviderConfig.fromJson(Map json) => + AiProviderConfig( + id: json['id'] as String, + name: json['name'] as String, + type: AiProviderType.fromValue(json['type'] as String? ?? ''), + baseUrl: json['baseUrl'] as String? ?? '', + apiKey: json['apiKey'] as String? ?? '', + model: json['model'] as String? ?? '', + enabled: json['enabled'] as bool? ?? true, + ); + + final String id; + final String name; + final AiProviderType type; + final String baseUrl; + final String apiKey; + final String model; + final bool enabled; + + Map toJson() => { + 'id': id, + 'name': name, + 'type': type.value, + 'baseUrl': baseUrl, + 'apiKey': apiKey, + 'model': model, + 'enabled': enabled, + }; + + AiProviderConfig copyWith({ + String? id, + String? name, + AiProviderType? type, + String? baseUrl, + String? apiKey, + String? model, + bool? enabled, + }) => AiProviderConfig( + id: id ?? this.id, + name: name ?? this.name, + type: type ?? this.type, + baseUrl: baseUrl ?? this.baseUrl, + apiKey: apiKey ?? this.apiKey, + model: model ?? this.model, + enabled: enabled ?? this.enabled, + ); +} diff --git a/lib/ai/model/ai_provider_type.dart b/lib/ai/model/ai_provider_type.dart new file mode 100644 index 0000000000..1b28278824 --- /dev/null +++ b/lib/ai/model/ai_provider_type.dart @@ -0,0 +1,15 @@ +enum AiProviderType { + openaiCompatible('openai_compatible'), + anthropic('anthropic') + ; + + const AiProviderType(this.value); + + final String value; + + static AiProviderType fromValue(String value) => + AiProviderType.values.firstWhere( + (element) => element.value == value, + orElse: () => AiProviderType.openaiCompatible, + ); +} diff --git a/lib/db/dao/ai_chat_message_dao.dart b/lib/db/dao/ai_chat_message_dao.dart new file mode 100644 index 0000000000..41ac9d5d64 --- /dev/null +++ b/lib/db/dao/ai_chat_message_dao.dart @@ -0,0 +1,66 @@ +import 'package:drift/drift.dart'; + +import '../mixin_database.dart'; + +part 'ai_chat_message_dao.g.dart'; + +@DriftAccessor() +class AiChatMessageDao extends DatabaseAccessor + with _$AiChatMessageDaoMixin { + AiChatMessageDao(super.db); + + Stream> watchConversationMessages( + String conversationId, + ) => + (select( + db.aiChatMessages, + ) + ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..orderBy([ + (tbl) => OrderingTerm.asc(tbl.createdAt), + (tbl) => OrderingTerm.asc(tbl.id), + ])) + .watch(); + + Future> conversationMessages(String conversationId) => + (select( + db.aiChatMessages, + ) + ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..orderBy([ + (tbl) => OrderingTerm.asc(tbl.createdAt), + (tbl) => OrderingTerm.asc(tbl.id), + ])) + .get(); + + Future insertMessage(AiChatMessagesCompanion row) => + into(db.aiChatMessages).insertOnConflictUpdate(row); + + Future updateMessageContent( + String id, + String content, { + required DateTime updatedAt, + }) => (update(db.aiChatMessages)..where((tbl) => tbl.id.equals(id))).write( + AiChatMessagesCompanion( + content: Value(content), + updatedAt: Value(updatedAt), + ), + ); + + Future updateMessageStatus( + String id, + String status, { + required DateTime updatedAt, + String? errorText, + }) => (update(db.aiChatMessages)..where((tbl) => tbl.id.equals(id))).write( + AiChatMessagesCompanion( + status: Value(status), + errorText: Value(errorText), + updatedAt: Value(updatedAt), + ), + ); + + Future deleteConversationMessages(String conversationId) => (delete( + db.aiChatMessages, + )..where((tbl) => tbl.conversationId.equals(conversationId))).go(); +} diff --git a/lib/db/dao/ai_chat_message_dao.g.dart b/lib/db/dao/ai_chat_message_dao.g.dart new file mode 100644 index 0000000000..7ffc7133ff --- /dev/null +++ b/lib/db/dao/ai_chat_message_dao.g.dart @@ -0,0 +1,13 @@ +// GENERATED CODE - DO NOT MODIFY BY HAND + +part of 'ai_chat_message_dao.dart'; + +// ignore_for_file: type=lint +mixin _$AiChatMessageDaoMixin on DatabaseAccessor { + AiChatMessageDaoManager get managers => AiChatMessageDaoManager(this); +} + +class AiChatMessageDaoManager { + final _$AiChatMessageDaoMixin _db; + AiChatMessageDaoManager(this._db); +} diff --git a/lib/db/dao/asset_dao.g.dart b/lib/db/dao/asset_dao.g.dart index 12e7d29377..cc367fcc26 100644 --- a/lib/db/dao/asset_dao.g.dart +++ b/lib/db/dao/asset_dao.g.dart @@ -39,6 +39,7 @@ mixin _$AssetDaoMixin on DatabaseAccessor { FavoriteApps get favoriteApps => attachedDatabase.favoriteApps; ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -164,6 +165,8 @@ class AssetDaoManager { $ExpiredMessagesTableManager(_db.attachedDatabase, _db.expiredMessages); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/chain_dao.g.dart b/lib/db/dao/chain_dao.g.dart index 0c5edb0e6d..6d337c8b17 100644 --- a/lib/db/dao/chain_dao.g.dart +++ b/lib/db/dao/chain_dao.g.dart @@ -39,6 +39,7 @@ mixin _$ChainDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -126,6 +127,8 @@ class ChainDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/circle_conversation_dao.g.dart b/lib/db/dao/circle_conversation_dao.g.dart index 2f4472a88c..ebe5a98ba9 100644 --- a/lib/db/dao/circle_conversation_dao.g.dart +++ b/lib/db/dao/circle_conversation_dao.g.dart @@ -39,6 +39,7 @@ mixin _$CircleConversationDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -145,6 +146,8 @@ class CircleConversationDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/circle_dao.g.dart b/lib/db/dao/circle_dao.g.dart index f2d332e379..ff29b0ce0b 100644 --- a/lib/db/dao/circle_dao.g.dart +++ b/lib/db/dao/circle_dao.g.dart @@ -39,6 +39,7 @@ mixin _$CircleDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -183,6 +184,8 @@ class CircleDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/conversation_dao.g.dart b/lib/db/dao/conversation_dao.g.dart index eab0696108..595adcc3ec 100644 --- a/lib/db/dao/conversation_dao.g.dart +++ b/lib/db/dao/conversation_dao.g.dart @@ -39,6 +39,7 @@ mixin _$ConversationDaoMixin on DatabaseAccessor { FavoriteApps get favoriteApps => attachedDatabase.favoriteApps; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -800,6 +801,8 @@ class ConversationDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/expired_message_dao.g.dart b/lib/db/dao/expired_message_dao.g.dart index a8c2b5524d..d1cb868dc4 100644 --- a/lib/db/dao/expired_message_dao.g.dart +++ b/lib/db/dao/expired_message_dao.g.dart @@ -39,6 +39,7 @@ mixin _$ExpiredMessageDaoMixin on DatabaseAccessor { FavoriteApps get favoriteApps => attachedDatabase.favoriteApps; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -196,6 +197,8 @@ class ExpiredMessageDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/favorite_app_dao.g.dart b/lib/db/dao/favorite_app_dao.g.dart index 69968be3c0..4642f1d59d 100644 --- a/lib/db/dao/favorite_app_dao.g.dart +++ b/lib/db/dao/favorite_app_dao.g.dart @@ -39,6 +39,7 @@ mixin _$FavoriteAppDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -143,6 +144,8 @@ class FavoriteAppDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/flood_message_dao.g.dart b/lib/db/dao/flood_message_dao.g.dart index 9d78277a5b..52042fb936 100644 --- a/lib/db/dao/flood_message_dao.g.dart +++ b/lib/db/dao/flood_message_dao.g.dart @@ -39,6 +39,7 @@ mixin _$FloodMessageDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -138,6 +139,8 @@ class FloodMessageDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/inscription_collection_dao.g.dart b/lib/db/dao/inscription_collection_dao.g.dart index 08ed37c7a4..e9d4cb8d3e 100644 --- a/lib/db/dao/inscription_collection_dao.g.dart +++ b/lib/db/dao/inscription_collection_dao.g.dart @@ -39,6 +39,7 @@ mixin _$InscriptionCollectionDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -127,6 +128,8 @@ class InscriptionCollectionDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/inscription_item_dao.g.dart b/lib/db/dao/inscription_item_dao.g.dart index 5c45471e24..6c995896e0 100644 --- a/lib/db/dao/inscription_item_dao.g.dart +++ b/lib/db/dao/inscription_item_dao.g.dart @@ -42,6 +42,7 @@ mixin _$InscriptionItemDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; Selectable inscriptionByHash(String hash) { @@ -151,6 +152,8 @@ class InscriptionItemDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/message_dao.g.dart b/lib/db/dao/message_dao.g.dart index e69968a36f..b9c45209c8 100644 --- a/lib/db/dao/message_dao.g.dart +++ b/lib/db/dao/message_dao.g.dart @@ -39,6 +39,7 @@ mixin _$MessageDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -492,6 +493,8 @@ class MessageDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/participant_dao.g.dart b/lib/db/dao/participant_dao.g.dart index b44cf301e7..52e0232ad6 100644 --- a/lib/db/dao/participant_dao.g.dart +++ b/lib/db/dao/participant_dao.g.dart @@ -39,6 +39,7 @@ mixin _$ParticipantDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -216,6 +217,8 @@ class ParticipantDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/participant_session_dao.g.dart b/lib/db/dao/participant_session_dao.g.dart index fcf4c41ee9..8f2e9cf93e 100644 --- a/lib/db/dao/participant_session_dao.g.dart +++ b/lib/db/dao/participant_session_dao.g.dart @@ -39,6 +39,7 @@ mixin _$ParticipantSessionDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -182,6 +183,8 @@ class ParticipantSessionDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/pin_message_dao.g.dart b/lib/db/dao/pin_message_dao.g.dart index a3d2dd2725..3375cd094b 100644 --- a/lib/db/dao/pin_message_dao.g.dart +++ b/lib/db/dao/pin_message_dao.g.dart @@ -39,6 +39,7 @@ mixin _$PinMessageDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -161,6 +162,8 @@ class PinMessageDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/property_dao.g.dart b/lib/db/dao/property_dao.g.dart index f611989531..d83c18cb18 100644 --- a/lib/db/dao/property_dao.g.dart +++ b/lib/db/dao/property_dao.g.dart @@ -39,6 +39,7 @@ mixin _$PropertyDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -126,6 +127,8 @@ class PropertyDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/safe_snapshot_dao.g.dart b/lib/db/dao/safe_snapshot_dao.g.dart index 85d4854f19..40d9713c2b 100644 --- a/lib/db/dao/safe_snapshot_dao.g.dart +++ b/lib/db/dao/safe_snapshot_dao.g.dart @@ -41,6 +41,7 @@ mixin _$SafeSnapshotDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; InscriptionCollections get inscriptionCollections => attachedDatabase.inscriptionCollections; InscriptionItems get inscriptionItems => attachedDatabase.inscriptionItems; @@ -221,6 +222,8 @@ class SafeSnapshotDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $InscriptionCollectionsTableManager get inscriptionCollections => $InscriptionCollectionsTableManager( _db.attachedDatabase, diff --git a/lib/db/dao/snapshot_dao.g.dart b/lib/db/dao/snapshot_dao.g.dart index aba62b6afe..fb32f4242b 100644 --- a/lib/db/dao/snapshot_dao.g.dart +++ b/lib/db/dao/snapshot_dao.g.dart @@ -39,6 +39,7 @@ mixin _$SnapshotDaoMixin on DatabaseAccessor { FavoriteApps get favoriteApps => attachedDatabase.favoriteApps; ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -230,6 +231,8 @@ class SnapshotDaoManager { $ExpiredMessagesTableManager(_db.attachedDatabase, _db.expiredMessages); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/sticker_album_dao.g.dart b/lib/db/dao/sticker_album_dao.g.dart index ae0ab42d06..acee16eacb 100644 --- a/lib/db/dao/sticker_album_dao.g.dart +++ b/lib/db/dao/sticker_album_dao.g.dart @@ -39,6 +39,7 @@ mixin _$StickerAlbumDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -126,6 +127,8 @@ class StickerAlbumDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/sticker_dao.g.dart b/lib/db/dao/sticker_dao.g.dart index f8b6a8a11d..f2b1d48f31 100644 --- a/lib/db/dao/sticker_dao.g.dart +++ b/lib/db/dao/sticker_dao.g.dart @@ -39,6 +39,7 @@ mixin _$StickerDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -150,6 +151,8 @@ class StickerDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/sticker_relationship_dao.g.dart b/lib/db/dao/sticker_relationship_dao.g.dart index efbeb178c6..078bba0360 100644 --- a/lib/db/dao/sticker_relationship_dao.g.dart +++ b/lib/db/dao/sticker_relationship_dao.g.dart @@ -39,6 +39,7 @@ mixin _$StickerRelationshipDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -143,6 +144,8 @@ class StickerRelationshipDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/token_dao.g.dart b/lib/db/dao/token_dao.g.dart index 641feb2cdf..82b5064992 100644 --- a/lib/db/dao/token_dao.g.dart +++ b/lib/db/dao/token_dao.g.dart @@ -40,6 +40,7 @@ mixin _$TokenDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; InscriptionCollections get inscriptionCollections => attachedDatabase.inscriptionCollections; @@ -136,6 +137,8 @@ class TokenDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $InscriptionCollectionsTableManager get inscriptionCollections => diff --git a/lib/db/dao/transcript_message_dao.g.dart b/lib/db/dao/transcript_message_dao.g.dart index 1a6c4cbfb7..ec5faee347 100644 --- a/lib/db/dao/transcript_message_dao.g.dart +++ b/lib/db/dao/transcript_message_dao.g.dart @@ -39,6 +39,7 @@ mixin _$TranscriptMessageDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -230,6 +231,8 @@ class TranscriptMessageDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/user_dao.g.dart b/lib/db/dao/user_dao.g.dart index 5cf1907be6..1939b66a2f 100644 --- a/lib/db/dao/user_dao.g.dart +++ b/lib/db/dao/user_dao.g.dart @@ -39,6 +39,7 @@ mixin _$UserDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; + AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -310,6 +311,8 @@ class UserDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/database.dart b/lib/db/database.dart index 3c06ecaa5c..1e76227eed 100644 --- a/lib/db/database.dart +++ b/lib/db/database.dart @@ -4,6 +4,7 @@ import '../ui/provider/slide_category_provider.dart'; import '../utils/extension/extension.dart'; import '../utils/logger.dart'; import '../utils/property/setting_property.dart'; +import 'dao/ai_chat_message_dao.dart'; import 'dao/app_dao.dart'; import 'dao/asset_dao.dart'; import 'dao/chain_dao.dart'; @@ -47,6 +48,8 @@ class Database { AppDao get appDao => mixinDatabase.appDao; + AiChatMessageDao get aiChatMessageDao => mixinDatabase.aiChatMessageDao; + AssetDao get assetDao => mixinDatabase.assetDao; ChainDao get chainDao => mixinDatabase.chainDao; diff --git a/lib/db/mixin_database.dart b/lib/db/mixin_database.dart index c1ce880652..d4f4bd9568 100644 --- a/lib/db/mixin_database.dart +++ b/lib/db/mixin_database.dart @@ -18,6 +18,7 @@ import 'converter/safe_deposit_type_converter.dart'; import 'converter/safe_withdrawal_type_converter.dart'; import 'converter/user_relationship_converter.dart'; import 'dao/address_dao.dart'; +import 'dao/ai_chat_message_dao.dart'; import 'dao/app_dao.dart'; import 'dao/asset_dao.dart'; import 'dao/chain_dao.dart'; @@ -61,6 +62,7 @@ part 'mixin_database.g.dart'; include: {'moor/mixin.drift', 'moor/dao/common.drift'}, daos: [ AddressDao, + AiChatMessageDao, AppDao, AssetDao, CircleConversationDao, @@ -99,7 +101,7 @@ class MixinDatabase extends _$MixinDatabase { MixinDatabase(super.e); @override - int get schemaVersion => 28; + int get schemaVersion => 30; final eventBus = DataBaseEventBus.instance; @@ -278,6 +280,22 @@ class MixinDatabase extends _$MixinDatabase { if (from <= 27) { await _addColumnIfNotExists(m, tokens, tokens.precision); } + if (from <= 28) { + await m.createTable(aiChatMessages); + await m.createIndex(indexAiChatMessagesConversationIdCreatedAt); + } + if (from <= 29) { + await _addColumnIfNotExists( + m, + aiChatMessages, + aiChatMessages.anchorMessageId, + ); + await _addColumnIfNotExists( + m, + aiChatMessages, + aiChatMessages.anchorCreatedAt, + ); + } }, beforeOpen: (details) async { if (details.hadUpgrade && details.versionBefore! <= 20) { diff --git a/lib/db/mixin_database.g.dart b/lib/db/mixin_database.g.dart index 9a4c3f0f7e..0aeb8bfca1 100644 --- a/lib/db/mixin_database.g.dart +++ b/lib/db/mixin_database.g.dart @@ -17665,6 +17665,735 @@ class PropertiesCompanion extends UpdateCompanion { } } +class AiChatMessages extends Table + with TableInfo { + @override + final GeneratedDatabase attachedDatabase; + final String? _alias; + AiChatMessages(this.attachedDatabase, [this._alias]); + static const VerificationMeta _idMeta = const VerificationMeta('id'); + late final GeneratedColumn id = GeneratedColumn( + 'id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _conversationIdMeta = const VerificationMeta( + 'conversationId', + ); + late final GeneratedColumn conversationId = GeneratedColumn( + 'conversation_id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _roleMeta = const VerificationMeta('role'); + late final GeneratedColumn role = GeneratedColumn( + 'role', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _providerIdMeta = const VerificationMeta( + 'providerId', + ); + late final GeneratedColumn providerId = GeneratedColumn( + 'provider_id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _anchorMessageIdMeta = const VerificationMeta( + 'anchorMessageId', + ); + late final GeneratedColumn anchorMessageId = GeneratedColumn( + 'anchor_message_id', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + late final GeneratedColumnWithTypeConverter anchorCreatedAt = + GeneratedColumn( + 'anchor_created_at', + aliasedName, + true, + type: DriftSqlType.int, + requiredDuringInsert: false, + $customConstraints: '', + ).withConverter(AiChatMessages.$converteranchorCreatedAtn); + static const VerificationMeta _contentMeta = const VerificationMeta( + 'content', + ); + late final GeneratedColumn content = GeneratedColumn( + 'content', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _statusMeta = const VerificationMeta('status'); + late final GeneratedColumn status = GeneratedColumn( + 'status', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _modelMeta = const VerificationMeta('model'); + late final GeneratedColumn model = GeneratedColumn( + 'model', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + static const VerificationMeta _errorTextMeta = const VerificationMeta( + 'errorText', + ); + late final GeneratedColumn errorText = GeneratedColumn( + 'error_text', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + late final GeneratedColumnWithTypeConverter createdAt = + GeneratedColumn( + 'created_at', + aliasedName, + false, + type: DriftSqlType.int, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ).withConverter(AiChatMessages.$convertercreatedAt); + late final GeneratedColumnWithTypeConverter updatedAt = + GeneratedColumn( + 'updated_at', + aliasedName, + false, + type: DriftSqlType.int, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ).withConverter(AiChatMessages.$converterupdatedAt); + @override + List get $columns => [ + id, + conversationId, + role, + providerId, + anchorMessageId, + anchorCreatedAt, + content, + status, + model, + errorText, + createdAt, + updatedAt, + ]; + @override + String get aliasedName => _alias ?? actualTableName; + @override + String get actualTableName => $name; + static const String $name = 'ai_chat_messages'; + @override + VerificationContext validateIntegrity( + Insertable instance, { + bool isInserting = false, + }) { + final context = VerificationContext(); + final data = instance.toColumns(true); + if (data.containsKey('id')) { + context.handle(_idMeta, id.isAcceptableOrUnknown(data['id']!, _idMeta)); + } else if (isInserting) { + context.missing(_idMeta); + } + if (data.containsKey('conversation_id')) { + context.handle( + _conversationIdMeta, + conversationId.isAcceptableOrUnknown( + data['conversation_id']!, + _conversationIdMeta, + ), + ); + } else if (isInserting) { + context.missing(_conversationIdMeta); + } + if (data.containsKey('role')) { + context.handle( + _roleMeta, + role.isAcceptableOrUnknown(data['role']!, _roleMeta), + ); + } else if (isInserting) { + context.missing(_roleMeta); + } + if (data.containsKey('provider_id')) { + context.handle( + _providerIdMeta, + providerId.isAcceptableOrUnknown(data['provider_id']!, _providerIdMeta), + ); + } else if (isInserting) { + context.missing(_providerIdMeta); + } + if (data.containsKey('anchor_message_id')) { + context.handle( + _anchorMessageIdMeta, + anchorMessageId.isAcceptableOrUnknown( + data['anchor_message_id']!, + _anchorMessageIdMeta, + ), + ); + } + if (data.containsKey('content')) { + context.handle( + _contentMeta, + content.isAcceptableOrUnknown(data['content']!, _contentMeta), + ); + } else if (isInserting) { + context.missing(_contentMeta); + } + if (data.containsKey('status')) { + context.handle( + _statusMeta, + status.isAcceptableOrUnknown(data['status']!, _statusMeta), + ); + } else if (isInserting) { + context.missing(_statusMeta); + } + if (data.containsKey('model')) { + context.handle( + _modelMeta, + model.isAcceptableOrUnknown(data['model']!, _modelMeta), + ); + } + if (data.containsKey('error_text')) { + context.handle( + _errorTextMeta, + errorText.isAcceptableOrUnknown(data['error_text']!, _errorTextMeta), + ); + } + return context; + } + + @override + Set get $primaryKey => {id}; + @override + AiChatMessage map(Map data, {String? tablePrefix}) { + final effectivePrefix = tablePrefix != null ? '$tablePrefix.' : ''; + return AiChatMessage( + id: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}id'], + )!, + conversationId: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}conversation_id'], + )!, + role: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}role'], + )!, + providerId: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}provider_id'], + )!, + anchorMessageId: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}anchor_message_id'], + ), + anchorCreatedAt: AiChatMessages.$converteranchorCreatedAtn.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}anchor_created_at'], + ), + ), + content: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}content'], + )!, + status: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}status'], + )!, + model: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}model'], + ), + errorText: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}error_text'], + ), + createdAt: AiChatMessages.$convertercreatedAt.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}created_at'], + )!, + ), + updatedAt: AiChatMessages.$converterupdatedAt.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}updated_at'], + )!, + ), + ); + } + + @override + AiChatMessages createAlias(String alias) { + return AiChatMessages(attachedDatabase, alias); + } + + static TypeConverter $converteranchorCreatedAt = + const MillisDateConverter(); + static TypeConverter $converteranchorCreatedAtn = + NullAwareTypeConverter.wrap($converteranchorCreatedAt); + static TypeConverter $convertercreatedAt = + const MillisDateConverter(); + static TypeConverter $converterupdatedAt = + const MillisDateConverter(); + @override + List get customConstraints => const ['PRIMARY KEY(id)']; + @override + bool get dontWriteConstraints => true; +} + +class AiChatMessage extends DataClass implements Insertable { + final String id; + final String conversationId; + final String role; + final String providerId; + final String? anchorMessageId; + final DateTime? anchorCreatedAt; + final String content; + final String status; + final String? model; + final String? errorText; + final DateTime createdAt; + final DateTime updatedAt; + const AiChatMessage({ + required this.id, + required this.conversationId, + required this.role, + required this.providerId, + this.anchorMessageId, + this.anchorCreatedAt, + required this.content, + required this.status, + this.model, + this.errorText, + required this.createdAt, + required this.updatedAt, + }); + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + map['id'] = Variable(id); + map['conversation_id'] = Variable(conversationId); + map['role'] = Variable(role); + map['provider_id'] = Variable(providerId); + if (!nullToAbsent || anchorMessageId != null) { + map['anchor_message_id'] = Variable(anchorMessageId); + } + if (!nullToAbsent || anchorCreatedAt != null) { + map['anchor_created_at'] = Variable( + AiChatMessages.$converteranchorCreatedAtn.toSql(anchorCreatedAt), + ); + } + map['content'] = Variable(content); + map['status'] = Variable(status); + if (!nullToAbsent || model != null) { + map['model'] = Variable(model); + } + if (!nullToAbsent || errorText != null) { + map['error_text'] = Variable(errorText); + } + { + map['created_at'] = Variable( + AiChatMessages.$convertercreatedAt.toSql(createdAt), + ); + } + { + map['updated_at'] = Variable( + AiChatMessages.$converterupdatedAt.toSql(updatedAt), + ); + } + return map; + } + + AiChatMessagesCompanion toCompanion(bool nullToAbsent) { + return AiChatMessagesCompanion( + id: Value(id), + conversationId: Value(conversationId), + role: Value(role), + providerId: Value(providerId), + anchorMessageId: anchorMessageId == null && nullToAbsent + ? const Value.absent() + : Value(anchorMessageId), + anchorCreatedAt: anchorCreatedAt == null && nullToAbsent + ? const Value.absent() + : Value(anchorCreatedAt), + content: Value(content), + status: Value(status), + model: model == null && nullToAbsent + ? const Value.absent() + : Value(model), + errorText: errorText == null && nullToAbsent + ? const Value.absent() + : Value(errorText), + createdAt: Value(createdAt), + updatedAt: Value(updatedAt), + ); + } + + factory AiChatMessage.fromJson( + Map json, { + ValueSerializer? serializer, + }) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return AiChatMessage( + id: serializer.fromJson(json['id']), + conversationId: serializer.fromJson(json['conversation_id']), + role: serializer.fromJson(json['role']), + providerId: serializer.fromJson(json['provider_id']), + anchorMessageId: serializer.fromJson(json['anchor_message_id']), + anchorCreatedAt: serializer.fromJson( + json['anchor_created_at'], + ), + content: serializer.fromJson(json['content']), + status: serializer.fromJson(json['status']), + model: serializer.fromJson(json['model']), + errorText: serializer.fromJson(json['error_text']), + createdAt: serializer.fromJson(json['created_at']), + updatedAt: serializer.fromJson(json['updated_at']), + ); + } + @override + Map toJson({ValueSerializer? serializer}) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return { + 'id': serializer.toJson(id), + 'conversation_id': serializer.toJson(conversationId), + 'role': serializer.toJson(role), + 'provider_id': serializer.toJson(providerId), + 'anchor_message_id': serializer.toJson(anchorMessageId), + 'anchor_created_at': serializer.toJson(anchorCreatedAt), + 'content': serializer.toJson(content), + 'status': serializer.toJson(status), + 'model': serializer.toJson(model), + 'error_text': serializer.toJson(errorText), + 'created_at': serializer.toJson(createdAt), + 'updated_at': serializer.toJson(updatedAt), + }; + } + + AiChatMessage copyWith({ + String? id, + String? conversationId, + String? role, + String? providerId, + Value anchorMessageId = const Value.absent(), + Value anchorCreatedAt = const Value.absent(), + String? content, + String? status, + Value model = const Value.absent(), + Value errorText = const Value.absent(), + DateTime? createdAt, + DateTime? updatedAt, + }) => AiChatMessage( + id: id ?? this.id, + conversationId: conversationId ?? this.conversationId, + role: role ?? this.role, + providerId: providerId ?? this.providerId, + anchorMessageId: anchorMessageId.present + ? anchorMessageId.value + : this.anchorMessageId, + anchorCreatedAt: anchorCreatedAt.present + ? anchorCreatedAt.value + : this.anchorCreatedAt, + content: content ?? this.content, + status: status ?? this.status, + model: model.present ? model.value : this.model, + errorText: errorText.present ? errorText.value : this.errorText, + createdAt: createdAt ?? this.createdAt, + updatedAt: updatedAt ?? this.updatedAt, + ); + AiChatMessage copyWithCompanion(AiChatMessagesCompanion data) { + return AiChatMessage( + id: data.id.present ? data.id.value : this.id, + conversationId: data.conversationId.present + ? data.conversationId.value + : this.conversationId, + role: data.role.present ? data.role.value : this.role, + providerId: data.providerId.present + ? data.providerId.value + : this.providerId, + anchorMessageId: data.anchorMessageId.present + ? data.anchorMessageId.value + : this.anchorMessageId, + anchorCreatedAt: data.anchorCreatedAt.present + ? data.anchorCreatedAt.value + : this.anchorCreatedAt, + content: data.content.present ? data.content.value : this.content, + status: data.status.present ? data.status.value : this.status, + model: data.model.present ? data.model.value : this.model, + errorText: data.errorText.present ? data.errorText.value : this.errorText, + createdAt: data.createdAt.present ? data.createdAt.value : this.createdAt, + updatedAt: data.updatedAt.present ? data.updatedAt.value : this.updatedAt, + ); + } + + @override + String toString() { + return (StringBuffer('AiChatMessage(') + ..write('id: $id, ') + ..write('conversationId: $conversationId, ') + ..write('role: $role, ') + ..write('providerId: $providerId, ') + ..write('anchorMessageId: $anchorMessageId, ') + ..write('anchorCreatedAt: $anchorCreatedAt, ') + ..write('content: $content, ') + ..write('status: $status, ') + ..write('model: $model, ') + ..write('errorText: $errorText, ') + ..write('createdAt: $createdAt, ') + ..write('updatedAt: $updatedAt') + ..write(')')) + .toString(); + } + + @override + int get hashCode => Object.hash( + id, + conversationId, + role, + providerId, + anchorMessageId, + anchorCreatedAt, + content, + status, + model, + errorText, + createdAt, + updatedAt, + ); + @override + bool operator ==(Object other) => + identical(this, other) || + (other is AiChatMessage && + other.id == this.id && + other.conversationId == this.conversationId && + other.role == this.role && + other.providerId == this.providerId && + other.anchorMessageId == this.anchorMessageId && + other.anchorCreatedAt == this.anchorCreatedAt && + other.content == this.content && + other.status == this.status && + other.model == this.model && + other.errorText == this.errorText && + other.createdAt == this.createdAt && + other.updatedAt == this.updatedAt); +} + +class AiChatMessagesCompanion extends UpdateCompanion { + final Value id; + final Value conversationId; + final Value role; + final Value providerId; + final Value anchorMessageId; + final Value anchorCreatedAt; + final Value content; + final Value status; + final Value model; + final Value errorText; + final Value createdAt; + final Value updatedAt; + final Value rowid; + const AiChatMessagesCompanion({ + this.id = const Value.absent(), + this.conversationId = const Value.absent(), + this.role = const Value.absent(), + this.providerId = const Value.absent(), + this.anchorMessageId = const Value.absent(), + this.anchorCreatedAt = const Value.absent(), + this.content = const Value.absent(), + this.status = const Value.absent(), + this.model = const Value.absent(), + this.errorText = const Value.absent(), + this.createdAt = const Value.absent(), + this.updatedAt = const Value.absent(), + this.rowid = const Value.absent(), + }); + AiChatMessagesCompanion.insert({ + required String id, + required String conversationId, + required String role, + required String providerId, + this.anchorMessageId = const Value.absent(), + this.anchorCreatedAt = const Value.absent(), + required String content, + required String status, + this.model = const Value.absent(), + this.errorText = const Value.absent(), + required DateTime createdAt, + required DateTime updatedAt, + this.rowid = const Value.absent(), + }) : id = Value(id), + conversationId = Value(conversationId), + role = Value(role), + providerId = Value(providerId), + content = Value(content), + status = Value(status), + createdAt = Value(createdAt), + updatedAt = Value(updatedAt); + static Insertable custom({ + Expression? id, + Expression? conversationId, + Expression? role, + Expression? providerId, + Expression? anchorMessageId, + Expression? anchorCreatedAt, + Expression? content, + Expression? status, + Expression? model, + Expression? errorText, + Expression? createdAt, + Expression? updatedAt, + Expression? rowid, + }) { + return RawValuesInsertable({ + if (id != null) 'id': id, + if (conversationId != null) 'conversation_id': conversationId, + if (role != null) 'role': role, + if (providerId != null) 'provider_id': providerId, + if (anchorMessageId != null) 'anchor_message_id': anchorMessageId, + if (anchorCreatedAt != null) 'anchor_created_at': anchorCreatedAt, + if (content != null) 'content': content, + if (status != null) 'status': status, + if (model != null) 'model': model, + if (errorText != null) 'error_text': errorText, + if (createdAt != null) 'created_at': createdAt, + if (updatedAt != null) 'updated_at': updatedAt, + if (rowid != null) 'rowid': rowid, + }); + } + + AiChatMessagesCompanion copyWith({ + Value? id, + Value? conversationId, + Value? role, + Value? providerId, + Value? anchorMessageId, + Value? anchorCreatedAt, + Value? content, + Value? status, + Value? model, + Value? errorText, + Value? createdAt, + Value? updatedAt, + Value? rowid, + }) { + return AiChatMessagesCompanion( + id: id ?? this.id, + conversationId: conversationId ?? this.conversationId, + role: role ?? this.role, + providerId: providerId ?? this.providerId, + anchorMessageId: anchorMessageId ?? this.anchorMessageId, + anchorCreatedAt: anchorCreatedAt ?? this.anchorCreatedAt, + content: content ?? this.content, + status: status ?? this.status, + model: model ?? this.model, + errorText: errorText ?? this.errorText, + createdAt: createdAt ?? this.createdAt, + updatedAt: updatedAt ?? this.updatedAt, + rowid: rowid ?? this.rowid, + ); + } + + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + if (id.present) { + map['id'] = Variable(id.value); + } + if (conversationId.present) { + map['conversation_id'] = Variable(conversationId.value); + } + if (role.present) { + map['role'] = Variable(role.value); + } + if (providerId.present) { + map['provider_id'] = Variable(providerId.value); + } + if (anchorMessageId.present) { + map['anchor_message_id'] = Variable(anchorMessageId.value); + } + if (anchorCreatedAt.present) { + map['anchor_created_at'] = Variable( + AiChatMessages.$converteranchorCreatedAtn.toSql(anchorCreatedAt.value), + ); + } + if (content.present) { + map['content'] = Variable(content.value); + } + if (status.present) { + map['status'] = Variable(status.value); + } + if (model.present) { + map['model'] = Variable(model.value); + } + if (errorText.present) { + map['error_text'] = Variable(errorText.value); + } + if (createdAt.present) { + map['created_at'] = Variable( + AiChatMessages.$convertercreatedAt.toSql(createdAt.value), + ); + } + if (updatedAt.present) { + map['updated_at'] = Variable( + AiChatMessages.$converterupdatedAt.toSql(updatedAt.value), + ); + } + if (rowid.present) { + map['rowid'] = Variable(rowid.value); + } + return map; + } + + @override + String toString() { + return (StringBuffer('AiChatMessagesCompanion(') + ..write('id: $id, ') + ..write('conversationId: $conversationId, ') + ..write('role: $role, ') + ..write('providerId: $providerId, ') + ..write('anchorMessageId: $anchorMessageId, ') + ..write('anchorCreatedAt: $anchorCreatedAt, ') + ..write('content: $content, ') + ..write('status: $status, ') + ..write('model: $model, ') + ..write('errorText: $errorText, ') + ..write('createdAt: $createdAt, ') + ..write('updatedAt: $updatedAt, ') + ..write('rowid: $rowid') + ..write(')')) + .toString(); + } +} + class InscriptionCollections extends Table with TableInfo { @override @@ -18821,6 +19550,7 @@ abstract class _$MixinDatabase extends GeneratedDatabase { late final Fiats fiats = Fiats(this); late final FavoriteApps favoriteApps = FavoriteApps(this); late final Properties properties = Properties(this); + late final AiChatMessages aiChatMessages = AiChatMessages(this); late final InscriptionCollections inscriptionCollections = InscriptionCollections(this); late final InscriptionItems inscriptionItems = InscriptionItems(this); @@ -18884,7 +19614,14 @@ abstract class _$MixinDatabase extends GeneratedDatabase { 'index_tokens_collection_hash', 'CREATE INDEX IF NOT EXISTS index_tokens_collection_hash ON tokens (collection_hash)', ); + late final Index indexAiChatMessagesConversationIdCreatedAt = Index( + 'index_ai_chat_messages_conversation_id_created_at', + 'CREATE INDEX IF NOT EXISTS index_ai_chat_messages_conversation_id_created_at ON ai_chat_messages (conversation_id, created_at DESC)', + ); late final AddressDao addressDao = AddressDao(this as MixinDatabase); + late final AiChatMessageDao aiChatMessageDao = AiChatMessageDao( + this as MixinDatabase, + ); late final AppDao appDao = AppDao(this as MixinDatabase); late final AssetDao assetDao = AssetDao(this as MixinDatabase); late final CircleConversationDao circleConversationDao = @@ -19415,6 +20152,7 @@ abstract class _$MixinDatabase extends GeneratedDatabase { fiats, favoriteApps, properties, + aiChatMessages, inscriptionCollections, inscriptionItems, indexConversationsCategoryStatus, @@ -19432,6 +20170,7 @@ abstract class _$MixinDatabase extends GeneratedDatabase { indexMessagesConversationIdQuoteMessageId, indexTokensKernelAssetId, indexTokensCollectionHash, + indexAiChatMessagesConversationIdCreatedAt, ]; @override StreamQueryUpdateRules get streamUpdateRules => const StreamQueryUpdateRules([ @@ -28013,6 +28752,351 @@ typedef $PropertiesProcessedTableManager = Property, PrefetchHooks Function() >; +typedef $AiChatMessagesCreateCompanionBuilder = + AiChatMessagesCompanion Function({ + required String id, + required String conversationId, + required String role, + required String providerId, + Value anchorMessageId, + Value anchorCreatedAt, + required String content, + required String status, + Value model, + Value errorText, + required DateTime createdAt, + required DateTime updatedAt, + Value rowid, + }); +typedef $AiChatMessagesUpdateCompanionBuilder = + AiChatMessagesCompanion Function({ + Value id, + Value conversationId, + Value role, + Value providerId, + Value anchorMessageId, + Value anchorCreatedAt, + Value content, + Value status, + Value model, + Value errorText, + Value createdAt, + Value updatedAt, + Value rowid, + }); + +class $AiChatMessagesFilterComposer + extends Composer<_$MixinDatabase, AiChatMessages> { + $AiChatMessagesFilterComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + ColumnFilters get id => $composableBuilder( + column: $table.id, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get role => $composableBuilder( + column: $table.role, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get providerId => $composableBuilder( + column: $table.providerId, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get anchorMessageId => $composableBuilder( + column: $table.anchorMessageId, + builder: (column) => ColumnFilters(column), + ); + + ColumnWithTypeConverterFilters + get anchorCreatedAt => $composableBuilder( + column: $table.anchorCreatedAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); + + ColumnFilters get content => $composableBuilder( + column: $table.content, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get status => $composableBuilder( + column: $table.status, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get model => $composableBuilder( + column: $table.model, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get errorText => $composableBuilder( + column: $table.errorText, + builder: (column) => ColumnFilters(column), + ); + + ColumnWithTypeConverterFilters get createdAt => + $composableBuilder( + column: $table.createdAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); + + ColumnWithTypeConverterFilters get updatedAt => + $composableBuilder( + column: $table.updatedAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); +} + +class $AiChatMessagesOrderingComposer + extends Composer<_$MixinDatabase, AiChatMessages> { + $AiChatMessagesOrderingComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + ColumnOrderings get id => $composableBuilder( + column: $table.id, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get role => $composableBuilder( + column: $table.role, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get providerId => $composableBuilder( + column: $table.providerId, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get anchorMessageId => $composableBuilder( + column: $table.anchorMessageId, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get anchorCreatedAt => $composableBuilder( + column: $table.anchorCreatedAt, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get content => $composableBuilder( + column: $table.content, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get status => $composableBuilder( + column: $table.status, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get model => $composableBuilder( + column: $table.model, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get errorText => $composableBuilder( + column: $table.errorText, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get createdAt => $composableBuilder( + column: $table.createdAt, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get updatedAt => $composableBuilder( + column: $table.updatedAt, + builder: (column) => ColumnOrderings(column), + ); +} + +class $AiChatMessagesAnnotationComposer + extends Composer<_$MixinDatabase, AiChatMessages> { + $AiChatMessagesAnnotationComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + GeneratedColumn get id => + $composableBuilder(column: $table.id, builder: (column) => column); + + GeneratedColumn get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => column, + ); + + GeneratedColumn get role => + $composableBuilder(column: $table.role, builder: (column) => column); + + GeneratedColumn get providerId => $composableBuilder( + column: $table.providerId, + builder: (column) => column, + ); + + GeneratedColumn get anchorMessageId => $composableBuilder( + column: $table.anchorMessageId, + builder: (column) => column, + ); + + GeneratedColumnWithTypeConverter get anchorCreatedAt => + $composableBuilder( + column: $table.anchorCreatedAt, + builder: (column) => column, + ); + + GeneratedColumn get content => + $composableBuilder(column: $table.content, builder: (column) => column); + + GeneratedColumn get status => + $composableBuilder(column: $table.status, builder: (column) => column); + + GeneratedColumn get model => + $composableBuilder(column: $table.model, builder: (column) => column); + + GeneratedColumn get errorText => + $composableBuilder(column: $table.errorText, builder: (column) => column); + + GeneratedColumnWithTypeConverter get createdAt => + $composableBuilder(column: $table.createdAt, builder: (column) => column); + + GeneratedColumnWithTypeConverter get updatedAt => + $composableBuilder(column: $table.updatedAt, builder: (column) => column); +} + +class $AiChatMessagesTableManager + extends + RootTableManager< + _$MixinDatabase, + AiChatMessages, + AiChatMessage, + $AiChatMessagesFilterComposer, + $AiChatMessagesOrderingComposer, + $AiChatMessagesAnnotationComposer, + $AiChatMessagesCreateCompanionBuilder, + $AiChatMessagesUpdateCompanionBuilder, + ( + AiChatMessage, + BaseReferences<_$MixinDatabase, AiChatMessages, AiChatMessage>, + ), + AiChatMessage, + PrefetchHooks Function() + > { + $AiChatMessagesTableManager(_$MixinDatabase db, AiChatMessages table) + : super( + TableManagerState( + db: db, + table: table, + createFilteringComposer: () => + $AiChatMessagesFilterComposer($db: db, $table: table), + createOrderingComposer: () => + $AiChatMessagesOrderingComposer($db: db, $table: table), + createComputedFieldComposer: () => + $AiChatMessagesAnnotationComposer($db: db, $table: table), + updateCompanionCallback: + ({ + Value id = const Value.absent(), + Value conversationId = const Value.absent(), + Value role = const Value.absent(), + Value providerId = const Value.absent(), + Value anchorMessageId = const Value.absent(), + Value anchorCreatedAt = const Value.absent(), + Value content = const Value.absent(), + Value status = const Value.absent(), + Value model = const Value.absent(), + Value errorText = const Value.absent(), + Value createdAt = const Value.absent(), + Value updatedAt = const Value.absent(), + Value rowid = const Value.absent(), + }) => AiChatMessagesCompanion( + id: id, + conversationId: conversationId, + role: role, + providerId: providerId, + anchorMessageId: anchorMessageId, + anchorCreatedAt: anchorCreatedAt, + content: content, + status: status, + model: model, + errorText: errorText, + createdAt: createdAt, + updatedAt: updatedAt, + rowid: rowid, + ), + createCompanionCallback: + ({ + required String id, + required String conversationId, + required String role, + required String providerId, + Value anchorMessageId = const Value.absent(), + Value anchorCreatedAt = const Value.absent(), + required String content, + required String status, + Value model = const Value.absent(), + Value errorText = const Value.absent(), + required DateTime createdAt, + required DateTime updatedAt, + Value rowid = const Value.absent(), + }) => AiChatMessagesCompanion.insert( + id: id, + conversationId: conversationId, + role: role, + providerId: providerId, + anchorMessageId: anchorMessageId, + anchorCreatedAt: anchorCreatedAt, + content: content, + status: status, + model: model, + errorText: errorText, + createdAt: createdAt, + updatedAt: updatedAt, + rowid: rowid, + ), + withReferenceMapper: (p0) => p0 + .map((e) => (e.readTable(table), BaseReferences(db, table, e))) + .toList(), + prefetchHooksCallback: null, + ), + ); +} + +typedef $AiChatMessagesProcessedTableManager = + ProcessedTableManager< + _$MixinDatabase, + AiChatMessages, + AiChatMessage, + $AiChatMessagesFilterComposer, + $AiChatMessagesOrderingComposer, + $AiChatMessagesAnnotationComposer, + $AiChatMessagesCreateCompanionBuilder, + $AiChatMessagesUpdateCompanionBuilder, + ( + AiChatMessage, + BaseReferences<_$MixinDatabase, AiChatMessages, AiChatMessage>, + ), + AiChatMessage, + PrefetchHooks Function() + >; typedef $InscriptionCollectionsCreateCompanionBuilder = InscriptionCollectionsCompanion Function({ required String collectionHash, @@ -28631,6 +29715,8 @@ class $MixinDatabaseManager { $FavoriteAppsTableManager(_db, _db.favoriteApps); $PropertiesTableManager get properties => $PropertiesTableManager(_db, _db.properties); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db, _db.aiChatMessages); $InscriptionCollectionsTableManager get inscriptionCollections => $InscriptionCollectionsTableManager(_db, _db.inscriptionCollections); $InscriptionItemsTableManager get inscriptionItems => diff --git a/lib/db/moor/mixin.drift b/lib/db/moor/mixin.drift index 52108ef196..805839e0b3 100644 --- a/lib/db/moor/mixin.drift +++ b/lib/db/moor/mixin.drift @@ -73,6 +73,22 @@ CREATE TABLE chains (chain_id TEXT NOT NULL, name TEXT NOT NULL, symbol TEXT NOT CREATE TABLE properties ("key" TEXT NOT NULL, "group" TEXT NOT NULL MAPPED BY `const PropertyGroupConverter()`, "value" TEXT NOT NULL, PRIMARY KEY("key", "group")); +CREATE TABLE ai_chat_messages ( + id TEXT NOT NULL, + conversation_id TEXT NOT NULL, + role TEXT NOT NULL, + provider_id TEXT NOT NULL, + anchor_message_id TEXT, + anchor_created_at INTEGER MAPPED BY `const MillisDateConverter()`, + content TEXT NOT NULL, + status TEXT NOT NULL, + model TEXT, + error_text TEXT, + created_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, + updated_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, + PRIMARY KEY(id) +); + CREATE TABLE safe_snapshots ( snapshot_id TEXT NOT NULL, type TEXT NOT NULL, @@ -152,4 +168,5 @@ CREATE INDEX IF NOT EXISTS index_messages_conversation_id_category_created_at ON CREATE INDEX IF NOT EXISTS index_message_conversation_id_status_user_id ON messages(conversation_id, status, user_id); CREATE INDEX IF NOT EXISTS index_messages_conversation_id_quote_message_id ON messages(conversation_id, quote_message_id); CREATE INDEX IF NOT EXISTS index_tokens_kernel_asset_id ON tokens(kernel_asset_id); -CREATE INDEX IF NOT EXISTS index_tokens_collection_hash ON tokens(collection_hash); \ No newline at end of file +CREATE INDEX IF NOT EXISTS index_tokens_collection_hash ON tokens(collection_hash); +CREATE INDEX IF NOT EXISTS index_ai_chat_messages_conversation_id_created_at ON ai_chat_messages(conversation_id, created_at DESC); diff --git a/lib/ui/home/bloc/message_bloc.dart b/lib/ui/home/bloc/message_bloc.dart index e606123881..961884e437 100644 --- a/lib/ui/home/bloc/message_bloc.dart +++ b/lib/ui/home/bloc/message_bloc.dart @@ -25,6 +25,25 @@ abstract class _MessageEvent extends Equatable { List get props => []; } +class ChatTimelineItem extends Equatable { + const ChatTimelineItem.message(this.message) : aiMessage = null; + + const ChatTimelineItem.ai(this.aiMessage) : message = null; + + final MessageItem? message; + final AiChatMessage? aiMessage; + + bool get isMessage => message != null; + bool get isAiMessage => aiMessage != null; + + String get id => message?.messageId ?? aiMessage!.id; + + DateTime get createdAt => message?.createdAt ?? aiMessage!.createdAt; + + @override + List get props => [message, aiMessage]; +} + class _MessageJumpCurrentEvent extends _MessageEvent {} class _MessageInitEvent extends _MessageEvent { @@ -73,11 +92,21 @@ class _MessageDeleteEvent extends _MessageEvent { List get props => [messageId]; } +class _AiMessagesChangedEvent extends _MessageEvent { + _AiMessagesChangedEvent(this.data); + + final List data; + + @override + List get props => [data]; +} + class MessageState extends Equatable { MessageState({ this.top = const [], this.center, this.bottom = const [], + this.aiMessages = const [], this.conversationId, this.isLatest = false, this.isOldest = false, @@ -101,6 +130,7 @@ class MessageState extends Equatable { final List top; final MessageItem? center; final List bottom; + final List aiMessages; final bool isLatest; final bool isOldest; final String? lastReadMessageId; @@ -112,6 +142,7 @@ class MessageState extends Equatable { top, center, bottom, + aiMessages, isLatest, isOldest, lastReadMessageId, @@ -132,11 +163,86 @@ class MessageState extends Equatable { ...bottom, ]; + List get visibleAiMessages { + if (aiMessages.isEmpty) return const []; + + final messages = list; + if (messages.isEmpty) { + return aiMessages.toList()..sort(_compareAiMessages); + } + + final messageIds = messages.map((message) => message.messageId).toSet(); + final start = topMessage?.createdAt; + final end = bottomMessage?.createdAt; + + bool inLoadedRange(DateTime? value) { + if (value == null || start == null || end == null) return false; + return !value.isBefore(start) && !value.isAfter(end); + } + + final visible = aiMessages.where((message) { + final anchorMessageId = message.anchorMessageId; + if (anchorMessageId != null && messageIds.contains(anchorMessageId)) { + return true; + } + return inLoadedRange(message.anchorCreatedAt) || + inLoadedRange(message.createdAt); + }).toList()..sort(_compareAiMessages); + + return visible; + } + + List get timeline { + final messages = list; + final visibleAi = visibleAiMessages; + + if (messages.isEmpty) { + return visibleAi.map(ChatTimelineItem.ai).toList(); + } + + final anchored = >{}; + final floating = []; + + for (final aiMessage in visibleAi) { + final anchorMessageId = aiMessage.anchorMessageId; + if (anchorMessageId != null) { + anchored.putIfAbsent(anchorMessageId, () => []).add(aiMessage); + } else { + floating.add(aiMessage); + } + } + + final result = []; + var floatingIndex = 0; + + for (final message in messages) { + while (floatingIndex < floating.length && + !floating[floatingIndex].createdAt.isAfter(message.createdAt)) { + result.add(ChatTimelineItem.ai(floating[floatingIndex])); + floatingIndex++; + } + + result.add(ChatTimelineItem.message(message)); + final anchoredMessages = anchored[message.messageId]; + if (anchoredMessages != null) { + result.addAll(anchoredMessages.map(ChatTimelineItem.ai)); + } + } + + while (floatingIndex < floating.length) { + result.add(ChatTimelineItem.ai(floating[floatingIndex])); + floatingIndex++; + } + + return result; + } + MessageState copyWith({ String? conversationId, List? top, MessageItem? center, List? bottom, + List? aiMessages, bool? isLatest, bool? isOldest, String? lastReadMessageId, @@ -146,6 +252,7 @@ class MessageState extends Equatable { top: top ?? this.top, center: center ?? this.center, bottom: bottom ?? this.bottom, + aiMessages: aiMessages ?? this.aiMessages, isLatest: isLatest ?? this.isLatest, isOldest: isOldest ?? this.isOldest, lastReadMessageId: lastReadMessageId ?? this.lastReadMessageId, @@ -154,6 +261,7 @@ class MessageState extends Equatable { MessageState _copyWithJumpCurrentState() => MessageState( top: list.toList(), + aiMessages: aiMessages, refreshKey: Object(), conversationId: conversationId, isLatest: isLatest, @@ -167,6 +275,7 @@ class MessageState extends Equatable { conversationId: conversationId, top: top, bottom: bottom, + aiMessages: aiMessages, isLatest: isLatest, isOldest: isOldest, lastReadMessageId: lastReadMessageId, @@ -189,6 +298,26 @@ class MessageState extends Equatable { } } +int _compareAiMessages(AiChatMessage a, AiChatMessage b) { + final anchorCompare = _compareNullableDateTime( + a.anchorCreatedAt, + b.anchorCreatedAt, + ); + if (anchorCompare != 0) return anchorCompare; + + final createdAtCompare = a.createdAt.compareTo(b.createdAt); + if (createdAtCompare != 0) return createdAtCompare; + + return a.id.compareTo(b.id); +} + +int _compareNullableDateTime(DateTime? a, DateTime? b) { + if (a == null && b == null) return 0; + if (a == null) return 1; + if (b == null) return -1; + return a.compareTo(b); +} + class MessageBloc extends Bloc<_MessageEvent, MessageState> with SubscribeMixin { MessageBloc({ @@ -218,6 +347,10 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> _onEvent, transformer: sequential(), ); + on<_AiMessagesChangedEvent>( + _onEvent, + transformer: restartable(), + ); on<_MessageScrollEvent>( _onEvent, transformer: restartable(), @@ -268,6 +401,22 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> .listen((state) => add(_MessageInsertOrReplaceEvent(state))), ); + addSubscription( + conversationNotifier.stream + .startWith(conversationNotifier.state) + .map((event) => event?.conversationId) + .distinct() + .switchMap((conversationId) { + if (conversationId == null) { + return Stream.value(const []); + } + return database.aiChatMessageDao.watchConversationMessages( + conversationId, + ); + }) + .listen((state) => add(_AiMessagesChangedEvent(state))), + ); + addSubscription( DataBaseEventBus.instance.deleteMessageIdStream.listen((messageIds) { messageIds.forEach((messageId) { @@ -325,6 +474,8 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> } else if (event is _MessageDeleteEvent) { final messageState = state.removeMessage(event.messageId); emit(_pretreatment(messageState)); + } else if (event is _AiMessagesChangedEvent) { + emit(_pretreatment(state.copyWith(aiMessages: event.data))); } else { if (event is _MessageLoadMoreEvent) { if (event is _MessageLoadAfterEvent) { @@ -408,12 +559,16 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> limit, centerMessageId: _centerMessageId, ); + final aiMessages = await database.aiChatMessageDao.conversationMessages( + conversationId, + ); return state.copyWith( conversationId: conversationId, center: state.center, bottom: state.bottom, top: state.top, + aiMessages: aiMessages, ); } @@ -429,6 +584,7 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> return MessageState( top: list.reversed.toList(), + aiMessages: state.aiMessages, isLatest: true, isOldest: list.length < limit, ); @@ -468,6 +624,7 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> top: topList, center: center, bottom: bottomList, + aiMessages: state.aiMessages, isLatest: isLatest, isOldest: isOldest, ); diff --git a/lib/ui/home/chat/chat_page.dart b/lib/ui/home/chat/chat_page.dart index 8d453bd22e..2b441c6377 100644 --- a/lib/ui/home/chat/chat_page.dart +++ b/lib/ui/home/chat/chat_page.dart @@ -21,6 +21,7 @@ import '../../../utils/extension/extension.dart'; import '../../../utils/hook.dart'; import '../../../widgets/action_button.dart'; import '../../../widgets/actions/actions.dart'; +import '../../../widgets/ai/ai_message_card.dart'; import '../../../widgets/animated_visibility.dart'; import '../../../widgets/clamping_custom_scroll_view/clamping_custom_scroll_view.dart'; import '../../../widgets/conversation/mute_dialog.dart'; @@ -570,20 +571,33 @@ class _List extends HookConsumerWidget { final state = useBlocState( when: (state) => state.conversationId != null, ); - final key = ValueKey((state.conversationId, state.refreshKey)); - final top = state.top; final center = state.center; - final bottom = state.bottom; + final timeline = state.timeline; - final ref = useRef>({}); + final centerTimelineIndex = center == null + ? null + : timeline.indexWhere( + (item) => item.message?.messageId == center.messageId, + ); + final topTimeline = centerTimelineIndex == null + ? timeline + : timeline.take(centerTimelineIndex).toList(); + final centerTimeline = centerTimelineIndex == null + ? null + : timeline[centerTimelineIndex]; + final bottomTimeline = centerTimelineIndex == null + ? const [] + : timeline.skip(centerTimelineIndex + 1).toList(); + + final keyRef = useRef>({}); final ids = state.list.map((e) => e.messageId); useMemoized(() { - ref.value.removeWhere((key, value) => !ids.contains(key)); + keyRef.value.removeWhere((key, value) => !ids.contains(key)); ids.forEach((id) { - ref.value[id] = ref.value[id] ?? GlobalKey(debugLabel: id); + keyRef.value[id] = keyRef.value[id] ?? GlobalKey(debugLabel: id); }); }, [ids]); @@ -596,6 +610,49 @@ class _List extends HookConsumerWidget { context, ).scrollController; + MessageItem? prevMessageOf( + ChatTimelineItem item, + List items, + ) { + final index = items.indexOf(item); + if (index <= 0) return null; + for (var i = index - 1; i >= 0; i--) { + final message = items[i].message; + if (message != null) return message; + } + return null; + } + + MessageItem? nextMessageOf( + ChatTimelineItem item, + List items, + ) { + final index = items.indexOf(item); + if (index == -1 || index >= items.length - 1) return null; + for (var i = index + 1; i < items.length; i++) { + final message = items[i].message; + if (message != null) return message; + } + return null; + } + + Widget buildTimelineChild(ChatTimelineItem item) { + if (item.isAiMessage) { + return AiMessageCard( + key: ValueKey('ai-${item.id}'), + message: item.aiMessage!, + ); + } + final message = item.message!; + return MessageItemWidget( + key: keyRef.value[message.messageId], + prev: prevMessageOf(item, timeline), + message: message, + next: nextMessageOf(item, timeline), + lastReadMessageId: state.lastReadMessageId, + ); + } + return MessageDayTimeViewportWidget.chatPage( key: key, bottomKey: bottomKey, @@ -604,7 +661,7 @@ class _List extends HookConsumerWidget { scrollController: scrollController, centerKey: center == null ? null - : ref.value[center.messageId] as GlobalKey?, + : keyRef.value[center.messageId] as GlobalKey?, child: ClampingCustomScrollView( key: key, center: key, @@ -618,50 +675,28 @@ class _List extends HookConsumerWidget { context, index, ) { - final actualIndex = top.length - index - 1; - final messageItem = top[actualIndex]; - return MessageItemWidget( - key: ref.value[messageItem.messageId], - prev: top.getOrNull(actualIndex - 1), - message: messageItem, - next: - top.getOrNull(actualIndex + 1) ?? - center ?? - bottom.lastOrNull, - lastReadMessageId: state.lastReadMessageId, - ); - }, childCount: top.length), + final actualIndex = topTimeline.length - index - 1; + return buildTimelineChild(topTimeline[actualIndex]); + }, childCount: topTimeline.length), ), SliverToBoxAdapter( key: key, child: Builder( builder: (context) { - if (center == null) return const SizedBox(); - return MessageItemWidget( - key: ref.value[center.messageId], - prev: top.lastOrNull, - message: center, - next: bottom.firstOrNull, - lastReadMessageId: state.lastReadMessageId, - ); + if (centerTimeline == null) return const SizedBox(); + return buildTimelineChild(centerTimeline); }, ), ), SliverList( key: bottomKey, - delegate: SliverChildBuilderDelegate(( - context, - index, - ) { - final messageItem = bottom[index]; - return MessageItemWidget( - key: ref.value[messageItem.messageId], - prev: bottom.getOrNull(index - 1) ?? center ?? top.lastOrNull, - message: messageItem, - next: bottom.getOrNull(index + 1), - lastReadMessageId: state.lastReadMessageId, - ); - }, childCount: bottom.length), + delegate: SliverChildBuilderDelegate( + ( + context, + index, + ) => buildTimelineChild(bottomTimeline[index]), + childCount: bottomTimeline.length, + ), ), const SliverToBoxAdapter(child: SizedBox(height: 10)), ], diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 4dde511bc4..ae5b73f1a9 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -18,6 +18,8 @@ import 'package:rxdart/rxdart.dart'; import 'package:simple_animations/simple_animations.dart'; import 'package:super_context_menu/super_context_menu.dart'; +import '../../../ai/ai_chat_controller.dart'; +import '../../../ai/model/ai_mode_state.dart'; import '../../../constants/constants.dart'; import '../../../constants/icon_fonts.dart'; import '../../../constants/resources.dart'; @@ -38,11 +40,13 @@ import '../../../widgets/hover_overlay.dart'; import '../../../widgets/mention_panel.dart'; import '../../../widgets/menu.dart'; import '../../../widgets/message/item/quote_message.dart'; +import '../../../widgets/message/message_bubble.dart'; import '../../../widgets/sticker_page/bloc/cubit/sticker_albums_cubit.dart'; import '../../../widgets/sticker_page/sticker_page.dart'; import '../../../widgets/toast.dart'; import '../../../widgets/user_selector/conversation_selector.dart'; import '../../provider/abstract_responsive_navigator.dart'; +import '../../provider/ai_input_mode_provider.dart'; import '../../provider/conversation_provider.dart'; import '../../provider/mention_cache_provider.dart'; import '../../provider/mention_provider.dart'; @@ -100,6 +104,9 @@ class _InputContainer extends HookConsumerWidget { (value) => (value?.conversationId, value?.conversation?.draft), ), ); + final aiModeState = ref.watch(aiInputModeProvider(conversationId ?? '')); + final selectedAiProvider = + context.database.settingProperties.selectedAiProvider; final quoteMessageId = ref.watch(quoteMessageIdProvider); @@ -197,6 +204,12 @@ class _InputContainer extends HookConsumerWidget { mainAxisAlignment: MainAxisAlignment.end, children: [ const _QuoteMessage(), + if (conversationId != null) + _AiModeBar( + conversationId: conversationId, + aiModeState: aiModeState, + providerName: selectedAiProvider?.name, + ), ConstrainedBox( constraints: const BoxConstraints(minHeight: 56), child: Container( @@ -223,6 +236,7 @@ class _InputContainer extends HookConsumerWidget { ), const SizedBox(width: 16), _AnimatedSendOrVoiceButton( + conversationId: conversationId, textEditingController: textEditingController, textEditingValueStream: textEditingValueStream, ), @@ -239,10 +253,12 @@ class _InputContainer extends HookConsumerWidget { class _AnimatedSendOrVoiceButton extends HookConsumerWidget { const _AnimatedSendOrVoiceButton({ + required this.conversationId, required this.textEditingValueStream, required this.textEditingController, }); + final String? conversationId; final Stream textEditingValueStream; final TextEditingController textEditingController; @@ -310,6 +326,7 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { callback: () => _sendMessage( context, textEditingController, + conversationId: conversationId, silent: true, ), ), @@ -318,7 +335,11 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { child: ActionButton( name: Resources.assetsImagesIcSendSvg, color: context.theme.icon, - onTap: () => _sendMessage(context, textEditingController), + onTap: () => _sendMessage( + context, + textEditingController, + conversationId: conversationId, + ), ), ), ), @@ -367,10 +388,12 @@ void _sendPostMessage( void _sendMessage( BuildContext context, TextEditingController textEditingController, { + required String? conversationId, bool silent = false, }) { final text = textEditingController.value.text.trim(); if (text.isEmpty) return; + if (conversationId == null) return; final conversationItem = context.providerContainer.read(conversationProvider); if (conversationItem == null) return; @@ -379,6 +402,62 @@ void _sendMessage( return; } + final aiModeController = context.providerContainer.read( + aiInputModeProvider(conversationId).notifier, + ); + final aiModeState = context.providerContainer.read( + aiInputModeProvider(conversationId), + ); + + if (text == '/ai') { + final provider = context.database.settingProperties.selectedAiProvider; + if (provider == null) { + showToastFailed(ToastError('Please add an AI provider first')); + return; + } + aiModeController.enter(providerId: provider.id); + textEditingController.text = ''; + return; + } + + final inlineAiInput = text.startsWith('/ai ') + ? text.substring(4).trim() + : null; + if (inlineAiInput != null && inlineAiInput.isNotEmpty) { + final provider = context.database.settingProperties.selectedAiProvider; + if (provider == null) { + showToastFailed(ToastError('Please add an AI provider first')); + return; + } + aiModeController.enter(providerId: provider.id); + textEditingController.text = ''; + unawaited( + AiChatController(context.database) + .send( + conversationId: conversationId, + input: inlineAiInput, + provider: provider, + ) + .catchError((Object error, StackTrace _) => showToastFailed(error)), + ); + return; + } + + if (aiModeState.enabled) { + final provider = context.database.settingProperties.selectedAiProvider; + if (provider == null) { + showToastFailed(ToastError('Please add an AI provider first')); + return; + } + textEditingController.text = ''; + unawaited( + AiChatController(context.database) + .send(conversationId: conversationId, input: text, provider: provider) + .catchError((Object error, StackTrace _) => showToastFailed(error)), + ); + return; + } + context.accountServer.sendTextMessage( text, conversationItem.encryptCategory, @@ -479,7 +558,11 @@ class _SendTextField extends HookConsumerWidget { }, actions: { _SendMessageIntent: CallbackAction( - onInvoke: (intent) => _sendMessage(context, textEditingController), + onInvoke: (intent) => _sendMessage( + context, + textEditingController, + conversationId: ref.read(currentConversationIdProvider), + ), ), PasteTextIntent: _PasteContextAction(context), _SendPostMessageIntent: CallbackAction( @@ -521,7 +604,7 @@ class _SendTextField extends HookConsumerWidget { child: Text( isEncryptConversation ? context.l10n.chatHintE2e - : context.l10n.typeMessage, + : 'Type message or /ai', style: TextStyle( color: context.theme.secondaryText, fontSize: 14, @@ -539,6 +622,52 @@ class _SendTextField extends HookConsumerWidget { } } +class _AiModeBar extends HookConsumerWidget { + const _AiModeBar({ + required this.conversationId, + required this.aiModeState, + required this.providerName, + }); + + final String conversationId; + final AiModeState aiModeState; + final String? providerName; + + @override + Widget build(BuildContext context, WidgetRef ref) { + if (!aiModeState.enabled) return const SizedBox(); + return Container( + width: double.infinity, + color: context.theme.primary, + padding: const EdgeInsets.fromLTRB(16, 10, 16, 0), + child: Container( + padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 10), + decoration: BoxDecoration( + color: context.messageBubbleColor(false), + borderRadius: const BorderRadius.all(Radius.circular(10)), + ), + child: Row( + children: [ + Expanded( + child: Text( + 'AI Mode · ${providerName ?? 'No Provider'}', + style: TextStyle(color: context.theme.text, fontSize: 14), + ), + ), + ActionButton( + name: Resources.assetsImagesIcCloseSvg, + color: context.theme.icon, + size: 18, + onTap: () => + ref.read(aiInputModeProvider(conversationId).notifier).exit(), + ), + ], + ), + ), + ); + } +} + class _QuoteMessage extends HookConsumerWidget { const _QuoteMessage(); diff --git a/lib/ui/provider/ai_input_mode_provider.dart b/lib/ui/provider/ai_input_mode_provider.dart new file mode 100644 index 0000000000..44cf77980d --- /dev/null +++ b/lib/ui/provider/ai_input_mode_provider.dart @@ -0,0 +1,20 @@ +import 'package:hooks_riverpod/hooks_riverpod.dart'; + +import '../../ai/model/ai_mode_state.dart'; + +class AiInputModeNotifier extends StateNotifier { + AiInputModeNotifier() : super(const AiModeState()); + + void enter({String? providerId}) { + state = AiModeState(enabled: true, providerId: providerId); + } + + void exit() { + state = const AiModeState(); + } +} + +final aiInputModeProvider = StateNotifierProvider.autoDispose + .family( + (ref, _) => AiInputModeNotifier(), + ); diff --git a/lib/ui/provider/responsive_navigator_provider.dart b/lib/ui/provider/responsive_navigator_provider.dart index 94e1c0c156..5fd1a32712 100644 --- a/lib/ui/provider/responsive_navigator_provider.dart +++ b/lib/ui/provider/responsive_navigator_provider.dart @@ -5,6 +5,7 @@ import '../home/chat/chat_page.dart'; import '../setting/about_page.dart'; import '../setting/account_delete_page.dart'; import '../setting/account_page.dart'; +import '../setting/ai_settings_page.dart'; import '../setting/appearance_page.dart'; import '../setting/backup_page.dart'; import '../setting/edit_profile_page.dart'; @@ -31,6 +32,7 @@ class ResponsiveNavigatorStateNotifier static const chatBackupPage = 'chatBackupPage'; static const dataAndStorageUsagePage = 'dataAndStorageUsagePage'; static const appearancePage = 'appearancePage'; + static const aiSettingsPage = 'aiSettingsPage'; static const aboutPage = 'aboutPage'; static const storageUsage = 'storageUsage'; static const storageUsageDetail = 'storageUsageDetail'; @@ -43,6 +45,7 @@ class ResponsiveNavigatorStateNotifier chatBackupPage, dataAndStorageUsagePage, appearancePage, + aiSettingsPage, aboutPage, storageUsage, storageUsageDetail, @@ -117,6 +120,12 @@ class ResponsiveNavigatorStateNotifier name: appearancePage, child: AppearancePage(key: ValueKey(appearancePage)), ); + case aiSettingsPage: + return const MaterialPage( + key: ValueKey(aiSettingsPage), + name: aiSettingsPage, + child: AiSettingsPage(key: ValueKey(aiSettingsPage)), + ); case accountPage: return const MaterialPage( key: ValueKey(accountPage), diff --git a/lib/ui/setting/ai_provider_edit_page.dart b/lib/ui/setting/ai_provider_edit_page.dart new file mode 100644 index 0000000000..4618d74089 --- /dev/null +++ b/lib/ui/setting/ai_provider_edit_page.dart @@ -0,0 +1,173 @@ +import 'package:flutter/material.dart'; +import 'package:flutter_hooks/flutter_hooks.dart'; +import 'package:hooks_riverpod/hooks_riverpod.dart'; +import 'package:uuid/uuid.dart'; + +import '../../ai/model/ai_provider_config.dart'; +import '../../ai/model/ai_provider_type.dart'; +import '../../utils/extension/extension.dart'; +import '../../widgets/app_bar.dart'; +import '../../widgets/cell.dart'; +import '../../widgets/toast.dart'; +import '../provider/database_provider.dart'; + +class AiProviderEditPage extends HookConsumerWidget { + const AiProviderEditPage({super.key, this.initial}); + + final AiProviderConfig? initial; + + @override + Widget build(BuildContext context, WidgetRef ref) { + final database = ref.watch(databaseProvider).requireValue; + final nameController = useTextEditingController(text: initial?.name ?? ''); + final baseUrlController = useTextEditingController( + text: initial?.baseUrl ?? '', + ); + final apiKeyController = useTextEditingController( + text: initial?.apiKey ?? '', + ); + final modelController = useTextEditingController( + text: initial?.model ?? '', + ); + final providerType = useState( + initial?.type ?? AiProviderType.openaiCompatible, + ); + + return Scaffold( + backgroundColor: context.theme.background, + appBar: MixinAppBar( + title: Text(initial == null ? 'Add AI Provider' : 'Edit AI Provider'), + actions: [ + TextButton( + onPressed: () { + final name = nameController.text.trim(); + final baseUrl = baseUrlController.text.trim(); + final apiKey = apiKeyController.text.trim(); + final model = modelController.text.trim(); + if (name.isEmpty || + baseUrl.isEmpty || + apiKey.isEmpty || + model.isEmpty) { + showToastFailed(ToastError('Please complete all fields')); + return; + } + + final provider = + (initial ?? + AiProviderConfig( + id: const Uuid().v4(), + name: name, + type: providerType.value, + baseUrl: baseUrl, + apiKey: apiKey, + model: model, + )) + .copyWith( + name: name, + type: providerType.value, + baseUrl: baseUrl, + apiKey: apiKey, + model: model, + ); + database.settingProperties.saveAiProvider(provider); + Navigator.of(context).pop(); + }, + child: Text( + 'Save', + style: TextStyle(color: context.theme.accent, fontSize: 16), + ), + ), + ], + ), + body: Align( + alignment: Alignment.topCenter, + child: SingleChildScrollView( + child: CellGroup( + cellBackgroundColor: context.theme.settingCellBackgroundColor, + child: Column( + children: [ + _TextFieldCell( + title: 'Display Name', + controller: nameController, + ), + _ProviderTypeCell( + value: providerType.value, + onChanged: (value) => providerType.value = value, + ), + _TextFieldCell( + title: 'Base URL', + controller: baseUrlController, + ), + _TextFieldCell( + title: 'API Key', + controller: apiKeyController, + obscureText: true, + ), + _TextFieldCell(title: 'Model', controller: modelController), + ], + ), + ), + ), + ), + ); + } +} + +class _ProviderTypeCell extends StatelessWidget { + const _ProviderTypeCell({required this.value, required this.onChanged}); + + final AiProviderType value; + final ValueChanged onChanged; + + @override + Widget build(BuildContext context) => CellItem( + title: const Text('Provider Type'), + trailing: DropdownButtonHideUnderline( + child: DropdownButton( + value: value, + onChanged: (value) { + if (value != null) onChanged(value); + }, + items: AiProviderType.values + .map( + (type) => DropdownMenuItem( + value: type, + child: Text( + type == AiProviderType.anthropic + ? 'Anthropic' + : 'OpenAI Compatible', + ), + ), + ) + .toList(), + ), + ), + ); +} + +class _TextFieldCell extends StatelessWidget { + const _TextFieldCell({ + required this.title, + required this.controller, + this.obscureText = false, + }); + + final String title; + final TextEditingController controller; + final bool obscureText; + + @override + Widget build(BuildContext context) => CellItem( + title: TextField( + controller: controller, + obscureText: obscureText, + style: TextStyle(color: context.theme.text, fontSize: 16), + decoration: InputDecoration( + border: InputBorder.none, + hintText: title, + hintStyle: TextStyle(color: context.theme.secondaryText), + ), + ), + trailing: const SizedBox.shrink(), + ); +} diff --git a/lib/ui/setting/ai_settings_page.dart b/lib/ui/setting/ai_settings_page.dart new file mode 100644 index 0000000000..f7c3b088ee --- /dev/null +++ b/lib/ui/setting/ai_settings_page.dart @@ -0,0 +1,113 @@ +import 'package:flutter/material.dart'; +import 'package:flutter_hooks/flutter_hooks.dart'; +import 'package:hooks_riverpod/hooks_riverpod.dart'; + +import '../../utils/extension/extension.dart'; +import '../../widgets/app_bar.dart'; +import '../../widgets/cell.dart'; +import '../../widgets/toast.dart'; +import '../provider/database_provider.dart'; +import 'ai_provider_edit_page.dart'; + +class AiSettingsPage extends HookConsumerWidget { + const AiSettingsPage({super.key}); + + @override + Widget build(BuildContext context, WidgetRef ref) { + final database = ref.watch(databaseProvider).requireValue; + useListenable(database.settingProperties); + final providers = database.settingProperties.aiProviders; + final selectedId = database.settingProperties.selectedAiProviderId; + + return Scaffold( + backgroundColor: context.theme.background, + appBar: MixinAppBar( + title: const Text('AI Settings'), + actions: [ + TextButton( + onPressed: () => Navigator.of(context).push( + MaterialPageRoute( + builder: (_) => const AiProviderEditPage(), + ), + ), + child: Text( + 'Add', + style: TextStyle(color: context.theme.accent, fontSize: 16), + ), + ), + ], + ), + body: Align( + alignment: Alignment.topCenter, + child: SingleChildScrollView( + child: Column( + children: [ + if (providers.isEmpty) + Padding( + padding: const EdgeInsets.all(24), + child: Text( + 'No AI provider configured yet.', + style: TextStyle(color: context.theme.secondaryText), + ), + ) + else + CellGroup( + cellBackgroundColor: context.theme.settingCellBackgroundColor, + child: Column( + children: providers.map((provider) { + final selected = provider.id == selectedId; + return CellItem( + title: Text(provider.name), + description: Text(provider.model), + selected: selected, + onTap: () => + database.settingProperties.selectedAiProviderId = + provider.id, + trailing: Row( + mainAxisSize: MainAxisSize.min, + children: [ + Switch( + value: provider.enabled, + onChanged: (value) { + database.settingProperties.saveAiProvider( + provider.copyWith(enabled: value), + ); + }, + ), + IconButton( + onPressed: () => Navigator.of(context).push( + MaterialPageRoute( + builder: (_) => + AiProviderEditPage(initial: provider), + ), + ), + icon: Icon( + Icons.edit_outlined, + color: context.theme.icon, + ), + ), + IconButton( + onPressed: () { + database.settingProperties.removeAiProvider( + provider.id, + ); + showToastSuccessful(); + }, + icon: Icon( + Icons.delete_outline, + color: context.theme.red, + ), + ), + ], + ), + ); + }).toList(), + ), + ), + ], + ), + ), + ), + ); + } +} diff --git a/lib/ui/setting/setting_page.dart b/lib/ui/setting/setting_page.dart index 3e66506620..9481c60ee3 100644 --- a/lib/ui/setting/setting_page.dart +++ b/lib/ui/setting/setting_page.dart @@ -138,6 +138,13 @@ class SettingPage extends HookConsumerWidget { ResponsiveNavigatorStateNotifier.appearancePage, title: context.l10n.appearance, ), + _Item( + leadingAssetName: + Resources.assetsImagesIcAppearanceSvg, + pageName: + ResponsiveNavigatorStateNotifier.aiSettingsPage, + title: 'AI Settings', + ), _Item( leadingAssetName: Resources.assetsImagesIcAboutSvg, pageName: diff --git a/lib/utils/property/setting_property.dart b/lib/utils/property/setting_property.dart index 5b6c512fe2..fa746a2e50 100644 --- a/lib/utils/property/setting_property.dart +++ b/lib/utils/property/setting_property.dart @@ -2,6 +2,7 @@ import 'dart:convert'; import 'package:mixin_logger/mixin_logger.dart'; +import '../../ai/model/ai_provider_config.dart'; import '../../db/dao/property_dao.dart'; import '../../db/util/property_storage.dart'; import '../../enum/property_group.dart'; @@ -11,6 +12,8 @@ import '../proxy.dart'; const _kEnableProxyKey = 'enable_proxy'; const _kSelectedProxyKey = 'selected_proxy'; const _kProxyListKey = 'proxy_list'; +const _kAiProviderListKey = 'ai_provider_list'; +const _kSelectedAiProviderKey = 'selected_ai_provider'; class SettingPropertyStorage extends PropertyStorage { SettingPropertyStorage(PropertyDao dao) : super(PropertyGroup.setting, dao); @@ -63,4 +66,65 @@ class SettingPropertyStorage extends PropertyStorage { final list = proxyList.where((element) => element.id != id).toList(); set(_kProxyListKey, jsonEncode(list)); } + + List get aiProviders { + final json = get(_kAiProviderListKey); + if (json == null || json.isEmpty) { + return []; + } + try { + final list = jsonDecode(json) as List; + return list + .cast>() + .map(AiProviderConfig.fromJson) + .toList(); + } catch (error, stacktrace) { + e('load aiProviders error: $error, $stacktrace'); + } + return []; + } + + String? get selectedAiProviderId => get(_kSelectedAiProviderKey); + + set selectedAiProviderId(String? value) => + set(_kSelectedAiProviderKey, value); + + AiProviderConfig? get selectedAiProvider { + final providers = aiProviders.where((element) => element.enabled).toList(); + if (providers.isEmpty) { + return null; + } + final selectedId = selectedAiProviderId; + if (selectedId == null) { + return providers.first; + } + return providers.firstWhereOrNull((element) => element.id == selectedId) ?? + providers.first; + } + + void saveAiProvider(AiProviderConfig config) { + final providers = aiProviders; + final index = providers.indexWhere((element) => element.id == config.id); + if (index >= 0) { + providers[index] = config; + } else { + providers.add(config); + } + set( + _kAiProviderListKey, + jsonEncode(providers.map((element) => element.toJson()).toList()), + ); + selectedAiProviderId ??= config.id; + } + + void removeAiProvider(String id) { + final providers = aiProviders.where((element) => element.id != id).toList(); + set( + _kAiProviderListKey, + jsonEncode(providers.map((element) => element.toJson()).toList()), + ); + if (selectedAiProviderId == id) { + selectedAiProviderId = providers.firstOrNull?.id; + } + } } diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart new file mode 100644 index 0000000000..f1a0b5b3a7 --- /dev/null +++ b/lib/widgets/ai/ai_message_card.dart @@ -0,0 +1,75 @@ +import 'package:flutter/material.dart'; + +import '../../db/mixin_database.dart' hide Offset; +import '../../utils/extension/extension.dart'; +import '../markdown.dart'; +import '../message/message_bubble.dart'; + +class AiMessageCard extends StatelessWidget { + const AiMessageCard({required this.message, super.key}); + + final AiChatMessage message; + + @override + Widget build(BuildContext context) { + final isUser = message.role == 'user'; + final title = isUser ? 'You -> AI' : 'AI Assistant'; + final time = message.createdAt.format; + final cardColor = isUser + ? context.messageBubbleColor(true) + : context.messageBubbleColor(false); + + return Align( + alignment: isUser ? Alignment.centerRight : Alignment.centerLeft, + child: ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 420), + child: Container( + margin: const EdgeInsets.symmetric(horizontal: 10, vertical: 6), + padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 10), + decoration: BoxDecoration( + color: cardColor, + borderRadius: const BorderRadius.all(Radius.circular(12)), + boxShadow: const [ + BoxShadow( + color: Color.fromRGBO(0, 0, 0, 0.08), + blurRadius: 8, + offset: Offset(0, 2), + ), + ], + ), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + '$title · $time', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + fontWeight: FontWeight.w500, + ), + ), + const SizedBox(height: 8), + if (message.content.trim().isEmpty) + SelectableText( + message.status == 'error' + ? (message.errorText ?? 'Request failed') + : 'Thinking...', + style: TextStyle(color: context.theme.text, fontSize: 14), + ) + else if (isUser) + SelectableText( + message.content, + style: TextStyle(color: context.theme.text, fontSize: 14), + ) + else + MarkdownColumn( + data: message.content, + selectable: true, + ), + ], + ), + ), + ), + ); + } +} diff --git a/lib/widgets/markdown.dart b/lib/widgets/markdown.dart index 01936fc36b..5d09130241 100644 --- a/lib/widgets/markdown.dart +++ b/lib/widgets/markdown.dart @@ -10,9 +10,14 @@ import '../utils/uri_utils.dart'; import 'mixin_image.dart'; class MarkdownColumn extends ConsumerWidget { - const MarkdownColumn({required this.data, super.key}); + const MarkdownColumn({ + required this.data, + super.key, + this.selectable = false, + }); final String data; + final bool selectable; @override Widget build(BuildContext context, WidgetRef ref) { @@ -24,7 +29,7 @@ class MarkdownColumn extends ConsumerWidget { child: MarkdownWidget( data: data, useColumn: true, - selectable: false, + selectable: selectable, padding: EdgeInsets.zero, theme: _createMarkdownTheme(context, chatFontSizeDelta), imageBuilder: _buildMarkdownImage, diff --git a/lib/widgets/message/message_day_time.dart b/lib/widgets/message/message_day_time.dart index c7727e55dc..7991f27081 100644 --- a/lib/widgets/message/message_day_time.dart +++ b/lib/widgets/message/message_day_time.dart @@ -80,14 +80,17 @@ class _CurrentShowingMessages { final List dayTimeElements = []; void dumpKeyedSubtree(Element element, {bool reverse = false}) { - final item = element.descendantFirstOf( + final item = element.descendantFirstWhere( (e) => e.widget is MessageItemWidget, ); + if (item == null) { + return; + } final widget = item.widget as MessageItemWidget; final dayTimeElement = !isSameDay(widget.message.createdAt, widget.prev?.createdAt) - ? element.descendantFirstOf( + ? element.descendantFirstWhere( (e) => e.widget is _MessageDayTimeWidget, ) : null; @@ -153,11 +156,11 @@ class MessageDayTimeViewportWidget extends HookConsumerWidget { }) => MessageDayTimeViewportWidget._create( () { final result = _CurrentShowingMessages(); - (listKey.currentContext! as Element) - .descendantFirstOf((e) => e.widget is SliverList) - .visitChildElements((e) { - result.dumpKeyedSubtree(e, reverse: reverse); - }); + final listElement = (listKey.currentContext! as Element) + .descendantFirstWhere((e) => e.widget is SliverList); + listElement?.visitChildElements((e) { + result.dumpKeyedSubtree(e, reverse: reverse); + }); return result; }, key: key, @@ -350,7 +353,7 @@ class MessageDayTimeViewportWidget extends HookConsumerWidget { } extension _ElementExt on Element { - Element descendantFirstOf(bool Function(Element e) predicate) { + Element? descendantFirstWhere(bool Function(Element e) predicate) { Element? dump(Element element) { if (predicate(element)) { return element; @@ -365,6 +368,6 @@ extension _ElementExt on Element { return child; } - return dump(this)!; + return dump(this); } } From 03b8bd3d4a8567fd0275c0e15295bcef074ef20d Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Mon, 20 Apr 2026 11:55:13 +0800 Subject: [PATCH 03/52] feat: enhance AI message handling with role comparison, improved UI, and context menu updates --- lib/ui/home/bloc/message_bloc.dart | 12 + lib/ui/home/chat/chat_page.dart | 28 + lib/widgets/ai/ai_message_card.dart | 631 ++++++++++++++++-- lib/widgets/markdown.dart | 1 + .../message/message_datetime_and_status.dart | 110 +-- 5 files changed, 689 insertions(+), 93 deletions(-) diff --git a/lib/ui/home/bloc/message_bloc.dart b/lib/ui/home/bloc/message_bloc.dart index 961884e437..dacd95dd42 100644 --- a/lib/ui/home/bloc/message_bloc.dart +++ b/lib/ui/home/bloc/message_bloc.dart @@ -308,6 +308,9 @@ int _compareAiMessages(AiChatMessage a, AiChatMessage b) { final createdAtCompare = a.createdAt.compareTo(b.createdAt); if (createdAtCompare != 0) return createdAtCompare; + final roleCompare = _compareAiRoles(a.role, b.role); + if (roleCompare != 0) return roleCompare; + return a.id.compareTo(b.id); } @@ -318,6 +321,15 @@ int _compareNullableDateTime(DateTime? a, DateTime? b) { return a.compareTo(b); } +int _compareAiRoles(String a, String b) { + const order = { + 'user': 0, + 'assistant': 1, + }; + + return (order[a] ?? order.length).compareTo(order[b] ?? order.length); +} + class MessageBloc extends Bloc<_MessageEvent, MessageState> with SubscribeMixin { MessageBloc({ diff --git a/lib/ui/home/chat/chat_page.dart b/lib/ui/home/chat/chat_page.dart index 2b441c6377..f93f9fb0ca 100644 --- a/lib/ui/home/chat/chat_page.dart +++ b/lib/ui/home/chat/chat_page.dart @@ -636,11 +636,39 @@ class _List extends HookConsumerWidget { return null; } + AiChatMessage? prevAiOf( + ChatTimelineItem item, + List items, + ) { + final index = items.indexOf(item); + if (index <= 0) return null; + for (var i = index - 1; i >= 0; i--) { + final aiMessage = items[i].aiMessage; + if (aiMessage != null) return aiMessage; + } + return null; + } + + AiChatMessage? nextAiOf( + ChatTimelineItem item, + List items, + ) { + final index = items.indexOf(item); + if (index == -1 || index >= items.length - 1) return null; + for (var i = index + 1; i < items.length; i++) { + final aiMessage = items[i].aiMessage; + if (aiMessage != null) return aiMessage; + } + return null; + } + Widget buildTimelineChild(ChatTimelineItem item) { if (item.isAiMessage) { return AiMessageCard( key: ValueKey('ai-${item.id}'), message: item.aiMessage!, + prev: prevAiOf(item, timeline), + next: nextAiOf(item, timeline), ); } final message = item.message!; diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index f1a0b5b3a7..46d75c9785 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -1,75 +1,604 @@ -import 'package:flutter/material.dart'; +import 'package:flutter/material.dart' + hide SelectableRegion, SelectableRegionState; +import 'package:flutter/rendering.dart' show SelectedContent, SelectionStatus; +import 'package:flutter/services.dart'; +import 'package:flutter_hooks/flutter_hooks.dart'; +import 'package:flutter_svg/svg.dart'; +import 'package:super_context_menu/super_context_menu.dart'; +import '../../constants/resources.dart'; import '../../db/mixin_database.dart' hide Offset; +import '../../utils/datetime_format_utils.dart'; import '../../utils/extension/extension.dart'; +import '../../utils/platform.dart'; import '../markdown.dart'; +import '../menu.dart'; +import '../message/item/text/selectable.dart'; import '../message/message_bubble.dart'; +import '../message/message_datetime_and_status.dart'; +import '../message/message_layout.dart'; +import '../message/message_style.dart'; +import '../qr_code.dart'; class AiMessageCard extends StatelessWidget { - const AiMessageCard({required this.message, super.key}); + const AiMessageCard({ + required this.message, + super.key, + this.prev, + this.next, + }); final AiChatMessage message; + final AiChatMessage? prev; + final AiChatMessage? next; @override Widget build(BuildContext context) { final isUser = message.role == 'user'; - final title = isUser ? 'You -> AI' : 'AI Assistant'; - final time = message.createdAt.format; - final cardColor = isUser - ? context.messageBubbleColor(true) - : context.messageBubbleColor(false); - - return Align( - alignment: isUser ? Alignment.centerRight : Alignment.centerLeft, + final sameDayPrev = isSameDay(prev?.createdAt, message.createdAt); + final sameRolePrev = prev?.role == message.role; + final sameDayNext = isSameDay(next?.createdAt, message.createdAt); + final sameRoleNext = next?.role == message.role; + final mergedWithPrev = sameDayPrev && sameRolePrev; + final mergedWithNext = sameDayNext && sameRoleNext; + final showAssistantMeta = !isUser && !mergedWithPrev; + final bubbleColor = _bubbleColor( + context, + isUser: isUser, + status: message.status, + ); + final body = _AiBubble( + isCurrentUser: isUser, + showNip: !mergedWithNext && !showAssistantMeta, + color: bubbleColor, child: ConstrainedBox( constraints: const BoxConstraints(maxWidth: 420), - child: Container( - margin: const EdgeInsets.symmetric(horizontal: 10, vertical: 6), - padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 10), - decoration: BoxDecoration( - color: cardColor, - borderRadius: const BorderRadius.all(Radius.circular(12)), - boxShadow: const [ - BoxShadow( - color: Color.fromRGBO(0, 0, 0, 0.08), - blurRadius: 8, - offset: Offset(0, 2), + child: _AiMessageBody(message: message), + ), + ); + final content = Column( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: isUser + ? CrossAxisAlignment.end + : CrossAxisAlignment.start, + children: [ + if (showAssistantMeta) + Padding( + padding: const EdgeInsets.only(left: 2, bottom: 4), + child: Text( + 'AI Assistant', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: context.messageStyle.statusFontSize, + fontWeight: FontWeight.w500, ), - ], + ), ), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Text( - '$title · $time', - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 12, - fontWeight: FontWeight.w500, - ), + body, + ], + ); + + if (isUser) { + return Padding( + padding: EdgeInsets.only( + left: 65, + right: 16, + top: mergedWithPrev ? 0 : 8, + bottom: 2, + ), + child: Align( + alignment: Alignment.centerRight, + child: _AiMessageMenu( + message: message, + child: content, + ), + ), + ); + } + + return Padding( + padding: EdgeInsets.only(top: mergedWithPrev ? 0 : 8, bottom: 2), + child: Row( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + const SizedBox(width: 8), + SizedBox( + width: 32, + child: showAssistantMeta + ? _AiAvatar(thinking: message.status == 'pending') + : null, + ), + Flexible( + child: Padding( + padding: const EdgeInsets.only(top: 2), + child: _AiMessageMenu( + message: message, + child: content, ), - const SizedBox(height: 8), - if (message.content.trim().isEmpty) - SelectableText( - message.status == 'error' - ? (message.errorText ?? 'Request failed') - : 'Thinking...', - style: TextStyle(color: context.theme.text, fontSize: 14), - ) - else if (isUser) - SelectableText( - message.content, - style: TextStyle(color: context.theme.text, fontSize: 14), - ) - else - MarkdownColumn( - data: message.content, - selectable: true, - ), - ], + ), ), + const SizedBox(width: 65), + ], + ), + ); + } +} + +class _AiMessageBody extends StatelessWidget { + const _AiMessageBody({required this.message}); + + final AiChatMessage message; + + @override + Widget build(BuildContext context) { + final isUser = message.role == 'user'; + final content = message.content.trim(); + final text = content.isNotEmpty + ? content + : message.status == 'error' + ? (message.errorText ?? 'Request failed') + : 'Thinking...'; + final statusColor = _statusColor( + context, + isUser: isUser, + status: message.status, + ); + + Widget body; + final textStyle = TextStyle( + color: context.theme.text, + fontSize: context.messageStyle.primaryFontSize, + height: 1.45, + ); + + if (isUser || message.status == 'error') { + body = _AiSelectableText(text: text, style: textStyle); + } else { + body = DefaultTextStyle.merge( + style: textStyle, + child: MarkdownColumn(data: text, selectable: true), + ); + } + + return MessageLayout( + spacing: 6, + content: body, + dateAndStatus: _AiFooter( + isUser: isUser, + status: message.status, + color: statusColor, + dateTime: message.createdAt, + ), + ); + } +} + +class _AiSelectableText extends StatefulWidget { + const _AiSelectableText({ + required this.text, + required this.style, + }); + + final String text; + final TextStyle style; + + @override + State<_AiSelectableText> createState() => _AiSelectableTextState(); +} + +class _AiSelectableTextState extends State<_AiSelectableText> { + late final FocusNode _focusNode = FocusNode( + debugLabel: 'ai_message_selection_focus', + ); + + @override + void dispose() { + _focusNode.dispose(); + super.dispose(); + } + + @override + Widget build(BuildContext context) { + final child = Text(widget.text, style: widget.style); + if (!kPlatformIsDesktop) { + return child; + } + return SelectableRegion( + focusNode: _focusNode, + contextMenuBuilder: (context, state) => const SizedBox.shrink(), + selectionControls: desktopTextSelectionHandleControls, + child: child, + ); + } +} + +class _AiBubble extends StatelessWidget { + const _AiBubble({ + required this.child, + required this.isCurrentUser, + required this.color, + required this.showNip, + }); + + final Widget child; + final bool isCurrentUser; + final Color color; + final bool showNip; + + @override + Widget build(BuildContext context) { + final clipper = BubbleClipper( + currentUser: isCurrentUser, + showNip: showNip, + ); + + return CustomPaint( + painter: BubblePainter(color: color, clipper: clipper), + child: Padding( + padding: const EdgeInsets.all(8), + child: MessageBubbleNipPadding( + currentUser: isCurrentUser, + child: child, ), ), ); } } + +class _AiAvatar extends HookWidget { + const _AiAvatar({required this.thinking}); + + final bool thinking; + + @override + Widget build(BuildContext context) { + final background = context.dynamicColor( + const Color.fromRGBO(227, 237, 213, 1), + darkColor: const Color.fromRGBO(64, 78, 56, 1), + ); + final foreground = context.dynamicColor( + const Color.fromRGBO(54, 87, 35, 1), + darkColor: const Color.fromRGBO(214, 235, 204, 1), + ); + final disableAnimations = + MediaQuery.maybeOf(context)?.disableAnimations ?? false; + final controller = useAnimationController( + duration: const Duration(milliseconds: 1800), + ); + useEffect(() { + if (!thinking || disableAnimations) { + controller + ..stop() + ..value = 0; + return null; + } + controller.repeat(); + return null; + }, [thinking, disableAnimations, controller]); + + final progress = useAnimation( + CurvedAnimation(parent: controller, curve: Curves.easeInOut), + ); + final scale = !thinking || disableAnimations + ? 1.0 + : 1 + (0.03 * (0.5 - (progress - 0.5).abs()) * 2); + final glowAlpha = !thinking || disableAnimations ? 0.0 : 0.16 * progress; + + return Transform.scale( + scale: scale, + child: Container( + width: 32, + height: 32, + decoration: BoxDecoration( + color: background, + shape: BoxShape.circle, + boxShadow: glowAlpha == 0 + ? null + : [ + BoxShadow( + color: foreground.withValues(alpha: glowAlpha), + blurRadius: 10, + spreadRadius: 0.5, + ), + ], + ), + alignment: Alignment.center, + child: SvgPicture.asset( + Resources.assetsImagesBotFillSvg, + width: 18, + height: 18, + colorFilter: ColorFilter.mode(foreground, BlendMode.srcIn), + ), + ), + ); + } +} + +class _AiMessageMenu extends StatelessWidget { + const _AiMessageMenu({ + required this.message, + required this.child, + }); + + final AiChatMessage message; + final Widget child; + + @override + Widget build(BuildContext context) { + final content = _menuCopyText(message); + + return Builder( + builder: (childContext) => CustomContextMenuWidget( + hitTestBehavior: HitTestBehavior.translucent, + desktopMenuWidgetBuilder: CustomDesktopMenuWidgetBuilder(), + menuProvider: (_) { + final selectedContent = _findSelectedContent(childContext); + return MenusWithSeparator( + childrens: [ + [ + MenuAction( + image: MenuImage.icon(Icons.copy), + title: context.l10n.copy, + callback: () { + Clipboard.setData(ClipboardData(text: content)); + }, + ), + if (selectedContent != null) + MenuAction( + image: MenuImage.icon(Icons.copy), + title: context.l10n.copySelectedText, + callback: () { + Clipboard.setData( + ClipboardData(text: selectedContent.plainText), + ); + }, + ), + if (content.isNotEmpty) + MenuAction( + image: MenuImage.icon(Icons.qr_code), + title: context.l10n.generateQrcode, + callback: () => showQrCodeDialog(context, content), + ), + ], + [ + MenuAction( + image: MenuImage.icon(Icons.data_object), + title: 'Copy AI message', + callback: () { + Clipboard.setData(ClipboardData(text: message.toString())); + }, + ), + ], + ], + ); + }, + child: child, + ), + ); + } +} + +SelectedContent? _findSelectedContent(BuildContext context) { + SelectableRegionState? findSelectableRegionState(BuildContext context) { + if (context is! Element) { + return null; + } + if (context.widget is SelectableRegion) { + return (context as StatefulElement).state as SelectableRegionState; + } + + SelectableRegionState? found; + context.visitChildren((element) { + if (found != null) return; + final result = findSelectableRegionState(element); + if (result != null) { + found = result; + } + }); + return found; + } + + final selectableRegion = findSelectableRegionState(context); + final status = selectableRegion?.selectable?.value.status; + final content = selectableRegion?.selectable?.getSelectedContent(); + if (status == SelectionStatus.uncollapsed && content != null) { + return content; + } + return null; +} + +class _AiStatusBadge extends HookWidget { + const _AiStatusBadge({ + required this.isUser, + required this.status, + required this.color, + }); + + final bool isUser; + final String status; + final Color color; + + @override + Widget build(BuildContext context) { + if (status == 'pending') { + return _AiThinkingIndicator(color: color); + } + + return Icon( + _statusIcon(messageRoleIsUser: isUser, status: status), + size: 12, + color: color, + ); + } +} + +class _AiFooter extends StatelessWidget { + const _AiFooter({ + required this.isUser, + required this.status, + required this.color, + required this.dateTime, + }); + + final bool isUser; + final String status; + final Color color; + final DateTime dateTime; + + @override + Widget build(BuildContext context) => MessageMetaRow( + dateTime: dateTime, + trailingSpacing: 4, + trailing: _AiStatusBadge( + isUser: isUser, + status: status, + color: color, + ), + ); +} + +class _AiThinkingIndicator extends HookWidget { + const _AiThinkingIndicator({required this.color}); + + final Color color; + + @override + Widget build(BuildContext context) { + final disableAnimations = + MediaQuery.maybeOf(context)?.disableAnimations ?? false; + + if (disableAnimations) { + return Icon(Icons.more_horiz_rounded, size: 12, color: color); + } + + final controller = useAnimationController( + duration: const Duration(milliseconds: 1200), + ); + useEffect(() { + controller.repeat(); + return null; + }, [controller]); + + return RotationTransition( + turns: controller, + child: CustomPaint( + size: const Size.square(12), + painter: _AiThinkingIndicatorPainter(color: color), + ), + ); + } +} + +class _AiThinkingIndicatorPainter extends CustomPainter { + const _AiThinkingIndicatorPainter({required this.color}); + + final Color color; + + @override + void paint(Canvas canvas, Size size) { + final center = size.center(Offset.zero); + final radius = (size.width / 2) - 1; + + final track = Paint() + ..color = color.withValues(alpha: 0.22) + ..style = PaintingStyle.stroke + ..strokeWidth = 1.2 + ..strokeCap = StrokeCap.round; + + final arc = Paint() + ..color = color + ..style = PaintingStyle.stroke + ..strokeWidth = 1.4 + ..strokeCap = StrokeCap.round; + + canvas + ..drawCircle(center, radius, track) + ..drawArc( + Rect.fromCircle(center: center, radius: radius), + -1.2, + 1.95, + false, + arc, + ); + } + + @override + bool shouldRepaint(covariant _AiThinkingIndicatorPainter oldDelegate) => + oldDelegate.color != color; +} + +IconData _statusIcon({ + required bool messageRoleIsUser, + required String status, +}) { + if (status == 'error') return Icons.error_outline_rounded; + if (messageRoleIsUser) return Icons.auto_awesome_rounded; + return Icons.smart_toy_rounded; +} + +Color _bubbleColor( + BuildContext context, { + required bool isUser, + required String status, +}) { + if (status == 'error') { + return context.dynamicColor( + const Color.fromRGBO(255, 235, 235, 1), + darkColor: const Color.fromRGBO(88, 46, 46, 1), + ); + } + + if (isUser) { + return context.dynamicColor( + const Color.fromRGBO(255, 241, 214, 1), + darkColor: const Color.fromRGBO(96, 76, 34, 1), + ); + } + + return context.dynamicColor( + const Color.fromRGBO(228, 245, 239, 1), + darkColor: const Color.fromRGBO(43, 77, 65, 1), + ); +} + +Color _statusColor( + BuildContext context, { + required bool isUser, + required String status, +}) { + if (status == 'error') { + return context.dynamicColor( + const Color.fromRGBO(193, 63, 63, 1), + darkColor: const Color.fromRGBO(255, 173, 173, 1), + ); + } + + if (isUser) { + return context.dynamicColor( + const Color.fromRGBO(176, 107, 18, 1), + darkColor: const Color.fromRGBO(255, 214, 143, 1), + ); + } + + if (status == 'pending') { + return context.dynamicColor( + const Color.fromRGBO(46, 123, 110, 1), + darkColor: const Color.fromRGBO(159, 230, 217, 1), + ); + } + + return context.dynamicColor( + const Color.fromRGBO(33, 126, 96, 1), + darkColor: const Color.fromRGBO(150, 238, 210, 1), + ); +} + +String _menuCopyText(AiChatMessage message) { + final content = message.content.trim(); + if (content.isNotEmpty) return content; + if (message.status == 'error') { + return message.errorText ?? 'Request failed'; + } + return 'Thinking...'; +} diff --git a/lib/widgets/markdown.dart b/lib/widgets/markdown.dart index 5d09130241..cc4bce0829 100644 --- a/lib/widgets/markdown.dart +++ b/lib/widgets/markdown.dart @@ -30,6 +30,7 @@ class MarkdownColumn extends ConsumerWidget { data: data, useColumn: true, selectable: selectable, + contextMenuBuilder: (_, _, _, _) => const SizedBox.shrink(), padding: EdgeInsets.zero, theme: _createMarkdownTheme(context, chatFontSizeDelta), imageBuilder: _buildMarkdownImage, diff --git a/lib/widgets/message/message_datetime_and_status.dart b/lib/widgets/message/message_datetime_and_status.dart index ee72401e5b..228220f07e 100644 --- a/lib/widgets/message/message_datetime_and_status.dart +++ b/lib/widgets/message/message_datetime_and_status.dart @@ -51,53 +51,79 @@ class MessageDatetimeAndStatus extends HookConsumerWidget { converter: (state) => state.createdAt, ); + return MessageMetaRow( + color: color, + dateTime: createdAt, + leading: [ + if (pinned) + _ChatIcon( + color: color, + assetName: Resources.assetsImagesMessagePinSvg, + ), + if (isSecret) + _ChatIcon( + color: color, + assetName: Resources.assetsImagesMessageSecretSvg, + ), + if (isRepresentative) + _ChatIcon( + color: color, + assetName: Resources.assetsImagesMessageRepresentativeSvg, + ), + ], + trailing: + isCurrentUser && !isTranscriptPage && !isPinnedPage && !hideStatus + ? HookBuilder( + builder: (context) { + final status = useMessageConverter( + converter: (state) => state.status, + ); + return MessageStatusIcon(status: status, color: color); + }, + ) + : null, + ); + } +} + +class MessageMetaRow extends StatelessWidget { + const MessageMetaRow({ + required this.dateTime, + super.key, + this.color, + this.leading = const [], + this.trailing, + this.trailingSpacing = 8, + }); + + final DateTime dateTime; + final Color? color; + final List leading; + final Widget? trailing; + final double trailingSpacing; + + @override + Widget build(BuildContext context) { + final children = [ + for (final widget in leading) + Padding( + padding: const EdgeInsets.only(right: 4), + child: widget, + ), + _MessageDatetime(dateTime: dateTime, color: color), + if (trailing != null) + Padding( + padding: EdgeInsets.only(left: trailingSpacing), + child: trailing, + ), + ]; + return SelectionContainer.disabled( child: SizedBox( height: 12, child: Row( mainAxisSize: MainAxisSize.min, - children: [ - if (pinned) - Padding( - padding: const EdgeInsets.only(right: 4), - child: _ChatIcon( - color: color, - assetName: Resources.assetsImagesMessagePinSvg, - ), - ), - if (isSecret) - Padding( - padding: const EdgeInsets.only(right: 4), - child: _ChatIcon( - color: color, - assetName: Resources.assetsImagesMessageSecretSvg, - ), - ), - if (isRepresentative) - Padding( - padding: const EdgeInsets.only(right: 4), - child: _ChatIcon( - color: color, - assetName: Resources.assetsImagesMessageRepresentativeSvg, - ), - ), - _MessageDatetime(dateTime: createdAt, color: color), - if (isCurrentUser && - !isTranscriptPage && - !isPinnedPage && - !hideStatus) - HookBuilder( - builder: (context) { - final status = useMessageConverter( - converter: (state) => state.status, - ); - return Padding( - padding: const EdgeInsets.only(left: 8), - child: MessageStatusIcon(status: status, color: color), - ); - }, - ), - ], + children: children, ), ), ); From 20355c4abb7ea0bb6c83fb6ed5d9a12807d77a1b Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Mon, 20 Apr 2026 15:48:33 +0800 Subject: [PATCH 04/52] fix: adjust padding for AI message card layout --- lib/widgets/ai/ai_message_card.dart | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index 46d75c9785..015991d66d 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -64,7 +64,7 @@ class AiMessageCard extends StatelessWidget { children: [ if (showAssistantMeta) Padding( - padding: const EdgeInsets.only(left: 2, bottom: 4), + padding: const EdgeInsets.only(left: 10, bottom: 2), child: Text( 'AI Assistant', style: TextStyle( @@ -111,7 +111,7 @@ class AiMessageCard extends StatelessWidget { ), Flexible( child: Padding( - padding: const EdgeInsets.only(top: 2), + padding: const EdgeInsets.symmetric(vertical: 2), child: _AiMessageMenu( message: message, child: content, From 22350753fa6dfd516b51eafea287503e223753aa Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Mon, 20 Apr 2026 15:53:58 +0800 Subject: [PATCH 05/52] feat(ai): display model name in status badge if available --- lib/widgets/ai/ai_message_card.dart | 36 ++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index 015991d66d..8720e77dc0 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -166,6 +166,7 @@ class _AiMessageBody extends StatelessWidget { content: body, dateAndStatus: _AiFooter( isUser: isUser, + model: message.model, status: message.status, color: statusColor, dateTime: message.createdAt, @@ -411,24 +412,44 @@ SelectedContent? _findSelectedContent(BuildContext context) { class _AiStatusBadge extends HookWidget { const _AiStatusBadge({ required this.isUser, + required this.model, required this.status, required this.color, }); final bool isUser; + final String? model; final String status; final Color color; @override Widget build(BuildContext context) { - if (status == 'pending') { - return _AiThinkingIndicator(color: color); + final trimmedModel = isUser ? null : model?.trim(); + final icon = status == 'pending' + ? _AiThinkingIndicator(color: color) + : Icon( + _statusIcon(messageRoleIsUser: isUser, status: status), + size: 12, + color: color, + ); + + if (trimmedModel == null || trimmedModel.isEmpty) { + return icon; } - return Icon( - _statusIcon(messageRoleIsUser: isUser, status: status), - size: 12, - color: color, + return Row( + mainAxisSize: MainAxisSize.min, + children: [ + icon, + const SizedBox(width: 4), + Text( + trimmedModel, + style: TextStyle( + fontSize: context.messageStyle.statusFontSize, + color: color, + ), + ), + ], ); } } @@ -436,12 +457,14 @@ class _AiStatusBadge extends HookWidget { class _AiFooter extends StatelessWidget { const _AiFooter({ required this.isUser, + required this.model, required this.status, required this.color, required this.dateTime, }); final bool isUser; + final String? model; final String status; final Color color; final DateTime dateTime; @@ -452,6 +475,7 @@ class _AiFooter extends StatelessWidget { trailingSpacing: 4, trailing: _AiStatusBadge( isUser: isUser, + model: model, status: status, color: color, ), From 83d825ca83707f56838485064b747cb0a9e13221 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Mon, 20 Apr 2026 18:10:23 +0800 Subject: [PATCH 06/52] feat: enhance AI provider configuration with multi-model support and improved UI --- lib/ai/model/ai_provider_config.dart | 82 +++- lib/constants/brightness_theme_data.dart | 24 + lib/ui/home/chat/input_container.dart | 295 +++++++++--- lib/ui/setting/ai_provider_edit_page.dart | 526 +++++++++++++++++++--- lib/ui/setting/ai_settings_page.dart | 254 +++++++---- lib/widgets/ai/ai_message_card.dart | 48 +- lib/widgets/brightness_observer.dart | 90 +++- 7 files changed, 1048 insertions(+), 271 deletions(-) diff --git a/lib/ai/model/ai_provider_config.dart b/lib/ai/model/ai_provider_config.dart index c7a7159b40..4602dc0702 100644 --- a/lib/ai/model/ai_provider_config.dart +++ b/lib/ai/model/ai_provider_config.dart @@ -7,29 +7,44 @@ class AiProviderConfig { required this.type, required this.baseUrl, required this.apiKey, - required this.model, + required String model, + List? models, + String? defaultModel, this.enabled = true, - }); - - factory AiProviderConfig.fromJson(Map json) => - AiProviderConfig( - id: json['id'] as String, - name: json['name'] as String, - type: AiProviderType.fromValue(json['type'] as String? ?? ''), - baseUrl: json['baseUrl'] as String? ?? '', - apiKey: json['apiKey'] as String? ?? '', - model: json['model'] as String? ?? '', - enabled: json['enabled'] as bool? ?? true, - ); + }) : models = _normalizeModels(models, model, defaultModel), + defaultModel = _resolveDefaultModel(models, model, defaultModel); + + factory AiProviderConfig.fromJson(Map json) => () { + final legacyModel = json['model'] as String? ?? ''; + final models = (json['models'] as List?) + ?.whereType() + .map((model) => model.trim()) + .where((model) => model.isNotEmpty) + .toList(); + return AiProviderConfig( + id: json['id'] as String, + name: json['name'] as String, + type: AiProviderType.fromValue(json['type'] as String? ?? ''), + baseUrl: json['baseUrl'] as String? ?? '', + apiKey: json['apiKey'] as String? ?? '', + model: legacyModel, + models: models, + defaultModel: json['defaultModel'] as String?, + enabled: json['enabled'] as bool? ?? true, + ); + }(); final String id; final String name; final AiProviderType type; final String baseUrl; final String apiKey; - final String model; + final List models; + final String defaultModel; final bool enabled; + String get model => defaultModel; + Map toJson() => { 'id': id, 'name': name, @@ -37,6 +52,8 @@ class AiProviderConfig { 'baseUrl': baseUrl, 'apiKey': apiKey, 'model': model, + 'models': models, + 'defaultModel': defaultModel, 'enabled': enabled, }; @@ -47,6 +64,8 @@ class AiProviderConfig { String? baseUrl, String? apiKey, String? model, + List? models, + String? defaultModel, bool? enabled, }) => AiProviderConfig( id: id ?? this.id, @@ -55,6 +74,41 @@ class AiProviderConfig { baseUrl: baseUrl ?? this.baseUrl, apiKey: apiKey ?? this.apiKey, model: model ?? this.model, + models: models ?? this.models, + defaultModel: defaultModel ?? this.defaultModel, enabled: enabled ?? this.enabled, ); + + static List _normalizeModels( + List? models, + String model, + String? defaultModel, + ) { + final values = + [ + ...?models, + model, + ...(switch (defaultModel) { + final String value => [value], + null => const [], + }), + ] + .whereType() + .map((item) => item.trim()) + .where((item) => item.isNotEmpty); + return values.toSet().toList(growable: false); + } + + static String _resolveDefaultModel( + List? models, + String model, + String? defaultModel, + ) { + final normalizedModels = _normalizeModels(models, model, defaultModel); + final candidate = defaultModel?.trim() ?? model.trim(); + if (candidate.isNotEmpty && normalizedModels.contains(candidate)) { + return candidate; + } + return normalizedModels.isNotEmpty ? normalizedModels.first : ''; + } } diff --git a/lib/constants/brightness_theme_data.dart b/lib/constants/brightness_theme_data.dart index 027df56af9..78501a5126 100644 --- a/lib/constants/brightness_theme_data.dart +++ b/lib/constants/brightness_theme_data.dart @@ -25,6 +25,18 @@ const lightBrightnessThemeData = BrightnessThemeData( waveformBackground: Color.fromRGBO(221, 221, 221, 1), waveformForeground: Color.fromRGBO(155, 155, 155, 1), settingCellBackgroundColor: Colors.white, + ai: AiColorScheme( + avatarBackground: Color.fromRGBO(227, 237, 213, 1), + accent: Color.fromRGBO(54, 87, 35, 1), + onAccent: Colors.white, + surface: Color.fromRGBO(241, 248, 243, 1), + surfaceBorder: Color.fromRGBO(200, 223, 208, 1), + surfaceVariant: Color.fromRGBO(223, 236, 214, 1), + userBubble: Color.fromRGBO(255, 241, 214, 1), + assistantBubble: Color.fromRGBO(228, 245, 239, 1), + errorBubble: Color.fromRGBO(255, 235, 235, 1), + error: Color.fromRGBO(193, 63, 63, 1), + ), ); const darkBrightnessThemeData = BrightnessThemeData( @@ -50,6 +62,18 @@ const darkBrightnessThemeData = BrightnessThemeData( waveformBackground: Color.fromRGBO(255, 255, 255, 0.4), waveformForeground: Color.fromRGBO(255, 255, 255, 1), settingCellBackgroundColor: Color.fromRGBO(255, 255, 255, 0.06), + ai: AiColorScheme( + avatarBackground: Color.fromRGBO(64, 78, 56, 1), + accent: Color.fromRGBO(214, 235, 204, 1), + onAccent: Color.fromRGBO(26, 42, 31, 1), + surface: Color.fromRGBO(35, 52, 44, 1), + surfaceBorder: Color.fromRGBO(72, 101, 88, 1), + surfaceVariant: Color.fromRGBO(58, 77, 66, 1), + userBubble: Color.fromRGBO(96, 76, 34, 1), + assistantBubble: Color.fromRGBO(43, 77, 65, 1), + errorBubble: Color.fromRGBO(88, 46, 46, 1), + error: Color.fromRGBO(255, 173, 173, 1), + ), ); final circleColors = [ diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index ae5b73f1a9..6ec2709469 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -19,7 +19,7 @@ import 'package:simple_animations/simple_animations.dart'; import 'package:super_context_menu/super_context_menu.dart'; import '../../../ai/ai_chat_controller.dart'; -import '../../../ai/model/ai_mode_state.dart'; +import '../../../ai/model/ai_provider_config.dart'; import '../../../constants/constants.dart'; import '../../../constants/icon_fonts.dart'; import '../../../constants/resources.dart'; @@ -40,7 +40,6 @@ import '../../../widgets/hover_overlay.dart'; import '../../../widgets/mention_panel.dart'; import '../../../widgets/menu.dart'; import '../../../widgets/message/item/quote_message.dart'; -import '../../../widgets/message/message_bubble.dart'; import '../../../widgets/sticker_page/bloc/cubit/sticker_albums_cubit.dart'; import '../../../widgets/sticker_page/sticker_page.dart'; import '../../../widgets/toast.dart'; @@ -107,6 +106,21 @@ class _InputContainer extends HookConsumerWidget { final aiModeState = ref.watch(aiInputModeProvider(conversationId ?? '')); final selectedAiProvider = context.database.settingProperties.selectedAiProvider; + final enabledAiProviders = context.database.settingProperties.aiProviders + .whereType() + .where((element) => element.enabled) + .toList(); + final aiProviderId = aiModeState.providerId; + var aiProvider = selectedAiProvider; + if (aiProviderId != null) { + for (final provider in enabledAiProviders) { + if (provider.id == aiProviderId) { + aiProvider = provider; + break; + } + } + } + final aiModeEnabled = aiModeState.enabled; final quoteMessageId = ref.watch(quoteMessageIdProvider); @@ -204,43 +218,78 @@ class _InputContainer extends HookConsumerWidget { mainAxisAlignment: MainAxisAlignment.end, children: [ const _QuoteMessage(), - if (conversationId != null) - _AiModeBar( - conversationId: conversationId, - aiModeState: aiModeState, - providerName: selectedAiProvider?.name, - ), ConstrainedBox( - constraints: const BoxConstraints(minHeight: 56), + constraints: BoxConstraints(minHeight: aiModeEnabled ? 108 : 56), child: Container( decoration: BoxDecoration(color: context.theme.primary), - padding: const EdgeInsets.symmetric( - horizontal: 16, - vertical: 8, + padding: EdgeInsets.fromLTRB( + 16, + aiModeEnabled ? 10 : 8, + 16, + 8, ), - child: Row( - crossAxisAlignment: CrossAxisAlignment.end, - children: [ - const _SendActionTypeButton(), - const SizedBox(width: 6), - _StickerButton( - textEditingController: textEditingController, - ), - const SizedBox(width: 16), - Expanded( - child: _SendTextField( - focusNode: focusNode, - textEditingController: textEditingController, - mentionProviderInstance: mentionProviderInstance, + child: AnimatedContainer( + duration: const Duration(milliseconds: 220), + curve: Curves.easeOutCubic, + padding: aiModeEnabled + ? const EdgeInsets.fromLTRB(12, 12, 12, 12) + : EdgeInsets.zero, + decoration: BoxDecoration( + color: aiModeEnabled + ? context.theme.ai.surface + : Colors.transparent, + borderRadius: const BorderRadius.all(Radius.circular(18)), + border: aiModeEnabled + ? Border.all( + color: context.theme.ai.surfaceBorder, + ) + : null, + ), + child: Column( + mainAxisSize: MainAxisSize.min, + children: [ + if (conversationId != null && aiModeEnabled) ...[ + _AiModeBar( + conversationId: conversationId, + provider: aiProvider, + ), + Container( + height: 1, + margin: const EdgeInsets.only(top: 10, bottom: 10), + color: context.theme.ai.surfaceBorder, + ), + ], + Row( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ + if (!aiModeEnabled) ...[ + const _SendActionTypeButton(), + const SizedBox(width: 6), + _StickerButton( + textEditingController: textEditingController, + ), + const SizedBox(width: 16), + ], + Expanded( + child: _SendTextField( + focusNode: focusNode, + textEditingController: textEditingController, + mentionProviderInstance: mentionProviderInstance, + aiModeEnabled: aiModeEnabled, + providerName: aiProvider?.name, + ), + ), + SizedBox(width: aiModeEnabled ? 10 : 16), + _AnimatedSendOrVoiceButton( + conversationId: conversationId, + textEditingController: textEditingController, + textEditingValueStream: textEditingValueStream, + aiModeEnabled: aiModeEnabled, + ), + ], ), - ), - const SizedBox(width: 16), - _AnimatedSendOrVoiceButton( - conversationId: conversationId, - textEditingController: textEditingController, - textEditingValueStream: textEditingValueStream, - ), - ], + ], + ), ), ), ), @@ -256,11 +305,13 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { required this.conversationId, required this.textEditingValueStream, required this.textEditingController, + required this.aiModeEnabled, }); final String? conversationId; final Stream textEditingValueStream; final TextEditingController textEditingController; + final bool aiModeEnabled; @override Widget build(BuildContext context, WidgetRef ref) { @@ -274,6 +325,39 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { ).data ?? false; + if (aiModeEnabled) { + final backgroundColor = context.theme.ai.accent; + final foregroundColor = context.theme.ai.onAccent; + + return AnimatedOpacity( + duration: const Duration(milliseconds: 180), + opacity: hasInputText ? 1 : 0.45, + child: IgnorePointer( + ignoring: !hasInputText, + child: DecoratedBox( + decoration: BoxDecoration( + color: backgroundColor, + shape: BoxShape.circle, + ), + child: ActionButton( + size: 20, + color: foregroundColor, + child: Icon( + Icons.arrow_upward_rounded, + size: 20, + color: foregroundColor, + ), + onTap: () => _sendMessage( + context, + textEditingController, + conversationId: conversationId, + ), + ), + ), + ), + ); + } + // start -> show voice button // end -> show send button final animationController = useAnimationController( @@ -476,12 +560,16 @@ class _SendTextField extends HookConsumerWidget { required this.focusNode, required this.textEditingController, required this.mentionProviderInstance, + required this.aiModeEnabled, + required this.providerName, }); final FocusNode focusNode; final TextEditingController textEditingController; final AutoDisposeStateNotifierProvider mentionProviderInstance; + final bool aiModeEnabled; + final String? providerName; @override Widget build(BuildContext context, WidgetRef ref) { @@ -531,14 +619,29 @@ class _SendTextField extends HookConsumerWidget { ).data ?? false; - return Container( + final placeholder = aiModeEnabled + ? 'Ask ${providerName?.trim().isNotEmpty == true ? providerName : 'AI'} anything' + : isEncryptConversation + ? context.l10n.chatHintE2e + : 'Type message or /ai'; + + return AnimatedContainer( + duration: const Duration(milliseconds: 220), + curve: Curves.easeOutCubic, constraints: const BoxConstraints(minHeight: 40), decoration: BoxDecoration( - borderRadius: const BorderRadius.all(Radius.circular(4)), + borderRadius: BorderRadius.all(Radius.circular(aiModeEnabled ? 14 : 4)), color: context.dynamicColor( - const Color.fromRGBO(245, 247, 250, 1), + aiModeEnabled + ? const Color.fromRGBO(255, 255, 255, 0.78) + : const Color.fromRGBO(245, 247, 250, 1), darkColor: const Color.fromRGBO(255, 255, 255, 0.08), ), + border: aiModeEnabled + ? Border.all( + color: context.theme.ai.surfaceBorder, + ) + : null, ), alignment: Alignment.center, child: FocusableActionDetector( @@ -569,8 +672,17 @@ class _SendTextField extends HookConsumerWidget { onInvoke: (_) => _sendPostMessage(context, textEditingController), ), EscapeIntent: CallbackAction( - onInvoke: (_) => - ref.read(quoteMessageProvider.notifier).state = null, + onInvoke: (_) { + if (aiModeEnabled) { + final conversationId = ref.read(currentConversationIdProvider); + if (conversationId != null) { + ref.read(aiInputModeProvider(conversationId).notifier).exit(); + return null; + } + } + ref.read(quoteMessageProvider.notifier).state = null; + return null; + }, ), }, child: Stack( @@ -589,7 +701,12 @@ class _SendTextField extends HookConsumerWidget { isDense: true, enabledBorder: InputBorder.none, focusedBorder: InputBorder.none, - contentPadding: EdgeInsets.only(left: 8, top: 8, bottom: 8), + contentPadding: EdgeInsets.only( + left: 10, + right: 10, + top: 10, + bottom: 10, + ), ), selectionHeightStyle: ui.BoxHeightStyle.includeLineSpacingMiddle, contextMenuBuilder: (context, state) => @@ -602,9 +719,7 @@ class _SendTextField extends HookConsumerWidget { alignment: Alignment.centerLeft, child: IgnorePointer( child: Text( - isEncryptConversation - ? context.l10n.chatHintE2e - : 'Type message or /ai', + placeholder, style: TextStyle( color: context.theme.secondaryText, fontSize: 14, @@ -625,45 +740,79 @@ class _SendTextField extends HookConsumerWidget { class _AiModeBar extends HookConsumerWidget { const _AiModeBar({ required this.conversationId, - required this.aiModeState, - required this.providerName, + required this.provider, }); final String conversationId; - final AiModeState aiModeState; - final String? providerName; + final AiProviderConfig? provider; @override Widget build(BuildContext context, WidgetRef ref) { - if (!aiModeState.enabled) return const SizedBox(); - return Container( - width: double.infinity, - color: context.theme.primary, - padding: const EdgeInsets.fromLTRB(16, 10, 16, 0), - child: Container( - padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 10), - decoration: BoxDecoration( - color: context.messageBubbleColor(false), - borderRadius: const BorderRadius.all(Radius.circular(10)), + final providerName = provider?.name ?? 'No Provider'; + final model = provider?.model.trim(); + final aiColors = context.theme.ai; + final accentColor = aiColors.accent; + + return Row( + children: [ + Container( + padding: const EdgeInsets.symmetric(horizontal: 10, vertical: 6), + decoration: BoxDecoration( + color: aiColors.surfaceVariant, + borderRadius: const BorderRadius.all(Radius.circular(999)), + ), + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + Icon(Icons.auto_awesome_rounded, size: 14, color: accentColor), + const SizedBox(width: 6), + Text( + 'AI Mode', + style: TextStyle( + color: accentColor, + fontSize: 12, + fontWeight: FontWeight.w600, + ), + ), + ], + ), ), - child: Row( - children: [ - Expanded( - child: Text( - 'AI Mode · ${providerName ?? 'No Provider'}', - style: TextStyle(color: context.theme.text, fontSize: 14), + const SizedBox(width: 10), + Expanded( + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + mainAxisSize: MainAxisSize.min, + children: [ + Text( + providerName, + style: TextStyle( + color: context.theme.text, + fontSize: 14, + fontWeight: FontWeight.w600, + ), ), - ), - ActionButton( - name: Resources.assetsImagesIcCloseSvg, - color: context.theme.icon, - size: 18, - onTap: () => - ref.read(aiInputModeProvider(conversationId).notifier).exit(), - ), - ], + const SizedBox(height: 2), + Text( + model?.isNotEmpty == true ? model! : 'Next message goes to AI', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + ), + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + ], + ), ), - ), + const SizedBox(width: 8), + ActionButton( + name: Resources.assetsImagesIcCloseSvg, + color: context.theme.icon, + size: 18, + onTap: () => + ref.read(aiInputModeProvider(conversationId).notifier).exit(), + ), + ], ); } } diff --git a/lib/ui/setting/ai_provider_edit_page.dart b/lib/ui/setting/ai_provider_edit_page.dart index 4618d74089..e7d6b6e772 100644 --- a/lib/ui/setting/ai_provider_edit_page.dart +++ b/lib/ui/setting/ai_provider_edit_page.dart @@ -8,6 +8,7 @@ import '../../ai/model/ai_provider_type.dart'; import '../../utils/extension/extension.dart'; import '../../widgets/app_bar.dart'; import '../../widgets/cell.dart'; +import '../../widgets/dialog.dart'; import '../../widgets/toast.dart'; import '../provider/database_provider.dart'; @@ -19,6 +20,19 @@ class AiProviderEditPage extends HookConsumerWidget { @override Widget build(BuildContext context, WidgetRef ref) { final database = ref.watch(databaseProvider).requireValue; + final theme = context.theme; + final inputBackgroundColor = context.dynamicColor( + Colors.white, + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ); + final inputBorderColor = context.dynamicColor( + theme.divider, + darkColor: const Color.fromRGBO(255, 255, 255, 0.10), + ); + final inputIconColor = context.dynamicColor( + theme.secondaryText, + darkColor: const Color.fromRGBO(255, 255, 255, 0.52), + ); final nameController = useTextEditingController(text: initial?.name ?? ''); final baseUrlController = useTextEditingController( text: initial?.baseUrl ?? '', @@ -26,15 +40,68 @@ class AiProviderEditPage extends HookConsumerWidget { final apiKeyController = useTextEditingController( text: initial?.apiKey ?? '', ); - final modelController = useTextEditingController( - text: initial?.model ?? '', - ); final providerType = useState( initial?.type ?? AiProviderType.openaiCompatible, ); + final models = useState( + _normalizeModels(initial?.models ?? [initial?.model ?? '']), + ); + final defaultModel = useState( + _resolveDefaultModel( + models.value, + initial?.defaultModel ?? initial?.model, + ), + ); + final obscureApiKey = useState(true); + + useEffect(() { + final resolved = _resolveDefaultModel(models.value, defaultModel.value); + if (resolved != defaultModel.value) { + defaultModel.value = resolved; + } + return null; + }, [models.value, defaultModel.value]); + + Future showModelDialog({String? initialValue, int? index}) async { + final result = await showMixinDialog( + context: context, + child: EditDialog( + title: Text(index == null ? 'Add Model' : 'Edit Model'), + editText: initialValue ?? '', + hintText: 'gpt-4.1-mini', + positiveAction: index == null ? 'Add' : 'Save', + ), + ); + final model = result?.trim(); + if (model == null || model.isEmpty) return; + + final nextModels = [...models.value]; + if (index != null && index >= 0 && index < nextModels.length) { + nextModels[index] = model; + } else { + nextModels.add(model); + } + models.value = _normalizeModels(nextModels); + defaultModel.value = _resolveDefaultModel( + models.value, + index != null && initialValue == defaultModel.value + ? model + : defaultModel.value, + ); + } + + void removeModelAt(int index) { + final nextModels = [...models.value]..removeAt(index); + final removed = models.value[index]; + models.value = nextModels; + defaultModel.value = _resolveDefaultModel( + nextModels, + removed == defaultModel.value ? null : defaultModel.value, + ); + } return Scaffold( - backgroundColor: context.theme.background, + backgroundColor: theme.background, appBar: MixinAppBar( title: Text(initial == null ? 'Add AI Provider' : 'Edit AI Provider'), actions: [ @@ -43,11 +110,16 @@ class AiProviderEditPage extends HookConsumerWidget { final name = nameController.text.trim(); final baseUrl = baseUrlController.text.trim(); final apiKey = apiKeyController.text.trim(); - final model = modelController.text.trim(); + final normalizedModels = _normalizeModels(models.value); + final resolvedDefaultModel = _resolveDefaultModel( + normalizedModels, + defaultModel.value, + ); if (name.isEmpty || baseUrl.isEmpty || apiKey.isEmpty || - model.isEmpty) { + normalizedModels.isEmpty || + resolvedDefaultModel.isEmpty) { showToastFailed(ToastError('Please complete all fields')); return; } @@ -60,21 +132,25 @@ class AiProviderEditPage extends HookConsumerWidget { type: providerType.value, baseUrl: baseUrl, apiKey: apiKey, - model: model, + model: resolvedDefaultModel, + models: normalizedModels, + defaultModel: resolvedDefaultModel, )) .copyWith( name: name, type: providerType.value, baseUrl: baseUrl, apiKey: apiKey, - model: model, + models: normalizedModels, + defaultModel: resolvedDefaultModel, + model: resolvedDefaultModel, ); database.settingProperties.saveAiProvider(provider); Navigator.of(context).pop(); }, child: Text( 'Save', - style: TextStyle(color: context.theme.accent, fontSize: 16), + style: TextStyle(color: theme.accent, fontSize: 16), ), ), ], @@ -82,28 +158,211 @@ class AiProviderEditPage extends HookConsumerWidget { body: Align( alignment: Alignment.topCenter, child: SingleChildScrollView( - child: CellGroup( - cellBackgroundColor: context.theme.settingCellBackgroundColor, + child: Padding( + padding: const EdgeInsets.only(top: 20, bottom: 20), child: Column( + crossAxisAlignment: CrossAxisAlignment.start, children: [ - _TextFieldCell( - title: 'Display Name', - controller: nameController, + const _SectionLabel( + title: 'Provider', + ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: context.theme.settingCellBackgroundColor, + child: Column( + children: [ + _FormFieldCell( + label: 'Display Name', + backgroundColor: inputBackgroundColor, + borderColor: inputBorderColor, + child: TextField( + controller: nameController, + style: TextStyle( + color: theme.text, + fontSize: 16, + ), + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: 'OpenAI / Anthropic / Self-hosted', + hintStyle: TextStyle(color: theme.secondaryText), + ), + ), + ), + _CellDivider(color: theme.divider), + _FormFieldCell( + label: 'Provider Type', + backgroundColor: inputBackgroundColor, + borderColor: inputBorderColor, + child: DropdownButtonHideUnderline( + child: DropdownButton( + value: providerType.value, + isExpanded: true, + dropdownColor: theme.popUp, + style: TextStyle( + color: theme.text, + fontSize: 16, + ), + iconEnabledColor: inputIconColor, + onChanged: (value) { + if (value != null) providerType.value = value; + }, + items: AiProviderType.values + .map( + (type) => DropdownMenuItem( + value: type, + child: Text( + type == AiProviderType.anthropic + ? 'Anthropic' + : 'OpenAI Compatible', + ), + ), + ) + .toList(), + ), + ), + ), + ], + ), + ), + const _SectionLabel( + title: 'Endpoint', + ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: theme.settingCellBackgroundColor, + child: _FormFieldCell( + label: 'Base URL', + backgroundColor: inputBackgroundColor, + borderColor: inputBorderColor, + child: TextField( + controller: baseUrlController, + keyboardType: TextInputType.url, + style: TextStyle( + color: theme.text, + fontSize: 16, + ), + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: 'https://api.example.com/v1', + hintStyle: TextStyle(color: theme.secondaryText), + ), + ), + ), + ), + const _SectionLabel( + title: 'Authorization', ), - _ProviderTypeCell( - value: providerType.value, - onChanged: (value) => providerType.value = value, + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: context.theme.settingCellBackgroundColor, + child: Column( + children: [ + _FormFieldCell( + label: 'API Key', + backgroundColor: inputBackgroundColor, + borderColor: inputBorderColor, + trailing: IconButton( + onPressed: () => + obscureApiKey.value = !obscureApiKey.value, + icon: Icon( + obscureApiKey.value + ? Icons.visibility_outlined + : Icons.visibility_off_outlined, + size: 20, + color: inputIconColor, + ), + ), + child: TextField( + controller: apiKeyController, + obscureText: obscureApiKey.value, + style: TextStyle( + color: theme.text, + fontSize: 16, + ), + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: 'sk-...', + hintStyle: TextStyle(color: theme.secondaryText), + ), + ), + ), + ], + ), ), - _TextFieldCell( - title: 'Base URL', - controller: baseUrlController, + const _SectionLabel( + title: 'Models', ), - _TextFieldCell( - title: 'API Key', - controller: apiKeyController, - obscureText: true, + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: theme.settingCellBackgroundColor, + child: Column( + children: [ + CellItem( + title: const Text('Default Model'), + description: Text( + defaultModel.value.isEmpty + ? 'No default model yet' + : defaultModel.value, + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + trailing: null, + ), + _CellDivider(color: context.theme.divider), + CellItem( + title: const Text('Add Model'), + leading: Icon(Icons.add, color: context.theme.icon), + trailing: null, + onTap: showModelDialog, + ), + if (models.value.isEmpty) ...[ + _CellDivider(color: context.theme.divider), + Padding( + padding: const EdgeInsets.symmetric( + horizontal: 16, + vertical: 20, + ), + child: Row( + children: [ + Icon( + Icons.view_list_outlined, + size: 18, + color: theme.secondaryText, + ), + const SizedBox(width: 10), + Expanded( + child: Text( + 'No models yet. Add at least one model before saving.', + style: TextStyle( + color: theme.secondaryText, + fontSize: 14, + ), + ), + ), + ], + ), + ), + ] else ...[ + for (var i = 0; i < models.value.length; i++) ...[ + _CellDivider(color: context.theme.divider), + _ModelItem( + model: models.value[i], + selected: models.value[i] == defaultModel.value, + onTap: () => defaultModel.value = models.value[i], + onEdit: () => showModelDialog( + initialValue: models.value[i], + index: i, + ), + onDelete: () => removeModelAt(i), + ), + ], + ], + ], + ), ), - _TextFieldCell(title: 'Model', controller: modelController), ], ), ), @@ -111,63 +370,196 @@ class AiProviderEditPage extends HookConsumerWidget { ), ); } + + static List _normalizeModels(List models) => models + .map((item) => item.trim()) + .where((item) => item.isNotEmpty) + .toSet() + .toList(growable: false); + + static String _resolveDefaultModel(List models, String? candidate) { + if (models.isEmpty) return ''; + final normalized = candidate?.trim(); + if (normalized != null && + normalized.isNotEmpty && + models.contains(normalized)) { + return normalized; + } + return models.first; + } } -class _ProviderTypeCell extends StatelessWidget { - const _ProviderTypeCell({required this.value, required this.onChanged}); +class _SectionLabel extends StatelessWidget { + const _SectionLabel({required this.title}); - final AiProviderType value; - final ValueChanged onChanged; + final String title; @override - Widget build(BuildContext context) => CellItem( - title: const Text('Provider Type'), - trailing: DropdownButtonHideUnderline( - child: DropdownButton( - value: value, - onChanged: (value) { - if (value != null) onChanged(value); - }, - items: AiProviderType.values - .map( - (type) => DropdownMenuItem( - value: type, - child: Text( - type == AiProviderType.anthropic - ? 'Anthropic' - : 'OpenAI Compatible', - ), - ), - ) - .toList(), + Widget build(BuildContext context) => Padding( + padding: const EdgeInsets.only(left: 20, right: 20, bottom: 6, top: 6), + child: Text( + title, + style: TextStyle( + color: context.theme.text, + fontSize: 14, + fontWeight: FontWeight.w600, ), ), ); } -class _TextFieldCell extends StatelessWidget { - const _TextFieldCell({ - required this.title, - required this.controller, - this.obscureText = false, +class _FormFieldCell extends StatelessWidget { + const _FormFieldCell({ + required this.label, + required this.child, + required this.backgroundColor, + required this.borderColor, + this.trailing, }); - final String title; - final TextEditingController controller; - final bool obscureText; + final String label; + final Widget child; + final Color backgroundColor; + final Color borderColor; + final Widget? trailing; + + @override + Widget build(BuildContext context) => Padding( + padding: const EdgeInsets.symmetric(horizontal: 14, vertical: 10), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + label, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + fontWeight: FontWeight.w500, + ), + ), + const SizedBox(height: 6), + _InputSurface( + backgroundColor: backgroundColor, + borderColor: borderColor, + trailing: trailing, + child: child, + ), + ], + ), + ); +} + +class _InputSurface extends StatelessWidget { + const _InputSurface({ + required this.child, + required this.backgroundColor, + required this.borderColor, + this.trailing, + }); + + final Widget child; + final Color backgroundColor; + final Color borderColor; + final Widget? trailing; + + @override + Widget build(BuildContext context) => Container( + padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 9), + decoration: BoxDecoration( + color: backgroundColor, + borderRadius: const BorderRadius.all(Radius.circular(10)), + border: Border.all(color: borderColor), + ), + child: Row( + children: [ + Expanded(child: child), + if (trailing != null) ...[ + const SizedBox(width: 8), + trailing!, + ], + ], + ), + ); +} + +class _ModelItem extends StatelessWidget { + const _ModelItem({ + required this.model, + required this.selected, + required this.onTap, + required this.onEdit, + required this.onDelete, + }); + + final String model; + final bool selected; + final VoidCallback onTap; + final VoidCallback onEdit; + final VoidCallback onDelete; @override Widget build(BuildContext context) => CellItem( - title: TextField( - controller: controller, - obscureText: obscureText, - style: TextStyle(color: context.theme.text, fontSize: 16), - decoration: InputDecoration( - border: InputBorder.none, - hintText: title, - hintStyle: TextStyle(color: context.theme.secondaryText), - ), + selected: selected, + onTap: onTap, + title: Row( + children: [ + Expanded( + child: Text( + model, + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + ), + if (selected) + Container( + margin: const EdgeInsets.only(left: 8), + padding: const EdgeInsets.symmetric(horizontal: 8, vertical: 3), + decoration: BoxDecoration( + color: context.theme.accent.withValues(alpha: 0.12), + borderRadius: const BorderRadius.all(Radius.circular(999)), + ), + child: Text( + 'Default', + style: TextStyle( + color: context.theme.accent, + fontSize: 11, + fontWeight: FontWeight.w600, + ), + ), + ), + ], + ), + description: Text( + selected ? 'Used for new AI requests' : 'Tap to set as default', + maxLines: 1, + overflow: TextOverflow.ellipsis, ), - trailing: const SizedBox.shrink(), + trailing: Row( + mainAxisSize: MainAxisSize.min, + children: [ + IconButton( + onPressed: onEdit, + icon: Icon(Icons.edit_outlined, color: context.theme.icon), + ), + IconButton( + onPressed: onDelete, + icon: Icon(Icons.delete_outline, color: context.theme.red), + ), + ], + ), + ); +} + +class _CellDivider extends StatelessWidget { + const _CellDivider({required this.color}); + + final Color color; + + @override + Widget build(BuildContext context) => Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: color, ); } diff --git a/lib/ui/setting/ai_settings_page.dart b/lib/ui/setting/ai_settings_page.dart index f7c3b088ee..0d4bf59e76 100644 --- a/lib/ui/setting/ai_settings_page.dart +++ b/lib/ui/setting/ai_settings_page.dart @@ -1,7 +1,9 @@ +import 'package:flutter/cupertino.dart'; import 'package:flutter/material.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; +import '../../ai/model/ai_provider_config.dart'; import '../../utils/extension/extension.dart'; import '../../widgets/app_bar.dart'; import '../../widgets/cell.dart'; @@ -18,96 +20,196 @@ class AiSettingsPage extends HookConsumerWidget { useListenable(database.settingProperties); final providers = database.settingProperties.aiProviders; final selectedId = database.settingProperties.selectedAiProviderId; + final selectedProvider = database.settingProperties.selectedAiProvider; return Scaffold( backgroundColor: context.theme.background, - appBar: MixinAppBar( - title: const Text('AI Settings'), - actions: [ - TextButton( - onPressed: () => Navigator.of(context).push( - MaterialPageRoute( - builder: (_) => const AiProviderEditPage(), - ), - ), - child: Text( - 'Add', - style: TextStyle(color: context.theme.accent, fontSize: 16), - ), - ), - ], - ), + appBar: const MixinAppBar(title: Text('AI Settings')), body: Align( alignment: Alignment.topCenter, child: SingleChildScrollView( - child: Column( - children: [ - if (providers.isEmpty) + child: Padding( + padding: const EdgeInsets.only(top: 40), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: context.theme.settingCellBackgroundColor, + child: CellItem( + title: const Text('Add Provider'), + leading: Icon(Icons.add, color: context.theme.icon), + trailing: null, + onTap: () => Navigator.of(context).push( + MaterialPageRoute( + builder: (_) => const AiProviderEditPage(), + ), + ), + ), + ), Padding( - padding: const EdgeInsets.all(24), + padding: const EdgeInsets.only(left: 20, bottom: 14, top: 10), child: Text( - 'No AI provider configured yet.', - style: TextStyle(color: context.theme.secondaryText), + providers.isEmpty + ? 'Add an AI provider to enable AI mode in chat.' + : 'The selected provider is used by default in AI mode.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), ), - ) - else - CellGroup( - cellBackgroundColor: context.theme.settingCellBackgroundColor, - child: Column( - children: providers.map((provider) { - final selected = provider.id == selectedId; - return CellItem( - title: Text(provider.name), - description: Text(provider.model), - selected: selected, - onTap: () => - database.settingProperties.selectedAiProviderId = - provider.id, - trailing: Row( - mainAxisSize: MainAxisSize.min, - children: [ - Switch( - value: provider.enabled, - onChanged: (value) { - database.settingProperties.saveAiProvider( - provider.copyWith(enabled: value), - ); - }, - ), - IconButton( - onPressed: () => Navigator.of(context).push( - MaterialPageRoute( - builder: (_) => - AiProviderEditPage(initial: provider), - ), - ), - icon: Icon( - Icons.edit_outlined, - color: context.theme.icon, - ), - ), - IconButton( - onPressed: () { - database.settingProperties.removeAiProvider( - provider.id, - ); - showToastSuccessful(); - }, - icon: Icon( - Icons.delete_outline, - color: context.theme.red, - ), + ), + if (providers.isNotEmpty) ...[ + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: CellItem( + title: const Text('Default Provider'), + description: Text( + _providerSummary(selectedProvider), + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + trailing: null, + ), + ), + Padding( + padding: const EdgeInsets.only( + left: 20, + bottom: 14, + top: 10, + ), + child: Text( + 'Each API endpoint can contain multiple models. One default model is used for new AI requests.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), + ), + ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: Column( + children: [ + for (var i = 0; i < providers.length; i++) ...[ + _ProviderCell( + provider: providers[i], + selected: selectedId == providers[i].id, + ), + if (i != providers.length - 1) + Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: context.theme.divider, ), - ], - ), - ); - }).toList(), + ], + ], + ), ), - ), - ], + ], + ], + ), ), ), ), ); } + + static String _providerSummary(AiProviderConfig? provider) { + if (provider == null) return 'No enabled provider'; + final modelCount = provider.models.length; + if (modelCount <= 1) { + return provider.model; + } + return '${provider.model} · $modelCount models'; + } +} + +class _ProviderCell extends HookConsumerWidget { + const _ProviderCell({required this.provider, required this.selected}); + + final AiProviderConfig provider; + final bool selected; + + @override + Widget build(BuildContext context, WidgetRef ref) { + final database = ref.watch(databaseProvider).requireValue; + final subtitle = [ + provider.baseUrl, + if (provider.models.isNotEmpty) + provider.models.length == 1 + ? provider.model + : '${provider.model} · ${provider.models.length} models', + ].join('\n'); + + return CellItem( + selected: selected, + onTap: () => + database.settingProperties.selectedAiProviderId = provider.id, + title: Row( + children: [ + Expanded( + child: Text( + provider.name, + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + ), + if (selected) + Padding( + padding: const EdgeInsets.only(left: 8), + child: Icon( + Icons.check_circle_rounded, + size: 18, + color: context.theme.accent, + ), + ), + ], + ), + description: Expanded( + child: Text( + subtitle, + textAlign: TextAlign.end, + maxLines: 2, + overflow: TextOverflow.ellipsis, + ), + ), + trailing: Row( + mainAxisSize: MainAxisSize.min, + children: [ + Transform.scale( + scale: 0.7, + child: CupertinoSwitch( + activeTrackColor: context.theme.accent, + value: provider.enabled, + onChanged: (value) { + database.settingProperties.saveAiProvider( + provider.copyWith(enabled: value), + ); + }, + ), + ), + IconButton( + onPressed: () => Navigator.of(context).push( + MaterialPageRoute( + builder: (_) => AiProviderEditPage(initial: provider), + ), + ), + icon: Icon(Icons.edit_outlined, color: context.theme.icon), + ), + IconButton( + onPressed: () { + database.settingProperties.removeAiProvider(provider.id); + showToastSuccessful(); + }, + icon: Icon(Icons.delete_outline, color: context.theme.red), + ), + ], + ), + ); + } } diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index 8720e77dc0..c1a047d351 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -254,14 +254,9 @@ class _AiAvatar extends HookWidget { @override Widget build(BuildContext context) { - final background = context.dynamicColor( - const Color.fromRGBO(227, 237, 213, 1), - darkColor: const Color.fromRGBO(64, 78, 56, 1), - ); - final foreground = context.dynamicColor( - const Color.fromRGBO(54, 87, 35, 1), - darkColor: const Color.fromRGBO(214, 235, 204, 1), - ); + final aiColors = context.theme.ai; + final background = aiColors.avatarBackground; + final foreground = aiColors.accent; final disableAnimations = MediaQuery.maybeOf(context)?.disableAnimations ?? false; final controller = useAnimationController( @@ -567,23 +562,14 @@ Color _bubbleColor( required String status, }) { if (status == 'error') { - return context.dynamicColor( - const Color.fromRGBO(255, 235, 235, 1), - darkColor: const Color.fromRGBO(88, 46, 46, 1), - ); + return context.theme.ai.errorBubble; } if (isUser) { - return context.dynamicColor( - const Color.fromRGBO(255, 241, 214, 1), - darkColor: const Color.fromRGBO(96, 76, 34, 1), - ); + return context.theme.ai.userBubble; } - return context.dynamicColor( - const Color.fromRGBO(228, 245, 239, 1), - darkColor: const Color.fromRGBO(43, 77, 65, 1), - ); + return context.theme.ai.assistantBubble; } Color _statusColor( @@ -592,30 +578,14 @@ Color _statusColor( required String status, }) { if (status == 'error') { - return context.dynamicColor( - const Color.fromRGBO(193, 63, 63, 1), - darkColor: const Color.fromRGBO(255, 173, 173, 1), - ); + return context.theme.ai.error; } if (isUser) { - return context.dynamicColor( - const Color.fromRGBO(176, 107, 18, 1), - darkColor: const Color.fromRGBO(255, 214, 143, 1), - ); + return context.theme.green; } - if (status == 'pending') { - return context.dynamicColor( - const Color.fromRGBO(46, 123, 110, 1), - darkColor: const Color.fromRGBO(159, 230, 217, 1), - ); - } - - return context.dynamicColor( - const Color.fromRGBO(33, 126, 96, 1), - darkColor: const Color.fromRGBO(150, 238, 210, 1), - ); + return context.theme.ai.accent; } String _menuCopyText(AiChatMessage message) { diff --git a/lib/widgets/brightness_observer.dart b/lib/widgets/brightness_observer.dart index dba48f033b..e4f352c6e8 100644 --- a/lib/widgets/brightness_observer.dart +++ b/lib/widgets/brightness_observer.dart @@ -122,6 +122,87 @@ class BrightnessData extends InheritedWidget { } } +@immutable +class AiColorScheme { + const AiColorScheme({ + required this.avatarBackground, + required this.accent, + required this.onAccent, + required this.surface, + required this.surfaceBorder, + required this.surfaceVariant, + required this.userBubble, + required this.assistantBubble, + required this.errorBubble, + required this.error, + }); + + final Color avatarBackground; + final Color accent; + final Color onAccent; + final Color surface; + final Color surfaceBorder; + final Color surfaceVariant; + final Color userBubble; + final Color assistantBubble; + final Color errorBubble; + final Color error; + + static AiColorScheme lerp( + AiColorScheme begin, + AiColorScheme end, + double t, + ) => AiColorScheme( + avatarBackground: Color.lerp( + begin.avatarBackground, + end.avatarBackground, + t, + )!, + accent: Color.lerp(begin.accent, end.accent, t)!, + onAccent: Color.lerp(begin.onAccent, end.onAccent, t)!, + surface: Color.lerp(begin.surface, end.surface, t)!, + surfaceBorder: Color.lerp(begin.surfaceBorder, end.surfaceBorder, t)!, + surfaceVariant: Color.lerp( + begin.surfaceVariant, + end.surfaceVariant, + t, + )!, + userBubble: Color.lerp(begin.userBubble, end.userBubble, t)!, + assistantBubble: Color.lerp(begin.assistantBubble, end.assistantBubble, t)!, + errorBubble: Color.lerp(begin.errorBubble, end.errorBubble, t)!, + error: Color.lerp(begin.error, end.error, t)!, + ); + + @override + bool operator ==(Object other) => + identical(this, other) || + other is AiColorScheme && + runtimeType == other.runtimeType && + avatarBackground == other.avatarBackground && + accent == other.accent && + onAccent == other.onAccent && + surface == other.surface && + surfaceBorder == other.surfaceBorder && + surfaceVariant == other.surfaceVariant && + userBubble == other.userBubble && + assistantBubble == other.assistantBubble && + errorBubble == other.errorBubble && + error == other.error; + + @override + int get hashCode => + avatarBackground.hashCode ^ + accent.hashCode ^ + onAccent.hashCode ^ + surface.hashCode ^ + surfaceBorder.hashCode ^ + surfaceVariant.hashCode ^ + userBubble.hashCode ^ + assistantBubble.hashCode ^ + errorBubble.hashCode ^ + error.hashCode; +} + @immutable class BrightnessThemeData { const BrightnessThemeData({ @@ -147,6 +228,7 @@ class BrightnessThemeData { required this.waveformBackground, required this.waveformForeground, required this.settingCellBackgroundColor, + required this.ai, }); final Color primary; @@ -171,6 +253,7 @@ class BrightnessThemeData { final Color waveformBackground; final Color waveformForeground; final Color settingCellBackgroundColor; + final AiColorScheme ai; static BrightnessThemeData lerp( BrightnessThemeData begin, @@ -219,6 +302,7 @@ class BrightnessThemeData { end.settingCellBackgroundColor, t, )!, + ai: AiColorScheme.lerp(begin.ai, end.ai, t), ); @override @@ -246,7 +330,8 @@ class BrightnessThemeData { stickerPlaceholderColor == other.stickerPlaceholderColor && waveformBackground == other.waveformBackground && waveformForeground == other.waveformForeground && - settingCellBackgroundColor == other.settingCellBackgroundColor; + settingCellBackgroundColor == other.settingCellBackgroundColor && + ai == other.ai; @override int get hashCode => @@ -270,5 +355,6 @@ class BrightnessThemeData { stickerPlaceholderColor.hashCode ^ waveformBackground.hashCode ^ waveformForeground.hashCode ^ - waveformForeground.hashCode; + settingCellBackgroundColor.hashCode ^ + ai.hashCode; } From ce53eb44094a530a9309d7b0f81f8bafe6dd7749 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Mon, 20 Apr 2026 21:14:39 +0800 Subject: [PATCH 07/52] feat(ai): enhance AI message handling with caching, warmup, and improved UI --- lib/ai/ai_chat_controller.dart | 19 + lib/ai/model/ai_mode_state.dart | 7 +- lib/db/dao/ai_chat_message_dao.dart | 51 +++ lib/ui/home/chat/chat_page.dart | 59 ++- lib/ui/home/chat/input_container.dart | 416 +++++++++++++++----- lib/ui/provider/ai_input_mode_provider.dart | 15 +- lib/widgets/ai/ai_message_card.dart | 11 +- lib/widgets/markdown.dart | 283 ++++++++++++- lib/widgets/message/item/post_message.dart | 19 +- 9 files changed, 763 insertions(+), 117 deletions(-) diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index 6c970a9842..75669c823f 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -23,6 +23,12 @@ const _kAiContextMessageLimit = 30; const _kAiHistoryLimit = 12; const _kAiStreamFlushChars = 32; const _kAiStreamFlushInterval = Duration(milliseconds: 80); +final kAiRuntimeStartedAt = DateTime.now(); + +bool isActivePendingAiMessage(AiChatMessage message) => + message.role == _kAiRoleAssistant && + message.status == _kAiStatusPending && + !message.updatedAt.isBefore(kAiRuntimeStartedAt); class AiChatController { AiChatController(this.database); @@ -37,6 +43,19 @@ class AiChatController { required String input, AiProviderConfig? provider, }) async { + await database.aiChatMessageDao.resolveStalePendingAssistantMessages( + updatedBefore: kAiRuntimeStartedAt, + conversationId: conversationId, + ); + final hasPendingAssistant = await database.aiChatMessageDao + .hasPendingAssistantMessage( + conversationId, + updatedAfter: kAiRuntimeStartedAt, + ); + if (hasPendingAssistant) { + throw Exception('AI is still responding'); + } + final config = provider ?? database.settingProperties.selectedAiProvider; if (config == null) { throw Exception('No AI provider configured'); diff --git a/lib/ai/model/ai_mode_state.dart b/lib/ai/model/ai_mode_state.dart index b938cba063..a835127912 100644 --- a/lib/ai/model/ai_mode_state.dart +++ b/lib/ai/model/ai_mode_state.dart @@ -4,20 +4,25 @@ class AiModeState extends Equatable { const AiModeState({ this.enabled = false, this.providerId, + this.model, }); final bool enabled; final String? providerId; + final String? model; @override - List get props => [enabled, providerId]; + List get props => [enabled, providerId, model]; AiModeState copyWith({ bool? enabled, String? providerId, + String? model, bool clearProviderId = false, + bool clearModel = false, }) => AiModeState( enabled: enabled ?? this.enabled, providerId: clearProviderId ? null : (providerId ?? this.providerId), + model: clearModel ? null : (model ?? this.model), ); } diff --git a/lib/db/dao/ai_chat_message_dao.dart b/lib/db/dao/ai_chat_message_dao.dart index 41ac9d5d64..12ddfe414a 100644 --- a/lib/db/dao/ai_chat_message_dao.dart +++ b/lib/db/dao/ai_chat_message_dao.dart @@ -9,6 +9,10 @@ class AiChatMessageDao extends DatabaseAccessor with _$AiChatMessageDaoMixin { AiChatMessageDao(super.db); + static const assistantRole = 'assistant'; + static const pendingStatus = 'pending'; + static const errorStatus = 'error'; + Stream> watchConversationMessages( String conversationId, ) => @@ -63,4 +67,51 @@ class AiChatMessageDao extends DatabaseAccessor Future deleteConversationMessages(String conversationId) => (delete( db.aiChatMessages, )..where((tbl) => tbl.conversationId.equals(conversationId))).go(); + + Future hasPendingAssistantMessage( + String conversationId, { + DateTime? updatedAfter, + }) async { + final query = selectOnly(db.aiChatMessages) + ..addColumns([db.aiChatMessages.id.count()]) + ..where( + db.aiChatMessages.conversationId.equals(conversationId) & + db.aiChatMessages.role.equals(assistantRole) & + db.aiChatMessages.status.equals(pendingStatus) & + (updatedAfter == null + ? const Constant(true) + : db.aiChatMessages.updatedAt.isBiggerOrEqualValue( + updatedAfter.millisecondsSinceEpoch, + )), + ); + final row = await query.getSingleOrNull(); + final count = row?.read(db.aiChatMessages.id.count()) ?? 0; + return count > 0; + } + + Future resolveStalePendingAssistantMessages({ + required DateTime updatedBefore, + String? conversationId, + String errorText = 'Interrupted by app restart', + }) { + final query = update(db.aiChatMessages) + ..where( + (tbl) => + tbl.role.equals(assistantRole) & + tbl.status.equals(pendingStatus) & + tbl.updatedAt.isSmallerThanValue( + updatedBefore.millisecondsSinceEpoch, + ) & + (conversationId == null + ? const Constant(true) + : tbl.conversationId.equals(conversationId)), + ); + return query.write( + AiChatMessagesCompanion( + status: const Value(errorStatus), + errorText: Value(errorText), + updatedAt: Value(DateTime.now()), + ), + ); + } } diff --git a/lib/ui/home/chat/chat_page.dart b/lib/ui/home/chat/chat_page.dart index f93f9fb0ca..849ef55088 100644 --- a/lib/ui/home/chat/chat_page.dart +++ b/lib/ui/home/chat/chat_page.dart @@ -1,4 +1,5 @@ import 'dart:io'; +import 'dart:math' as math; import 'package:desktop_drop/desktop_drop.dart'; import 'package:flutter/material.dart'; @@ -29,6 +30,7 @@ import '../../../widgets/dash_path_border.dart'; import '../../../widgets/dialog.dart'; import '../../../widgets/high_light_text.dart'; import '../../../widgets/interactive_decorated_box.dart'; +import '../../../widgets/markdown.dart'; import '../../../widgets/menu.dart'; import '../../../widgets/message/message.dart'; import '../../../widgets/message/message_bubble.dart'; @@ -662,7 +664,53 @@ class _List extends HookConsumerWidget { return null; } - Widget buildTimelineChild(ChatTimelineItem item) { + ({String key, String data})? markdownWarmupEntryOf(ChatTimelineItem item) { + final aiMessage = item.aiMessage; + if (aiMessage != null) { + if (aiMessage.role == 'user' || aiMessage.status == 'error') { + return null; + } + final content = aiMessage.content.trim(); + if (content.isEmpty) return null; + return ( + key: buildMarkdownCacheKey( + namespace: 'ai', + id: aiMessage.id, + data: content, + ), + data: content, + ); + } + + final message = item.message; + if (message == null || !message.type.isPost) return null; + final content = (message.content ?? '').postOptimize(); + if (content.isEmpty) return null; + return ( + key: buildMarkdownCacheKey( + namespace: 'post', + id: message.messageId, + data: content, + ), + data: content, + ); + } + + void warmupMarkdownAround(int index) { + final start = math.max(0, index - 6); + final end = math.min(timeline.length, index + 7); + final entries = <({String key, String data})>[]; + for (var i = start; i < end; i++) { + final entry = markdownWarmupEntryOf(timeline[i]); + if (entry != null) { + entries.add(entry); + } + } + markdownControllerCache.warmupAll(entries); + } + + Widget buildTimelineChild(ChatTimelineItem item, int index) { + warmupMarkdownAround(index); if (item.isAiMessage) { return AiMessageCard( key: ValueKey('ai-${item.id}'), @@ -704,7 +752,7 @@ class _List extends HookConsumerWidget { index, ) { final actualIndex = topTimeline.length - index - 1; - return buildTimelineChild(topTimeline[actualIndex]); + return buildTimelineChild(topTimeline[actualIndex], actualIndex); }, childCount: topTimeline.length), ), SliverToBoxAdapter( @@ -712,7 +760,7 @@ class _List extends HookConsumerWidget { child: Builder( builder: (context) { if (centerTimeline == null) return const SizedBox(); - return buildTimelineChild(centerTimeline); + return buildTimelineChild(centerTimeline, centerTimelineIndex!); }, ), ), @@ -722,7 +770,10 @@ class _List extends HookConsumerWidget { ( context, index, - ) => buildTimelineChild(bottomTimeline[index]), + ) => buildTimelineChild( + bottomTimeline[index], + (centerTimelineIndex ?? -1) + index + 1, + ), childCount: bottomTimeline.length, ), ), diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 6ec2709469..b8aa78810e 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -110,17 +110,25 @@ class _InputContainer extends HookConsumerWidget { .whereType() .where((element) => element.enabled) .toList(); - final aiProviderId = aiModeState.providerId; - var aiProvider = selectedAiProvider; - if (aiProviderId != null) { - for (final provider in enabledAiProviders) { - if (provider.id == aiProviderId) { - aiProvider = provider; - break; - } - } - } + final aiProvider = _resolveAiModeProvider( + selectedAiProvider: selectedAiProvider, + enabledAiProviders: enabledAiProviders, + providerId: aiModeState.providerId, + selectedModel: aiModeState.model, + ); final aiModeEnabled = aiModeState.enabled; + final aiMessages = + useMemoizedStream( + () => conversationId == null + ? Stream.value(const []) + : context.database.aiChatMessageDao.watchConversationMessages( + conversationId, + ), + keys: [conversationId], + initialData: const [], + ).data ?? + const []; + final aiRequestInFlight = aiMessages.any(isActivePendingAiMessage); final quoteMessageId = ref.watch(quoteMessageIdProvider); @@ -144,6 +152,17 @@ class _InputContainer extends HookConsumerWidget { final mentionProviderInstance = mentionProvider(textEditingValueStream); + useEffect(() { + if (conversationId == null) return null; + unawaited( + context.database.aiChatMessageDao.resolveStalePendingAssistantMessages( + updatedBefore: kAiRuntimeStartedAt, + conversationId: conversationId, + ), + ); + return null; + }, [conversationId]); + useEffect(() { final updateDraft = context.database.conversationDao.updateDraft; return () { @@ -224,7 +243,7 @@ class _InputContainer extends HookConsumerWidget { decoration: BoxDecoration(color: context.theme.primary), padding: EdgeInsets.fromLTRB( 16, - aiModeEnabled ? 10 : 8, + aiModeEnabled ? 8 : 8, 16, 8, ), @@ -232,7 +251,7 @@ class _InputContainer extends HookConsumerWidget { duration: const Duration(milliseconds: 220), curve: Curves.easeOutCubic, padding: aiModeEnabled - ? const EdgeInsets.fromLTRB(12, 12, 12, 12) + ? const EdgeInsets.fromLTRB(10, 8, 10, 8) : EdgeInsets.zero, decoration: BoxDecoration( color: aiModeEnabled @@ -255,7 +274,7 @@ class _InputContainer extends HookConsumerWidget { ), Container( height: 1, - margin: const EdgeInsets.only(top: 10, bottom: 10), + margin: const EdgeInsets.only(top: 8, bottom: 8), color: context.theme.ai.surfaceBorder, ), ], @@ -277,6 +296,8 @@ class _InputContainer extends HookConsumerWidget { mentionProviderInstance: mentionProviderInstance, aiModeEnabled: aiModeEnabled, providerName: aiProvider?.name, + modelName: aiProvider?.model, + aiRequestInFlight: aiRequestInFlight, ), ), SizedBox(width: aiModeEnabled ? 10 : 16), @@ -285,6 +306,7 @@ class _InputContainer extends HookConsumerWidget { textEditingController: textEditingController, textEditingValueStream: textEditingValueStream, aiModeEnabled: aiModeEnabled, + aiRequestInFlight: aiRequestInFlight, ), ], ), @@ -306,12 +328,14 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { required this.textEditingValueStream, required this.textEditingController, required this.aiModeEnabled, + required this.aiRequestInFlight, }); final String? conversationId; final Stream textEditingValueStream; final TextEditingController textEditingController; final bool aiModeEnabled; + final bool aiRequestInFlight; @override Widget build(BuildContext context, WidgetRef ref) { @@ -328,12 +352,13 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { if (aiModeEnabled) { final backgroundColor = context.theme.ai.accent; final foregroundColor = context.theme.ai.onAccent; + final canSend = hasInputText && !aiRequestInFlight; return AnimatedOpacity( duration: const Duration(milliseconds: 180), - opacity: hasInputText ? 1 : 0.45, + opacity: canSend ? 1 : 0.45, child: IgnorePointer( - ignoring: !hasInputText, + ignoring: !canSend, child: DecoratedBox( decoration: BoxDecoration( color: backgroundColor, @@ -347,10 +372,12 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { size: 20, color: foregroundColor, ), - onTap: () => _sendMessage( - context, - textEditingController, - conversationId: conversationId, + onTap: () => unawaited( + _sendMessage( + context, + textEditingController, + conversationId: conversationId, + ), ), ), ), @@ -407,11 +434,13 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { MenuAction( image: MenuImage.icon(IconFonts.mute), title: context.l10n.sendWithoutSound, - callback: () => _sendMessage( - context, - textEditingController, - conversationId: conversationId, - silent: true, + callback: () => unawaited( + _sendMessage( + context, + textEditingController, + conversationId: conversationId, + silent: true, + ), ), ), ], @@ -419,10 +448,12 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { child: ActionButton( name: Resources.assetsImagesIcSendSvg, color: context.theme.icon, - onTap: () => _sendMessage( - context, - textEditingController, - conversationId: conversationId, + onTap: () => unawaited( + _sendMessage( + context, + textEditingController, + conversationId: conversationId, + ), ), ), ), @@ -469,12 +500,12 @@ void _sendPostMessage( context.providerContainer.read(quoteMessageProvider.notifier).state = null; } -void _sendMessage( +Future _sendMessage( BuildContext context, TextEditingController textEditingController, { required String? conversationId, bool silent = false, -}) { +}) async { final text = textEditingController.value.text.trim(); if (text.isEmpty) return; if (conversationId == null) return; @@ -499,7 +530,7 @@ void _sendMessage( showToastFailed(ToastError('Please add an AI provider first')); return; } - aiModeController.enter(providerId: provider.id); + aiModeController.enter(providerId: provider.id, model: provider.model); textEditingController.text = ''; return; } @@ -513,36 +544,46 @@ void _sendMessage( showToastFailed(ToastError('Please add an AI provider first')); return; } - aiModeController.enter(providerId: provider.id); - textEditingController.text = ''; - unawaited( - AiChatController(context.database) - .send( - conversationId: conversationId, - input: inlineAiInput, - provider: provider, - ) - .catchError((Object error, StackTrace _) => showToastFailed(error)), - ); + aiModeController.enter(providerId: provider.id, model: provider.model); + try { + await AiChatController(context.database).send( + conversationId: conversationId, + input: inlineAiInput, + provider: provider, + ); + textEditingController.text = ''; + } catch (error, _) { + showToastFailed(error); + } return; } if (aiModeState.enabled) { - final provider = context.database.settingProperties.selectedAiProvider; + final provider = _resolveAiModeProvider( + selectedAiProvider: context.database.settingProperties.selectedAiProvider, + enabledAiProviders: context.database.settingProperties.aiProviders + .whereType() + .where((element) => element.enabled) + .toList(), + providerId: aiModeState.providerId, + selectedModel: aiModeState.model, + ); if (provider == null) { showToastFailed(ToastError('Please add an AI provider first')); return; } - textEditingController.text = ''; - unawaited( - AiChatController(context.database) - .send(conversationId: conversationId, input: text, provider: provider) - .catchError((Object error, StackTrace _) => showToastFailed(error)), - ); + try { + await AiChatController( + context.database, + ).send(conversationId: conversationId, input: text, provider: provider); + textEditingController.text = ''; + } catch (error, _) { + showToastFailed(error); + } return; } - context.accountServer.sendTextMessage( + await context.accountServer.sendTextMessage( text, conversationItem.encryptCategory, conversationId: conversationItem.conversationId, @@ -555,6 +596,32 @@ void _sendMessage( context.providerContainer.read(quoteMessageProvider.notifier).state = null; } +AiProviderConfig? _resolveAiModeProvider({ + required AiProviderConfig? selectedAiProvider, + required List enabledAiProviders, + required String? providerId, + required String? selectedModel, +}) { + var provider = selectedAiProvider; + if (providerId != null) { + for (final item in enabledAiProviders) { + if (item.id == providerId) { + provider = item; + break; + } + } + } + if (provider == null) return null; + + final trimmedModel = selectedModel?.trim(); + if (trimmedModel == null || trimmedModel.isEmpty) return provider; + if (provider.models.isNotEmpty && !provider.models.contains(trimmedModel)) { + return provider; + } + if (provider.model == trimmedModel) return provider; + return provider.copyWith(defaultModel: trimmedModel, model: trimmedModel); +} + class _SendTextField extends HookConsumerWidget { const _SendTextField({ required this.focusNode, @@ -562,6 +629,8 @@ class _SendTextField extends HookConsumerWidget { required this.mentionProviderInstance, required this.aiModeEnabled, required this.providerName, + required this.modelName, + required this.aiRequestInFlight, }); final FocusNode focusNode; @@ -570,6 +639,8 @@ class _SendTextField extends HookConsumerWidget { mentionProviderInstance; final bool aiModeEnabled; final String? providerName; + final String? modelName; + final bool aiRequestInFlight; @override Widget build(BuildContext context, WidgetRef ref) { @@ -620,20 +691,39 @@ class _SendTextField extends HookConsumerWidget { false; final placeholder = aiModeEnabled - ? 'Ask ${providerName?.trim().isNotEmpty == true ? providerName : 'AI'} anything' + ? aiRequestInFlight + ? [ + if (providerName?.trim().isNotEmpty == true) + providerName!.trim() + else + 'AI', + if (modelName?.trim().isNotEmpty == true) + '(${modelName!.trim()})', + 'is responding...', + ].join(' ') + : [ + 'Ask', + if (providerName?.trim().isNotEmpty == true) + providerName!.trim() + else + 'AI', + if (modelName?.trim().isNotEmpty == true) + '(${modelName!.trim()})', + ].join(' ') : isEncryptConversation ? context.l10n.chatHintE2e : 'Type message or /ai'; + final canSubmit = sendable && (!aiModeEnabled || !aiRequestInFlight); return AnimatedContainer( duration: const Duration(milliseconds: 220), curve: Curves.easeOutCubic, constraints: const BoxConstraints(minHeight: 40), decoration: BoxDecoration( - borderRadius: BorderRadius.all(Radius.circular(aiModeEnabled ? 14 : 4)), + borderRadius: BorderRadius.all(Radius.circular(aiModeEnabled ? 12 : 4)), color: context.dynamicColor( aiModeEnabled - ? const Color.fromRGBO(255, 255, 255, 0.78) + ? const Color.fromRGBO(255, 255, 255, 0.82) : const Color.fromRGBO(245, 247, 250, 1), darkColor: const Color.fromRGBO(255, 255, 255, 0.08), ), @@ -647,7 +737,7 @@ class _SendTextField extends HookConsumerWidget { child: FocusableActionDetector( autofocus: true, shortcuts: { - if (sendable) + if (canSubmit) const SingleActivator(LogicalKeyboardKey.enter): const _SendMessageIntent(), SingleActivator( @@ -661,10 +751,12 @@ class _SendTextField extends HookConsumerWidget { }, actions: { _SendMessageIntent: CallbackAction( - onInvoke: (intent) => _sendMessage( - context, - textEditingController, - conversationId: ref.read(currentConversationIdProvider), + onInvoke: (intent) => unawaited( + _sendMessage( + context, + textEditingController, + conversationId: ref.read(currentConversationIdProvider), + ), ), ), PasteTextIntent: _PasteContextAction(context), @@ -704,8 +796,8 @@ class _SendTextField extends HookConsumerWidget { contentPadding: EdgeInsets.only( left: 10, right: 10, - top: 10, - bottom: 10, + top: 8, + bottom: 8, ), ), selectionHeightStyle: ui.BoxHeightStyle.includeLineSpacingMiddle, @@ -752,59 +844,67 @@ class _AiModeBar extends HookConsumerWidget { final model = provider?.model.trim(); final aiColors = context.theme.ai; final accentColor = aiColors.accent; + final hasProvider = provider != null; + final notifier = ref.read(aiInputModeProvider(conversationId).notifier); + final enabledAiProviders = context.database.settingProperties.aiProviders + .whereType() + .where((element) => element.enabled) + .toList(); + final providerOptions = enabledAiProviders + .map( + (item) => CustomPopupMenuItem( + title: item.name, + value: item, + ), + ) + .toList(growable: false); + final modelOptions = + provider?.models + .where((item) => item.trim().isNotEmpty) + .map( + (item) => CustomPopupMenuItem(title: item, value: item), + ) + .toList(growable: false) ?? + >[]; return Row( children: [ - Container( - padding: const EdgeInsets.symmetric(horizontal: 10, vertical: 6), - decoration: BoxDecoration( - color: aiColors.surfaceVariant, - borderRadius: const BorderRadius.all(Radius.circular(999)), - ), - child: Row( - mainAxisSize: MainAxisSize.min, - children: [ - Icon(Icons.auto_awesome_rounded, size: 14, color: accentColor), - const SizedBox(width: 6), - Text( - 'AI Mode', - style: TextStyle( - color: accentColor, - fontSize: 12, - fontWeight: FontWeight.w600, - ), - ), - ], - ), + _AiModeChip( + icon: Icons.auto_awesome_rounded, + label: 'AI', + foregroundColor: accentColor, + backgroundColor: aiColors.surfaceVariant, ), - const SizedBox(width: 10), Expanded( - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - mainAxisSize: MainAxisSize.min, - children: [ - Text( - providerName, - style: TextStyle( - color: context.theme.text, - fontSize: 14, - fontWeight: FontWeight.w600, + child: SingleChildScrollView( + scrollDirection: Axis.horizontal, + padding: const EdgeInsets.symmetric(horizontal: 8), + child: Row( + children: [ + _AiModeMenuChip( + icon: Icons.hub_rounded, + label: providerName, + items: providerOptions, + enabled: providerOptions.length > 1, + onSelected: (value) => notifier.updateProvider( + providerId: value.id, + model: value.model, + ), ), - ), - const SizedBox(height: 2), - Text( - model?.isNotEmpty == true ? model! : 'Next message goes to AI', - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 12, + const SizedBox(width: 6), + _AiModeMenuChip( + icon: Icons.tune_rounded, + label: model?.isNotEmpty == true + ? model! + : (hasProvider ? 'Select Model' : 'No Model'), + items: modelOptions, + enabled: modelOptions.length > 1, + onSelected: notifier.updateModel, ), - maxLines: 1, - overflow: TextOverflow.ellipsis, - ), - ], + ], + ), ), ), - const SizedBox(width: 8), ActionButton( name: Resources.assetsImagesIcCloseSvg, color: context.theme.icon, @@ -817,6 +917,114 @@ class _AiModeBar extends HookConsumerWidget { } } +class _AiModeChip extends StatelessWidget { + const _AiModeChip({ + required this.icon, + required this.label, + required this.foregroundColor, + required this.backgroundColor, + }); + + final IconData icon; + final String label; + final Color foregroundColor; + final Color backgroundColor; + + @override + Widget build(BuildContext context) => Container( + height: 28, + padding: const EdgeInsets.symmetric(horizontal: 10), + decoration: BoxDecoration( + color: backgroundColor, + borderRadius: const BorderRadius.all(Radius.circular(999)), + ), + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + Icon(icon, size: 14, color: foregroundColor), + const SizedBox(width: 6), + Text( + label, + style: TextStyle( + color: foregroundColor, + fontSize: 12, + fontWeight: FontWeight.w700, + ), + ), + ], + ), + ); +} + +class _AiModeMenuChip extends StatelessWidget { + const _AiModeMenuChip({ + required this.icon, + required this.label, + required this.items, + required this.onSelected, + this.enabled = true, + }); + + final IconData icon; + final String label; + final List> items; + final ValueChanged onSelected; + final bool enabled; + + @override + Widget build(BuildContext context) { + final child = Container( + height: 28, + padding: const EdgeInsets.symmetric(horizontal: 10), + decoration: BoxDecoration( + color: context.dynamicColor( + const Color.fromRGBO(255, 255, 255, 0.74), + darkColor: const Color.fromRGBO(255, 255, 255, 0.06), + ), + border: Border.all(color: context.theme.ai.surfaceBorder), + borderRadius: const BorderRadius.all(Radius.circular(999)), + ), + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + Icon(icon, size: 13, color: context.theme.secondaryText), + const SizedBox(width: 6), + ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 160), + child: Text( + label, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.text, + fontSize: 12, + fontWeight: FontWeight.w600, + ), + ), + ), + if (enabled) ...[ + const SizedBox(width: 4), + Icon( + Icons.keyboard_arrow_down_rounded, + size: 16, + color: context.theme.secondaryText, + ), + ], + ], + ), + ); + + if (!enabled || items.isEmpty) return child; + + return CustomPopupMenuButton( + itemBuilder: (_) => items, + onSelected: onSelected, + color: Colors.transparent, + child: child, + ); + } +} + class _QuoteMessage extends HookConsumerWidget { const _QuoteMessage(); diff --git a/lib/ui/provider/ai_input_mode_provider.dart b/lib/ui/provider/ai_input_mode_provider.dart index 44cf77980d..538ba8d8ec 100644 --- a/lib/ui/provider/ai_input_mode_provider.dart +++ b/lib/ui/provider/ai_input_mode_provider.dart @@ -5,8 +5,19 @@ import '../../ai/model/ai_mode_state.dart'; class AiInputModeNotifier extends StateNotifier { AiInputModeNotifier() : super(const AiModeState()); - void enter({String? providerId}) { - state = AiModeState(enabled: true, providerId: providerId); + void enter({String? providerId, String? model}) { + state = AiModeState(enabled: true, providerId: providerId, model: model); + } + + void updateProvider({ + required String providerId, + String? model, + }) { + state = state.copyWith(providerId: providerId, model: model); + } + + void updateModel(String? model) { + state = state.copyWith(model: model); } void exit() { diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index c1a047d351..2129c95fab 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -155,9 +155,18 @@ class _AiMessageBody extends StatelessWidget { if (isUser || message.status == 'error') { body = _AiSelectableText(text: text, style: textStyle); } else { + final cacheKey = buildMarkdownCacheKey( + namespace: 'ai', + id: message.id, + data: text, + ); body = DefaultTextStyle.merge( style: textStyle, - child: MarkdownColumn(data: text, selectable: true), + child: MarkdownColumn( + data: text, + selectable: true, + cacheKey: cacheKey, + ), ); } diff --git a/lib/widgets/markdown.dart b/lib/widgets/markdown.dart index cc4bce0829..7edd3f91ae 100644 --- a/lib/widgets/markdown.dart +++ b/lib/widgets/markdown.dart @@ -1,6 +1,10 @@ +import 'dart:async'; +import 'dart:collection'; import 'dart:io'; import 'package:flutter/material.dart'; +import 'package:flutter/scheduler.dart'; +import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; import 'package:mixin_markdown_widget/mixin_markdown_widget.dart'; @@ -9,15 +13,166 @@ import '../utils/extension/extension.dart'; import '../utils/uri_utils.dart'; import 'mixin_image.dart'; -class MarkdownColumn extends ConsumerWidget { +const _kMarkdownControllerCacheLimit = 120; +const _kMarkdownWarmupPerFrame = 2; + +String buildMarkdownCacheKey({ + required String namespace, + required String id, + required String data, +}) => '$namespace:$id:${data.hashCode}'; + +final markdownControllerCache = MarkdownControllerCache(); + +class MarkdownControllerCache { + final _entries = {}; + final _pending = >{}; + final _queuedKeys = {}; + final _warmupQueue = ListQueue<({String key, String data})>(); + + bool _warmupScheduled = false; + + MarkdownController? acquire(String key, String data) { + final entry = _entries[key]; + if (entry == null) return null; + if (entry.data != data) { + _removeEntry(key, entry); + return null; + } + _touch(key, entry); + entry.retainCount += 1; + return entry.controller; + } + + void release(String key, MarkdownController controller) { + final entry = _entries[key]; + if (entry == null || !identical(entry.controller, controller)) return; + if (entry.retainCount > 0) { + entry.retainCount -= 1; + } + } + + Future warmup(String key, String data) { + final entry = _entries[key]; + if (entry != null) { + if (entry.data == data) { + _touch(key, entry); + return Future.value(); + } + _removeEntry(key, entry); + } + + final pending = _pending[key]; + if (pending != null) return pending.future; + + final completer = Completer(); + _pending[key] = completer; + if (_queuedKeys.add(key)) { + _warmupQueue.add((key: key, data: data)); + _scheduleWarmup(); + } + return completer.future; + } + + void warmupAll(Iterable<({String key, String data})> entries) { + for (final entry in entries) { + unawaited(warmup(entry.key, entry.data)); + } + } + + void _scheduleWarmup() { + if (_warmupScheduled) return; + _warmupScheduled = true; + SchedulerBinding.instance.addPostFrameCallback((_) { + _warmupScheduled = false; + _drainWarmupQueue(); + }); + } + + void _drainWarmupQueue() { + var count = 0; + while (_warmupQueue.isNotEmpty && count < _kMarkdownWarmupPerFrame) { + final task = _warmupQueue.removeFirst(); + _queuedKeys.remove(task.key); + final completer = _pending.remove(task.key); + + try { + final existing = _entries[task.key]; + if (existing != null && existing.data == task.data) { + _touch(task.key, existing); + } else { + if (existing != null) { + _removeEntry(task.key, existing); + } + _entries[task.key] = _MarkdownCacheEntry( + data: task.data, + controller: MarkdownController(data: task.data), + ); + _evictIfNeeded(); + } + completer?.complete(); + } catch (error, stackTrace) { + completer?.completeError(error, stackTrace); + } + count += 1; + } + + if (_warmupQueue.isNotEmpty) { + _scheduleWarmup(); + } + } + + void _touch(String key, _MarkdownCacheEntry entry) { + _entries.remove(key); + _entries[key] = entry; + } + + void _evictIfNeeded() { + while (_entries.length > _kMarkdownControllerCacheLimit) { + String? keyToRemove; + _MarkdownCacheEntry? entryToRemove; + for (final entry in _entries.entries) { + if (entry.value.retainCount == 0) { + keyToRemove = entry.key; + entryToRemove = entry.value; + break; + } + } + if (keyToRemove == null || entryToRemove == null) { + return; + } + _removeEntry(keyToRemove, entryToRemove); + } + } + + void _removeEntry(String key, _MarkdownCacheEntry entry) { + _entries.remove(key); + entry.controller.dispose(); + } +} + +class _MarkdownCacheEntry { + _MarkdownCacheEntry({ + required this.data, + required this.controller, + }); + + final String data; + final MarkdownController controller; + int retainCount = 0; +} + +class MarkdownColumn extends HookConsumerWidget { const MarkdownColumn({ required this.data, super.key, this.selectable = false, + this.cacheKey, }); final String data; final bool selectable; + final String? cacheKey; @override Widget build(BuildContext context, WidgetRef ref) { @@ -26,8 +181,9 @@ class MarkdownColumn extends ConsumerWidget { ); return ClipRect( - child: MarkdownWidget( + child: _MarkdownView( data: data, + cacheKey: cacheKey, useColumn: true, selectable: selectable, contextMenuBuilder: (_, _, _, _) => const SizedBox.shrink(), @@ -43,17 +199,19 @@ class MarkdownColumn extends ConsumerWidget { } } -class Markdown extends ConsumerWidget { +class Markdown extends HookConsumerWidget { const Markdown({ required this.data, super.key, this.padding = EdgeInsets.zero, this.physics, + this.cacheKey, }); final String data; final EdgeInsetsGeometry? padding; final ScrollPhysics? physics; + final String? cacheKey; @override Widget build(BuildContext context, WidgetRef ref) { @@ -61,8 +219,9 @@ class Markdown extends ConsumerWidget { settingProvider.select((value) => value.chatFontSizeDelta), ); - return MarkdownWidget( + return _MarkdownView( data: data, + cacheKey: cacheKey, padding: padding, physics: physics, theme: _createMarkdownTheme(context, chatFontSizeDelta), @@ -75,6 +234,122 @@ class Markdown extends ConsumerWidget { } } +class _MarkdownView extends HookWidget { + const _MarkdownView({ + required this.data, + required this.theme, + required this.imageBuilder, + required this.onTapLink, + this.cacheKey, + this.padding, + this.physics, + this.useColumn = false, + this.selectable = true, + this.contextMenuBuilder, + }); + + final String data; + final String? cacheKey; + final MarkdownThemeData theme; + final EdgeInsetsGeometry? padding; + final ScrollPhysics? physics; + final bool useColumn; + final bool selectable; + final MarkdownImageBuilder imageBuilder; + final MarkdownTapLinkCallback onTapLink; + final MarkdownContextMenuBuilder? contextMenuBuilder; + + @override + Widget build(BuildContext context) { + if (cacheKey == null) { + return _buildMarkdownWidget(data: data); + } + + final controller = useState(null); + + useEffect(() { + var disposed = false; + MarkdownController? retained; + + void bindCachedController() { + final cached = markdownControllerCache.acquire(cacheKey!, data); + if (cached == null || disposed) return; + retained = cached; + controller.value = cached; + } + + bindCachedController(); + if (controller.value == null) { + unawaited( + markdownControllerCache.warmup(cacheKey!, data).then((_) { + if (disposed) return; + bindCachedController(); + }), + ); + } + + return () { + disposed = true; + final current = retained; + if (current != null) { + markdownControllerCache.release(cacheKey!, current); + } + }; + }, [cacheKey, data]); + + final cachedController = controller.value; + if (cachedController != null) { + return _buildMarkdownWidget(controller: cachedController); + } + + return _MarkdownFallback( + data: data, + theme: theme, + padding: padding, + ); + } + + Widget _buildMarkdownWidget({ + String? data, + MarkdownController? controller, + }) => MarkdownWidget( + data: data, + controller: controller, + padding: padding, + physics: physics, + useColumn: useColumn, + selectable: selectable, + contextMenuBuilder: contextMenuBuilder, + theme: theme, + imageBuilder: imageBuilder, + onTapLink: onTapLink, + ); +} + +class _MarkdownFallback extends StatelessWidget { + const _MarkdownFallback({ + required this.data, + required this.theme, + this.padding, + }); + + final String data; + final MarkdownThemeData theme; + final EdgeInsetsGeometry? padding; + + @override + Widget build(BuildContext context) { + final effectivePadding = padding ?? EdgeInsets.zero; + return Padding( + padding: effectivePadding, + child: Text( + data, + style: theme.bodyStyle, + ), + ); + } +} + Widget _buildMarkdownImage( BuildContext context, ImageBlock block, diff --git a/lib/widgets/message/item/post_message.dart b/lib/widgets/message/item/post_message.dart index 6f5d3970dc..bada94e10e 100644 --- a/lib/widgets/message/item/post_message.dart +++ b/lib/widgets/message/item/post_message.dart @@ -70,7 +70,16 @@ class MessagePost extends StatelessWidget { children: [ HookBuilder( builder: (context) { + final messageId = context.message.messageId; final postContent = useMemoized(content.postOptimize, [content]); + final cacheKey = useMemoized( + () => buildMarkdownCacheKey( + namespace: 'post', + id: messageId, + data: postContent, + ), + [messageId, postContent], + ); return ConstrainedBox( constraints: BoxConstraints( @@ -84,7 +93,10 @@ class MessagePost extends StatelessWidget { ).copyWith(scrollbars: false), child: SingleChildScrollView( physics: const NeverScrollableScrollPhysics(), - child: MarkdownColumn(data: postContent), + child: MarkdownColumn( + data: postContent, + cacheKey: cacheKey, + ), ), ), ); @@ -162,6 +174,11 @@ class PostPreview extends StatelessWidget { constraints: const BoxConstraints(maxWidth: 600), child: Markdown( data: message.content ?? '', + cacheKey: buildMarkdownCacheKey( + namespace: 'post-preview', + id: message.messageId, + data: message.content ?? '', + ), padding: const EdgeInsets.symmetric(vertical: 8, horizontal: 32), ), ), From c83081dc2d79ae8cbab0ff538f4708feea107f3a Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 23 Apr 2026 18:27:07 +0800 Subject: [PATCH 08/52] feat: add streaming support to Markdown rendering --- lib/ui/home/chat/chat_page.dart | 2 - lib/widgets/ai/ai_message_card.dart | 2 +- lib/widgets/markdown.dart | 85 ++++++++++++++++------ lib/widgets/message/item/post_message.dart | 4 +- 4 files changed, 66 insertions(+), 27 deletions(-) diff --git a/lib/ui/home/chat/chat_page.dart b/lib/ui/home/chat/chat_page.dart index 849ef55088..af234f2e02 100644 --- a/lib/ui/home/chat/chat_page.dart +++ b/lib/ui/home/chat/chat_page.dart @@ -676,7 +676,6 @@ class _List extends HookConsumerWidget { key: buildMarkdownCacheKey( namespace: 'ai', id: aiMessage.id, - data: content, ), data: content, ); @@ -690,7 +689,6 @@ class _List extends HookConsumerWidget { key: buildMarkdownCacheKey( namespace: 'post', id: message.messageId, - data: content, ), data: content, ); diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index 2129c95fab..66b6c2d93d 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -158,7 +158,6 @@ class _AiMessageBody extends StatelessWidget { final cacheKey = buildMarkdownCacheKey( namespace: 'ai', id: message.id, - data: text, ); body = DefaultTextStyle.merge( style: textStyle, @@ -166,6 +165,7 @@ class _AiMessageBody extends StatelessWidget { data: text, selectable: true, cacheKey: cacheKey, + streaming: message.status == 'pending', ), ); } diff --git a/lib/widgets/markdown.dart b/lib/widgets/markdown.dart index 7edd3f91ae..df45f25375 100644 --- a/lib/widgets/markdown.dart +++ b/lib/widgets/markdown.dart @@ -19,8 +19,7 @@ const _kMarkdownWarmupPerFrame = 2; String buildMarkdownCacheKey({ required String namespace, required String id, - required String data, -}) => '$namespace:$id:${data.hashCode}'; +}) => '$namespace:$id'; final markdownControllerCache = MarkdownControllerCache(); @@ -32,12 +31,17 @@ class MarkdownControllerCache { bool _warmupScheduled = false; - MarkdownController? acquire(String key, String data) { + MarkdownController? acquire( + String key, + String data, { + bool streaming = false, + }) { final entry = _entries[key]; if (entry == null) return null; if (entry.data != data) { - _removeEntry(key, entry); - return null; + _updateEntryData(entry, data, streaming: streaming); + } else if (!streaming) { + entry.controller.commitStream(); } _touch(key, entry); entry.retainCount += 1; @@ -57,9 +61,12 @@ class MarkdownControllerCache { if (entry != null) { if (entry.data == data) { _touch(key, entry); + entry.controller.commitStream(); return Future.value(); } - _removeEntry(key, entry); + _updateEntryData(entry, data, streaming: false); + _touch(key, entry); + return Future.value(); } final pending = _pending[key]; @@ -98,12 +105,14 @@ class MarkdownControllerCache { try { final existing = _entries[task.key]; - if (existing != null && existing.data == task.data) { + if (existing != null) { + if (existing.data != task.data) { + _updateEntryData(existing, task.data, streaming: false); + } else { + existing.controller.commitStream(); + } _touch(task.key, existing); } else { - if (existing != null) { - _removeEntry(task.key, existing); - } _entries[task.key] = _MarkdownCacheEntry( data: task.data, controller: MarkdownController(data: task.data), @@ -149,6 +158,23 @@ class MarkdownControllerCache { _entries.remove(key); entry.controller.dispose(); } + + void _updateEntryData( + _MarkdownCacheEntry entry, + String data, { + required bool streaming, + }) { + final previousData = entry.data; + entry.data = data; + if (streaming && data.startsWith(previousData)) { + entry.controller.appendChunk(data.substring(previousData.length)); + return; + } + entry.controller.setData(data); + if (!streaming) { + entry.controller.commitStream(); + } + } } class _MarkdownCacheEntry { @@ -157,7 +183,7 @@ class _MarkdownCacheEntry { required this.controller, }); - final String data; + String data; final MarkdownController controller; int retainCount = 0; } @@ -168,11 +194,13 @@ class MarkdownColumn extends HookConsumerWidget { super.key, this.selectable = false, this.cacheKey, + this.streaming = false, }); final String data; final bool selectable; final String? cacheKey; + final bool streaming; @override Widget build(BuildContext context, WidgetRef ref) { @@ -184,6 +212,7 @@ class MarkdownColumn extends HookConsumerWidget { child: _MarkdownView( data: data, cacheKey: cacheKey, + streaming: streaming, useColumn: true, selectable: selectable, contextMenuBuilder: (_, _, _, _) => const SizedBox.shrink(), @@ -206,12 +235,14 @@ class Markdown extends HookConsumerWidget { this.padding = EdgeInsets.zero, this.physics, this.cacheKey, + this.streaming = false, }); final String data; final EdgeInsetsGeometry? padding; final ScrollPhysics? physics; final String? cacheKey; + final bool streaming; @override Widget build(BuildContext context, WidgetRef ref) { @@ -222,6 +253,7 @@ class Markdown extends HookConsumerWidget { return _MarkdownView( data: data, cacheKey: cacheKey, + streaming: streaming, padding: padding, physics: physics, theme: _createMarkdownTheme(context, chatFontSizeDelta), @@ -243,6 +275,7 @@ class _MarkdownView extends HookWidget { this.cacheKey, this.padding, this.physics, + this.streaming = false, this.useColumn = false, this.selectable = true, this.contextMenuBuilder, @@ -250,6 +283,7 @@ class _MarkdownView extends HookWidget { final String data; final String? cacheKey; + final bool streaming; final MarkdownThemeData theme; final EdgeInsetsGeometry? padding; final ScrollPhysics? physics; @@ -265,21 +299,28 @@ class _MarkdownView extends HookWidget { return _buildMarkdownWidget(data: data); } - final controller = useState(null); + final controller = + useState<({String key, String data, MarkdownController controller})?>( + null, + ); useEffect(() { var disposed = false; MarkdownController? retained; - void bindCachedController() { - final cached = markdownControllerCache.acquire(cacheKey!, data); - if (cached == null || disposed) return; + bool bindCachedController() { + final cached = markdownControllerCache.acquire( + cacheKey!, + data, + streaming: streaming, + ); + if (cached == null || disposed) return false; retained = cached; - controller.value = cached; + controller.value = (key: cacheKey!, data: data, controller: cached); + return true; } - bindCachedController(); - if (controller.value == null) { + if (!bindCachedController()) { unawaited( markdownControllerCache.warmup(cacheKey!, data).then((_) { if (disposed) return; @@ -295,11 +336,13 @@ class _MarkdownView extends HookWidget { markdownControllerCache.release(cacheKey!, current); } }; - }, [cacheKey, data]); + }, [cacheKey, data, streaming]); final cachedController = controller.value; - if (cachedController != null) { - return _buildMarkdownWidget(controller: cachedController); + if (cachedController != null && + cachedController.key == cacheKey && + cachedController.data == data) { + return _buildMarkdownWidget(controller: cachedController.controller); } return _MarkdownFallback( diff --git a/lib/widgets/message/item/post_message.dart b/lib/widgets/message/item/post_message.dart index bada94e10e..ee42ae1a63 100644 --- a/lib/widgets/message/item/post_message.dart +++ b/lib/widgets/message/item/post_message.dart @@ -76,9 +76,8 @@ class MessagePost extends StatelessWidget { () => buildMarkdownCacheKey( namespace: 'post', id: messageId, - data: postContent, ), - [messageId, postContent], + [messageId], ); return ConstrainedBox( @@ -177,7 +176,6 @@ class PostPreview extends StatelessWidget { cacheKey: buildMarkdownCacheKey( namespace: 'post-preview', id: message.messageId, - data: message.content ?? '', ), padding: const EdgeInsets.symmetric(vertical: 8, horizontal: 32), ), From e527fca3d2f2a7240e4f48693e12a1799311b648 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 23 Apr 2026 19:01:34 +0800 Subject: [PATCH 09/52] feat: update AI mode UI and theme colors --- lib/constants/brightness_theme_data.dart | 38 ++-- lib/ui/home/chat/input_container.dart | 270 ++++++++++------------- lib/widgets/menu.dart | 81 ++++--- 3 files changed, 184 insertions(+), 205 deletions(-) diff --git a/lib/constants/brightness_theme_data.dart b/lib/constants/brightness_theme_data.dart index 78501a5126..0f746748f8 100644 --- a/lib/constants/brightness_theme_data.dart +++ b/lib/constants/brightness_theme_data.dart @@ -26,16 +26,16 @@ const lightBrightnessThemeData = BrightnessThemeData( waveformForeground: Color.fromRGBO(155, 155, 155, 1), settingCellBackgroundColor: Colors.white, ai: AiColorScheme( - avatarBackground: Color.fromRGBO(227, 237, 213, 1), - accent: Color.fromRGBO(54, 87, 35, 1), + avatarBackground: Color(0xFFE0E7FF), + accent: Color(0xFF4F46E5), onAccent: Colors.white, - surface: Color.fromRGBO(241, 248, 243, 1), - surfaceBorder: Color.fromRGBO(200, 223, 208, 1), - surfaceVariant: Color.fromRGBO(223, 236, 214, 1), - userBubble: Color.fromRGBO(255, 241, 214, 1), - assistantBubble: Color.fromRGBO(228, 245, 239, 1), - errorBubble: Color.fromRGBO(255, 235, 235, 1), - error: Color.fromRGBO(193, 63, 63, 1), + surface: Color(0xFFF5F3FF), + surfaceBorder: Color(0xFFE0E7FF), + surfaceVariant: Color(0xFFEDE9FE), + userBubble: Color(0xFFF1F5F9), + assistantBubble: Color(0xFFEEF2FF), + errorBubble: Color(0xFFFEF2F2), + error: Color(0xFFEF4444), ), ); @@ -63,16 +63,16 @@ const darkBrightnessThemeData = BrightnessThemeData( waveformForeground: Color.fromRGBO(255, 255, 255, 1), settingCellBackgroundColor: Color.fromRGBO(255, 255, 255, 0.06), ai: AiColorScheme( - avatarBackground: Color.fromRGBO(64, 78, 56, 1), - accent: Color.fromRGBO(214, 235, 204, 1), - onAccent: Color.fromRGBO(26, 42, 31, 1), - surface: Color.fromRGBO(35, 52, 44, 1), - surfaceBorder: Color.fromRGBO(72, 101, 88, 1), - surfaceVariant: Color.fromRGBO(58, 77, 66, 1), - userBubble: Color.fromRGBO(96, 76, 34, 1), - assistantBubble: Color.fromRGBO(43, 77, 65, 1), - errorBubble: Color.fromRGBO(88, 46, 46, 1), - error: Color.fromRGBO(255, 173, 173, 1), + avatarBackground: Color(0xFF312E81), + accent: Color(0xFFA5B4FC), + onAccent: Color(0xFF1E1B4B), + surface: Color(0xFF282836), + surfaceBorder: Color(0xFF3F3F5A), + surfaceVariant: Color(0xFF312E4A), + userBubble: Color(0xFF334155), + assistantBubble: Color(0xFF2D2B4A), + errorBubble: Color(0xFF450A0A), + error: Color(0xFFFCA5A5), ), ); diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index b8aa78810e..259deb2dd7 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -238,7 +238,7 @@ class _InputContainer extends HookConsumerWidget { children: [ const _QuoteMessage(), ConstrainedBox( - constraints: BoxConstraints(minHeight: aiModeEnabled ? 108 : 56), + constraints: BoxConstraints(minHeight: aiModeEnabled ? 92 : 56), child: Container( decoration: BoxDecoration(color: context.theme.primary), padding: EdgeInsets.fromLTRB( @@ -250,19 +250,10 @@ class _InputContainer extends HookConsumerWidget { child: AnimatedContainer( duration: const Duration(milliseconds: 220), curve: Curves.easeOutCubic, - padding: aiModeEnabled - ? const EdgeInsets.fromLTRB(10, 8, 10, 8) - : EdgeInsets.zero, - decoration: BoxDecoration( - color: aiModeEnabled - ? context.theme.ai.surface - : Colors.transparent, - borderRadius: const BorderRadius.all(Radius.circular(18)), - border: aiModeEnabled - ? Border.all( - color: context.theme.ai.surfaceBorder, - ) - : null, + padding: EdgeInsets.zero, + decoration: const BoxDecoration( + color: Colors.transparent, + borderRadius: BorderRadius.all(Radius.circular(4)), ), child: Column( mainAxisSize: MainAxisSize.min, @@ -272,16 +263,15 @@ class _InputContainer extends HookConsumerWidget { conversationId: conversationId, provider: aiProvider, ), - Container( - height: 1, - margin: const EdgeInsets.only(top: 8, bottom: 8), - color: context.theme.ai.surfaceBorder, - ), + const SizedBox(height: 8), ], Row( crossAxisAlignment: CrossAxisAlignment.end, children: [ - if (!aiModeEnabled) ...[ + if (aiModeEnabled) ...[ + _AiModeBadge(color: context.theme.accent), + const SizedBox(width: 16), + ] else ...[ const _SendActionTypeButton(), const SizedBox(width: 6), _StickerButton( @@ -350,8 +340,6 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { false; if (aiModeEnabled) { - final backgroundColor = context.theme.ai.accent; - final foregroundColor = context.theme.ai.onAccent; final canSend = hasInputText && !aiRequestInFlight; return AnimatedOpacity( @@ -359,25 +347,14 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { opacity: canSend ? 1 : 0.45, child: IgnorePointer( ignoring: !canSend, - child: DecoratedBox( - decoration: BoxDecoration( - color: backgroundColor, - shape: BoxShape.circle, - ), - child: ActionButton( - size: 20, - color: foregroundColor, - child: Icon( - Icons.arrow_upward_rounded, - size: 20, - color: foregroundColor, - ), - onTap: () => unawaited( - _sendMessage( - context, - textEditingController, - conversationId: conversationId, - ), + child: ActionButton( + name: Resources.assetsImagesIcSendSvg, + color: context.theme.accent, + onTap: () => unawaited( + _sendMessage( + context, + textEditingController, + conversationId: conversationId, ), ), ), @@ -720,18 +697,11 @@ class _SendTextField extends HookConsumerWidget { curve: Curves.easeOutCubic, constraints: const BoxConstraints(minHeight: 40), decoration: BoxDecoration( - borderRadius: BorderRadius.all(Radius.circular(aiModeEnabled ? 12 : 4)), + borderRadius: const BorderRadius.all(Radius.circular(4)), color: context.dynamicColor( - aiModeEnabled - ? const Color.fromRGBO(255, 255, 255, 0.82) - : const Color.fromRGBO(245, 247, 250, 1), + const Color.fromRGBO(245, 247, 250, 1), darkColor: const Color.fromRGBO(255, 255, 255, 0.08), ), - border: aiModeEnabled - ? Border.all( - color: context.theme.ai.surfaceBorder, - ) - : null, ), alignment: Alignment.center, child: FocusableActionDetector( @@ -840,10 +810,10 @@ class _AiModeBar extends HookConsumerWidget { @override Widget build(BuildContext context, WidgetRef ref) { - final providerName = provider?.name ?? 'No Provider'; + final providerName = provider?.name.trim().isNotEmpty == true + ? provider!.name.trim() + : 'No Provider'; final model = provider?.model.trim(); - final aiColors = context.theme.ai; - final accentColor = aiColors.accent; final hasProvider = provider != null; final notifier = ref.read(aiInputModeProvider(conversationId).notifier); final enabledAiProviders = context.database.settingProperties.aiProviders @@ -862,94 +832,92 @@ class _AiModeBar extends HookConsumerWidget { provider?.models .where((item) => item.trim().isNotEmpty) .map( - (item) => CustomPopupMenuItem(title: item, value: item), + (item) => CustomPopupMenuItem( + title: item.trim(), + value: item.trim(), + ), ) .toList(growable: false) ?? >[]; - return Row( - children: [ - _AiModeChip( - icon: Icons.auto_awesome_rounded, - label: 'AI', - foregroundColor: accentColor, - backgroundColor: aiColors.surfaceVariant, - ), - Expanded( - child: SingleChildScrollView( - scrollDirection: Axis.horizontal, - padding: const EdgeInsets.symmetric(horizontal: 8), + return SizedBox( + width: double.infinity, + height: 30, + child: Row( + children: [ + Expanded( child: Row( children: [ - _AiModeMenuChip( - icon: Icons.hub_rounded, - label: providerName, - items: providerOptions, - enabled: providerOptions.length > 1, - onSelected: (value) => notifier.updateProvider( - providerId: value.id, - model: value.model, + Flexible( + child: _AiModeMenuChip( + icon: Icons.hub_rounded, + label: providerName, + items: providerOptions, + enabled: providerOptions.length > 1, + onSelected: (value) => notifier.updateProvider( + providerId: value.id, + model: value.model, + ), ), ), - const SizedBox(width: 6), - _AiModeMenuChip( - icon: Icons.tune_rounded, - label: model?.isNotEmpty == true - ? model! - : (hasProvider ? 'Select Model' : 'No Model'), - items: modelOptions, - enabled: modelOptions.length > 1, - onSelected: notifier.updateModel, + const SizedBox(width: 10), + _AiModeDivider(), + const SizedBox(width: 10), + Flexible( + child: _AiModeMenuChip( + icon: Icons.tune_rounded, + label: model?.isNotEmpty == true + ? model! + : (hasProvider ? 'Select Model' : 'No Model'), + items: modelOptions, + enabled: modelOptions.length > 1, + onSelected: notifier.updateModel, + ), ), ], ), ), - ), - ActionButton( - name: Resources.assetsImagesIcCloseSvg, - color: context.theme.icon, - size: 18, - onTap: () => - ref.read(aiInputModeProvider(conversationId).notifier).exit(), - ), - ], + const SizedBox(width: 8), + ActionButton( + name: Resources.assetsImagesIcCloseSvg, + color: context.theme.icon, + size: 20, + onTap: () => + ref.read(aiInputModeProvider(conversationId).notifier).exit(), + ), + ], + ), ); } } -class _AiModeChip extends StatelessWidget { - const _AiModeChip({ - required this.icon, - required this.label, - required this.foregroundColor, - required this.backgroundColor, - }); - - final IconData icon; - final String label; - final Color foregroundColor; - final Color backgroundColor; - +class _AiModeDivider extends StatelessWidget { @override Widget build(BuildContext context) => Container( - height: 28, - padding: const EdgeInsets.symmetric(horizontal: 10), - decoration: BoxDecoration( - color: backgroundColor, - borderRadius: const BorderRadius.all(Radius.circular(999)), + width: 1, + height: 14, + color: context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.08), + darkColor: const Color.fromRGBO(255, 255, 255, 0.1), ), + ); +} + +class _AiModeBadge extends StatelessWidget { + const _AiModeBadge({required this.color}); + + final Color color; + + @override + Widget build(BuildContext context) => SizedBox( + height: 40, child: Row( mainAxisSize: MainAxisSize.min, children: [ - Icon(icon, size: 14, color: foregroundColor), - const SizedBox(width: 6), - Text( - label, - style: TextStyle( - color: foregroundColor, - fontSize: 12, - fontWeight: FontWeight.w700, - ), + Icon( + Icons.auto_awesome_rounded, + size: 14, + color: color, ), ], ), @@ -962,6 +930,7 @@ class _AiModeMenuChip extends StatelessWidget { required this.label, required this.items, required this.onSelected, + this.maxWidth = 200, this.enabled = true, }); @@ -969,48 +938,40 @@ class _AiModeMenuChip extends StatelessWidget { final String label; final List> items; final ValueChanged onSelected; + final double maxWidth; final bool enabled; @override Widget build(BuildContext context) { - final child = Container( - height: 28, - padding: const EdgeInsets.symmetric(horizontal: 10), - decoration: BoxDecoration( - color: context.dynamicColor( - const Color.fromRGBO(255, 255, 255, 0.74), - darkColor: const Color.fromRGBO(255, 255, 255, 0.06), - ), - border: Border.all(color: context.theme.ai.surfaceBorder), - borderRadius: const BorderRadius.all(Radius.circular(999)), - ), - child: Row( - mainAxisSize: MainAxisSize.min, - children: [ - Icon(icon, size: 13, color: context.theme.secondaryText), - const SizedBox(width: 6), - ConstrainedBox( - constraints: const BoxConstraints(maxWidth: 160), - child: Text( - label, - maxLines: 1, - overflow: TextOverflow.ellipsis, - style: TextStyle( - color: context.theme.text, - fontSize: 12, - fontWeight: FontWeight.w600, + final child = IntrinsicWidth( + child: ConstrainedBox( + constraints: BoxConstraints(maxWidth: maxWidth), + child: Row( + children: [ + Icon(icon, size: 13, color: context.theme.secondaryText), + const SizedBox(width: 6), + Expanded( + child: Text( + label, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + fontWeight: FontWeight.w500, + ), ), ), - ), - if (enabled) ...[ - const SizedBox(width: 4), - Icon( - Icons.keyboard_arrow_down_rounded, - size: 16, - color: context.theme.secondaryText, - ), + if (enabled) ...[ + const SizedBox(width: 2), + Icon( + Icons.keyboard_arrow_down_rounded, + size: 14, + color: context.theme.secondaryText, + ), + ], ], - ], + ), ), ); @@ -1020,6 +981,7 @@ class _AiModeMenuChip extends StatelessWidget { itemBuilder: (_) => items, onSelected: onSelected, color: Colors.transparent, + useActionButton: false, child: child, ); } diff --git a/lib/widgets/menu.dart b/lib/widgets/menu.dart index a26e4a6b69..08fefaf06c 100644 --- a/lib/widgets/menu.dart +++ b/lib/widgets/menu.dart @@ -65,6 +65,7 @@ class CustomPopupMenuButton extends HookConsumerWidget { this.icon, this.color, this.alignment, + this.useActionButton = true, }); final CustomPopupMenuItemBuilder itemBuilder; @@ -73,43 +74,59 @@ class CustomPopupMenuButton extends HookConsumerWidget { final Widget? child; final Color? color; final Alignment? alignment; + final bool useActionButton; @override - Widget build(BuildContext context, WidgetRef ref) => ContextMenuPortalEntry( - interactive: false, - buildMenus: () => itemBuilder(context) - .map( - (e) => ContextMenu( - title: e.title, - onTap: () => onSelected?.call(e.value), - isDestructiveAction: e.isDestructiveAction, - icon: e.icon, - ), - ) - .toList(), - child: Builder( - builder: (context) => ActionButton( - name: icon, - color: color ?? context.theme.icon, - onTapUp: (details) { - d('onTapUp: $alignment'); - if (alignment == null) { - context.sendMenuPosition(details.globalPosition); - return; - } - final renderBox = context.findRenderObject() as RenderBox?; - if (renderBox != null) { - var position = alignment!.withinRect(renderBox.paintBounds); - position = renderBox.localToGlobal(position); - context.sendMenuPosition(position); - } else { - context.sendMenuPosition(details.globalPosition); + Widget build(BuildContext context, WidgetRef ref) { + void showMenu(TapUpDetails details, BuildContext buildContext) { + d('onTapUp: $alignment'); + final targetAlignment = alignment; + if (targetAlignment == null) { + buildContext.sendMenuPosition(details.globalPosition); + return; + } + final renderBox = buildContext.findRenderObject() as RenderBox?; + if (renderBox != null) { + var position = targetAlignment.withinRect(renderBox.paintBounds); + position = renderBox.localToGlobal(position); + buildContext.sendMenuPosition(position); + } else { + buildContext.sendMenuPosition(details.globalPosition); + } + } + + return ContextMenuPortalEntry( + interactive: false, + buildMenus: () => itemBuilder(context) + .map( + (e) => ContextMenu( + title: e.title, + onTap: () => onSelected?.call(e.value), + isDestructiveAction: e.isDestructiveAction, + icon: e.icon, + ), + ) + .toList(), + child: Builder( + builder: (context) { + final triggerChild = child; + if (!useActionButton && triggerChild != null) { + return GestureDetector( + behavior: HitTestBehavior.opaque, + onTapUp: (details) => showMenu(details, context), + child: triggerChild, + ); } + return ActionButton( + name: icon, + color: color ?? context.theme.icon, + onTapUp: (details) => showMenu(details, context), + child: child, + ); }, - child: child, ), - ), - ); + ); + } } class CustomPopupMenuItem { From 1488982d3d22f6ab024f50c3d057154e599aab5f Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 23 Apr 2026 19:15:29 +0800 Subject: [PATCH 10/52] feat(ai): add cancelation support for AI requests and improve input handling --- lib/ai/ai_chat_controller.dart | 39 ++++++++++++++++++++++++--- lib/ui/home/chat/input_container.dart | 24 ++++++++++++++--- 2 files changed, 55 insertions(+), 8 deletions(-) diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index 75669c823f..034dd67883 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -24,6 +24,7 @@ const _kAiHistoryLimit = 12; const _kAiStreamFlushChars = 32; const _kAiStreamFlushInterval = Duration(milliseconds: 80); final kAiRuntimeStartedAt = DateTime.now(); +final _activeAiRequests = {}; bool isActivePendingAiMessage(AiChatMessage message) => message.role == _kAiRoleAssistant && @@ -42,6 +43,7 @@ class AiChatController { required String conversationId, required String input, AiProviderConfig? provider, + void Function()? onInputAccepted, }) async { await database.aiChatMessageDao.resolveStalePendingAssistantMessages( updatedBefore: kAiRuntimeStartedAt, @@ -64,6 +66,7 @@ class AiChatController { final now = DateTime.now(); final userMessageId = _uuid.v4(); final assistantMessageId = _uuid.v4(); + final cancelToken = CancelToken(); final anchorMessage = await database.messageDao .messagesByConversationId(conversationId, 1) .getSingleOrNull(); @@ -100,15 +103,19 @@ class AiChatController { ), ); + onInputAccepted?.call(); + + final updater = _StreamingMessageUpdater( + dao: database.aiChatMessageDao, + messageId: assistantMessageId, + ); + _activeAiRequests[conversationId] = cancelToken; try { final messages = await _buildPromptMessages(conversationId, input); - final updater = _StreamingMessageUpdater( - dao: database.aiChatMessageDao, - messageId: assistantMessageId, - ); final result = await _streamRequest( config, messages, + cancelToken: cancelToken, onContent: updater.append, ); await updater.flush(contentOverride: result, force: true); @@ -118,6 +125,15 @@ class AiChatController { updatedAt: DateTime.now(), ); } catch (error, stacktrace) { + if (cancelToken.isCancelled) { + await updater.flush(force: true); + await database.aiChatMessageDao.updateMessageStatus( + assistantMessageId, + _kAiStatusDone, + updatedAt: DateTime.now(), + ); + return; + } e('AI chat error: $error, $stacktrace'); await database.aiChatMessageDao.updateMessageStatus( assistantMessageId, @@ -126,9 +142,17 @@ class AiChatController { errorText: error.toString(), ); rethrow; + } finally { + if (_activeAiRequests[conversationId] == cancelToken) { + _activeAiRequests.remove(conversationId); + } } } + void stop(String conversationId) { + _activeAiRequests[conversationId]?.cancel('AI generation stopped'); + } + Future> _buildPromptMessages( String conversationId, String input, @@ -193,6 +217,7 @@ class AiChatController { Future _streamRequest( AiProviderConfig config, List messages, { + required CancelToken cancelToken, required Future Function(String chunk) onContent, }) async { final dio = Dio( @@ -209,6 +234,7 @@ class AiChatController { dio: dio, config: config, messages: messages, + cancelToken: cancelToken, onContent: onContent, ); } @@ -239,6 +265,7 @@ abstract interface class _AiProviderStrategy { required Dio dio, required AiProviderConfig config, required List messages, + required CancelToken cancelToken, required Future Function(String chunk) onContent, }); } @@ -257,6 +284,7 @@ class _OpenAiCompatibleStrategy implements _AiProviderStrategy { required Dio dio, required AiProviderConfig config, required List messages, + required CancelToken cancelToken, required Future Function(String chunk) onContent, }) async { final response = await dio.post( @@ -274,6 +302,7 @@ class _OpenAiCompatibleStrategy implements _AiProviderStrategy { .toList(), }, options: Options(responseType: ResponseType.stream), + cancelToken: cancelToken, ); final body = response.data; @@ -337,6 +366,7 @@ class _AnthropicStrategy implements _AiProviderStrategy { required Dio dio, required AiProviderConfig config, required List messages, + required CancelToken cancelToken, required Future Function(String chunk) onContent, }) async { final response = await dio.post( @@ -360,6 +390,7 @@ class _AnthropicStrategy implements _AiProviderStrategy { .join('\n\n'), }, options: Options(responseType: ResponseType.stream), + cancelToken: cancelToken, ); final body = response.data; diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 259deb2dd7..9d41315191 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -339,8 +339,20 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { ).data ?? false; + if (aiRequestInFlight) { + return ActionButton( + name: Resources.assetsImagesRecordStopSvg, + color: context.theme.accent, + onTap: () { + final currentConversationId = conversationId; + if (currentConversationId == null) return; + AiChatController(context.database).stop(currentConversationId); + }, + ); + } + if (aiModeEnabled) { - final canSend = hasInputText && !aiRequestInFlight; + final canSend = hasInputText; return AnimatedOpacity( duration: const Duration(milliseconds: 180), @@ -527,8 +539,8 @@ Future _sendMessage( conversationId: conversationId, input: inlineAiInput, provider: provider, + onInputAccepted: () => textEditingController.text = '', ); - textEditingController.text = ''; } catch (error, _) { showToastFailed(error); } @@ -552,8 +564,12 @@ Future _sendMessage( try { await AiChatController( context.database, - ).send(conversationId: conversationId, input: text, provider: provider); - textEditingController.text = ''; + ).send( + conversationId: conversationId, + input: text, + provider: provider, + onInputAccepted: () => textEditingController.text = '', + ); } catch (error, _) { showToastFailed(error); } From 69b959c164ba242c8eba7557fd65f52e146dc769 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Fri, 24 Apr 2026 11:03:41 +0800 Subject: [PATCH 11/52] feat: add AI-assisted text generation and inline message actions --- lib/ai/ai_chat_controller.dart | 95 +++++- lib/ui/home/chat/input_container.dart | 205 ++++++++++-- lib/widgets/ai/ai_message_card.dart | 14 +- lib/widgets/ai/ai_text_result_dialog.dart | 124 +++++++ lib/widgets/message/message.dart | 389 ++++++++++++++++++++++ pubspec.lock | 7 +- pubspec.yaml | 3 +- 7 files changed, 775 insertions(+), 62 deletions(-) create mode 100644 lib/widgets/ai/ai_text_result_dialog.dart diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index 034dd67883..ce6b2510f4 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -39,6 +39,31 @@ class AiChatController { static const _openAiStrategy = _OpenAiCompatibleStrategy(); static const _anthropicStrategy = _AnthropicStrategy(); + Future assistText({ + required String instruction, + String? input, + String? conversationId, + AiProviderConfig? provider, + }) async { + final config = provider ?? database.settingProperties.selectedAiProvider; + if (config == null) { + throw Exception('No AI provider configured'); + } + + final messages = await _buildAssistPromptMessages( + instruction: instruction, + input: input, + conversationId: conversationId, + ); + + return _streamRequest( + config, + messages, + cancelToken: CancelToken(), + onContent: (_) async {}, + ); + } + Future send({ required String conversationId, required String input, @@ -131,6 +156,7 @@ class AiChatController { assistantMessageId, _kAiStatusDone, updatedAt: DateTime.now(), + errorText: 'Stopped', ); return; } @@ -204,6 +230,55 @@ class AiChatController { return promptMessages; } + Future> _buildAssistPromptMessages({ + required String instruction, + required String? input, + required String? conversationId, + }) async { + final promptMessages = [ + AiPromptMessage( + role: 'system', + content: + 'You are an invisible writing assistant inside a chat app. ' + 'Return only the requested text. Do not add explanations, labels, ' + 'markdown fences, or greetings unless explicitly requested.', + ), + ]; + + if (conversationId != null) { + final recentMessages = await database.messageDao + .messagesByConversationId(conversationId, _kAiContextMessageLimit) + .get(); + if (recentMessages.isNotEmpty) { + final lines = recentMessages.reversed + .map((message) { + final sender = message.userFullName ?? message.userId; + final content = _messagePlainText(message); + return '[${message.createdAt.toIso8601String()}] $sender: $content'; + }) + .join('\n'); + promptMessages.add( + AiPromptMessage( + role: 'system', + content: 'Current conversation recent messages:\n$lines', + ), + ); + } + } + + final inputText = input?.trim(); + promptMessages.add( + AiPromptMessage( + role: _kAiRoleUser, + content: [ + instruction.trim(), + if (inputText != null && inputText.isNotEmpty) '\nText:\n$inputText', + ].join('\n'), + ), + ); + return promptMessages; + } + String _messagePlainText(MessageItem message) { if (message.content?.trim().isNotEmpty == true) { return message.content!.trim(); @@ -294,10 +369,7 @@ class _OpenAiCompatibleStrategy implements _AiProviderStrategy { 'stream': true, 'messages': messages .map( - (message) => { - 'role': message.role, - 'content': message.content, - }, + (message) => {'role': message.role, 'content': message.content}, ) .toList(), }, @@ -378,10 +450,7 @@ class _AnthropicStrategy implements _AiProviderStrategy { 'messages': messages .where((message) => message.role != 'system') .map( - (message) => { - 'role': message.role, - 'content': message.content, - }, + (message) => {'role': message.role, 'content': message.content}, ) .toList(), 'system': messages @@ -443,10 +512,7 @@ class _AnthropicStrategy implements _AiProviderStrategy { } class _StreamingMessageUpdater { - _StreamingMessageUpdater({ - required this.dao, - required this.messageId, - }); + _StreamingMessageUpdater({required this.dao, required this.messageId}); final AiChatMessageDao dao; final String messageId; @@ -469,10 +535,7 @@ class _StreamingMessageUpdater { await flush(); } - Future flush({ - String? contentOverride, - bool force = false, - }) async { + Future flush({String? contentOverride, bool force = false}) async { final content = contentOverride ?? _buffer.toString(); if (!force && content == _persistedContent) { return; diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 9d41315191..bd435c8b39 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -13,6 +13,7 @@ import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:flutter_svg/svg.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart' hide ChangeNotifierProvider; import 'package:image_picker/image_picker.dart'; +import 'package:mixin_logger/mixin_logger.dart'; import 'package:provider/provider.dart' hide Consumer; import 'package:rxdart/rxdart.dart'; import 'package:simple_animations/simple_animations.dart'; @@ -35,6 +36,7 @@ import '../../../utils/reg_exp_utils.dart'; import '../../../utils/system/clipboard.dart'; import '../../../widgets/action_button.dart'; import '../../../widgets/actions/actions.dart'; +import '../../../widgets/ai/ai_text_result_dialog.dart'; import '../../../widgets/high_light_text.dart'; import '../../../widgets/hover_overlay.dart'; import '../../../widgets/mention_panel.dart'; @@ -241,12 +243,7 @@ class _InputContainer extends HookConsumerWidget { constraints: BoxConstraints(minHeight: aiModeEnabled ? 92 : 56), child: Container( decoration: BoxDecoration(color: context.theme.primary), - padding: EdgeInsets.fromLTRB( - 16, - aiModeEnabled ? 8 : 8, - 16, - 8, - ), + padding: EdgeInsets.fromLTRB(16, aiModeEnabled ? 8 : 8, 16, 8), child: AnimatedContainer( duration: const Duration(milliseconds: 220), curve: Curves.easeOutCubic, @@ -290,7 +287,14 @@ class _InputContainer extends HookConsumerWidget { aiRequestInFlight: aiRequestInFlight, ), ), - SizedBox(width: aiModeEnabled ? 10 : 16), + if (!aiModeEnabled) ...[ + const SizedBox(width: 8), + _AiDraftAssistButton( + conversationId: conversationId, + textEditingController: textEditingController, + ), + ], + SizedBox(width: aiModeEnabled ? 10 : 8), _AnimatedSendOrVoiceButton( conversationId: conversationId, textEditingController: textEditingController, @@ -461,6 +465,160 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { } } +enum _AiDraftAction { polish, shorten, polite, translate, replyWithContext } + +class _AiDraftAssistButton extends StatelessWidget { + const _AiDraftAssistButton({ + required this.conversationId, + required this.textEditingController, + }); + + final String? conversationId; + final TextEditingController textEditingController; + + @override + Widget build(BuildContext context) { + final enabled = + context.database.settingProperties.selectedAiProvider != null; + return AnimatedOpacity( + duration: const Duration(milliseconds: 180), + opacity: enabled ? 1 : 0.45, + child: IgnorePointer( + ignoring: !enabled, + child: CustomPopupMenuButton<_AiDraftAction>( + itemBuilder: (_) => [ + CustomPopupMenuItem( + title: 'Polish', + value: _AiDraftAction.polish, + icon: Resources.assetsImagesBotSvg, + ), + CustomPopupMenuItem( + title: 'Make shorter', + value: _AiDraftAction.shorten, + icon: Resources.assetsImagesBotSvg, + ), + CustomPopupMenuItem( + title: 'Make polite', + value: _AiDraftAction.polite, + icon: Resources.assetsImagesBotSvg, + ), + CustomPopupMenuItem( + title: 'Translate draft', + value: _AiDraftAction.translate, + icon: Resources.assetsImagesBotSvg, + ), + CustomPopupMenuItem( + title: 'Reply with context', + value: _AiDraftAction.replyWithContext, + icon: Resources.assetsImagesBotSvg, + ), + ], + onSelected: (action) => unawaited( + _runAiDraftAction( + context, + action: action, + conversationId: conversationId, + textEditingController: textEditingController, + ), + ), + child: Icon( + Icons.auto_awesome_rounded, + size: 20, + color: context.theme.icon, + ), + ), + ), + ); + } +} + +Future _runAiDraftAction( + BuildContext context, { + required _AiDraftAction action, + required String? conversationId, + required TextEditingController textEditingController, +}) async { + final original = textEditingController.text.trim(); + if (conversationId == null) return; + if (action != _AiDraftAction.replyWithContext && original.isEmpty) { + showToastFailed(ToastError('Please type a message first')); + return; + } + + final language = _currentLanguageTag(context); + final instruction = switch (action) { + _AiDraftAction.polish => + 'Polish this draft for a chat message. Keep the original meaning, language, and approximate length.', + _AiDraftAction.shorten => + 'Rewrite this chat draft to be shorter and clearer. Keep the original language and intent.', + _AiDraftAction.polite => + 'Rewrite this chat draft to sound polite, natural, and still concise. Keep the original language.', + _AiDraftAction.translate => + 'Translate this chat draft into $language. Return only the translation.', + _AiDraftAction.replyWithContext => + 'Draft a concise, natural reply to the latest conversation message using the recent context. Return only the reply text.', + }; + final title = switch (action) { + _AiDraftAction.polish => 'Polish', + _AiDraftAction.shorten => 'Make shorter', + _AiDraftAction.polite => 'Make polite', + _AiDraftAction.translate => 'Translate draft', + _AiDraftAction.replyWithContext => 'Reply with context', + }; + + showToastLoading(context: context); + try { + final result = await AiChatController(context.database).assistText( + instruction: instruction, + input: action == _AiDraftAction.replyWithContext ? null : original, + conversationId: conversationId, + ); + Toast.dismiss(); + if (!context.mounted) return; + final selectedAction = await showAiTextResultDialog( + context: context, + title: title, + original: original, + result: result, + allowReplace: original.isNotEmpty, + ); + if (selectedAction == null) return; + switch (selectedAction) { + case AiTextResultAction.replace: + _replaceDraft(textEditingController, result); + case AiTextResultAction.insert: + _insertDraft(textEditingController, result); + } + } catch (error, stackTrace) { + e('AI draft assist failed: $error, $stackTrace'); + showToastFailed(error, context: context); + } +} + +String _currentLanguageTag(BuildContext context) { + final locale = Localizations.localeOf(context); + final countryCode = locale.countryCode; + if (countryCode == null || countryCode.isEmpty) return locale.languageCode; + return '${locale.languageCode}-$countryCode'; +} + +void _replaceDraft(TextEditingController controller, String text) { + controller.value = TextEditingValue( + text: text, + selection: TextSelection.collapsed(offset: text.length), + ); +} + +void _insertDraft(TextEditingController controller, String text) { + final current = controller.text; + final separator = current.trim().isEmpty ? '' : '\n'; + final next = '$current$separator$text'; + controller.value = TextEditingValue( + text: next, + selection: TextSelection.collapsed(offset: next.length), + ); +} + void showMaxLengthReachedToast(BuildContext context) => showToastFailed(ToastError(context.l10n.contentTooLong)); @@ -562,9 +720,7 @@ Future _sendMessage( return; } try { - await AiChatController( - context.database, - ).send( + await AiChatController(context.database).send( conversationId: conversationId, input: text, provider: provider, @@ -816,10 +972,7 @@ class _SendTextField extends HookConsumerWidget { } class _AiModeBar extends HookConsumerWidget { - const _AiModeBar({ - required this.conversationId, - required this.provider, - }); + const _AiModeBar({required this.conversationId, required this.provider}); final String conversationId; final AiProviderConfig? provider; @@ -929,13 +1082,7 @@ class _AiModeBadge extends StatelessWidget { height: 40, child: Row( mainAxisSize: MainAxisSize.min, - children: [ - Icon( - Icons.auto_awesome_rounded, - size: 14, - color: color, - ), - ], + children: [Icon(Icons.auto_awesome_rounded, size: 14, color: color)], ), ); } @@ -1106,9 +1253,7 @@ class _SendActionTypeButton extends HookConsumerWidget { .getSingleOrNull(); if (user == null) throw Exception('User not found'); - final quoteMessage = ref.read( - quoteMessageProvider.notifier, - ); + final quoteMessage = ref.read(quoteMessageProvider.notifier); await context.accountServer.sendContactMessage( userId, @@ -1152,9 +1297,7 @@ class _SendActionTypeButton extends HookConsumerWidget { source: ImageSource.gallery, ); if (image == null) return; - await showFilesPreviewDialog(context, [ - image.withMineType(), - ]); + await showFilesPreviewDialog(context, [image.withMineType()]); }, ), if (!isDesktop) @@ -1166,9 +1309,7 @@ class _SendActionTypeButton extends HookConsumerWidget { source: ImageSource.gallery, ); if (video == null) return; - await showFilesPreviewDialog(context, [ - video.withMineType(), - ]); + await showFilesPreviewDialog(context, [video.withMineType()]); }, ), ], @@ -1361,9 +1502,7 @@ class MentionTextMatcher extends TextMatcher implements EquatableMixin { return TextSpan( text: displayString, style: valid - ? (span.style ?? const TextStyle()).merge( - highlightTextStyle, - ) + ? (span.style ?? const TextStyle()).merge(highlightTextStyle) : span.style, ); }, diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index 66b6c2d93d..f55061ec90 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -133,12 +133,7 @@ class _AiMessageBody extends StatelessWidget { @override Widget build(BuildContext context) { final isUser = message.role == 'user'; - final content = message.content.trim(); - final text = content.isNotEmpty - ? content - : message.status == 'error' - ? (message.errorText ?? 'Request failed') - : 'Thinking...'; + final text = _displayText(message); final statusColor = _statusColor( context, isUser: isUser, @@ -597,11 +592,14 @@ Color _statusColor( return context.theme.ai.accent; } -String _menuCopyText(AiChatMessage message) { +String _menuCopyText(AiChatMessage message) => _displayText(message); + +String _displayText(AiChatMessage message) { final content = message.content.trim(); if (content.isNotEmpty) return content; if (message.status == 'error') { return message.errorText ?? 'Request failed'; } - return 'Thinking...'; + if (message.status == 'pending') return 'Thinking...'; + return message.errorText ?? 'No response'; } diff --git a/lib/widgets/ai/ai_text_result_dialog.dart b/lib/widgets/ai/ai_text_result_dialog.dart new file mode 100644 index 0000000000..d4dab9589b --- /dev/null +++ b/lib/widgets/ai/ai_text_result_dialog.dart @@ -0,0 +1,124 @@ +import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; + +import '../../utils/extension/extension.dart'; +import '../dialog.dart'; +import '../toast.dart'; + +enum AiTextResultAction { replace, insert } + +Future showAiTextResultDialog({ + required BuildContext context, + required String title, + required String result, + String? original, + bool allowReplace = true, +}) => showMixinDialog( + context: context, + constraints: const BoxConstraints(maxWidth: 560), + child: AlertDialogLayout( + minWidth: 420, + minHeight: 0, + titleMarginBottom: 20, + title: Text(title), + content: _AiTextResultContent(original: original, result: result), + actions: [ + MixinButton( + backgroundTransparent: true, + child: const Text('Copy'), + onTap: () { + Clipboard.setData(ClipboardData(text: result)); + showToastSuccessful(context: context); + Navigator.pop(context); + }, + ), + const MixinButton( + backgroundTransparent: true, + value: AiTextResultAction.insert, + child: Text('Insert'), + ), + if (allowReplace) + const MixinButton( + value: AiTextResultAction.replace, + child: Text('Replace'), + ), + ], + ), +); + +class _AiTextResultContent extends StatelessWidget { + const _AiTextResultContent({required this.result, this.original}); + + final String? original; + final String result; + + @override + Widget build(BuildContext context) { + final original = this.original?.trim(); + return ConstrainedBox( + constraints: const BoxConstraints(maxHeight: 420), + child: SingleChildScrollView( + child: DefaultTextStyle.merge( + style: TextStyle( + color: context.theme.text, + fontSize: 14, + fontWeight: FontWeight.normal, + height: 1.45, + ), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + mainAxisSize: MainAxisSize.min, + children: [ + if (original != null && original.isNotEmpty) ...[ + const _SectionLabel('Original'), + _TextBlock(original), + const SizedBox(height: 16), + ], + const _SectionLabel('AI'), + _TextBlock(result), + ], + ), + ), + ), + ); + } +} + +class _SectionLabel extends StatelessWidget { + const _SectionLabel(this.text); + + final String text; + + @override + Widget build(BuildContext context) => Padding( + padding: const EdgeInsets.only(bottom: 6), + child: Text( + text, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + fontWeight: FontWeight.w500, + ), + ), + ); +} + +class _TextBlock extends StatelessWidget { + const _TextBlock(this.text); + + final String text; + + @override + Widget build(BuildContext context) => Container( + width: double.infinity, + padding: const EdgeInsets.all(12), + decoration: BoxDecoration( + color: context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ), + borderRadius: const BorderRadius.all(Radius.circular(6)), + ), + child: SelectableText(text), + ); +} diff --git a/lib/widgets/message/message.dart b/lib/widgets/message/message.dart index b2a89a44e3..5077b1975d 100644 --- a/lib/widgets/message/message.dart +++ b/lib/widgets/message/message.dart @@ -20,6 +20,7 @@ import 'package:super_context_menu/super_context_menu.dart'; import 'package:visibility_detector/visibility_detector.dart'; import '../../account/account_server.dart'; +import '../../ai/ai_chat_controller.dart'; import '../../blaze/vo/pin_message_minimal.dart'; import '../../bloc/simple_cubit.dart'; import '../../constants/icon_fonts.dart'; @@ -190,6 +191,124 @@ SelectedContent? _findSelectedContent(BuildContext context) { return null; } +enum _MessageAiAction { translate, explain, suggestReplies } + +class _InlineMessageAiState with EquatableMixin { + const _InlineMessageAiState({this.entries = const {}}); + + final Map<_MessageAiAction, _InlineMessageAiEntry> entries; + + _InlineMessageAiState put( + _MessageAiAction action, + _InlineMessageAiEntry entry, + ) => _InlineMessageAiState( + entries: Map<_MessageAiAction, _InlineMessageAiEntry>.from(entries) + ..[action] = entry, + ); + + _InlineMessageAiEntry? operator [](_MessageAiAction action) => + entries[action]; + + bool get hasVisibleEntry => + entries.values.any((entry) => entry.loading || entry.hasContent); + + @override + List get props => [entries]; +} + +class _InlineMessageAiEntry with EquatableMixin { + const _InlineMessageAiEntry({ + this.loading = false, + this.result, + this.error, + }); + + final bool loading; + final String? result; + final String? error; + + bool get hasContent => + (result != null && result!.trim().isNotEmpty) || + (error != null && error!.trim().isNotEmpty); + + @override + List get props => [loading, result, error]; +} + +String? _messageAiText(MessageItem message) { + final content = message.content?.trim(); + if ((message.type.isText || message.type.isPost) && + content != null && + content.isNotEmpty) { + return content; + } + + final caption = message.caption?.trim(); + if (caption != null && caption.isNotEmpty) { + return caption; + } + return null; +} + +Future _runMessageAiAction( + BuildContext context, { + required MessageItem message, + required String input, + required _MessageAiAction action, + required void Function(_MessageAiAction, _InlineMessageAiEntry) + onStateChanged, +}) async { + final language = _currentLanguageTag(context); + final instruction = switch (action) { + _MessageAiAction.translate => + 'Translate this chat message into $language. Return only the translation.', + _MessageAiAction.explain => + 'Explain this chat message clearly and concisely. Clarify slang, abbreviations, technical terms, and implied meaning when useful.', + _MessageAiAction.suggestReplies => + 'Suggest three concise, natural replies to this chat message using the recent conversation context. Return one reply per line, without numbering.', + }; + final title = switch (action) { + _MessageAiAction.translate => 'Translate', + _MessageAiAction.explain => 'Explain', + _MessageAiAction.suggestReplies => 'Suggest replies', + }; + + onStateChanged(action, const _InlineMessageAiEntry(loading: true)); + try { + final result = await AiChatController(context.database).assistText( + instruction: instruction, + input: input, + conversationId: message.conversationId, + ); + if (!context.mounted) return; + onStateChanged( + action, + _InlineMessageAiEntry(result: result.trim()), + ); + } catch (error, stackTrace) { + e('AI message assist failed: $error, $stackTrace'); + if (!context.mounted) return; + onStateChanged( + action, + _InlineMessageAiEntry(error: '$title failed: $error'), + ); + } +} + +String _currentLanguageTag(BuildContext context) { + final locale = Localizations.localeOf(context); + final countryCode = locale.countryCode; + if (countryCode == null || countryCode.isEmpty) return locale.languageCode; + return '${locale.languageCode}-$countryCode'; +} + +List _parseAiReplySuggestions(String result) => result + .split('\n') + .map((line) => line.trim().replaceFirst(RegExp(r'^[-*\d.)\s]+'), '')) + .where((line) => line.isNotEmpty) + .take(3) + .toList(growable: false); + class MessageItemWidget extends HookConsumerWidget { const MessageItemWidget({ required this.message, @@ -277,6 +396,7 @@ class MessageItemWidget extends HookConsumerWidget { keys: [message.messageId], ).data ?? Colors.transparent; + final inlineAiState = useState(const _InlineMessageAiState()); Widget child = Column( mainAxisSize: MainAxisSize.min, @@ -314,6 +434,9 @@ class MessageItemWidget extends HookConsumerWidget { pinArrowWidth: isPinnedPage ? _pinArrowWidth : 0, isBot: message.isBot, isVerified: message.isVerified, + aiSection: _InlineMessageAiSection( + state: inlineAiState.value, + ), buildMenus: (request) { request.onShowMenu.addListener(() { showedMenuCubit.emit(true); @@ -627,6 +750,67 @@ class MessageItemWidget extends HookConsumerWidget { ), ]; + final aiText = _messageAiText(message); + final aiActions = [ + if (aiText != null) + MenuAction( + image: MenuImage.icon(Icons.translate), + title: 'Translate', + callback: () => unawaited( + _runMessageAiAction( + context, + message: message, + input: aiText, + action: _MessageAiAction.translate, + onStateChanged: (action, entry) { + inlineAiState.value = inlineAiState.value.put( + action, + entry, + ); + }, + ), + ), + ), + if (aiText != null) + MenuAction( + image: MenuImage.icon(Icons.psychology_alt), + title: 'Explain', + callback: () => unawaited( + _runMessageAiAction( + context, + message: message, + input: aiText, + action: _MessageAiAction.explain, + onStateChanged: (action, entry) { + inlineAiState.value = inlineAiState.value.put( + action, + entry, + ); + }, + ), + ), + ), + if (aiText != null && !isTranscriptPage) + MenuAction( + image: MenuImage.icon(Icons.auto_awesome), + title: 'Suggest replies', + callback: () => unawaited( + _runMessageAiAction( + context, + message: message, + input: aiText, + action: _MessageAiAction.suggestReplies, + onStateChanged: (action, entry) { + inlineAiState.value = inlineAiState.value.put( + action, + entry, + ); + }, + ), + ), + ), + ]; + final devActions = [ if (!kReleaseMode) MenuAction( @@ -642,6 +826,7 @@ class MessageItemWidget extends HookConsumerWidget { childrens: [ replayAction, copyActions, + aiActions, messageActions, saveActions, addStickerMenuAction, @@ -920,6 +1105,7 @@ class _MessageBubbleMargin extends HookConsumerWidget { required this.showAvatar, required this.isBot, required this.isVerified, + required this.aiSection, }); final bool isCurrentUser; @@ -932,6 +1118,7 @@ class _MessageBubbleMargin extends HookConsumerWidget { final bool showAvatar; final bool isBot; final bool isVerified; + final Widget aiSection; @override Widget build(BuildContext context, WidgetRef ref) { @@ -969,6 +1156,7 @@ class _MessageBubbleMargin extends HookConsumerWidget { child: Builder(builder: builder), ), ), + aiSection, ], ); @@ -1011,6 +1199,207 @@ class _MessageBubbleMargin extends HookConsumerWidget { } } +class _InlineMessageAiSection extends StatelessWidget { + const _InlineMessageAiSection({required this.state}); + + final _InlineMessageAiState state; + + @override + Widget build(BuildContext context) { + if (!state.hasVisibleEntry) { + return const SizedBox.shrink(); + } + + final children = [ + for (final action in _MessageAiAction.values) + if (state[action]?.loading == true || state[action]?.hasContent == true) + Padding( + padding: const EdgeInsets.only(top: 8), + child: _InlineMessageAiCard( + action: action, + entry: state[action]!, + ), + ), + ]; + + if (children.isEmpty) { + return const SizedBox.shrink(); + } + + return Column( + crossAxisAlignment: CrossAxisAlignment.stretch, + children: children, + ); + } +} + +class _InlineMessageAiCard extends StatelessWidget { + const _InlineMessageAiCard({ + required this.action, + required this.entry, + }); + + final _MessageAiAction action; + final _InlineMessageAiEntry entry; + + @override + Widget build(BuildContext context) { + final title = switch (action) { + _MessageAiAction.translate => 'Translation', + _MessageAiAction.explain => 'Explanation', + _MessageAiAction.suggestReplies => 'Suggested replies', + }; + final loadingText = switch (action) { + _MessageAiAction.translate => 'Translating...', + _MessageAiAction.explain => 'Explaining...', + _MessageAiAction.suggestReplies => 'Generating replies...', + }; + + Widget content; + if (entry.loading) { + content = Row( + mainAxisSize: MainAxisSize.min, + children: [ + SizedBox( + width: 14, + height: 14, + child: CircularProgressIndicator( + strokeWidth: 1.8, + color: context.theme.secondaryText, + ), + ), + const SizedBox(width: 8), + Text( + loadingText, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 13, + height: 1.4, + ), + ), + ], + ); + } else if (entry.error?.isNotEmpty == true) { + content = Text( + entry.error!, + style: TextStyle( + color: context.theme.red, + fontSize: 13, + height: 1.45, + ), + ); + } else if (action == _MessageAiAction.suggestReplies) { + content = _InlineReplySuggestions(result: entry.result ?? ''); + } else { + content = SelectableText( + entry.result ?? '', + style: TextStyle( + color: context.theme.text, + fontSize: 13, + height: 1.45, + ), + ); + } + + return Container( + constraints: const BoxConstraints(maxWidth: 420), + padding: const EdgeInsets.all(10), + decoration: BoxDecoration( + color: context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.06), + ), + borderRadius: const BorderRadius.all(Radius.circular(8)), + ), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + title, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + fontWeight: FontWeight.w500, + ), + ), + const SizedBox(height: 6), + content, + ], + ), + ); + } +} + +class _InlineReplySuggestions extends StatelessWidget { + const _InlineReplySuggestions({required this.result}); + + final String result; + + @override + Widget build(BuildContext context) { + final replies = _parseAiReplySuggestions(result); + if (replies.isEmpty) { + return SelectableText( + result, + style: TextStyle( + color: context.theme.text, + fontSize: 13, + height: 1.45, + ), + ); + } + + return Column( + crossAxisAlignment: CrossAxisAlignment.stretch, + children: [ + for (var i = 0; i < replies.length; i++) + Padding( + padding: EdgeInsets.only(bottom: i == replies.length - 1 ? 0 : 6), + child: _InlineReplyButton(reply: replies[i]), + ), + ], + ); + } +} + +class _InlineReplyButton extends StatelessWidget { + const _InlineReplyButton({required this.reply}); + + final String reply; + + @override + Widget build(BuildContext context) => InteractiveDecoratedBox( + onTap: () => context.providerContainer + .read(recallMessageNotifierProvider) + .onReedit(reply), + decoration: BoxDecoration( + color: context.dynamicColor( + const Color.fromRGBO(255, 255, 255, 0.92), + darkColor: const Color.fromRGBO(255, 255, 255, 0.04), + ), + borderRadius: const BorderRadius.all(Radius.circular(6)), + ), + hoveringDecoration: BoxDecoration( + color: context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.03), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ), + borderRadius: const BorderRadius.all(Radius.circular(6)), + ), + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 10, vertical: 8), + child: Text( + reply, + style: TextStyle( + color: context.theme.text, + fontSize: 13, + height: 1.35, + ), + ), + ), + ); +} + class _UnreadMessageBar extends StatelessWidget { const _UnreadMessageBar(); diff --git a/pubspec.lock b/pubspec.lock index dd5b1cd2ef..039d36056a 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -1270,9 +1270,10 @@ packages: mixin_markdown_widget: dependency: "direct main" description: - path: "/Users/yangbin/workspace/mixin/flutter-plugins/packages/mixin_markdown_widget" - relative: false - source: path + name: mixin_markdown_widget + sha256: c7e6134e5e98a2c390e0cc7f56245336152cc6bfddda14f77a6d920021e82186 + url: "https://pub.dev" + source: hosted version: "0.1.0" msix: dependency: "direct dev" diff --git a/pubspec.yaml b/pubspec.yaml index 6456500b46..5d2d2a765e 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -96,8 +96,7 @@ dependencies: local_auth: ^3.0.1 lottie: ^3.3.3 map: ^2.0.2 - mixin_markdown_widget: - path: /Users/yangbin/workspace/mixin/flutter-plugins/packages/mixin_markdown_widget + mixin_markdown_widget: ^0.1.0 mime: ^2.0.0 mixin_bot_sdk_dart: ^1.5.0 mixin_logger: ^0.1.3 From 1d7363a80313d813c7b2a4c88cc4299927b8a779 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Fri, 24 Apr 2026 11:33:52 +0800 Subject: [PATCH 12/52] refactor: modularize and optimize AI message assist functionality --- lib/widgets/message/message.dart | 376 ++------------------ lib/widgets/message/message_ai_assist.dart | 380 +++++++++++++++++++++ 2 files changed, 409 insertions(+), 347 deletions(-) create mode 100644 lib/widgets/message/message_ai_assist.dart diff --git a/lib/widgets/message/message.dart b/lib/widgets/message/message.dart index 5077b1975d..cf31c38e2c 100644 --- a/lib/widgets/message/message.dart +++ b/lib/widgets/message/message.dart @@ -20,7 +20,6 @@ import 'package:super_context_menu/super_context_menu.dart'; import 'package:visibility_detector/visibility_detector.dart'; import '../../account/account_server.dart'; -import '../../ai/ai_chat_controller.dart'; import '../../blaze/vo/pin_message_minimal.dart'; import '../../bloc/simple_cubit.dart'; import '../../constants/icon_fonts.dart'; @@ -76,6 +75,7 @@ import 'item/transfer/transfer_message.dart'; import 'item/unknown_message.dart'; import 'item/video/video_message.dart'; import 'item/waiting_message.dart'; +import 'message_ai_assist.dart'; import 'message_day_time.dart'; import 'message_name.dart'; import 'message_style.dart'; @@ -191,124 +191,6 @@ SelectedContent? _findSelectedContent(BuildContext context) { return null; } -enum _MessageAiAction { translate, explain, suggestReplies } - -class _InlineMessageAiState with EquatableMixin { - const _InlineMessageAiState({this.entries = const {}}); - - final Map<_MessageAiAction, _InlineMessageAiEntry> entries; - - _InlineMessageAiState put( - _MessageAiAction action, - _InlineMessageAiEntry entry, - ) => _InlineMessageAiState( - entries: Map<_MessageAiAction, _InlineMessageAiEntry>.from(entries) - ..[action] = entry, - ); - - _InlineMessageAiEntry? operator [](_MessageAiAction action) => - entries[action]; - - bool get hasVisibleEntry => - entries.values.any((entry) => entry.loading || entry.hasContent); - - @override - List get props => [entries]; -} - -class _InlineMessageAiEntry with EquatableMixin { - const _InlineMessageAiEntry({ - this.loading = false, - this.result, - this.error, - }); - - final bool loading; - final String? result; - final String? error; - - bool get hasContent => - (result != null && result!.trim().isNotEmpty) || - (error != null && error!.trim().isNotEmpty); - - @override - List get props => [loading, result, error]; -} - -String? _messageAiText(MessageItem message) { - final content = message.content?.trim(); - if ((message.type.isText || message.type.isPost) && - content != null && - content.isNotEmpty) { - return content; - } - - final caption = message.caption?.trim(); - if (caption != null && caption.isNotEmpty) { - return caption; - } - return null; -} - -Future _runMessageAiAction( - BuildContext context, { - required MessageItem message, - required String input, - required _MessageAiAction action, - required void Function(_MessageAiAction, _InlineMessageAiEntry) - onStateChanged, -}) async { - final language = _currentLanguageTag(context); - final instruction = switch (action) { - _MessageAiAction.translate => - 'Translate this chat message into $language. Return only the translation.', - _MessageAiAction.explain => - 'Explain this chat message clearly and concisely. Clarify slang, abbreviations, technical terms, and implied meaning when useful.', - _MessageAiAction.suggestReplies => - 'Suggest three concise, natural replies to this chat message using the recent conversation context. Return one reply per line, without numbering.', - }; - final title = switch (action) { - _MessageAiAction.translate => 'Translate', - _MessageAiAction.explain => 'Explain', - _MessageAiAction.suggestReplies => 'Suggest replies', - }; - - onStateChanged(action, const _InlineMessageAiEntry(loading: true)); - try { - final result = await AiChatController(context.database).assistText( - instruction: instruction, - input: input, - conversationId: message.conversationId, - ); - if (!context.mounted) return; - onStateChanged( - action, - _InlineMessageAiEntry(result: result.trim()), - ); - } catch (error, stackTrace) { - e('AI message assist failed: $error, $stackTrace'); - if (!context.mounted) return; - onStateChanged( - action, - _InlineMessageAiEntry(error: '$title failed: $error'), - ); - } -} - -String _currentLanguageTag(BuildContext context) { - final locale = Localizations.localeOf(context); - final countryCode = locale.countryCode; - if (countryCode == null || countryCode.isEmpty) return locale.languageCode; - return '${locale.languageCode}-$countryCode'; -} - -List _parseAiReplySuggestions(String result) => result - .split('\n') - .map((line) => line.trim().replaceFirst(RegExp(r'^[-*\d.)\s]+'), '')) - .where((line) => line.isNotEmpty) - .take(3) - .toList(growable: false); - class MessageItemWidget extends HookConsumerWidget { const MessageItemWidget({ required this.message, @@ -396,7 +278,14 @@ class MessageItemWidget extends HookConsumerWidget { keys: [message.messageId], ).data ?? Colors.transparent; - final inlineAiState = useState(const _InlineMessageAiState()); + final inlineAiState = useState( + readInlineMessageAiState(message.messageId), + ); + + useEffect(() { + inlineAiState.value = readInlineMessageAiState(message.messageId); + return null; + }, [message.messageId]); Widget child = Column( mainAxisSize: MainAxisSize.min, @@ -434,7 +323,7 @@ class MessageItemWidget extends HookConsumerWidget { pinArrowWidth: isPinnedPage ? _pinArrowWidth : 0, isBot: message.isBot, isVerified: message.isVerified, - aiSection: _InlineMessageAiSection( + aiSection: MessageInlineAiSection( state: inlineAiState.value, ), buildMenus: (request) { @@ -750,24 +639,28 @@ class MessageItemWidget extends HookConsumerWidget { ), ]; - final aiText = _messageAiText(message); + final aiText = messageAiText(message); + void updateInlineAiState( + MessageAiAction action, + InlineMessageAiEntry entry, + ) { + final nextState = inlineAiState.value.put(action, entry); + inlineAiState.value = nextState; + writeInlineMessageAiState(message.messageId, nextState); + } + final aiActions = [ if (aiText != null) MenuAction( image: MenuImage.icon(Icons.translate), title: 'Translate', callback: () => unawaited( - _runMessageAiAction( + runMessageAiAction( context, message: message, input: aiText, - action: _MessageAiAction.translate, - onStateChanged: (action, entry) { - inlineAiState.value = inlineAiState.value.put( - action, - entry, - ); - }, + action: MessageAiAction.translate, + onStateChanged: updateInlineAiState, ), ), ), @@ -776,17 +669,12 @@ class MessageItemWidget extends HookConsumerWidget { image: MenuImage.icon(Icons.psychology_alt), title: 'Explain', callback: () => unawaited( - _runMessageAiAction( + runMessageAiAction( context, message: message, input: aiText, - action: _MessageAiAction.explain, - onStateChanged: (action, entry) { - inlineAiState.value = inlineAiState.value.put( - action, - entry, - ); - }, + action: MessageAiAction.explain, + onStateChanged: updateInlineAiState, ), ), ), @@ -795,17 +683,12 @@ class MessageItemWidget extends HookConsumerWidget { image: MenuImage.icon(Icons.auto_awesome), title: 'Suggest replies', callback: () => unawaited( - _runMessageAiAction( + runMessageAiAction( context, message: message, input: aiText, - action: _MessageAiAction.suggestReplies, - onStateChanged: (action, entry) { - inlineAiState.value = inlineAiState.value.put( - action, - entry, - ); - }, + action: MessageAiAction.suggestReplies, + onStateChanged: updateInlineAiState, ), ), ), @@ -1199,207 +1082,6 @@ class _MessageBubbleMargin extends HookConsumerWidget { } } -class _InlineMessageAiSection extends StatelessWidget { - const _InlineMessageAiSection({required this.state}); - - final _InlineMessageAiState state; - - @override - Widget build(BuildContext context) { - if (!state.hasVisibleEntry) { - return const SizedBox.shrink(); - } - - final children = [ - for (final action in _MessageAiAction.values) - if (state[action]?.loading == true || state[action]?.hasContent == true) - Padding( - padding: const EdgeInsets.only(top: 8), - child: _InlineMessageAiCard( - action: action, - entry: state[action]!, - ), - ), - ]; - - if (children.isEmpty) { - return const SizedBox.shrink(); - } - - return Column( - crossAxisAlignment: CrossAxisAlignment.stretch, - children: children, - ); - } -} - -class _InlineMessageAiCard extends StatelessWidget { - const _InlineMessageAiCard({ - required this.action, - required this.entry, - }); - - final _MessageAiAction action; - final _InlineMessageAiEntry entry; - - @override - Widget build(BuildContext context) { - final title = switch (action) { - _MessageAiAction.translate => 'Translation', - _MessageAiAction.explain => 'Explanation', - _MessageAiAction.suggestReplies => 'Suggested replies', - }; - final loadingText = switch (action) { - _MessageAiAction.translate => 'Translating...', - _MessageAiAction.explain => 'Explaining...', - _MessageAiAction.suggestReplies => 'Generating replies...', - }; - - Widget content; - if (entry.loading) { - content = Row( - mainAxisSize: MainAxisSize.min, - children: [ - SizedBox( - width: 14, - height: 14, - child: CircularProgressIndicator( - strokeWidth: 1.8, - color: context.theme.secondaryText, - ), - ), - const SizedBox(width: 8), - Text( - loadingText, - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 13, - height: 1.4, - ), - ), - ], - ); - } else if (entry.error?.isNotEmpty == true) { - content = Text( - entry.error!, - style: TextStyle( - color: context.theme.red, - fontSize: 13, - height: 1.45, - ), - ); - } else if (action == _MessageAiAction.suggestReplies) { - content = _InlineReplySuggestions(result: entry.result ?? ''); - } else { - content = SelectableText( - entry.result ?? '', - style: TextStyle( - color: context.theme.text, - fontSize: 13, - height: 1.45, - ), - ); - } - - return Container( - constraints: const BoxConstraints(maxWidth: 420), - padding: const EdgeInsets.all(10), - decoration: BoxDecoration( - color: context.dynamicColor( - const Color.fromRGBO(245, 247, 250, 1), - darkColor: const Color.fromRGBO(255, 255, 255, 0.06), - ), - borderRadius: const BorderRadius.all(Radius.circular(8)), - ), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Text( - title, - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 12, - fontWeight: FontWeight.w500, - ), - ), - const SizedBox(height: 6), - content, - ], - ), - ); - } -} - -class _InlineReplySuggestions extends StatelessWidget { - const _InlineReplySuggestions({required this.result}); - - final String result; - - @override - Widget build(BuildContext context) { - final replies = _parseAiReplySuggestions(result); - if (replies.isEmpty) { - return SelectableText( - result, - style: TextStyle( - color: context.theme.text, - fontSize: 13, - height: 1.45, - ), - ); - } - - return Column( - crossAxisAlignment: CrossAxisAlignment.stretch, - children: [ - for (var i = 0; i < replies.length; i++) - Padding( - padding: EdgeInsets.only(bottom: i == replies.length - 1 ? 0 : 6), - child: _InlineReplyButton(reply: replies[i]), - ), - ], - ); - } -} - -class _InlineReplyButton extends StatelessWidget { - const _InlineReplyButton({required this.reply}); - - final String reply; - - @override - Widget build(BuildContext context) => InteractiveDecoratedBox( - onTap: () => context.providerContainer - .read(recallMessageNotifierProvider) - .onReedit(reply), - decoration: BoxDecoration( - color: context.dynamicColor( - const Color.fromRGBO(255, 255, 255, 0.92), - darkColor: const Color.fromRGBO(255, 255, 255, 0.04), - ), - borderRadius: const BorderRadius.all(Radius.circular(6)), - ), - hoveringDecoration: BoxDecoration( - color: context.dynamicColor( - const Color.fromRGBO(0, 0, 0, 0.03), - darkColor: const Color.fromRGBO(255, 255, 255, 0.08), - ), - borderRadius: const BorderRadius.all(Radius.circular(6)), - ), - child: Padding( - padding: const EdgeInsets.symmetric(horizontal: 10, vertical: 8), - child: Text( - reply, - style: TextStyle( - color: context.theme.text, - fontSize: 13, - height: 1.35, - ), - ), - ), - ); -} - class _UnreadMessageBar extends StatelessWidget { const _UnreadMessageBar(); diff --git a/lib/widgets/message/message_ai_assist.dart b/lib/widgets/message/message_ai_assist.dart new file mode 100644 index 0000000000..52ff0bc448 --- /dev/null +++ b/lib/widgets/message/message_ai_assist.dart @@ -0,0 +1,380 @@ +import 'package:equatable/equatable.dart'; +import 'package:flutter/material.dart'; + +import '../../ai/ai_chat_controller.dart'; +import '../../db/mixin_database.dart'; +import '../../ui/provider/recall_message_reedit_provider.dart'; +import '../../utils/extension/extension.dart'; +import '../../utils/logger.dart'; +import '../markdown.dart'; + +enum MessageAiAction { translate, explain, suggestReplies } + +final _inlineMessageAiStateCache = {}; + +class InlineMessageAiState with EquatableMixin { + const InlineMessageAiState({this.entries = const {}}); + + final Map entries; + + InlineMessageAiState put( + MessageAiAction action, + InlineMessageAiEntry entry, + ) => InlineMessageAiState( + entries: Map.from(entries) + ..[action] = entry, + ); + + InlineMessageAiEntry? operator [](MessageAiAction action) => entries[action]; + + bool get hasVisibleEntry => + entries.values.any((entry) => entry.loading || entry.hasContent); + + @override + List get props => [entries]; +} + +InlineMessageAiState readInlineMessageAiState(String messageId) => + _inlineMessageAiStateCache[messageId] ?? const InlineMessageAiState(); + +void writeInlineMessageAiState( + String messageId, + InlineMessageAiState state, +) { + if (!state.hasVisibleEntry) { + _inlineMessageAiStateCache.remove(messageId); + return; + } + _inlineMessageAiStateCache[messageId] = state; +} + +class InlineMessageAiEntry with EquatableMixin { + const InlineMessageAiEntry({ + this.loading = false, + this.result, + this.error, + this.model, + }); + + final bool loading; + final String? result; + final String? error; + final String? model; + + bool get hasContent => + (result != null && result!.trim().isNotEmpty) || + (error != null && error!.trim().isNotEmpty); + + @override + List get props => [loading, result, error, model]; +} + +String? messageAiText(MessageItem message) { + final content = message.content?.trim(); + if ((message.type.isText || message.type.isPost) && + content != null && + content.isNotEmpty) { + return content; + } + + final caption = message.caption?.trim(); + if (caption != null && caption.isNotEmpty) { + return caption; + } + return null; +} + +Future runMessageAiAction( + BuildContext context, { + required MessageItem message, + required String input, + required MessageAiAction action, + required void Function(MessageAiAction, InlineMessageAiEntry) onStateChanged, +}) async { + final language = _currentLanguageTag(context); + final provider = context.database.settingProperties.selectedAiProvider; + final model = provider?.model; + final instruction = switch (action) { + MessageAiAction.translate => + 'Translate this chat message into $language. Return only the translation.', + MessageAiAction.explain => + 'Explain this chat message clearly and concisely in $language. ' + 'Clarify slang, abbreviations, technical terms, and implied meaning when useful. ' + 'Return only the explanation.', + MessageAiAction.suggestReplies => + 'Suggest three concise, natural replies in $language to this chat message ' + 'using the recent conversation context. Return one reply per line, without numbering.', + }; + final title = switch (action) { + MessageAiAction.translate => 'Translate', + MessageAiAction.explain => 'Explain', + MessageAiAction.suggestReplies => 'Suggest replies', + }; + + onStateChanged( + action, + InlineMessageAiEntry(loading: true, model: model), + ); + try { + final result = await AiChatController(context.database).assistText( + instruction: instruction, + input: input, + conversationId: message.conversationId, + provider: provider, + ); + if (!context.mounted) return; + onStateChanged( + action, + InlineMessageAiEntry(result: result.trim(), model: model), + ); + } catch (error, stackTrace) { + e('AI message assist failed: $error, $stackTrace'); + if (!context.mounted) return; + onStateChanged( + action, + InlineMessageAiEntry(error: '$title failed: $error', model: model), + ); + } +} + +String _currentLanguageTag(BuildContext context) { + final locale = Localizations.localeOf(context); + final countryCode = locale.countryCode; + if (countryCode == null || countryCode.isEmpty) return locale.languageCode; + return '${locale.languageCode}-$countryCode'; +} + +List _parseAiReplySuggestions(String result) => result + .split('\n') + .map((line) => line.trim().replaceFirst(RegExp(r'^[-*\d.)\s]+'), '')) + .where((line) => line.isNotEmpty) + .take(3) + .toList(growable: false); + +class MessageInlineAiSection extends StatelessWidget { + const MessageInlineAiSection({required this.state, super.key}); + + final InlineMessageAiState state; + + @override + Widget build(BuildContext context) { + if (!state.hasVisibleEntry) { + return const SizedBox.shrink(); + } + + final children = [ + for (final action in MessageAiAction.values) + if (state[action]?.loading == true || state[action]?.hasContent == true) + Padding( + padding: const EdgeInsets.only(top: 8), + child: _InlineMessageAiCard( + action: action, + entry: state[action]!, + ), + ), + ]; + + if (children.isEmpty) { + return const SizedBox.shrink(); + } + + return Column( + crossAxisAlignment: CrossAxisAlignment.stretch, + children: children, + ); + } +} + +class _InlineMessageAiCard extends StatelessWidget { + const _InlineMessageAiCard({ + required this.action, + required this.entry, + }); + + final MessageAiAction action; + final InlineMessageAiEntry entry; + + @override + Widget build(BuildContext context) { + final title = switch (action) { + MessageAiAction.translate => 'Translation', + MessageAiAction.explain => 'Explanation', + MessageAiAction.suggestReplies => 'Suggested replies', + }; + final loadingText = switch (action) { + MessageAiAction.translate => 'Translating...', + MessageAiAction.explain => 'Explaining...', + MessageAiAction.suggestReplies => 'Generating replies...', + }; + + Widget content; + if (entry.loading) { + content = Row( + mainAxisSize: MainAxisSize.min, + children: [ + SizedBox( + width: 14, + height: 14, + child: CircularProgressIndicator( + strokeWidth: 1.8, + color: context.theme.secondaryText, + ), + ), + const SizedBox(width: 8), + Text( + loadingText, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 13, + height: 1.4, + ), + ), + ], + ); + } else if (entry.error?.isNotEmpty == true) { + content = Text( + entry.error!, + style: TextStyle( + color: context.theme.red, + fontSize: 13, + height: 1.45, + ), + ); + } else if (action == MessageAiAction.suggestReplies) { + content = _InlineReplySuggestions(result: entry.result ?? ''); + } else if (action == MessageAiAction.explain) { + final data = entry.result ?? ''; + content = MarkdownColumn( + data: data, + selectable: true, + cacheKey: buildMarkdownCacheKey( + namespace: 'inline-message-ai-explain', + id: '${entry.model ?? 'unknown'}:${data.hashCode}', + ), + ); + } else { + content = SelectableText( + entry.result ?? '', + style: TextStyle( + color: context.theme.text, + fontSize: 13, + height: 1.45, + ), + ); + } + + return Container( + constraints: const BoxConstraints(maxWidth: 420), + padding: const EdgeInsets.all(10), + decoration: BoxDecoration( + color: context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.06), + ), + borderRadius: const BorderRadius.all(Radius.circular(8)), + ), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Row( + children: [ + Expanded( + child: Text( + title, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + fontWeight: FontWeight.w500, + ), + ), + ), + if (entry.model?.isNotEmpty == true) + Text( + entry.model!, + textAlign: TextAlign.right, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 11, + height: 1.2, + ), + ), + ], + ), + if (entry.model?.isNotEmpty == true) const SizedBox(height: 2), + const SizedBox(height: 6), + DefaultTextStyle.merge( + style: const TextStyle(height: 1.45), + child: content, + ), + ], + ), + ); + } +} + +class _InlineReplySuggestions extends StatelessWidget { + const _InlineReplySuggestions({required this.result}); + + final String result; + + @override + Widget build(BuildContext context) { + final replies = _parseAiReplySuggestions(result); + if (replies.isEmpty) { + return SelectableText( + result, + style: TextStyle( + color: context.theme.text, + fontSize: 13, + height: 1.45, + ), + ); + } + + return Column( + crossAxisAlignment: CrossAxisAlignment.stretch, + children: [ + for (var i = 0; i < replies.length; i++) + Padding( + padding: EdgeInsets.only(bottom: i == replies.length - 1 ? 0 : 6), + child: _InlineReplyButton(reply: replies[i]), + ), + ], + ); + } +} + +class _InlineReplyButton extends StatelessWidget { + const _InlineReplyButton({required this.reply}); + + final String reply; + + @override + Widget build(BuildContext context) { + const borderRadius = BorderRadius.all(Radius.circular(6)); + return Material( + color: context.dynamicColor( + const Color.fromRGBO(255, 255, 255, 0.92), + darkColor: const Color.fromRGBO(255, 255, 255, 0.04), + ), + borderRadius: borderRadius, + child: InkWell( + borderRadius: borderRadius, + onTap: () => context.providerContainer + .read(recallMessageNotifierProvider) + .onReedit(reply), + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 10, vertical: 8), + child: Text( + reply, + style: TextStyle( + color: context.theme.text, + fontSize: 13, + height: 1.35, + ), + ), + ), + ), + ); + } +} From ad8b36663329389bb5022dcc66e632b9acd6cb3f Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:08:37 +0800 Subject: [PATCH 13/52] feat: add AI draft assist panel and inline message AI actions --- lib/ai/ai_chat_controller.dart | 23 +- lib/ui/home/chat/ai_draft_assist_panel.dart | 685 ++++++++++++++++++++ lib/ui/home/chat/input_container.dart | 284 ++++---- lib/widgets/message/message.dart | 8 + lib/widgets/message/message_ai_assist.dart | 42 +- 5 files changed, 901 insertions(+), 141 deletions(-) create mode 100644 lib/ui/home/chat/ai_draft_assist_panel.dart diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index ce6b2510f4..d03bfc442e 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -56,12 +56,23 @@ class AiChatController { conversationId: conversationId, ); - return _streamRequest( - config, - messages, - cancelToken: CancelToken(), - onContent: (_) async {}, - ); + final cancelToken = CancelToken(); + if (conversationId != null) { + _activeAiRequests[conversationId] = cancelToken; + } + try { + return await _streamRequest( + config, + messages, + cancelToken: cancelToken, + onContent: (_) async {}, + ); + } finally { + if (conversationId != null && + _activeAiRequests[conversationId] == cancelToken) { + _activeAiRequests.remove(conversationId); + } + } } Future send({ diff --git a/lib/ui/home/chat/ai_draft_assist_panel.dart b/lib/ui/home/chat/ai_draft_assist_panel.dart new file mode 100644 index 0000000000..0a024f5749 --- /dev/null +++ b/lib/ui/home/chat/ai_draft_assist_panel.dart @@ -0,0 +1,685 @@ +import 'package:flutter/material.dart'; +import 'package:flutter_hooks/flutter_hooks.dart'; +import 'package:flutter_portal/flutter_portal.dart'; +import 'package:hooks_riverpod/hooks_riverpod.dart'; + +import '../../../utils/extension/extension.dart'; +import '../../../widgets/action_button.dart'; +import '../../../widgets/interactive_decorated_box.dart'; +import '../../../widgets/menu.dart'; + +enum AiDraftAction { polish, shorten, polite, translate, replyWithContext } + +enum AiDraftAssistPhase { idle, loading, result, error } + +class AiDraftAssistViewState { + const AiDraftAssistViewState({ + this.phase = AiDraftAssistPhase.idle, + this.action, + this.original = '', + this.result, + this.error, + }); + + final AiDraftAssistPhase phase; + final AiDraftAction? action; + final String original; + final String? result; + final String? error; + + bool get isIdle => phase == AiDraftAssistPhase.idle; + bool get isLoading => phase == AiDraftAssistPhase.loading; + static const idle = AiDraftAssistViewState(); +} + +class AiDraftAssistButton extends HookConsumerWidget { + const AiDraftAssistButton({ + required this.enabled, + required this.textEditingController, + required this.viewState, + required this.onSelected, + required this.onStop, + super.key, + }); + + final bool enabled; + final TextEditingController textEditingController; + final AiDraftAssistViewState viewState; + final ValueChanged onSelected; + final VoidCallback onStop; + + @override + Widget build(BuildContext context, WidgetRef ref) { + final draftValue = useValueListenable(textEditingController); + final hasDraft = draftValue.text.trim().isNotEmpty; + final visible = useState(false); + final hovering = useState(false); + + void closePanel() { + visible.value = false; + } + + useEffect(() { + if (viewState.phase != AiDraftAssistPhase.idle) { + visible.value = false; + } + return null; + }, [viewState.phase]); + + return AnimatedOpacity( + duration: const Duration(milliseconds: 180), + opacity: enabled ? 1 : 0.45, + child: Barrier( + visible: enabled && visible.value, + onClose: closePanel, + duration: const Duration(milliseconds: 160), + child: PortalTarget( + visible: enabled && visible.value, + closeDuration: const Duration(milliseconds: 160), + anchor: const Aligned( + follower: Alignment.bottomRight, + target: Alignment.topRight, + ), + portalFollower: TweenAnimationBuilder( + duration: const Duration(milliseconds: 160), + curve: Curves.easeOutCubic, + tween: Tween(begin: 0, end: visible.value ? 1 : 0), + child: Padding( + padding: const EdgeInsets.only(bottom: 8), + child: _AiDraftAssistActionPanel( + hasDraft: hasDraft, + onSelected: (action) { + closePanel(); + onSelected(action); + }, + ), + ), + builder: (context, progress, child) => Opacity( + opacity: progress, + child: Transform.translate( + offset: Offset(0, 8 * (1 - progress)), + child: child, + ), + ), + ), + child: IgnorePointer( + ignoring: !enabled, + child: MouseRegion( + onEnter: (_) => hovering.value = true, + onExit: (_) => hovering.value = false, + child: ActionButton( + onTap: () { + if (viewState.isLoading) { + onStop(); + return; + } + if (visible.value) { + closePanel(); + return; + } + visible.value = true; + }, + child: AnimatedSwitcher( + duration: const Duration(milliseconds: 160), + transitionBuilder: (child, animation) => FadeTransition( + opacity: animation, + child: ScaleTransition(scale: animation, child: child), + ), + child: _AiDraftAssistButtonIcon( + key: ValueKey('${viewState.phase}-${hovering.value}'), + viewState: viewState, + hovering: hovering.value, + ), + ), + ), + ), + ), + ), + ), + ); + } +} + +class _AiDraftAssistButtonIcon extends StatelessWidget { + const _AiDraftAssistButtonIcon({ + required this.viewState, + required this.hovering, + super.key, + }); + + final AiDraftAssistViewState viewState; + final bool hovering; + + @override + Widget build(BuildContext context) { + if (viewState.isLoading) { + if (hovering) { + return Icon( + Icons.stop_rounded, + size: 18, + color: context.theme.red, + ); + } + return SizedBox( + width: 18, + height: 18, + child: CircularProgressIndicator( + strokeWidth: 2, + color: context.theme.accent, + ), + ); + } + if (viewState.phase == AiDraftAssistPhase.result) { + return Icon( + Icons.auto_awesome_rounded, + size: 20, + color: context.theme.accent, + ); + } + if (viewState.phase == AiDraftAssistPhase.error) { + return Icon( + Icons.error_outline_rounded, + size: 20, + color: context.theme.red, + ); + } + return Icon( + Icons.auto_awesome_rounded, + size: 20, + color: context.theme.icon, + ); + } +} + +class _AiDraftAssistActionPanel extends StatelessWidget { + const _AiDraftAssistActionPanel({ + required this.hasDraft, + required this.onSelected, + }); + + final bool hasDraft; + final ValueChanged onSelected; + + @override + Widget build(BuildContext context) => ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 320), + child: DecoratedBox( + decoration: BoxDecoration( + color: context.theme.popUp, + borderRadius: const BorderRadius.all(Radius.circular(12)), + boxShadow: const [ + BoxShadow( + color: Color.fromRGBO(0, 0, 0, 0.12), + offset: Offset(0, 8), + blurRadius: 28, + ), + ], + border: Border.all( + color: context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.05), + darkColor: const Color.fromRGBO(255, 255, 255, 0.06), + ), + ), + ), + child: Padding( + padding: const EdgeInsets.all(12), + child: Column( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + _AiDraftAssistGroup( + title: 'Draft', + children: [ + _AiDraftAssistActionTile( + title: 'Polish', + subtitle: 'Clearer and more natural', + icon: Icons.auto_fix_high_rounded, + enabled: hasDraft, + onTap: () => onSelected(AiDraftAction.polish), + ), + _AiDraftAssistActionTile( + title: 'Make shorter', + subtitle: 'Cut extra words', + icon: Icons.short_text_rounded, + enabled: hasDraft, + onTap: () => onSelected(AiDraftAction.shorten), + ), + _AiDraftAssistActionTile( + title: 'Make polite', + subtitle: 'Softer tone', + icon: Icons.favorite_border_rounded, + enabled: hasDraft, + onTap: () => onSelected(AiDraftAction.polite), + ), + _AiDraftAssistActionTile( + title: 'Translate draft', + subtitle: 'Translate current input', + icon: Icons.translate_rounded, + enabled: hasDraft, + onTap: () => onSelected(AiDraftAction.translate), + ), + ], + ), + const SizedBox(height: 14), + _AiDraftAssistGroup( + title: 'Conversation', + children: [ + _AiDraftAssistActionTile( + title: 'Reply with context', + subtitle: 'Generate from recent messages', + icon: Icons.reply_rounded, + enabled: true, + onTap: () => onSelected(AiDraftAction.replyWithContext), + ), + ], + ), + ], + ), + ), + ), + ); +} + +class AiDraftAssistInlineCandidate extends StatelessWidget { + const AiDraftAssistInlineCandidate({ + required this.viewState, + required this.onDismiss, + required this.onCopy, + required this.onAppend, + required this.onReplace, + super.key, + }); + + final AiDraftAssistViewState viewState; + final VoidCallback onDismiss; + final VoidCallback onCopy; + final VoidCallback onAppend; + final VoidCallback onReplace; + + @override + Widget build(BuildContext context) { + if (viewState.isIdle || viewState.phase == AiDraftAssistPhase.loading) { + return const SizedBox.shrink(); + } + + final accent = context.theme.accent; + final background = context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ); + final border = context.dynamicColor( + accent.withValues(alpha: 0.22), + darkColor: accent.withValues(alpha: 0.28), + ); + + return AnimatedSwitcher( + duration: const Duration(milliseconds: 180), + child: Container( + key: ValueKey( + '${viewState.phase}-${viewState.result}-${viewState.error}', + ), + margin: const EdgeInsets.fromLTRB(16, 0, 16, 8), + padding: const EdgeInsets.all(10), + decoration: BoxDecoration( + color: background, + borderRadius: const BorderRadius.all(Radius.circular(12)), + border: Border.all(color: border), + ), + child: switch (viewState.phase) { + AiDraftAssistPhase.result => _AiDraftAssistInlineResult( + action: viewState.action, + result: viewState.result ?? '', + onDismiss: onDismiss, + onCopy: onCopy, + onAppend: onAppend, + onReplace: onReplace, + ), + AiDraftAssistPhase.error => _AiDraftAssistInlineError( + error: viewState.error ?? 'Unknown error', + onDismiss: onDismiss, + ), + AiDraftAssistPhase.loading || + AiDraftAssistPhase.idle => const SizedBox.shrink(), + }, + ), + ); + } +} + +class _AiDraftAssistInlineResult extends StatelessWidget { + const _AiDraftAssistInlineResult({ + required this.action, + required this.result, + required this.onDismiss, + required this.onCopy, + required this.onAppend, + required this.onReplace, + }); + + final AiDraftAction? action; + final String result; + final VoidCallback onDismiss; + final VoidCallback onCopy; + final VoidCallback onAppend; + final VoidCallback onReplace; + + @override + Widget build(BuildContext context) => Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Row( + children: [ + Expanded( + child: Text( + aiDraftActionTitle(action ?? AiDraftAction.polish), + style: TextStyle( + color: context.theme.text, + fontSize: 13, + fontWeight: FontWeight.w600, + ), + ), + ), + _AiDraftInlineIconButton( + icon: Icons.close_rounded, + color: context.theme.secondaryText, + onTap: onDismiss, + ), + ], + ), + const SizedBox(height: 8), + ConstrainedBox( + constraints: const BoxConstraints(maxHeight: 160), + child: ScrollConfiguration( + behavior: ScrollConfiguration.of(context).copyWith(scrollbars: true), + child: SingleChildScrollView( + child: SelectableText( + result, + style: TextStyle( + color: context.theme.text, + fontSize: 13, + height: 1.4, + ), + ), + ), + ), + ), + const SizedBox(height: 10), + Wrap( + spacing: 8, + runSpacing: 8, + children: [ + _AiDraftInlineTextButton( + title: 'Copy', + onTap: onCopy, + secondary: true, + ), + _AiDraftInlineTextButton( + title: 'Append', + onTap: onAppend, + secondary: true, + ), + _AiDraftInlineTextButton( + title: 'Replace Draft', + onTap: onReplace, + ), + ], + ), + ], + ); +} + +class _AiDraftAssistInlineError extends StatelessWidget { + const _AiDraftAssistInlineError({ + required this.error, + required this.onDismiss, + }); + + final String error; + final VoidCallback onDismiss; + + @override + Widget build(BuildContext context) => Row( + children: [ + Icon( + Icons.error_outline_rounded, + size: 16, + color: context.theme.red, + ), + const SizedBox(width: 8), + Expanded( + child: Text( + error, + style: TextStyle( + color: context.theme.red, + fontSize: 12, + height: 1.35, + ), + ), + ), + _AiDraftInlineIconButton( + icon: Icons.close_rounded, + color: context.theme.secondaryText, + onTap: onDismiss, + ), + ], + ); +} + +class _AiDraftAssistGroup extends StatelessWidget { + const _AiDraftAssistGroup({ + required this.title, + required this.children, + }); + + final String title; + final List children; + + @override + Widget build(BuildContext context) => Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + title, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + fontWeight: FontWeight.w600, + ), + ), + const SizedBox(height: 8), + ...children.expand((child) => [child, const SizedBox(height: 8)]).toList() + ..removeLast(), + ], + ); +} + +class _AiDraftAssistActionTile extends StatelessWidget { + const _AiDraftAssistActionTile({ + required this.title, + required this.subtitle, + required this.icon, + required this.enabled, + required this.onTap, + }); + + final String title; + final String subtitle; + final IconData icon; + final bool enabled; + final VoidCallback onTap; + + @override + Widget build(BuildContext context) { + final backgroundColor = context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.06), + ); + final accentColor = context.theme.accent; + final iconColor = enabled ? accentColor : context.theme.secondaryText; + + return Opacity( + opacity: enabled ? 1 : 0.5, + child: IgnorePointer( + ignoring: !enabled, + child: InteractiveDecoratedBox.color( + decoration: BoxDecoration( + color: backgroundColor, + borderRadius: const BorderRadius.all(Radius.circular(8)), + border: Border.all( + color: context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.04), + darkColor: const Color.fromRGBO(255, 255, 255, 0.05), + ), + ), + ), + hoveringColor: accentColor.withValues(alpha: 0.08), + tapDowningColor: accentColor.withValues(alpha: 0.12), + onTap: onTap, + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 11), + child: Row( + children: [ + Container( + width: 28, + height: 28, + decoration: BoxDecoration( + color: iconColor.withValues(alpha: 0.12), + borderRadius: const BorderRadius.all(Radius.circular(8)), + ), + alignment: Alignment.center, + child: Icon(icon, size: 16, color: iconColor), + ), + const SizedBox(width: 10), + Expanded( + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + title, + style: TextStyle( + color: context.theme.text, + fontSize: 13, + fontWeight: FontWeight.w600, + ), + ), + const SizedBox(height: 3), + Text( + subtitle, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + height: 1.3, + ), + ), + ], + ), + ), + ], + ), + ), + ), + ), + ); + } +} + +class _AiDraftInlineIconButton extends StatelessWidget { + const _AiDraftInlineIconButton({ + required this.icon, + required this.color, + required this.onTap, + }); + + final IconData icon; + final Color color; + final VoidCallback onTap; + + @override + Widget build(BuildContext context) => ActionButton( + size: 16, + padding: const EdgeInsets.all(4), + onTap: onTap, + child: Icon(icon, size: 16, color: color), + ); +} + +class _AiDraftInlineTextButton extends StatelessWidget { + const _AiDraftInlineTextButton({ + required this.title, + required this.onTap, + this.secondary = false, + }); + + final String title; + final VoidCallback onTap; + final bool secondary; + + @override + Widget build(BuildContext context) => InteractiveDecoratedBox.color( + decoration: BoxDecoration( + color: secondary + ? context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.06), + ) + : context.theme.accent, + borderRadius: const BorderRadius.all(Radius.circular(8)), + ), + hoveringColor: secondary + ? context.dynamicColor( + const Color.fromRGBO(235, 238, 242, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.1), + ) + : context.theme.accent.withValues(alpha: 0.88), + tapDowningColor: secondary + ? context.dynamicColor( + const Color.fromRGBO(225, 229, 235, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.12), + ) + : context.theme.accent.withValues(alpha: 0.8), + onTap: onTap, + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 8), + child: Text( + title, + style: TextStyle( + color: secondary + ? context.theme.text + : context.dynamicColor(const Color.fromRGBO(255, 255, 255, 1)), + fontSize: 12, + fontWeight: FontWeight.w600, + ), + ), + ), + ); +} + +void applyAiDraftAssistResult( + TextEditingController controller, + String text, { + required bool replace, +}) { + if (replace) { + controller.value = TextEditingValue( + text: text, + selection: TextSelection.collapsed(offset: text.length), + ); + return; + } + + final current = controller.text; + final separator = current.trim().isEmpty ? '' : '\n'; + final next = '$current$separator$text'; + controller.value = TextEditingValue( + text: next, + selection: TextSelection.collapsed(offset: next.length), + ); +} + +String aiDraftActionTitle(AiDraftAction action) => switch (action) { + AiDraftAction.polish => 'Polish', + AiDraftAction.shorten => 'Make shorter', + AiDraftAction.polite => 'Make polite', + AiDraftAction.translate => 'Translate draft', + AiDraftAction.replyWithContext => 'Reply with context', +}; diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index bd435c8b39..351f0b494c 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -36,7 +36,6 @@ import '../../../utils/reg_exp_utils.dart'; import '../../../utils/system/clipboard.dart'; import '../../../widgets/action_button.dart'; import '../../../widgets/actions/actions.dart'; -import '../../../widgets/ai/ai_text_result_dialog.dart'; import '../../../widgets/high_light_text.dart'; import '../../../widgets/hover_overlay.dart'; import '../../../widgets/mention_panel.dart'; @@ -53,6 +52,7 @@ import '../../provider/mention_cache_provider.dart'; import '../../provider/mention_provider.dart'; import '../../provider/quote_message_provider.dart'; import '../../provider/recall_message_reedit_provider.dart'; +import 'ai_draft_assist_panel.dart'; import 'chat_page.dart'; import 'files_preview.dart'; import 'voice_recorder_bottom_bar.dart'; @@ -153,6 +153,59 @@ class _InputContainer extends HookConsumerWidget { ); final mentionProviderInstance = mentionProvider(textEditingValueStream); + final aiDraftAssistState = useState(AiDraftAssistViewState.idle); + final aiDraftAssistRequestVersion = useState(0); + + Future handleAiDraftRequest( + AiDraftAction action, + String original, + ) async { + final currentConversationId = conversationId; + if (currentConversationId == null) { + throw ToastError('Conversation unavailable'); + } + + final requestId = aiDraftAssistRequestVersion.value + 1; + aiDraftAssistRequestVersion.value = requestId; + aiDraftAssistState.value = AiDraftAssistViewState( + phase: AiDraftAssistPhase.loading, + action: action, + original: original, + ); + + try { + final result = await _requestAiDraftAction( + context, + action: action, + conversationId: currentConversationId, + original: original, + ); + if (aiDraftAssistRequestVersion.value == requestId) { + aiDraftAssistState.value = AiDraftAssistViewState( + phase: AiDraftAssistPhase.result, + action: action, + original: original, + result: result, + ); + } + return result; + } catch (error) { + if (aiDraftAssistRequestVersion.value == requestId) { + aiDraftAssistState.value = AiDraftAssistViewState( + phase: AiDraftAssistPhase.error, + action: action, + original: original, + error: '$error', + ); + } + rethrow; + } + } + + void dismissAiDraftAssist() { + aiDraftAssistRequestVersion.value += 1; + aiDraftAssistState.value = AiDraftAssistViewState.idle; + } useEffect(() { if (conversationId == null) return null; @@ -165,6 +218,11 @@ class _InputContainer extends HookConsumerWidget { return null; }, [conversationId]); + useEffect(() { + dismissAiDraftAssist(); + return null; + }, [conversationId]); + useEffect(() { final updateDraft = context.database.conversationDao.updateDraft; return () { @@ -262,6 +320,39 @@ class _InputContainer extends HookConsumerWidget { ), const SizedBox(height: 8), ], + if (!aiModeEnabled && + !aiDraftAssistState.value.isIdle) ...[ + AiDraftAssistInlineCandidate( + viewState: aiDraftAssistState.value, + onDismiss: dismissAiDraftAssist, + onCopy: () { + final result = aiDraftAssistState.value.result; + if (result == null) return; + Clipboard.setData(ClipboardData(text: result)); + showToastSuccessful(context: context); + }, + onAppend: () { + final result = aiDraftAssistState.value.result; + if (result == null) return; + applyAiDraftAssistResult( + textEditingController, + result, + replace: false, + ); + dismissAiDraftAssist(); + }, + onReplace: () { + final result = aiDraftAssistState.value.result; + if (result == null) return; + applyAiDraftAssistResult( + textEditingController, + result, + replace: true, + ); + dismissAiDraftAssist(); + }, + ), + ], Row( crossAxisAlignment: CrossAxisAlignment.end, children: [ @@ -285,13 +376,35 @@ class _InputContainer extends HookConsumerWidget { providerName: aiProvider?.name, modelName: aiProvider?.model, aiRequestInFlight: aiRequestInFlight, + aiDraftAssistState: aiDraftAssistState.value, ), ), if (!aiModeEnabled) ...[ const SizedBox(width: 8), - _AiDraftAssistButton( - conversationId: conversationId, + AiDraftAssistButton( + enabled: + context + .database + .settingProperties + .selectedAiProvider != + null, textEditingController: textEditingController, + viewState: aiDraftAssistState.value, + onSelected: (action) => unawaited( + handleAiDraftRequest( + action, + textEditingController.text.trim(), + ), + ), + onStop: () { + final currentConversationId = conversationId; + if (currentConversationId != null) { + AiChatController( + context.database, + ).stop(currentConversationId); + } + dismissAiDraftAssist(); + }, ), ], SizedBox(width: aiModeEnabled ? 10 : 8), @@ -465,133 +578,47 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { } } -enum _AiDraftAction { polish, shorten, polite, translate, replyWithContext } - -class _AiDraftAssistButton extends StatelessWidget { - const _AiDraftAssistButton({ - required this.conversationId, - required this.textEditingController, - }); - - final String? conversationId; - final TextEditingController textEditingController; - - @override - Widget build(BuildContext context) { - final enabled = - context.database.settingProperties.selectedAiProvider != null; - return AnimatedOpacity( - duration: const Duration(milliseconds: 180), - opacity: enabled ? 1 : 0.45, - child: IgnorePointer( - ignoring: !enabled, - child: CustomPopupMenuButton<_AiDraftAction>( - itemBuilder: (_) => [ - CustomPopupMenuItem( - title: 'Polish', - value: _AiDraftAction.polish, - icon: Resources.assetsImagesBotSvg, - ), - CustomPopupMenuItem( - title: 'Make shorter', - value: _AiDraftAction.shorten, - icon: Resources.assetsImagesBotSvg, - ), - CustomPopupMenuItem( - title: 'Make polite', - value: _AiDraftAction.polite, - icon: Resources.assetsImagesBotSvg, - ), - CustomPopupMenuItem( - title: 'Translate draft', - value: _AiDraftAction.translate, - icon: Resources.assetsImagesBotSvg, - ), - CustomPopupMenuItem( - title: 'Reply with context', - value: _AiDraftAction.replyWithContext, - icon: Resources.assetsImagesBotSvg, - ), - ], - onSelected: (action) => unawaited( - _runAiDraftAction( - context, - action: action, - conversationId: conversationId, - textEditingController: textEditingController, - ), - ), - child: Icon( - Icons.auto_awesome_rounded, - size: 20, - color: context.theme.icon, - ), - ), - ), - ); - } -} - -Future _runAiDraftAction( +Future _requestAiDraftAction( BuildContext context, { - required _AiDraftAction action, - required String? conversationId, - required TextEditingController textEditingController, + required AiDraftAction action, + required String conversationId, + required String original, }) async { - final original = textEditingController.text.trim(); - if (conversationId == null) return; - if (action != _AiDraftAction.replyWithContext && original.isEmpty) { - showToastFailed(ToastError('Please type a message first')); - return; + if (action != AiDraftAction.replyWithContext && original.isEmpty) { + throw ToastError('Please type a message first'); } final language = _currentLanguageTag(context); final instruction = switch (action) { - _AiDraftAction.polish => + AiDraftAction.polish => 'Polish this draft for a chat message. Keep the original meaning, language, and approximate length.', - _AiDraftAction.shorten => + AiDraftAction.shorten => 'Rewrite this chat draft to be shorter and clearer. Keep the original language and intent.', - _AiDraftAction.polite => + AiDraftAction.polite => 'Rewrite this chat draft to sound polite, natural, and still concise. Keep the original language.', - _AiDraftAction.translate => + AiDraftAction.translate => 'Translate this chat draft into $language. Return only the translation.', - _AiDraftAction.replyWithContext => + AiDraftAction.replyWithContext => 'Draft a concise, natural reply to the latest conversation message using the recent context. Return only the reply text.', }; final title = switch (action) { - _AiDraftAction.polish => 'Polish', - _AiDraftAction.shorten => 'Make shorter', - _AiDraftAction.polite => 'Make polite', - _AiDraftAction.translate => 'Translate draft', - _AiDraftAction.replyWithContext => 'Reply with context', + AiDraftAction.polish => 'Polish', + AiDraftAction.shorten => 'Make shorter', + AiDraftAction.polite => 'Make polite', + AiDraftAction.translate => 'Translate draft', + AiDraftAction.replyWithContext => 'Reply with context', }; - showToastLoading(context: context); try { final result = await AiChatController(context.database).assistText( instruction: instruction, - input: action == _AiDraftAction.replyWithContext ? null : original, + input: action == AiDraftAction.replyWithContext ? null : original, conversationId: conversationId, ); - Toast.dismiss(); - if (!context.mounted) return; - final selectedAction = await showAiTextResultDialog( - context: context, - title: title, - original: original, - result: result, - allowReplace: original.isNotEmpty, - ); - if (selectedAction == null) return; - switch (selectedAction) { - case AiTextResultAction.replace: - _replaceDraft(textEditingController, result); - case AiTextResultAction.insert: - _insertDraft(textEditingController, result); - } + return result.trim(); } catch (error, stackTrace) { - e('AI draft assist failed: $error, $stackTrace'); - showToastFailed(error, context: context); + e('AI draft assist failed: $title: $error, $stackTrace'); + rethrow; } } @@ -602,23 +629,6 @@ String _currentLanguageTag(BuildContext context) { return '${locale.languageCode}-$countryCode'; } -void _replaceDraft(TextEditingController controller, String text) { - controller.value = TextEditingValue( - text: text, - selection: TextSelection.collapsed(offset: text.length), - ); -} - -void _insertDraft(TextEditingController controller, String text) { - final current = controller.text; - final separator = current.trim().isEmpty ? '' : '\n'; - final next = '$current$separator$text'; - controller.value = TextEditingValue( - text: next, - selection: TextSelection.collapsed(offset: next.length), - ); -} - void showMaxLengthReachedToast(BuildContext context) => showToastFailed(ToastError(context.l10n.contentTooLong)); @@ -780,6 +790,7 @@ class _SendTextField extends HookConsumerWidget { required this.providerName, required this.modelName, required this.aiRequestInFlight, + required this.aiDraftAssistState, }); final FocusNode focusNode; @@ -790,6 +801,7 @@ class _SendTextField extends HookConsumerWidget { final String? providerName; final String? modelName; final bool aiRequestInFlight; + final AiDraftAssistViewState aiDraftAssistState; @override Widget build(BuildContext context, WidgetRef ref) { @@ -863,17 +875,27 @@ class _SendTextField extends HookConsumerWidget { ? context.l10n.chatHintE2e : 'Type message or /ai'; final canSubmit = sendable && (!aiModeEnabled || !aiRequestInFlight); + final aiDraftAssistActive = !aiDraftAssistState.isIdle; + final aiDraftAssistHasResult = + aiDraftAssistState.phase == AiDraftAssistPhase.result; + final fieldColor = context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ); + final borderColor = aiDraftAssistActive + ? context.theme.accent.withValues( + alpha: aiDraftAssistHasResult ? 0.26 : 0.16, + ) + : Colors.transparent; return AnimatedContainer( duration: const Duration(milliseconds: 220), curve: Curves.easeOutCubic, constraints: const BoxConstraints(minHeight: 40), decoration: BoxDecoration( - borderRadius: const BorderRadius.all(Radius.circular(4)), - color: context.dynamicColor( - const Color.fromRGBO(245, 247, 250, 1), - darkColor: const Color.fromRGBO(255, 255, 255, 0.08), - ), + borderRadius: const BorderRadius.all(Radius.circular(10)), + color: fieldColor, + border: Border.all(color: borderColor), ), alignment: Alignment.center, child: FocusableActionDetector( diff --git a/lib/widgets/message/message.dart b/lib/widgets/message/message.dart index cf31c38e2c..b4caf3753c 100644 --- a/lib/widgets/message/message.dart +++ b/lib/widgets/message/message.dart @@ -325,6 +325,14 @@ class MessageItemWidget extends HookConsumerWidget { isVerified: message.isVerified, aiSection: MessageInlineAiSection( state: inlineAiState.value, + leadingPadding: !isCurrentUser + ? kInlineMessageAiLeadingPadding + : 0, + onClose: (action) { + final nextState = inlineAiState.value.remove(action); + inlineAiState.value = nextState; + writeInlineMessageAiState(message.messageId, nextState); + }, ), buildMenus: (request) { request.onShowMenu.addListener(() { diff --git a/lib/widgets/message/message_ai_assist.dart b/lib/widgets/message/message_ai_assist.dart index 52ff0bc448..e24dbba024 100644 --- a/lib/widgets/message/message_ai_assist.dart +++ b/lib/widgets/message/message_ai_assist.dart @@ -6,10 +6,13 @@ import '../../db/mixin_database.dart'; import '../../ui/provider/recall_message_reedit_provider.dart'; import '../../utils/extension/extension.dart'; import '../../utils/logger.dart'; +import '../action_button.dart'; import '../markdown.dart'; enum MessageAiAction { translate, explain, suggestReplies } +const kInlineMessageAiLeadingPadding = 9.0; + final _inlineMessageAiStateCache = {}; class InlineMessageAiState with EquatableMixin { @@ -25,6 +28,13 @@ class InlineMessageAiState with EquatableMixin { ..[action] = entry, ); + InlineMessageAiState remove(MessageAiAction action) { + if (!entries.containsKey(action)) return this; + final nextEntries = Map.from(entries) + ..remove(action); + return InlineMessageAiState(entries: nextEntries); + } + InlineMessageAiEntry? operator [](MessageAiAction action) => entries[action]; bool get hasVisibleEntry => @@ -152,9 +162,16 @@ List _parseAiReplySuggestions(String result) => result .toList(growable: false); class MessageInlineAiSection extends StatelessWidget { - const MessageInlineAiSection({required this.state, super.key}); + const MessageInlineAiSection({ + required this.state, + required this.onClose, + this.leadingPadding = 0, + super.key, + }); final InlineMessageAiState state; + final void Function(MessageAiAction action) onClose; + final double leadingPadding; @override Widget build(BuildContext context) { @@ -170,6 +187,7 @@ class MessageInlineAiSection extends StatelessWidget { child: _InlineMessageAiCard( action: action, entry: state[action]!, + onClose: () => onClose(action), ), ), ]; @@ -178,9 +196,12 @@ class MessageInlineAiSection extends StatelessWidget { return const SizedBox.shrink(); } - return Column( - crossAxisAlignment: CrossAxisAlignment.stretch, - children: children, + return Padding( + padding: EdgeInsets.only(left: leadingPadding), + child: Column( + crossAxisAlignment: CrossAxisAlignment.stretch, + children: children, + ), ); } } @@ -189,10 +210,12 @@ class _InlineMessageAiCard extends StatelessWidget { const _InlineMessageAiCard({ required this.action, required this.entry, + required this.onClose, }); final MessageAiAction action; final InlineMessageAiEntry entry; + final VoidCallback onClose; @override Widget build(BuildContext context) { @@ -298,6 +321,17 @@ class _InlineMessageAiCard extends StatelessWidget { height: 1.2, ), ), + const SizedBox(width: 4), + ActionButton( + size: 14, + padding: const EdgeInsets.all(2), + onTap: onClose, + child: Icon( + Icons.close, + size: 14, + color: context.theme.secondaryText, + ), + ), ], ), if (entry.model?.isNotEmpty == true) const SizedBox(height: 2), From e9ed85b712b9bd9e46f4ffc674d9b19e030e6bb7 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Fri, 24 Apr 2026 19:37:26 +0800 Subject: [PATCH 14/52] fix: ensure AI features respect enabled providers configuration --- lib/ui/home/chat/input_container.dart | 3 ++- lib/widgets/message/message.dart | 9 ++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 351f0b494c..d18c0822c4 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -379,7 +379,8 @@ class _InputContainer extends HookConsumerWidget { aiDraftAssistState: aiDraftAssistState.value, ), ), - if (!aiModeEnabled) ...[ + if (!aiModeEnabled && + enabledAiProviders.isNotEmpty) ...[ const SizedBox(width: 8), AiDraftAssistButton( enabled: diff --git a/lib/widgets/message/message.dart b/lib/widgets/message/message.dart index b4caf3753c..7177c7bad8 100644 --- a/lib/widgets/message/message.dart +++ b/lib/widgets/message/message.dart @@ -647,7 +647,14 @@ class MessageItemWidget extends HookConsumerWidget { ), ]; - final aiText = messageAiText(message); + final hasEnabledAiProvider = context + .database + .settingProperties + .aiProviders + .any((p) => p.enabled); + final aiText = hasEnabledAiProvider + ? messageAiText(message) + : null; void updateInlineAiState( MessageAiAction action, InlineMessageAiEntry entry, From 6cfce8bf49f477be4442f3fa62ec4a3956ecc840 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Mon, 27 Apr 2026 18:36:08 +0800 Subject: [PATCH 15/52] refactor: simplify placeholder logic and update markdown theme implementation --- lib/ui/home/chat/input_container.dart | 24 ++---------------------- lib/widgets/markdown.dart | 14 ++++++++------ pubspec.lock | 7 +++---- pubspec.yaml | 3 ++- 4 files changed, 15 insertions(+), 33 deletions(-) diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index d18c0822c4..75667edcc5 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -852,29 +852,9 @@ class _SendTextField extends HookConsumerWidget { ).data ?? false; - final placeholder = aiModeEnabled - ? aiRequestInFlight - ? [ - if (providerName?.trim().isNotEmpty == true) - providerName!.trim() - else - 'AI', - if (modelName?.trim().isNotEmpty == true) - '(${modelName!.trim()})', - 'is responding...', - ].join(' ') - : [ - 'Ask', - if (providerName?.trim().isNotEmpty == true) - providerName!.trim() - else - 'AI', - if (modelName?.trim().isNotEmpty == true) - '(${modelName!.trim()})', - ].join(' ') - : isEncryptConversation + final placeholder = isEncryptConversation ? context.l10n.chatHintE2e - : 'Type message or /ai'; + : context.l10n.typeMessage; final canSubmit = sendable && (!aiModeEnabled || !aiRequestInFlight); final aiDraftAssistActive = !aiDraftAssistState.isIdle; final aiDraftAssistHasResult = diff --git a/lib/widgets/markdown.dart b/lib/widgets/markdown.dart index df45f25375..4a80721c91 100644 --- a/lib/widgets/markdown.dart +++ b/lib/widgets/markdown.dart @@ -504,10 +504,15 @@ MarkdownThemeData _createMarkdownTheme( BuildContext context, double chatFontSizeDelta, ) { - final base = MarkdownThemeData.fallback(context); + final foreground = context.brightness == Brightness.dark + ? MarkdownThemeForeground.dark + : MarkdownThemeForeground.light; + final base = MarkdownThemeData.themed( + context, + foreground: foreground, + ); final textColor = context.theme.text; final accentColor = context.theme.accent; - final codeBlockBackgroundColor = context.theme.chatBackground; TextStyle applyTextColor(TextStyle style) => style.copyWith(color: textColor); TextStyle applyFontSizeDelta(TextStyle style) { @@ -523,7 +528,7 @@ MarkdownThemeData _createMarkdownTheme( bodyStyle: applyTextStyle(base.bodyStyle), quoteStyle: applyFontSizeDelta( base.quoteStyle.copyWith( - color: textColor.withValues(alpha: 0.82), + color: base.quoteStyle.color ?? textColor.withValues(alpha: 0.82), ), ), linkStyle: base.linkStyle.copyWith( @@ -535,9 +540,6 @@ MarkdownThemeData _createMarkdownTheme( ), inlineCodeStyle: applyTextStyle(base.inlineCodeStyle), codeBlockStyle: applyTextStyle(base.codeBlockStyle), - codeBlockBackgroundColor: codeBlockBackgroundColor, - inlineCodeBackgroundColor: codeBlockBackgroundColor, - quoteBackgroundColor: codeBlockBackgroundColor, tableHeaderStyle: applyTextStyle(base.tableHeaderStyle), heading1Style: applyTextStyle( applyFontSizeDelta( diff --git a/pubspec.lock b/pubspec.lock index 039d36056a..24db9422e3 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -1270,10 +1270,9 @@ packages: mixin_markdown_widget: dependency: "direct main" description: - name: mixin_markdown_widget - sha256: c7e6134e5e98a2c390e0cc7f56245336152cc6bfddda14f77a6d920021e82186 - url: "https://pub.dev" - source: hosted + path: "../flutter-plugins/packages/mixin_markdown_widget" + relative: true + source: path version: "0.1.0" msix: dependency: "direct dev" diff --git a/pubspec.yaml b/pubspec.yaml index 5d2d2a765e..2d231a8cd3 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -96,7 +96,8 @@ dependencies: local_auth: ^3.0.1 lottie: ^3.3.3 map: ^2.0.2 - mixin_markdown_widget: ^0.1.0 + mixin_markdown_widget: + path: ../flutter-plugins/packages/mixin_markdown_widget mime: ^2.0.0 mixin_bot_sdk_dart: ^1.5.0 mixin_logger: ^0.1.3 From 7e793452c227f6db7d5ce921ddce30863fc23420 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Mon, 27 Apr 2026 18:43:27 +0800 Subject: [PATCH 16/52] feat: add support for Gemini AI provider --- lib/ai/ai_chat_controller.dart | 124 ++++++++++++++++++++++ lib/ai/model/ai_provider_type.dart | 3 +- lib/ui/setting/ai_provider_edit_page.dart | 79 ++++++++++++-- 3 files changed, 198 insertions(+), 8 deletions(-) diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index d03bfc442e..987cc3e084 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -38,6 +38,7 @@ class AiChatController { final _uuid = const Uuid(); static const _openAiStrategy = _OpenAiCompatibleStrategy(); static const _anthropicStrategy = _AnthropicStrategy(); + static const _geminiStrategy = _GeminiStrategy(); Future assistText({ required String instruction, @@ -328,6 +329,7 @@ class AiChatController { _AiProviderStrategy _strategyFor(AiProviderType type) => switch (type) { AiProviderType.openaiCompatible => _openAiStrategy, AiProviderType.anthropic => _anthropicStrategy, + AiProviderType.gemini => _geminiStrategy, }; } @@ -522,6 +524,128 @@ class _AnthropicStrategy implements _AiProviderStrategy { } } +class _GeminiStrategy implements _AiProviderStrategy { + const _GeminiStrategy(); + + @override + Map headers(AiProviderConfig config) => { + 'x-goog-api-key': config.apiKey, + 'content-type': 'application/json', + }; + + @override + Future streamResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required CancelToken cancelToken, + required Future Function(String chunk) onContent, + }) async { + final systemInstruction = messages + .where((message) => message.role == 'system') + .map((message) => message.content.trim()) + .where((content) => content.isNotEmpty) + .join('\n\n'); + + final contents = messages + .where((message) => message.role != 'system') + .map( + (message) => { + 'role': message.role == _kAiRoleAssistant ? 'model' : 'user', + 'parts': [ + {'text': message.content}, + ], + }, + ) + .toList(); + + final response = await dio.post( + '/models/${Uri.encodeComponent(config.model)}:streamGenerateContent', + queryParameters: const {'alt': 'sse'}, + data: { + 'contents': contents, + if (systemInstruction.isNotEmpty) + 'system_instruction': { + 'parts': [ + {'text': systemInstruction}, + ], + }, + 'generationConfig': { + 'candidateCount': 1, + }, + }, + options: Options(responseType: ResponseType.stream), + cancelToken: cancelToken, + ); + + final body = response.data; + if (body == null) { + throw Exception('Empty AI response'); + } + + final buffer = StringBuffer(); + await for (final data in _decodeSse(body.stream)) { + final json = jsonDecode(data); + if (json is! Map) { + continue; + } + + final promptFeedback = json['promptFeedback']; + if (promptFeedback is Map) { + final blockReason = promptFeedback['blockReason']; + if (blockReason is String && blockReason.isNotEmpty) { + throw Exception('Gemini request blocked: $blockReason'); + } + } + + final candidates = json['candidates'] as List?; + if (candidates == null || candidates.isEmpty) { + continue; + } + + final first = candidates.first; + if (first is! Map) { + continue; + } + + final finishReason = first['finishReason']; + if (finishReason is String && + finishReason.isNotEmpty && + finishReason != 'STOP' && + finishReason != 'FINISH_REASON_UNSPECIFIED') { + throw Exception('Gemini request finished with reason: $finishReason'); + } + + final content = first['content']; + if (content is! Map) { + continue; + } + + final parts = content['parts'] as List?; + if (parts == null || parts.isEmpty) { + continue; + } + + for (final part in parts) { + if (part is! Map) { + continue; + } + final text = part['text']; + if (text is String && text.isNotEmpty) { + buffer.write(text); + await onContent(text); + } + } + } + + final text = buffer.toString().trim(); + if (text.isEmpty) { + throw Exception('Empty AI response'); + } + return text; + } +} + class _StreamingMessageUpdater { _StreamingMessageUpdater({required this.dao, required this.messageId}); diff --git a/lib/ai/model/ai_provider_type.dart b/lib/ai/model/ai_provider_type.dart index 1b28278824..4c8c31e4ec 100644 --- a/lib/ai/model/ai_provider_type.dart +++ b/lib/ai/model/ai_provider_type.dart @@ -1,6 +1,7 @@ enum AiProviderType { openaiCompatible('openai_compatible'), - anthropic('anthropic') + anthropic('anthropic'), + gemini('gemini') ; const AiProviderType(this.value); diff --git a/lib/ui/setting/ai_provider_edit_page.dart b/lib/ui/setting/ai_provider_edit_page.dart index e7d6b6e772..574301d7a3 100644 --- a/lib/ui/setting/ai_provider_edit_page.dart +++ b/lib/ui/setting/ai_provider_edit_page.dart @@ -54,6 +54,15 @@ class AiProviderEditPage extends HookConsumerWidget { ); final obscureApiKey = useState(true); + useEffect(() { + if (initial != null) return null; + final suggestion = _defaultBaseUrlFor(providerType.value); + if (baseUrlController.text.trim().isEmpty && suggestion.isNotEmpty) { + baseUrlController.text = suggestion; + } + return null; + }, [initial, providerType.value]); + useEffect(() { final resolved = _resolveDefaultModel(models.value, defaultModel.value); if (resolved != defaultModel.value) { @@ -184,7 +193,8 @@ class AiProviderEditPage extends HookConsumerWidget { decoration: InputDecoration( isDense: true, border: InputBorder.none, - hintText: 'OpenAI / Anthropic / Self-hosted', + hintText: + 'OpenAI / Anthropic / Gemini / Self-hosted', hintStyle: TextStyle(color: theme.secondaryText), ), ), @@ -205,16 +215,34 @@ class AiProviderEditPage extends HookConsumerWidget { ), iconEnabledColor: inputIconColor, onChanged: (value) { - if (value != null) providerType.value = value; + if (value == null || + value == providerType.value) { + return; + } + final previousType = providerType.value; + providerType.value = value; + if (initial == null) { + final suggestion = _defaultBaseUrlFor(value); + final current = baseUrlController.text.trim(); + final replaceCurrent = + current.isEmpty || + current == _defaultBaseUrlFor(previousType); + if (replaceCurrent && suggestion.isNotEmpty) { + baseUrlController.text = suggestion; + } + } }, items: AiProviderType.values .map( (type) => DropdownMenuItem( value: type, child: Text( - type == AiProviderType.anthropic - ? 'Anthropic' - : 'OpenAI Compatible', + switch (type) { + AiProviderType.anthropic => 'Anthropic', + AiProviderType.gemini => 'Gemini', + AiProviderType.openaiCompatible => + 'OpenAI Compatible', + }, ), ), ) @@ -245,12 +273,22 @@ class AiProviderEditPage extends HookConsumerWidget { decoration: InputDecoration( isDense: true, border: InputBorder.none, - hintText: 'https://api.example.com/v1', + hintText: _baseUrlHintFor(providerType.value), hintStyle: TextStyle(color: theme.secondaryText), ), ), ), ), + Padding( + padding: const EdgeInsets.only(left: 20, bottom: 14, top: 10), + child: Text( + _baseUrlHelperTextFor(providerType.value), + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), + ), + ), const _SectionLabel( title: 'Authorization', ), @@ -284,7 +322,7 @@ class AiProviderEditPage extends HookConsumerWidget { decoration: InputDecoration( isDense: true, border: InputBorder.none, - hintText: 'sk-...', + hintText: _apiKeyHintFor(providerType.value), hintStyle: TextStyle(color: theme.secondaryText), ), ), @@ -387,6 +425,33 @@ class AiProviderEditPage extends HookConsumerWidget { } return models.first; } + + static String _defaultBaseUrlFor(AiProviderType type) => switch (type) { + AiProviderType.openaiCompatible => '', + AiProviderType.anthropic => 'https://api.anthropic.com/v1', + AiProviderType.gemini => 'https://generativelanguage.googleapis.com/v1beta', + }; + + static String _baseUrlHintFor(AiProviderType type) => switch (type) { + AiProviderType.openaiCompatible => 'https://api.example.com/v1', + AiProviderType.anthropic => 'https://api.anthropic.com/v1', + AiProviderType.gemini => 'https://generativelanguage.googleapis.com/v1beta', + }; + + static String _baseUrlHelperTextFor(AiProviderType type) => switch (type) { + AiProviderType.openaiCompatible => + 'For OpenAI-compatible APIs, use the server root that exposes /chat/completions.', + AiProviderType.anthropic => + 'Anthropic uses the Messages API under /v1/messages.', + AiProviderType.gemini => + 'Gemini uses the Google Generative Language API and appends /models/{model}:streamGenerateContent automatically.', + }; + + static String _apiKeyHintFor(AiProviderType type) => switch (type) { + AiProviderType.openaiCompatible => 'sk-...', + AiProviderType.anthropic => 'sk-ant-...', + AiProviderType.gemini => 'AIza...', + }; } class _SectionLabel extends StatelessWidget { From 69ce809214173bf6527d7c2ce924e1fb29c0030d Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 09:28:52 +0800 Subject: [PATCH 17/52] feat: add AI conversation tools and enhance message handling --- lib/ai/ai_chat_controller.dart | 1069 +++++++++++++++-- lib/ai/model/ai_prompt_message.dart | 19 +- lib/ai/model/ai_tool.dart | 39 + .../tools/ai_conversation_tool_service.dart | 642 ++++++++++ lib/db/dao/message_dao.dart | 53 + lib/ui/home/chat/ai_draft_assist_panel.dart | 8 +- lib/ui/home/chat/input_container.dart | 3 +- 7 files changed, 1760 insertions(+), 73 deletions(-) create mode 100644 lib/ai/model/ai_tool.dart create mode 100644 lib/ai/tools/ai_conversation_tool_service.dart diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index 987cc3e084..91d741ee3a 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -7,12 +7,15 @@ import 'package:mixin_logger/mixin_logger.dart'; import 'package:uuid/uuid.dart'; import '../db/dao/ai_chat_message_dao.dart'; +import '../db/dao/message_dao.dart'; import '../db/database.dart'; import '../db/mixin_database.dart'; import '../utils/proxy.dart'; import 'model/ai_prompt_message.dart'; import 'model/ai_provider_config.dart'; import 'model/ai_provider_type.dart'; +import 'model/ai_tool.dart'; +import 'tools/ai_conversation_tool_service.dart'; const _kAiRoleUser = 'user'; const _kAiRoleAssistant = 'assistant'; @@ -20,9 +23,14 @@ const _kAiStatusPending = 'pending'; const _kAiStatusDone = 'done'; const _kAiStatusError = 'error'; const _kAiContextMessageLimit = 30; +const _kAiRetrievedMessageLimit = 6; const _kAiHistoryLimit = 12; const _kAiStreamFlushChars = 32; const _kAiStreamFlushInterval = Duration(milliseconds: 80); +const _kAiRetrievalQueryMaxLength = 120; +const _kAiToolMaxRounds = 8; +const _kAiLogPreviewLength = 240; +const _kAiLogJsonPreviewLength = 480; final kAiRuntimeStartedAt = DateTime.now(); final _activeAiRequests = {}; @@ -39,6 +47,11 @@ class AiChatController { static const _openAiStrategy = _OpenAiCompatibleStrategy(); static const _anthropicStrategy = _AnthropicStrategy(); static const _geminiStrategy = _GeminiStrategy(); + late final DatabaseAiConversationToolService _conversationToolService = + DatabaseAiConversationToolService(database); + late final AiConversationToolKit _conversationTools = AiConversationToolKit( + _conversationToolService, + ); Future assistText({ required String instruction, @@ -51,6 +64,12 @@ class AiChatController { throw Exception('No AI provider configured'); } + d( + 'AI assist start: provider=${config.type.name} model=${config.model} ' + 'conversationId=$conversationId instruction=${_previewText(instruction)} ' + 'input=${_previewText(input)}', + ); + final messages = await _buildAssistPromptMessages( instruction: instruction, input: input, @@ -62,12 +81,19 @@ class AiChatController { _activeAiRequests[conversationId] = cancelToken; } try { - return await _streamRequest( + final result = await _requestText( config, messages, cancelToken: cancelToken, onContent: (_) async {}, + conversationId: conversationId, + streamFinalResponse: false, ); + d( + 'AI assist done: provider=${config.type.name} model=${config.model} ' + 'conversationId=$conversationId output=${_previewText(result)}', + ); + return result; } finally { if (conversationId != null && _activeAiRequests[conversationId] == cancelToken) { @@ -100,6 +126,12 @@ class AiChatController { throw Exception('No AI provider configured'); } + d( + 'AI send start: conversationId=$conversationId ' + 'provider=${config.type.name} model=${config.model} ' + 'input=${_previewText(input)}', + ); + final now = DateTime.now(); final userMessageId = _uuid.v4(); final assistantMessageId = _uuid.v4(); @@ -149,11 +181,13 @@ class AiChatController { _activeAiRequests[conversationId] = cancelToken; try { final messages = await _buildPromptMessages(conversationId, input); - final result = await _streamRequest( + final result = await _requestText( config, messages, cancelToken: cancelToken, onContent: updater.append, + conversationId: conversationId, + streamFinalResponse: true, ); await updater.flush(contentOverride: result, force: true); await database.aiChatMessageDao.updateMessageStatus( @@ -161,8 +195,16 @@ class AiChatController { _kAiStatusDone, updatedAt: DateTime.now(), ); + d( + 'AI send done: conversationId=$conversationId ' + 'assistantMessageId=$assistantMessageId output=${_previewText(result)}', + ); } catch (error, stacktrace) { if (cancelToken.isCancelled) { + d( + 'AI send cancelled: conversationId=$conversationId ' + 'assistantMessageId=$assistantMessageId', + ); await updater.flush(force: true); await database.aiChatMessageDao.updateMessageStatus( assistantMessageId, @@ -188,9 +230,74 @@ class AiChatController { } void stop(String conversationId) { + d('AI stop requested: conversationId=$conversationId'); _activeAiRequests[conversationId]?.cancel('AI generation stopped'); } + Future summarizeConversationRange({ + required String conversationId, + required DateTime startInclusive, + required DateTime endExclusive, + String? languageTag, + AiProviderConfig? provider, + }) async { + final config = provider ?? database.settingProperties.selectedAiProvider; + if (config == null) { + throw Exception('No AI provider configured'); + } + + final stats = await _conversationToolService.getConversationStats( + conversationId: conversationId, + startInclusive: startInclusive, + endExclusive: endExclusive, + ); + if (stats.messageCount <= 0) { + return 'No messages found in the selected time range.'; + } + + final messages = _buildConversationSummaryPromptMessages( + stats: stats, + languageTag: languageTag, + ); + final cancelToken = CancelToken(); + _activeAiRequests[conversationId] = cancelToken; + try { + return await _requestText( + config, + messages, + cancelToken: cancelToken, + onContent: (_) async {}, + conversationId: conversationId, + streamFinalResponse: false, + ); + } finally { + if (_activeAiRequests[conversationId] == cancelToken) { + _activeAiRequests.remove(conversationId); + } + } + } + + Future summarizeConversationToday({ + required String conversationId, + String? languageTag, + AiProviderConfig? provider, + DateTime? now, + }) { + final localNow = now ?? DateTime.now(); + final startInclusive = DateTime( + localNow.year, + localNow.month, + localNow.day, + ); + return summarizeConversationRange( + conversationId: conversationId, + startInclusive: startInclusive, + endExclusive: startInclusive.add(const Duration(days: 1)), + languageTag: languageTag, + provider: provider, + ); + } + Future> _buildPromptMessages( String conversationId, String input, @@ -198,6 +305,11 @@ class AiChatController { final recentMessages = await database.messageDao .messagesByConversationId(conversationId, _kAiContextMessageLimit) .get(); + final retrievedMessages = await _retrieveConversationMessages( + conversationId: conversationId, + recentMessages: recentMessages, + query: input, + ); final aiMessages = await database.aiChatMessageDao.conversationMessages( conversationId, ); @@ -213,21 +325,12 @@ class AiChatController { ), ]; - if (recentMessages.isNotEmpty) { - final lines = recentMessages.reversed - .map((message) { - final sender = message.userFullName ?? message.userId; - final content = _messagePlainText(message); - return '[${message.createdAt.toIso8601String()}] $sender: $content'; - }) - .join('\n'); - promptMessages.add( - AiPromptMessage( - role: 'system', - content: 'Current conversation recent messages:\n$lines', - ), - ); - } + _appendConversationToolInstruction(promptMessages, enabled: true); + _appendConversationContext( + promptMessages, + recentMessages: recentMessages, + retrievedMessages: retrievedMessages, + ); final history = aiMessages .where((element) => element.status != _kAiStatusPending) @@ -239,6 +342,11 @@ class AiChatController { } promptMessages.add(AiPromptMessage(role: _kAiRoleUser, content: input)); + d( + 'AI prompt built: conversationId=$conversationId ' + 'recent=${recentMessages.length} retrieved=${retrievedMessages.length} ' + 'history=${history.length} promptMessages=${promptMessages.length}', + ); return promptMessages; } @@ -258,24 +366,20 @@ class AiChatController { ]; if (conversationId != null) { + _appendConversationToolInstruction(promptMessages, enabled: true); final recentMessages = await database.messageDao .messagesByConversationId(conversationId, _kAiContextMessageLimit) .get(); - if (recentMessages.isNotEmpty) { - final lines = recentMessages.reversed - .map((message) { - final sender = message.userFullName ?? message.userId; - final content = _messagePlainText(message); - return '[${message.createdAt.toIso8601String()}] $sender: $content'; - }) - .join('\n'); - promptMessages.add( - AiPromptMessage( - role: 'system', - content: 'Current conversation recent messages:\n$lines', - ), - ); - } + final retrievedMessages = await _retrieveConversationMessages( + conversationId: conversationId, + recentMessages: recentMessages, + query: input ?? _latestRetrievalSeed(recentMessages), + ); + _appendConversationContext( + promptMessages, + recentMessages: recentMessages, + retrievedMessages: retrievedMessages, + ); } final inputText = input?.trim(); @@ -288,44 +392,456 @@ class AiChatController { ].join('\n'), ), ); + d( + 'AI assist prompt built: conversationId=$conversationId ' + 'messages=${promptMessages.length}', + ); return promptMessages; } - String _messagePlainText(MessageItem message) { - if (message.content?.trim().isNotEmpty == true) { - return message.content!.trim(); + List _buildConversationSummaryPromptMessages({ + required AiConversationToolStats stats, + required String? languageTag, + }) { + final outputLanguage = languageTag?.trim(); + return [ + AiPromptMessage( + role: 'system', + content: + 'You are a conversation summarizer inside a chat application. ' + 'For this task, you must use the available read-only conversation tools ' + 'to inspect the requested time range before writing the final answer. ' + 'Start by calling list_conversation_chunks for the exact range, then ' + 'read_conversation_chunk until you have covered the full range. ' + 'Do not rely only on recent context or search for this task.', + ), + AiPromptMessage( + role: 'system', + content: + 'Summaries must cover the requested range completely and should include ' + 'main topics, key decisions, action items, unresolved questions, and ' + 'notable follow-ups. Keep the final answer concise but comprehensive.', + ), + AiPromptMessage( + role: 'user', + content: [ + 'Summarize the conversation messages in this time range.', + 'Conversation ID: ${stats.conversationId}', + 'Start time: ${stats.startInclusive?.toIso8601String() ?? 'unspecified'}', + 'End time: ${stats.endExclusive?.toIso8601String() ?? 'unspecified'}', + 'Messages in range: ${stats.messageCount}', + if (stats.firstMessageAt != null) + 'First message at: ${stats.firstMessageAt!.toIso8601String()}', + if (stats.lastMessageAt != null) + 'Last message at: ${stats.lastMessageAt!.toIso8601String()}', + if (outputLanguage != null && outputLanguage.isNotEmpty) + 'Write the final summary in $outputLanguage.', + 'Before finalizing, make sure you have covered every chunk in the range.', + 'Return only the summary text.', + ].join('\n'), + ), + ]; + } + + void _appendConversationToolInstruction( + List promptMessages, { + required bool enabled, + }) { + if (!enabled) { + return; } - if (message.mediaName?.isNotEmpty == true) { - return '[${message.type}] ${message.mediaName}'; + promptMessages.add( + AiPromptMessage( + role: 'system', + content: + 'Read-only conversation tools are available for the current conversation. ' + 'Use them when you need exhaustive coverage, date-scoped summaries, ' + 'statistics, older messages, or more context than the provided messages. ' + 'Do not call tools when the provided context is already sufficient.', + ), + ); + } + + void _appendConversationContext( + List promptMessages, { + required List recentMessages, + required List retrievedMessages, + }) { + if (recentMessages.isNotEmpty) { + final lines = recentMessages.reversed + .map( + (message) => _conversationContextLine( + createdAt: message.createdAt, + sender: message.userFullName ?? message.userId, + content: _messagePlainText(message), + ), + ) + .join('\n'); + promptMessages.add( + AiPromptMessage( + role: 'system', + content: 'Current conversation recent messages:\n$lines', + ), + ); + } + + if (retrievedMessages.isEmpty) { + return; + } + + final lines = retrievedMessages + .map( + (message) => _conversationContextLine( + createdAt: message.createdAt, + sender: message.senderFullName ?? message.senderId, + content: _searchMessagePlainText(message), + ), + ) + .join('\n'); + promptMessages.add( + AiPromptMessage( + role: 'system', + content: + 'Relevant older conversation messages matched by search ' + '(use only if they help answer the current request):\n$lines', + ), + ); + } + + Future> _retrieveConversationMessages({ + required String conversationId, + required List recentMessages, + required String? query, + }) async { + final normalizedQuery = _normalizeRetrievalQuery(query); + if (normalizedQuery == null) { + d('AI retrieval skipped: conversationId=$conversationId empty query'); + return const []; + } + + final recentIds = recentMessages + .map((message) => message.messageId) + .toSet(); + final matchedIds = await database.ftsDatabase.fuzzySearchMessage( + query: normalizedQuery, + limit: _kAiRetrievedMessageLimit + recentIds.length, + conversationIds: [conversationId], + ); + final candidateIds = matchedIds + .where((messageId) => !recentIds.contains(messageId)) + .take(_kAiRetrievedMessageLimit) + .toList(growable: false); + if (candidateIds.isEmpty) { + d( + 'AI retrieval no match: conversationId=$conversationId ' + 'query=${_previewText(normalizedQuery)}', + ); + return const []; + } + + final matchedMessages = await database.messageDao + .searchMessageByIds(candidateIds) + .get(); + final messagesById = { + for (final message in matchedMessages) message.messageId: message, + }; + final ordered = []; + for (final messageId in candidateIds) { + final message = messagesById[messageId]; + if (message != null) { + ordered.add(message); + } } - return '[${message.type}]'; + ordered.sort((left, right) => left.createdAt.compareTo(right.createdAt)); + d( + 'AI retrieval matched: conversationId=$conversationId ' + 'query=${_previewText(normalizedQuery)} matches=${ordered.length}', + ); + return ordered; } - Future _streamRequest( + String? _latestRetrievalSeed(List recentMessages) { + for (final message in recentMessages) { + final content = _messagePlainText(message); + final normalized = _normalizeRetrievalQuery(content); + if (normalized != null) { + return normalized; + } + } + return null; + } + + String? _normalizeRetrievalQuery(String? query) { + final compact = query?.replaceAll(RegExp(r'\s+'), ' ').trim(); + if (compact == null || compact.isEmpty) { + return null; + } + if (compact.length <= _kAiRetrievalQueryMaxLength) { + return compact; + } + return compact.substring(0, _kAiRetrievalQueryMaxLength); + } + + String _conversationContextLine({ + required DateTime createdAt, + required String sender, + required String content, + }) => '[${createdAt.toIso8601String()}] $sender: $content'; + + String _messagePlainText(MessageItem message) => _messagePlainTextFromFields( + content: message.content, + mediaName: message.mediaName, + type: message.type, + ); + + String _searchMessagePlainText(SearchMessageDetailItem message) => + _messagePlainTextFromFields( + content: message.content, + mediaName: message.mediaName, + type: message.type, + ); + + String _messagePlainTextFromFields({ + required String? content, + required String? mediaName, + required String type, + }) { + if (content?.trim().isNotEmpty == true) { + return content!.trim(); + } + if (mediaName?.isNotEmpty == true) { + return '[$type] $mediaName'; + } + return '[$type]'; + } + + Future _requestText( AiProviderConfig config, List messages, { required CancelToken cancelToken, required Future Function(String chunk) onContent, + required bool streamFinalResponse, + String? conversationId, }) async { - final dio = Dio( - BaseOptions( - baseUrl: config.baseUrl, - connectTimeout: const Duration(seconds: 20), - receiveTimeout: const Duration(minutes: 5), - sendTimeout: const Duration(seconds: 20), - headers: _strategyFor(config.type).headers(config), - ), - )..applyProxy(database.settingProperties.activatedProxy); + d( + 'AI request start: provider=${config.type.name} model=${config.model} ' + 'conversationId=$conversationId streamFinal=$streamFinalResponse ' + 'messages=${messages.length} tools=${conversationId != null}', + ); + final dio = + Dio( + BaseOptions( + baseUrl: config.baseUrl, + connectTimeout: const Duration(seconds: 20), + receiveTimeout: const Duration(minutes: 5), + sendTimeout: const Duration(seconds: 20), + headers: _strategyFor(config.type).headers(config), + ), + ) + ..interceptors.add( + InterceptorsWrapper( + onRequest: (options, handler) { + options.extra['ai_request_started_at'] = DateTime.now(); + d( + 'AI HTTP request: ${options.method} ${options.uri} ' + 'provider=${config.type.name} model=${config.model}', + ); + handler.next(options); + }, + onResponse: (response, handler) { + final startedAt = + response.requestOptions.extra['ai_request_started_at'] + as DateTime?; + d( + 'AI HTTP response: ${response.requestOptions.method} ' + '${response.requestOptions.uri} status=${response.statusCode} ' + 'elapsedMs=${startedAt == null ? -1 : DateTime.now().difference(startedAt).inMilliseconds}', + ); + handler.next(response); + }, + onError: (error, handler) { + final startedAt = + error.requestOptions.extra['ai_request_started_at'] + as DateTime?; + e( + 'AI HTTP error: ${error.requestOptions.method} ' + '${error.requestOptions.uri} ' + 'elapsedMs=${startedAt == null ? -1 : DateTime.now().difference(startedAt).inMilliseconds} ' + 'error=${error.message}', + error, + error.stackTrace, + ); + handler.next(error); + }, + ), + ) + ..applyProxy(database.settingProperties.activatedProxy); + + if (conversationId == null) { + return _strategyFor(config.type).streamResponse( + dio: dio, + config: config, + messages: messages, + cancelToken: cancelToken, + onContent: onContent, + ); + } - return _strategyFor(config.type).streamResponse( - dio: dio, - config: config, - messages: messages, + return _requestWithTools( + dio, + config, + [...messages], + conversationId: conversationId, cancelToken: cancelToken, onContent: onContent, + streamFinalResponse: streamFinalResponse, ); } + Future _requestWithTools( + Dio dio, + AiProviderConfig config, + List messages, { + required String conversationId, + required CancelToken cancelToken, + required Future Function(String chunk) onContent, + required bool streamFinalResponse, + }) async { + for (var round = 0; round < _kAiToolMaxRounds; round++) { + d( + 'AI tool round start: conversationId=$conversationId ' + 'round=${round + 1}/$_kAiToolMaxRounds messages=${messages.length}', + ); + final response = await _strategyFor(config.type).completeResponse( + dio: dio, + config: config, + messages: messages, + tools: AiConversationToolKit.definitions, + cancelToken: cancelToken, + ); + d( + 'AI tool round response: conversationId=$conversationId ' + 'round=${round + 1} text=${_previewText(response.text)} ' + 'toolCalls=${_previewToolCalls(response.toolCalls)}', + ); + + if (!response.hasToolCalls) { + final text = response.text.trim(); + if (text.isEmpty) { + throw Exception('Empty AI response'); + } + if (streamFinalResponse) { + try { + d( + 'AI final stream start: conversationId=$conversationId ' + 'round=${round + 1}', + ); + return await _strategyFor(config.type).streamResponse( + dio: dio, + config: config, + messages: messages, + cancelToken: cancelToken, + onContent: onContent, + ); + } catch (error, stacktrace) { + e('AI final streaming fallback: $error, $stacktrace'); + await _emitBufferedText(text, onContent); + d( + 'AI final stream fallback: conversationId=$conversationId ' + 'round=${round + 1} text=${_previewText(text)}', + ); + return text; + } + } + await onContent(text); + d( + 'AI tool request done without stream: conversationId=$conversationId ' + 'round=${round + 1} text=${_previewText(text)}', + ); + return text; + } + + messages.add( + AiPromptMessage( + role: _kAiRoleAssistant, + content: response.text, + toolCalls: response.toolCalls, + ), + ); + for (final toolCall in response.toolCalls) { + final result = await _executeConversationTool( + conversationId: conversationId, + toolCall: toolCall, + ); + messages.add( + AiPromptMessage( + role: 'tool', + content: result.content, + toolCallId: result.toolCallId, + toolName: result.toolName, + toolPayload: result.payload, + ), + ); + } + } + + e( + 'AI exceeded tool call limit: conversationId=$conversationId ' + 'maxRounds=$_kAiToolMaxRounds', + ); + throw Exception('AI exceeded tool call limit'); + } + + Future _emitBufferedText( + String text, + Future Function(String chunk) onContent, + ) async { + final trimmed = text.trim(); + if (trimmed.isEmpty) { + return; + } + if (trimmed.length <= _kAiStreamFlushChars) { + await onContent(trimmed); + return; + } + for (var start = 0; start < trimmed.length; start += _kAiStreamFlushChars) { + final end = (start + _kAiStreamFlushChars).clamp(0, trimmed.length); + await onContent(trimmed.substring(start, end)); + } + } + + Future _executeConversationTool({ + required String conversationId, + required AiToolCall toolCall, + }) async { + final stopwatch = Stopwatch()..start(); + d( + 'AI tool execute start: conversationId=$conversationId ' + 'tool=${toolCall.name} id=${toolCall.id} ' + 'arguments=${_previewJson(toolCall.arguments)}', + ); + try { + final result = await _conversationTools.execute( + conversationId: conversationId, + call: toolCall, + ); + d( + 'AI tool execute done: conversationId=$conversationId ' + 'tool=${toolCall.name} id=${toolCall.id} ' + 'elapsedMs=${stopwatch.elapsedMilliseconds} ' + 'result=${_previewJson(result.payload)}', + ); + return result; + } catch (error, stacktrace) { + e('AI tool execution error: $error, $stacktrace'); + return AiToolExecutionResult( + toolCallId: toolCall.id, + toolName: toolCall.name, + payload: {'error': '$error'}, + ); + } + } + _AiProviderStrategy _strategyFor(AiProviderType type) => switch (type) { AiProviderType.openaiCompatible => _openAiStrategy, AiProviderType.anthropic => _anthropicStrategy, @@ -333,6 +849,41 @@ class AiChatController { }; } +String _previewText(String? text, {int maxLength = _kAiLogPreviewLength}) { + final compact = text?.replaceAll(RegExp(r'\s+'), ' ').trim() ?? ''; + if (compact.isEmpty) { + return '""'; + } + if (compact.length <= maxLength) { + return compact; + } + return '${compact.substring(0, maxLength)}...(${compact.length} chars)'; +} + +String _previewJson(Object? value, {int maxLength = _kAiLogJsonPreviewLength}) { + try { + final encoded = jsonEncode(value); + if (encoded.length <= maxLength) { + return encoded; + } + return '${encoded.substring(0, maxLength)}...(${encoded.length} chars)'; + } catch (_) { + return '$value'; + } +} + +String _previewToolCalls(List toolCalls) { + if (toolCalls.isEmpty) { + return '[]'; + } + return toolCalls + .map( + (toolCall) => + '${toolCall.name}#${toolCall.id}(${_previewJson(toolCall.arguments, maxLength: 120)})', + ) + .join(', '); +} + extension on Iterable { Iterable takeLast(int count) { if (count <= 0) return const []; @@ -349,6 +900,14 @@ abstract interface class _AiProviderStrategy { Map headers(AiProviderConfig config); + Future<_AiCompletionResponse> completeResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required List tools, + required CancelToken cancelToken, + }); + Future streamResponse({ required Dio dio, required AiProviderConfig config, @@ -367,6 +926,54 @@ class _OpenAiCompatibleStrategy implements _AiProviderStrategy { 'Content-Type': 'application/json', }; + @override + Future<_AiCompletionResponse> completeResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required List tools, + required CancelToken cancelToken, + }) async { + final response = await dio.post( + '/chat/completions', + data: { + 'model': config.model, + 'messages': messages.map(_openAiMessagePayload).toList(growable: false), + if (tools.isNotEmpty) + 'tools': tools + .map( + (tool) => { + 'type': 'function', + 'function': { + 'name': tool.name, + 'description': tool.description, + 'parameters': tool.inputSchema, + }, + }, + ) + .toList(growable: false), + if (tools.isNotEmpty) 'tool_choice': 'auto', + }, + cancelToken: cancelToken, + ); + + final body = _jsonMap(response.data); + final choices = body['choices'] as List?; + if (choices == null || choices.isEmpty) { + throw Exception('Empty AI response'); + } + final first = _jsonMap(choices.first); + final message = _jsonMap(first['message']); + final text = _stringContent(message['content']); + final toolCalls = (message['tool_calls'] as List? ?? const []) + .map((item) => _openAiToolCall(_jsonMap(item))) + .toList(growable: false); + if (text.trim().isEmpty && toolCalls.isEmpty) { + throw Exception('Empty AI response'); + } + return _AiCompletionResponse(text: text, toolCalls: toolCalls); + } + @override Future streamResponse({ required Dio dio, @@ -380,11 +987,7 @@ class _OpenAiCompatibleStrategy implements _AiProviderStrategy { data: { 'model': config.model, 'stream': true, - 'messages': messages - .map( - (message) => {'role': message.role, 'content': message.content}, - ) - .toList(), + 'messages': messages.map(_openAiMessagePayload).toList(growable: false), }, options: Options(responseType: ResponseType.stream), cancelToken: cancelToken, @@ -434,6 +1037,38 @@ class _OpenAiCompatibleStrategy implements _AiProviderStrategy { } return text; } + + Map _openAiMessagePayload(AiPromptMessage message) => { + 'role': message.role, + 'content': message.content, + if (message.hasToolCalls) + 'tool_calls': message.toolCalls + .map( + (toolCall) => { + 'id': toolCall.id, + 'type': 'function', + 'function': { + 'name': toolCall.name, + 'arguments': jsonEncode(toolCall.arguments), + }, + }, + ) + .toList(growable: false), + if (message.isToolResult) 'tool_call_id': message.toolCallId, + }; + + AiToolCall _openAiToolCall(Map value) { + final function = _jsonMap(value['function']); + final name = function['name'] as String?; + if (name == null || name.isEmpty) { + throw Exception('Invalid AI tool call name'); + } + return AiToolCall( + id: value['id'] as String? ?? '${name}_${value.hashCode}', + name: name, + arguments: _toolArguments(function['arguments']), + ); + } } class _AnthropicStrategy implements _AiProviderStrategy { @@ -446,6 +1081,85 @@ class _AnthropicStrategy implements _AiProviderStrategy { 'content-type': 'application/json', }; + @override + Future<_AiCompletionResponse> completeResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required List tools, + required CancelToken cancelToken, + }) async { + final response = await dio.post( + '/messages', + data: { + 'model': config.model, + 'max_tokens': 1024, + 'messages': messages + .where((message) => message.role != 'system') + .map(_anthropicMessagePayload) + .toList(growable: false), + 'system': messages + .where((message) => message.role == 'system') + .map((message) => message.content) + .where((content) => content.isNotEmpty) + .join('\n\n'), + if (tools.isNotEmpty) + 'tools': tools + .map( + (tool) => { + 'name': tool.name, + 'description': tool.description, + 'input_schema': tool.inputSchema, + }, + ) + .toList(growable: false), + }, + cancelToken: cancelToken, + ); + + final body = _jsonMap(response.data); + if (body['type'] == 'error') { + final error = _jsonMap(body['error']); + throw Exception(error['message'] ?? 'Anthropic request failed'); + } + + final content = body['content'] as List?; + if (content == null || content.isEmpty) { + throw Exception('Empty AI response'); + } + + final textBuffer = StringBuffer(); + final toolCalls = []; + for (final item in content) { + final block = _jsonMap(item); + switch (block['type']) { + case 'text': + final text = block['text']; + if (text is String && text.isNotEmpty) { + textBuffer.write(text); + } + case 'tool_use': + final name = block['name'] as String?; + if (name == null || name.isEmpty) { + throw Exception('Invalid AI tool call name'); + } + toolCalls.add( + AiToolCall( + id: block['id'] as String? ?? '${name}_${block.hashCode}', + name: name, + arguments: _toolArguments(block['input']), + ), + ); + } + } + + final text = textBuffer.toString(); + if (text.trim().isEmpty && toolCalls.isEmpty) { + throw Exception('Empty AI response'); + } + return _AiCompletionResponse(text: text, toolCalls: toolCalls); + } + @override Future streamResponse({ required Dio dio, @@ -462,10 +1176,8 @@ class _AnthropicStrategy implements _AiProviderStrategy { 'stream': true, 'messages': messages .where((message) => message.role != 'system') - .map( - (message) => {'role': message.role, 'content': message.content}, - ) - .toList(), + .map(_anthropicMessagePayload) + .toList(growable: false), 'system': messages .where((message) => message.role == 'system') .map((message) => message.content) @@ -522,6 +1234,37 @@ class _AnthropicStrategy implements _AiProviderStrategy { } return text; } + + Map _anthropicMessagePayload(AiPromptMessage message) => { + 'role': message.isToolResult ? 'user' : message.role, + 'content': _anthropicContentBlocks(message), + }; + + List> _anthropicContentBlocks(AiPromptMessage message) { + if (message.isToolResult) { + return [ + { + 'type': 'tool_result', + 'tool_use_id': message.toolCallId, + 'content': message.content, + }, + ]; + } + + final blocks = >[]; + if (message.content.isNotEmpty) { + blocks.add({'type': 'text', 'text': message.content}); + } + for (final toolCall in message.toolCalls) { + blocks.add({ + 'type': 'tool_use', + 'id': toolCall.id, + 'name': toolCall.name, + 'input': toolCall.arguments, + }); + } + return blocks; + } } class _GeminiStrategy implements _AiProviderStrategy { @@ -533,6 +1276,116 @@ class _GeminiStrategy implements _AiProviderStrategy { 'content-type': 'application/json', }; + @override + Future<_AiCompletionResponse> completeResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required List tools, + required CancelToken cancelToken, + }) async { + final systemInstruction = messages + .where((message) => message.role == 'system') + .map((message) => message.content.trim()) + .where((content) => content.isNotEmpty) + .join('\n\n'); + final response = await dio.post( + '/models/${Uri.encodeComponent(config.model)}:generateContent', + data: { + 'contents': messages + .where((message) => message.role != 'system') + .map(_geminiMessagePayload) + .toList(growable: false), + if (systemInstruction.isNotEmpty) + 'system_instruction': { + 'parts': [ + {'text': systemInstruction}, + ], + }, + if (tools.isNotEmpty) + 'tools': [ + { + 'functionDeclarations': tools + .map( + (tool) => { + 'name': tool.name, + 'description': tool.description, + 'parameters': tool.inputSchema, + }, + ) + .toList(growable: false), + }, + ], + if (tools.isNotEmpty) + 'toolConfig': { + 'functionCallingConfig': {'mode': 'AUTO'}, + }, + 'generationConfig': { + 'candidateCount': 1, + }, + }, + cancelToken: cancelToken, + ); + + final body = _jsonMap(response.data); + final promptFeedback = body['promptFeedback']; + if (promptFeedback is Map) { + final blockReason = promptFeedback['blockReason']; + if (blockReason is String && blockReason.isNotEmpty) { + throw Exception('Gemini request blocked: $blockReason'); + } + } + + final candidates = body['candidates'] as List?; + if (candidates == null || candidates.isEmpty) { + throw Exception('Empty AI response'); + } + final first = _jsonMap(candidates.first); + final finishReason = first['finishReason']; + if (finishReason is String && + finishReason.isNotEmpty && + finishReason != 'STOP' && + finishReason != 'FINISH_REASON_UNSPECIFIED') { + throw Exception('Gemini request finished with reason: $finishReason'); + } + + final content = _jsonMap(first['content']); + final parts = content['parts'] as List?; + if (parts == null || parts.isEmpty) { + throw Exception('Empty AI response'); + } + + final textBuffer = StringBuffer(); + final toolCalls = []; + for (final item in parts) { + final part = _jsonMap(item); + final text = part['text']; + if (text is String && text.isNotEmpty) { + textBuffer.write(text); + } + final functionCall = part['functionCall']; + if (functionCall is Map) { + final name = functionCall['name'] as String?; + if (name == null || name.isEmpty) { + throw Exception('Invalid AI tool call name'); + } + toolCalls.add( + AiToolCall( + id: '${name}_${functionCall.hashCode}', + name: name, + arguments: _toolArguments(functionCall['args']), + ), + ); + } + } + + final text = textBuffer.toString(); + if (text.trim().isEmpty && toolCalls.isEmpty) { + throw Exception('Empty AI response'); + } + return _AiCompletionResponse(text: text, toolCalls: toolCalls); + } + @override Future streamResponse({ required Dio dio, @@ -549,15 +1402,8 @@ class _GeminiStrategy implements _AiProviderStrategy { final contents = messages .where((message) => message.role != 'system') - .map( - (message) => { - 'role': message.role == _kAiRoleAssistant ? 'model' : 'user', - 'parts': [ - {'text': message.content}, - ], - }, - ) - .toList(); + .map(_geminiMessagePayload) + .toList(growable: false); final response = await dio.post( '/models/${Uri.encodeComponent(config.model)}:streamGenerateContent', @@ -644,6 +1490,89 @@ class _GeminiStrategy implements _AiProviderStrategy { } return text; } + + Map _geminiMessagePayload(AiPromptMessage message) => { + 'role': message.role == _kAiRoleAssistant ? 'model' : 'user', + 'parts': _geminiMessageParts(message), + }; + + List> _geminiMessageParts(AiPromptMessage message) { + if (message.isToolResult) { + return [ + { + 'functionResponse': { + 'name': message.toolName, + 'response': message.toolPayload ?? {'content': message.content}, + }, + }, + ]; + } + + final parts = >[]; + if (message.content.isNotEmpty) { + parts.add({'text': message.content}); + } + for (final toolCall in message.toolCalls) { + parts.add({ + 'functionCall': { + 'name': toolCall.name, + 'args': toolCall.arguments, + }, + }); + } + return parts; + } +} + +class _AiCompletionResponse { + const _AiCompletionResponse({ + this.text = '', + this.toolCalls = const [], + }); + + final String text; + final List toolCalls; + + bool get hasToolCalls => toolCalls.isNotEmpty; +} + +Map _jsonMap(dynamic value) { + if (value is Map) { + return value; + } + if (value is Map) { + return value.map((key, value) => MapEntry('$key', value)); + } + throw Exception('Invalid AI response payload'); +} + +Map _toolArguments(dynamic value) { + if (value == null) { + return const {}; + } + if (value is String) { + final trimmed = value.trim(); + if (trimmed.isEmpty) { + return const {}; + } + final decoded = jsonDecode(trimmed); + return _jsonMap(decoded); + } + return _jsonMap(value); +} + +String _stringContent(dynamic value) { + if (value is String) { + return value; + } + if (value is List) { + return value + .whereType() + .map((item) => item['text']) + .whereType() + .join('\n'); + } + return ''; } class _StreamingMessageUpdater { diff --git a/lib/ai/model/ai_prompt_message.dart b/lib/ai/model/ai_prompt_message.dart index c7ab2f94a4..0574dccdeb 100644 --- a/lib/ai/model/ai_prompt_message.dart +++ b/lib/ai/model/ai_prompt_message.dart @@ -1,6 +1,23 @@ +import 'ai_tool.dart'; + class AiPromptMessage { - AiPromptMessage({required this.role, required this.content}); + AiPromptMessage({ + required this.role, + required this.content, + List? toolCalls, + this.toolCallId, + this.toolName, + this.toolPayload, + }) : toolCalls = toolCalls ?? const []; final String role; final String content; + final List toolCalls; + final String? toolCallId; + final String? toolName; + final Map? toolPayload; + + bool get hasToolCalls => toolCalls.isNotEmpty; + + bool get isToolResult => role == 'tool'; } diff --git a/lib/ai/model/ai_tool.dart b/lib/ai/model/ai_tool.dart new file mode 100644 index 0000000000..76e2fbff5b --- /dev/null +++ b/lib/ai/model/ai_tool.dart @@ -0,0 +1,39 @@ +import 'dart:convert'; + +class AiToolDefinition { + const AiToolDefinition({ + required this.name, + required this.description, + required this.inputSchema, + }); + + final String name; + final String description; + final Map inputSchema; +} + +class AiToolCall { + const AiToolCall({ + required this.id, + required this.name, + required this.arguments, + }); + + final String id; + final String name; + final Map arguments; +} + +class AiToolExecutionResult { + const AiToolExecutionResult({ + required this.toolCallId, + required this.toolName, + required this.payload, + }); + + final String toolCallId; + final String toolName; + final Map payload; + + String get content => jsonEncode(payload); +} diff --git a/lib/ai/tools/ai_conversation_tool_service.dart b/lib/ai/tools/ai_conversation_tool_service.dart new file mode 100644 index 0000000000..8ba90468ea --- /dev/null +++ b/lib/ai/tools/ai_conversation_tool_service.dart @@ -0,0 +1,642 @@ +import 'dart:math' as math; + +import '../../db/dao/message_dao.dart'; +import '../../db/database.dart'; +import '../../db/mixin_database.dart'; +import '../model/ai_tool.dart'; + +const _kDefaultConversationChunkSize = 100; +const _kMaxConversationChunkSize = 200; +const _kDefaultConversationSearchLimit = 8; +const _kMaxConversationSearchLimit = 20; + +class AiConversationToolMessage { + const AiConversationToolMessage({ + required this.messageId, + required this.createdAt, + required this.senderId, + required this.senderName, + required this.type, + required this.text, + }); + + final String messageId; + final DateTime createdAt; + final String senderId; + final String senderName; + final String type; + final String text; + + Map toJson() => { + 'message_id': messageId, + 'created_at': createdAt.toIso8601String(), + 'sender_id': senderId, + 'sender_name': senderName, + 'type': type, + 'text': text, + }; +} + +class AiConversationToolStats { + const AiConversationToolStats({ + required this.conversationId, + required this.messageCount, + required this.startInclusive, + required this.endExclusive, + this.firstMessageAt, + this.lastMessageAt, + }); + + final String conversationId; + final int messageCount; + final DateTime? startInclusive; + final DateTime? endExclusive; + final DateTime? firstMessageAt; + final DateTime? lastMessageAt; + + Map toJson() => { + 'conversation_id': conversationId, + 'message_count': messageCount, + 'start_time': startInclusive?.toIso8601String(), + 'end_time': endExclusive?.toIso8601String(), + 'first_message_at': firstMessageAt?.toIso8601String(), + 'last_message_at': lastMessageAt?.toIso8601String(), + }; +} + +class AiConversationToolChunk { + const AiConversationToolChunk({ + required this.index, + required this.offset, + required this.messageCount, + }); + + final int index; + final int offset; + final int messageCount; + + Map toJson() => { + 'index': index, + 'offset': offset, + 'message_count': messageCount, + }; +} + +class AiConversationToolChunkList { + const AiConversationToolChunkList({ + required this.conversationId, + required this.chunkSize, + required this.totalMessages, + required this.startInclusive, + required this.endExclusive, + required this.chunks, + }); + + final String conversationId; + final int chunkSize; + final int totalMessages; + final DateTime? startInclusive; + final DateTime? endExclusive; + final List chunks; + + Map toJson() => { + 'conversation_id': conversationId, + 'chunk_size': chunkSize, + 'total_messages': totalMessages, + 'total_chunks': chunks.length, + 'start_time': startInclusive?.toIso8601String(), + 'end_time': endExclusive?.toIso8601String(), + 'chunks': chunks.map((chunk) => chunk.toJson()).toList(growable: false), + }; +} + +class AiConversationToolChunkPage { + const AiConversationToolChunkPage({ + required this.conversationId, + required this.offset, + required this.limit, + required this.totalMessages, + required this.startInclusive, + required this.endExclusive, + required this.messages, + required this.nextOffset, + }); + + final String conversationId; + final int offset; + final int limit; + final int totalMessages; + final DateTime? startInclusive; + final DateTime? endExclusive; + final List messages; + final int? nextOffset; + + Map toJson() => { + 'conversation_id': conversationId, + 'offset': offset, + 'limit': limit, + 'total_messages': totalMessages, + 'returned_count': messages.length, + 'next_offset': nextOffset, + 'start_time': startInclusive?.toIso8601String(), + 'end_time': endExclusive?.toIso8601String(), + 'messages': messages + .map((message) => message.toJson()) + .toList(growable: false), + }; +} + +class AiConversationToolSearchResult { + const AiConversationToolSearchResult({ + required this.conversationId, + required this.query, + required this.limit, + required this.messages, + }); + + final String conversationId; + final String query; + final int limit; + final List messages; + + Map toJson() => { + 'conversation_id': conversationId, + 'query': query, + 'limit': limit, + 'returned_count': messages.length, + 'messages': messages + .map((message) => message.toJson()) + .toList(growable: false), + }; +} + +abstract interface class AiConversationToolService { + Future getConversationStats({ + required String conversationId, + DateTime? startInclusive, + DateTime? endExclusive, + }); + + Future listConversationChunks({ + required String conversationId, + required int chunkSize, + DateTime? startInclusive, + DateTime? endExclusive, + }); + + Future readConversationChunk({ + required String conversationId, + required int offset, + required int limit, + DateTime? startInclusive, + DateTime? endExclusive, + }); + + Future searchConversationMessages({ + required String conversationId, + required String query, + required int limit, + }); +} + +class DatabaseAiConversationToolService implements AiConversationToolService { + DatabaseAiConversationToolService(this.database); + + final Database database; + + @override + Future getConversationStats({ + required String conversationId, + DateTime? startInclusive, + DateTime? endExclusive, + }) async { + final messageCount = await database.messageDao + .messageCountByConversationIdAndCreatedAtRange( + conversationId, + startInclusive: startInclusive, + endExclusive: endExclusive, + ) + .getSingle(); + + DateTime? firstMessageAt; + DateTime? lastMessageAt; + if (messageCount > 0) { + final firstMessage = await database.messageDao + .messagesByConversationIdAndCreatedAtRange( + conversationId, + limit: 1, + startInclusive: startInclusive, + endExclusive: endExclusive, + ) + .getSingleOrNull(); + final lastMessage = await database.messageDao + .messagesByConversationIdAndCreatedAtRange( + conversationId, + limit: 1, + startInclusive: startInclusive, + endExclusive: endExclusive, + ascending: false, + ) + .getSingleOrNull(); + firstMessageAt = firstMessage?.createdAt; + lastMessageAt = lastMessage?.createdAt; + } + + return AiConversationToolStats( + conversationId: conversationId, + messageCount: messageCount, + startInclusive: startInclusive, + endExclusive: endExclusive, + firstMessageAt: firstMessageAt, + lastMessageAt: lastMessageAt, + ); + } + + @override + Future listConversationChunks({ + required String conversationId, + required int chunkSize, + DateTime? startInclusive, + DateTime? endExclusive, + }) async { + final totalMessages = await database.messageDao + .messageCountByConversationIdAndCreatedAtRange( + conversationId, + startInclusive: startInclusive, + endExclusive: endExclusive, + ) + .getSingle(); + final chunks = []; + for (var offset = 0; offset < totalMessages; offset += chunkSize) { + final index = offset ~/ chunkSize; + final messageCount = math.min(chunkSize, totalMessages - offset); + chunks.add( + AiConversationToolChunk( + index: index, + offset: offset, + messageCount: messageCount, + ), + ); + } + return AiConversationToolChunkList( + conversationId: conversationId, + chunkSize: chunkSize, + totalMessages: totalMessages, + startInclusive: startInclusive, + endExclusive: endExclusive, + chunks: chunks, + ); + } + + @override + Future readConversationChunk({ + required String conversationId, + required int offset, + required int limit, + DateTime? startInclusive, + DateTime? endExclusive, + }) async { + final totalMessages = await database.messageDao + .messageCountByConversationIdAndCreatedAtRange( + conversationId, + startInclusive: startInclusive, + endExclusive: endExclusive, + ) + .getSingle(); + final safeOffset = math.max(0, offset); + final messages = safeOffset >= totalMessages + ? const [] + : await database.messageDao + .messagesByConversationIdAndCreatedAtRange( + conversationId, + limit: limit, + offset: safeOffset, + startInclusive: startInclusive, + endExclusive: endExclusive, + ) + .get(); + final nextOffset = safeOffset + messages.length < totalMessages + ? safeOffset + messages.length + : null; + + return AiConversationToolChunkPage( + conversationId: conversationId, + offset: safeOffset, + limit: limit, + totalMessages: totalMessages, + startInclusive: startInclusive, + endExclusive: endExclusive, + messages: messages.map(_messageItemToToolMessage).toList(growable: false), + nextOffset: nextOffset, + ); + } + + @override + Future searchConversationMessages({ + required String conversationId, + required String query, + required int limit, + }) async { + final messages = await database.fuzzySearchMessage( + query: query, + limit: limit, + conversationIds: [conversationId], + ); + return AiConversationToolSearchResult( + conversationId: conversationId, + query: query, + limit: limit, + messages: messages + .map(_searchMessageToToolMessage) + .toList(growable: false), + ); + } + + AiConversationToolMessage _messageItemToToolMessage(MessageItem message) => + AiConversationToolMessage( + messageId: message.messageId, + createdAt: message.createdAt, + senderId: message.userId, + senderName: message.userFullName ?? message.userId, + type: message.type, + text: _messageText( + content: message.content, + mediaName: message.mediaName, + type: message.type, + ), + ); + + AiConversationToolMessage _searchMessageToToolMessage( + SearchMessageDetailItem message, + ) => AiConversationToolMessage( + messageId: message.messageId, + createdAt: message.createdAt, + senderId: message.senderId, + senderName: message.senderFullName ?? message.senderId, + type: message.type, + text: _messageText( + content: message.content, + mediaName: message.mediaName, + type: message.type, + ), + ); + + String _messageText({ + required String? content, + required String? mediaName, + required String type, + }) { + if (content?.trim().isNotEmpty == true) { + return content!.trim(); + } + if (mediaName?.isNotEmpty == true) { + return '[$type] $mediaName'; + } + return '[$type]'; + } +} + +class AiConversationToolKit { + const AiConversationToolKit(this.service); + + final AiConversationToolService service; + + static const definitions = [ + AiToolDefinition( + name: 'get_conversation_stats', + description: + 'Get message counts and boundary timestamps for the current conversation or a specific time range.', + inputSchema: { + 'type': 'object', + 'properties': { + 'start_time': { + 'type': 'string', + 'description': 'Optional inclusive ISO-8601 start time.', + }, + 'end_time': { + 'type': 'string', + 'description': 'Optional exclusive ISO-8601 end time.', + }, + }, + 'additionalProperties': false, + }, + ), + AiToolDefinition( + name: 'list_conversation_chunks', + description: + 'List chunk offsets that can be used to read the current conversation in fixed-size batches, optionally scoped to a time range.', + inputSchema: { + 'type': 'object', + 'properties': { + 'chunk_size': { + 'type': 'integer', + 'description': 'Optional chunk size between 1 and 200.', + }, + 'start_time': { + 'type': 'string', + 'description': 'Optional inclusive ISO-8601 start time.', + }, + 'end_time': { + 'type': 'string', + 'description': 'Optional exclusive ISO-8601 end time.', + }, + }, + 'additionalProperties': false, + }, + ), + AiToolDefinition( + name: 'read_conversation_chunk', + description: + 'Read a batch of messages from the current conversation by offset and limit, optionally scoped to a time range.', + inputSchema: { + 'type': 'object', + 'properties': { + 'offset': { + 'type': 'integer', + 'description': 'Zero-based offset into the matching message list.', + }, + 'limit': { + 'type': 'integer', + 'description': 'Number of messages to read, between 1 and 200.', + }, + 'start_time': { + 'type': 'string', + 'description': 'Optional inclusive ISO-8601 start time.', + }, + 'end_time': { + 'type': 'string', + 'description': 'Optional exclusive ISO-8601 end time.', + }, + }, + 'required': ['offset'], + 'additionalProperties': false, + }, + ), + AiToolDefinition( + name: 'search_conversation_messages', + description: + 'Search the current conversation for messages relevant to a query string.', + inputSchema: { + 'type': 'object', + 'properties': { + 'query': { + 'type': 'string', + 'description': 'Search query text.', + }, + 'limit': { + 'type': 'integer', + 'description': + 'Maximum number of matches to return, between 1 and 20.', + }, + }, + 'required': ['query'], + 'additionalProperties': false, + }, + ), + ]; + + Future execute({ + required String conversationId, + required AiToolCall call, + }) async { + final arguments = call.arguments; + switch (call.name) { + case 'get_conversation_stats': + final (startInclusive, endExclusive) = _parseRange(arguments); + final stats = await service.getConversationStats( + conversationId: conversationId, + startInclusive: startInclusive, + endExclusive: endExclusive, + ); + return AiToolExecutionResult( + toolCallId: call.id, + toolName: call.name, + payload: stats.toJson(), + ); + case 'list_conversation_chunks': + final (startInclusive, endExclusive) = _parseRange(arguments); + final chunkSize = _parseInt( + arguments, + 'chunk_size', + defaultValue: _kDefaultConversationChunkSize, + min: 1, + max: _kMaxConversationChunkSize, + ); + final chunks = await service.listConversationChunks( + conversationId: conversationId, + chunkSize: chunkSize, + startInclusive: startInclusive, + endExclusive: endExclusive, + ); + return AiToolExecutionResult( + toolCallId: call.id, + toolName: call.name, + payload: chunks.toJson(), + ); + case 'read_conversation_chunk': + final (startInclusive, endExclusive) = _parseRange(arguments); + final offset = _parseInt( + arguments, + 'offset', + defaultValue: 0, + min: 0, + max: 1 << 20, + ); + final limit = _parseInt( + arguments, + 'limit', + defaultValue: _kDefaultConversationChunkSize, + min: 1, + max: _kMaxConversationChunkSize, + ); + final page = await service.readConversationChunk( + conversationId: conversationId, + offset: offset, + limit: limit, + startInclusive: startInclusive, + endExclusive: endExclusive, + ); + return AiToolExecutionResult( + toolCallId: call.id, + toolName: call.name, + payload: page.toJson(), + ); + case 'search_conversation_messages': + final query = _parseRequiredString(arguments, 'query'); + final limit = _parseInt( + arguments, + 'limit', + defaultValue: _kDefaultConversationSearchLimit, + min: 1, + max: _kMaxConversationSearchLimit, + ); + final result = await service.searchConversationMessages( + conversationId: conversationId, + query: query, + limit: limit, + ); + return AiToolExecutionResult( + toolCallId: call.id, + toolName: call.name, + payload: result.toJson(), + ); + default: + throw UnsupportedError('Unknown conversation tool: ${call.name}'); + } + } + + (DateTime?, DateTime?) _parseRange(Map arguments) { + final startInclusive = _parseDateTime(arguments, 'start_time'); + final endExclusive = _parseDateTime(arguments, 'end_time'); + if (startInclusive != null && + endExclusive != null && + !endExclusive.isAfter(startInclusive)) { + throw const FormatException('end_time must be later than start_time'); + } + return (startInclusive, endExclusive); + } + + DateTime? _parseDateTime(Map arguments, String key) { + final raw = arguments[key]; + if (raw == null) { + return null; + } + if (raw is! String || raw.trim().isEmpty) { + throw FormatException('$key must be an ISO-8601 string'); + } + final value = DateTime.tryParse(raw.trim()); + if (value == null) { + throw FormatException('$key must be a valid ISO-8601 string'); + } + return value; + } + + int _parseInt( + Map arguments, + String key, { + required int defaultValue, + required int min, + required int max, + }) { + final raw = arguments[key]; + if (raw == null) { + return defaultValue; + } + final value = switch (raw) { + final int value => value, + final String value => + int.tryParse(value.trim()) ?? + (throw FormatException('$key must be an integer')), + _ => throw FormatException('$key must be an integer'), + }; + return value.clamp(min, max); + } + + String _parseRequiredString(Map arguments, String key) { + final raw = arguments[key]; + if (raw is! String || raw.trim().isEmpty) { + throw FormatException('$key must be a non-empty string'); + } + return raw.trim(); + } +} diff --git a/lib/db/dao/message_dao.dart b/lib/db/dao/message_dao.dart index 67c348752c..4544d0229c 100644 --- a/lib/db/dao/message_dao.dart +++ b/lib/db/dao/message_dao.dart @@ -684,6 +684,59 @@ class MessageDao extends DatabaseAccessor .map((row) => row.read(countExp)!); } + Selectable messagesByConversationIdAndCreatedAtRange( + String conversationId, { + required int limit, + int offset = 0, + DateTime? startInclusive, + DateTime? endExclusive, + bool ascending = true, + }) { + final startMillis = startInclusive?.millisecondsSinceEpoch; + final endMillis = endExclusive?.millisecondsSinceEpoch; + return _baseMessageItems( + (message, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => + message.conversationId.equals(conversationId) & + (startMillis == null + ? const Constant(true) + : message.createdAt.isBiggerOrEqualValue(startMillis)) & + (endMillis == null + ? const Constant(true) + : message.createdAt.isSmallerThanValue(endMillis)), + (_, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => Limit(limit, offset), + order: (message, _, _, _, _, _, _, _, _, _, _, _, _, em) => OrderBy([ + if (ascending) OrderingTerm.asc(message.createdAt), + if (ascending) OrderingTerm.asc(message.rowId), + if (!ascending) OrderingTerm.desc(message.createdAt), + if (!ascending) OrderingTerm.desc(message.rowId), + ]), + ); + } + + Selectable messageCountByConversationIdAndCreatedAtRange( + String conversationId, { + DateTime? startInclusive, + DateTime? endExclusive, + }) { + final startMillis = startInclusive?.millisecondsSinceEpoch; + final endMillis = endExclusive?.millisecondsSinceEpoch; + final countExp = countAll(); + return (db.selectOnly(db.messages) + ..addColumns([countExp]) + ..where( + db.messages.conversationId.equals(conversationId) & + (startMillis == null + ? const Constant(true) + : db.messages.createdAt.isBiggerOrEqualValue( + startMillis, + )) & + (endMillis == null + ? const Constant(true) + : db.messages.createdAt.isSmallerThanValue(endMillis)), + )) + .map((row) => row.read(countExp)!); + } + Future> getUnreadMessageIds( String conversationId, String userId, diff --git a/lib/ui/home/chat/ai_draft_assist_panel.dart b/lib/ui/home/chat/ai_draft_assist_panel.dart index 0a024f5749..8bf92b886c 100644 --- a/lib/ui/home/chat/ai_draft_assist_panel.dart +++ b/lib/ui/home/chat/ai_draft_assist_panel.dart @@ -8,7 +8,13 @@ import '../../../widgets/action_button.dart'; import '../../../widgets/interactive_decorated_box.dart'; import '../../../widgets/menu.dart'; -enum AiDraftAction { polish, shorten, polite, translate, replyWithContext } +enum AiDraftAction { + polish, + shorten, + polite, + translate, + replyWithContext, +} enum AiDraftAssistPhase { idle, loading, result, error } diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 75667edcc5..41d616cc5a 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -611,7 +611,8 @@ Future _requestAiDraftAction( }; try { - final result = await AiChatController(context.database).assistText( + final controller = AiChatController(context.database); + final result = await controller.assistText( instruction: instruction, input: action == AiDraftAction.replyWithContext ? null : original, conversationId: conversationId, From cafba8b9f2395acd6073303c32203dd0a32ea782 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 10:31:24 +0800 Subject: [PATCH 18/52] feat(chat): add support for displaying day/time separators in message timeline --- lib/ui/home/chat/chat_page.dart | 17 ++++-- lib/widgets/message/message.dart | 7 ++- lib/widgets/message/message_day_time.dart | 64 ++++++++++++++++++----- 3 files changed, 69 insertions(+), 19 deletions(-) diff --git a/lib/ui/home/chat/chat_page.dart b/lib/ui/home/chat/chat_page.dart index af234f2e02..e2f1e7fcc0 100644 --- a/lib/ui/home/chat/chat_page.dart +++ b/lib/ui/home/chat/chat_page.dart @@ -709,18 +709,25 @@ class _List extends HookConsumerWidget { Widget buildTimelineChild(ChatTimelineItem item, int index) { warmupMarkdownAround(index); + final prevDateTime = index > 0 ? timeline[index - 1].createdAt : null; if (item.isAiMessage) { - return AiMessageCard( - key: ValueKey('ai-${item.id}'), - message: item.aiMessage!, - prev: prevAiOf(item, timeline), - next: nextAiOf(item, timeline), + return MessageDayTimeItem( + key: ValueKey('ai-daytime-${item.id}'), + dateTime: item.createdAt, + prevDateTime: prevDateTime, + child: AiMessageCard( + key: ValueKey('ai-${item.id}'), + message: item.aiMessage!, + prev: prevAiOf(item, timeline), + next: nextAiOf(item, timeline), + ), ); } final message = item.message!; return MessageItemWidget( key: keyRef.value[message.messageId], prev: prevMessageOf(item, timeline), + prevDateTime: prevDateTime, message: message, next: nextMessageOf(item, timeline), lastReadMessageId: state.lastReadMessageId, diff --git a/lib/widgets/message/message.dart b/lib/widgets/message/message.dart index 7177c7bad8..cd85345341 100644 --- a/lib/widgets/message/message.dart +++ b/lib/widgets/message/message.dart @@ -196,6 +196,7 @@ class MessageItemWidget extends HookConsumerWidget { required this.message, super.key, this.prev, + this.prevDateTime, this.next, this.lastReadMessageId, this.isTranscriptPage = false, @@ -205,6 +206,7 @@ class MessageItemWidget extends HookConsumerWidget { final MessageItem message; final MessageItem? prev; + final DateTime? prevDateTime; final MessageItem? next; final String? lastReadMessageId; final bool isTranscriptPage; @@ -242,7 +244,10 @@ class MessageItemWidget extends HookConsumerWidget { final showNip = !(sameUserNext && sameDayNext) && (!showAvatar || isCurrentUser); - final datetime = sameDayPrev ? null : message.createdAt; + final datetime = + isSameDay(prevDateTime ?? prev?.createdAt, message.createdAt) + ? null + : message.createdAt; String? userName; String? userId; String? userAvatarUrl; diff --git a/lib/widgets/message/message_day_time.dart b/lib/widgets/message/message_day_time.dart index 7991f27081..5a4d6520a8 100644 --- a/lib/widgets/message/message_day_time.dart +++ b/lib/widgets/message/message_day_time.dart @@ -33,6 +33,31 @@ class MessageDayTime extends HookConsumerWidget { } } +class MessageDayTimeItem extends StatelessWidget { + const MessageDayTimeItem({ + required this.dateTime, + required this.child, + super.key, + this.prevDateTime, + }); + + final DateTime dateTime; + final DateTime? prevDateTime; + final Widget child; + + bool get showDayTime => !isSameDay(prevDateTime, dateTime); + + @override + Widget build(BuildContext context) => Column( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + if (showDayTime) MessageDayTime(dateTime: dateTime), + child, + ], + ); +} + class _MessageDayTimeWidget extends HookConsumerWidget { const _MessageDayTimeWidget({required this.dateTime}); @@ -75,31 +100,44 @@ class HiddenMessageDayTimeBloc extends Cubit { class _CurrentShowingMessages { _CurrentShowingMessages(); - final List items = []; + final List items = []; final List elements = []; final List dayTimeElements = []; void dumpKeyedSubtree(Element element, {bool reverse = false}) { - final item = element.descendantFirstWhere( - (e) => e.widget is MessageItemWidget, - ); + final item = element.descendantFirstWhere((e) { + final widget = e.widget; + return widget is MessageItemWidget || widget is MessageDayTimeItem; + }); if (item == null) { return; } - final widget = item.widget as MessageItemWidget; + final widget = item.widget; + + late final DateTime createdAt; + DateTime? prevCreatedAt; + + if (widget is MessageDayTimeItem) { + createdAt = widget.dateTime; + prevCreatedAt = widget.prevDateTime; + } else if (widget is MessageItemWidget) { + createdAt = widget.message.createdAt; + prevCreatedAt = widget.prev?.createdAt; + } else { + return; + } - final dayTimeElement = - !isSameDay(widget.message.createdAt, widget.prev?.createdAt) + final dayTimeElement = !isSameDay(createdAt, prevCreatedAt) ? element.descendantFirstWhere( (e) => e.widget is _MessageDayTimeWidget, ) : null; if (!reverse) { - items.add(widget.message); + items.add(createdAt); elements.add(item); dayTimeElements.add(dayTimeElement); } else { - items.insert(0, widget.message); + items.insert(0, createdAt); elements.insert(0, item); dayTimeElements.insert(0, dayTimeElement); } @@ -265,7 +303,7 @@ class MessageDayTimeViewportWidget extends HookConsumerWidget { if (offset.dy < render.size.height / 2) { // up firstInScreenIndex = closestToTopDayTimeIndex; - bloc.update(items[closestToTopDayTimeIndex].createdAt); + bloc.update(items[closestToTopDayTimeIndex]); dateTimeTopOffset.value = 0; } else { // down @@ -275,8 +313,8 @@ class MessageDayTimeViewportWidget extends HookConsumerWidget { e('firstInScreenIndex > closestToTopDayTimeIndex'); } if (isSameDay( - items[firstInScreenIndex].createdAt, - items[closestToTopDayTimeIndex].createdAt, + items[firstInScreenIndex], + items[closestToTopDayTimeIndex], )) { e( 'there is a day time item but is the same day.' @@ -298,7 +336,7 @@ class MessageDayTimeViewportWidget extends HookConsumerWidget { dateTimeTopOffset.value = 0; } - dateTime.value = items.getOrNull(firstInScreenIndex)?.createdAt; + dateTime.value = items.getOrNull(firstInScreenIndex); } useEffect(() { From 02bbd7858c49b08a01b12c640b828d68213b0472 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 11:29:54 +0800 Subject: [PATCH 19/52] feat(ai): refactor AI chat controller and introduce modular strategies for provider requests --- AGENTS.md | 157 ++ lib/ai/ai_chat_controller.dart | 1278 +---------------- lib/ai/ai_chat_prompt_builder.dart | 370 +++++ lib/ai/ai_provider_requester.dart | 277 ++++ lib/ai/provider/ai_provider_strategy.dart | 116 ++ lib/ai/provider/anthropic_strategy.dart | 208 +++ lib/ai/provider/gemini_strategy.dart | 267 ++++ .../provider/openai_compatible_strategy.dart | 162 +++ pubspec.lock | 9 +- pubspec.yaml | 3 +- 10 files changed, 1589 insertions(+), 1258 deletions(-) create mode 100644 AGENTS.md create mode 100644 lib/ai/ai_chat_prompt_builder.dart create mode 100644 lib/ai/ai_provider_requester.dart create mode 100644 lib/ai/provider/ai_provider_strategy.dart create mode 100644 lib/ai/provider/anthropic_strategy.dart create mode 100644 lib/ai/provider/gemini_strategy.dart create mode 100644 lib/ai/provider/openai_compatible_strategy.dart diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000000..009702fd4d --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,157 @@ +# AGENTS.md + +## Project + +Mixin Messenger desktop Flutter app. + +Tech stack: + +- Flutter/Dart app, `environment` in `pubspec.yaml`: Dart `^3.10.0`, Flutter `^3.38.0`. +- Desktop targets: macOS, Linux, Windows. Web/iOS/Android folders exist, but README/release flow focuses on desktop. +- State/UI: Flutter Hooks, Riverpod, Provider, Bloc/Hydrated Bloc. +- Storage: Drift/Moor over SQLite, Hive, Hydrated Bloc storage. +- Networking/runtime: Dio, rhttp, WebSocket, Mixin SDK, Signal protocol implementation. +- Code generation: `build_runner`, `drift_dev`, `json_serializable`, `envied_generator`, `flutter_intl`. + +Main directories: + +- `lib/main.dart`, `lib/app.dart`: app bootstrap, desktop window setup, localization and root providers. +- `lib/account`: account server, notification and key-value account state. +- `lib/ai`: AI chat controller, models and tools. +- `lib/api`, `lib/blaze`: HTTP/API and Blaze message models. +- `lib/db`: main Drift database, DAOs, converters, FTS database, open helpers. +- `lib/db/moor`: Drift SQL schema and DAO `.drift` files. +- `lib/crypto/signal`: Signal protocol database, DAOs and crypto storage. +- `lib/ui`, `lib/widgets`: screens and reusable UI. +- `lib/utils`, `lib/workers`: platform utilities, background work, transfer and job queues. +- `lib/l10n`: source ARB localization files. +- `lib/generated`: generated localization output; do not edit by hand. +- `assets`, `fonts`: bundled assets and fonts. +- `test`: Flutter/unit tests. +- `third_party/system_tray`: local path dependency. +- `dist`: packaging scripts and platform distribution metadata. + +## Commands + +Setup: + +```sh +flutter pub get +``` + +Generate code: + +```sh +dart run build_runner build --delete-conflicting-outputs +``` + +Short scripts: + +```sh +./generate.sh # dart run build_runner build +./db_generate.sh # dart run build_runner build --delete-conflicting-outputs +``` + +Format/lint: + +```sh +dart format --set-exit-if-changed . +dart analyze --fatal-infos +``` + +Tests: + +```sh +dart run webcrypto:setup +flutter test +flutter test test/path/to/file_test.dart +``` + +Run: + +```sh +flutter run -d macos +flutter run -d linux +flutter run -d windows +``` + +Build: + +```sh +flutter build macos --release +flutter build linux --release +flutter build windows --release +``` + +Packaging helpers: + +```sh +./dist/macos.sh +./dist/win.sh +./dist/linux_deb.sh amd64 +./dist/linux_deb.sh arm64 +``` + +Linux desktop build dependencies used by CI: + +```sh +sudo apt-get install -y ninja-build libgtk-3-dev libsdl2-dev \ + libwebkit2gtk-4.1-dev libopus-dev libogg-dev libcurl4-openssl-dev +``` + +## Environment + +- `lib/constants/env.dart` uses `envied` with `.env`. +- `.env` may contain: + +```env +SENTRY_DSN=... +``` + +- `SENTRY_DSN` is optional for local debug; release builds read it through generated env code and may also pass `--dart-define SENTRY_DSN=$SENTRY_DSN` in CI/release scripts. +- If `.env` changes or `EnviedField` changes, rerun build runner. + +## Database + +- Main DB: `lib/db/mixin_database.dart`, schema in `lib/db/moor/**`, current `schemaVersion` is `30`. +- FTS DB: `lib/db/fts_database.dart`, schema in `lib/db/moor/fts.drift`. +- Signal DB: `lib/crypto/signal/signal_database.dart`, schema in `lib/crypto/signal/moor/**`. +- Drift generation options are in `build.yaml`; FTS5 is enabled. +- When changing `.drift` schemas or DAOs: + - update `schemaVersion` when persistent schema changes; + - add an `onUpgrade` migration in `MigrationStrategy`; + - prefer idempotent helpers like `_addColumnIfNotExists` for additive migrations; + - keep existing data migration jobs in `lib/workers/job` in mind; + - rerun `dart run build_runner build --delete-conflicting-outputs`; + - add or update focused tests under `test/db` when behavior changes. + +## Code Generation + +- Do not hand-edit `*.g.dart`, `lib/generated/**`, or other files marked generated. +- Source annotations/models commonly use `part '*.g.dart'` with `json_serializable`, `drift_dev`, or `envied_generator`. +- Reserve `part` and `part of` for code generation only. For manual code organization, split code into separate libraries and connect them with imports. +- Localization source is `lib/l10n/*.arb`; generated class is `Localization` in `lib/generated/l10n.dart`. +- Asset constants in `lib/constants/resources.dart` are generated; update the generator flow instead of manual edits if assets change. + +## Coding Conventions + +- Follow `analysis_options.yaml` and `very_good_analysis` overrides. +- Prefer relative imports inside `lib`; avoid broad package imports for local files. +- Prefer `final` locals, expression-bodied members where already used, and concise null handling. +- Keep generated, third-party, and platform registrant files untouched unless the task explicitly requires them. +- Keep changes scoped to the requested behavior; do not refactor unrelated areas. +- Reuse existing UI components from `lib/widgets` and patterns from nearby screens. +- Reuse existing DB access through DAOs and providers instead of bypassing with ad hoc SQL unless Drift APIs cannot express the query. +- For user-facing text, use `Localization` and ARB files rather than hard-coded strings. +- For async work, preserve current error propagation style and do not swallow exceptions without a concrete recovery path. +- Before finalizing non-trivial changes, run the narrowest relevant tests plus `dart analyze --fatal-infos` when feasible. + +## Code Style + +- Follow the Dart style guide and `analysis_options.yaml` rules. +- Use consistent naming conventions for variables, functions, classes, and files. +- Keep lines within 80 characters where possible for readability. +- Prefer clear and descriptive names over abbreviations. +- Maintain consistent indentation and spacing throughout the codebase. +- Flow effective-dart guidelines, such as using `final` for variables that are not reassigned, and preferring composition over inheritance where appropriate. +- Use comments judiciously to explain complex logic, but avoid redundant comments that do not add value beyond the code itself. diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index 91d741ee3a..43e0470855 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -7,13 +7,12 @@ import 'package:mixin_logger/mixin_logger.dart'; import 'package:uuid/uuid.dart'; import '../db/dao/ai_chat_message_dao.dart'; -import '../db/dao/message_dao.dart'; import '../db/database.dart'; import '../db/mixin_database.dart'; -import '../utils/proxy.dart'; +import 'ai_chat_prompt_builder.dart'; +import 'ai_provider_requester.dart'; import 'model/ai_prompt_message.dart'; import 'model/ai_provider_config.dart'; -import 'model/ai_provider_type.dart'; import 'model/ai_tool.dart'; import 'tools/ai_conversation_tool_service.dart'; @@ -22,13 +21,8 @@ const _kAiRoleAssistant = 'assistant'; const _kAiStatusPending = 'pending'; const _kAiStatusDone = 'done'; const _kAiStatusError = 'error'; -const _kAiContextMessageLimit = 30; -const _kAiRetrievedMessageLimit = 6; -const _kAiHistoryLimit = 12; const _kAiStreamFlushChars = 32; const _kAiStreamFlushInterval = Duration(milliseconds: 80); -const _kAiRetrievalQueryMaxLength = 120; -const _kAiToolMaxRounds = 8; const _kAiLogPreviewLength = 240; const _kAiLogJsonPreviewLength = 480; final kAiRuntimeStartedAt = DateTime.now(); @@ -44,14 +38,13 @@ class AiChatController { final Database database; final _uuid = const Uuid(); - static const _openAiStrategy = _OpenAiCompatibleStrategy(); - static const _anthropicStrategy = _AnthropicStrategy(); - static const _geminiStrategy = _GeminiStrategy(); + static const _providerRequester = AiProviderRequester(); late final DatabaseAiConversationToolService _conversationToolService = DatabaseAiConversationToolService(database); late final AiConversationToolKit _conversationTools = AiConversationToolKit( _conversationToolService, ); + late final AiChatPromptBuilder _promptBuilder = AiChatPromptBuilder(database); Future assistText({ required String instruction, @@ -70,7 +63,7 @@ class AiChatController { 'input=${_previewText(input)}', ); - final messages = await _buildAssistPromptMessages( + final messages = await _promptBuilder.buildAssistPromptMessages( instruction: instruction, input: input, conversationId: conversationId, @@ -180,7 +173,10 @@ class AiChatController { ); _activeAiRequests[conversationId] = cancelToken; try { - final messages = await _buildPromptMessages(conversationId, input); + final messages = await _promptBuilder.buildPromptMessages( + conversationId, + input, + ); final result = await _requestText( config, messages, @@ -255,7 +251,7 @@ class AiChatController { return 'No messages found in the selected time range.'; } - final messages = _buildConversationSummaryPromptMessages( + final messages = _promptBuilder.buildConversationSummaryPromptMessages( stats: stats, languageTag: languageTag, ); @@ -298,323 +294,6 @@ class AiChatController { ); } - Future> _buildPromptMessages( - String conversationId, - String input, - ) async { - final recentMessages = await database.messageDao - .messagesByConversationId(conversationId, _kAiContextMessageLimit) - .get(); - final retrievedMessages = await _retrieveConversationMessages( - conversationId: conversationId, - recentMessages: recentMessages, - query: input, - ); - final aiMessages = await database.aiChatMessageDao.conversationMessages( - conversationId, - ); - - final promptMessages = [ - AiPromptMessage( - role: 'system', - content: - 'You are a local AI assistant inside a chat application. ' - 'Only use the provided current conversation context. ' - 'Help summarize, answer questions about the conversation, and draft replies. ' - 'Be concise and practical.', - ), - ]; - - _appendConversationToolInstruction(promptMessages, enabled: true); - _appendConversationContext( - promptMessages, - recentMessages: recentMessages, - retrievedMessages: retrievedMessages, - ); - - final history = aiMessages - .where((element) => element.status != _kAiStatusPending) - .takeLast(_kAiHistoryLimit); - for (final item in history) { - promptMessages.add( - AiPromptMessage(role: item.role, content: item.content), - ); - } - - promptMessages.add(AiPromptMessage(role: _kAiRoleUser, content: input)); - d( - 'AI prompt built: conversationId=$conversationId ' - 'recent=${recentMessages.length} retrieved=${retrievedMessages.length} ' - 'history=${history.length} promptMessages=${promptMessages.length}', - ); - return promptMessages; - } - - Future> _buildAssistPromptMessages({ - required String instruction, - required String? input, - required String? conversationId, - }) async { - final promptMessages = [ - AiPromptMessage( - role: 'system', - content: - 'You are an invisible writing assistant inside a chat app. ' - 'Return only the requested text. Do not add explanations, labels, ' - 'markdown fences, or greetings unless explicitly requested.', - ), - ]; - - if (conversationId != null) { - _appendConversationToolInstruction(promptMessages, enabled: true); - final recentMessages = await database.messageDao - .messagesByConversationId(conversationId, _kAiContextMessageLimit) - .get(); - final retrievedMessages = await _retrieveConversationMessages( - conversationId: conversationId, - recentMessages: recentMessages, - query: input ?? _latestRetrievalSeed(recentMessages), - ); - _appendConversationContext( - promptMessages, - recentMessages: recentMessages, - retrievedMessages: retrievedMessages, - ); - } - - final inputText = input?.trim(); - promptMessages.add( - AiPromptMessage( - role: _kAiRoleUser, - content: [ - instruction.trim(), - if (inputText != null && inputText.isNotEmpty) '\nText:\n$inputText', - ].join('\n'), - ), - ); - d( - 'AI assist prompt built: conversationId=$conversationId ' - 'messages=${promptMessages.length}', - ); - return promptMessages; - } - - List _buildConversationSummaryPromptMessages({ - required AiConversationToolStats stats, - required String? languageTag, - }) { - final outputLanguage = languageTag?.trim(); - return [ - AiPromptMessage( - role: 'system', - content: - 'You are a conversation summarizer inside a chat application. ' - 'For this task, you must use the available read-only conversation tools ' - 'to inspect the requested time range before writing the final answer. ' - 'Start by calling list_conversation_chunks for the exact range, then ' - 'read_conversation_chunk until you have covered the full range. ' - 'Do not rely only on recent context or search for this task.', - ), - AiPromptMessage( - role: 'system', - content: - 'Summaries must cover the requested range completely and should include ' - 'main topics, key decisions, action items, unresolved questions, and ' - 'notable follow-ups. Keep the final answer concise but comprehensive.', - ), - AiPromptMessage( - role: 'user', - content: [ - 'Summarize the conversation messages in this time range.', - 'Conversation ID: ${stats.conversationId}', - 'Start time: ${stats.startInclusive?.toIso8601String() ?? 'unspecified'}', - 'End time: ${stats.endExclusive?.toIso8601String() ?? 'unspecified'}', - 'Messages in range: ${stats.messageCount}', - if (stats.firstMessageAt != null) - 'First message at: ${stats.firstMessageAt!.toIso8601String()}', - if (stats.lastMessageAt != null) - 'Last message at: ${stats.lastMessageAt!.toIso8601String()}', - if (outputLanguage != null && outputLanguage.isNotEmpty) - 'Write the final summary in $outputLanguage.', - 'Before finalizing, make sure you have covered every chunk in the range.', - 'Return only the summary text.', - ].join('\n'), - ), - ]; - } - - void _appendConversationToolInstruction( - List promptMessages, { - required bool enabled, - }) { - if (!enabled) { - return; - } - promptMessages.add( - AiPromptMessage( - role: 'system', - content: - 'Read-only conversation tools are available for the current conversation. ' - 'Use them when you need exhaustive coverage, date-scoped summaries, ' - 'statistics, older messages, or more context than the provided messages. ' - 'Do not call tools when the provided context is already sufficient.', - ), - ); - } - - void _appendConversationContext( - List promptMessages, { - required List recentMessages, - required List retrievedMessages, - }) { - if (recentMessages.isNotEmpty) { - final lines = recentMessages.reversed - .map( - (message) => _conversationContextLine( - createdAt: message.createdAt, - sender: message.userFullName ?? message.userId, - content: _messagePlainText(message), - ), - ) - .join('\n'); - promptMessages.add( - AiPromptMessage( - role: 'system', - content: 'Current conversation recent messages:\n$lines', - ), - ); - } - - if (retrievedMessages.isEmpty) { - return; - } - - final lines = retrievedMessages - .map( - (message) => _conversationContextLine( - createdAt: message.createdAt, - sender: message.senderFullName ?? message.senderId, - content: _searchMessagePlainText(message), - ), - ) - .join('\n'); - promptMessages.add( - AiPromptMessage( - role: 'system', - content: - 'Relevant older conversation messages matched by search ' - '(use only if they help answer the current request):\n$lines', - ), - ); - } - - Future> _retrieveConversationMessages({ - required String conversationId, - required List recentMessages, - required String? query, - }) async { - final normalizedQuery = _normalizeRetrievalQuery(query); - if (normalizedQuery == null) { - d('AI retrieval skipped: conversationId=$conversationId empty query'); - return const []; - } - - final recentIds = recentMessages - .map((message) => message.messageId) - .toSet(); - final matchedIds = await database.ftsDatabase.fuzzySearchMessage( - query: normalizedQuery, - limit: _kAiRetrievedMessageLimit + recentIds.length, - conversationIds: [conversationId], - ); - final candidateIds = matchedIds - .where((messageId) => !recentIds.contains(messageId)) - .take(_kAiRetrievedMessageLimit) - .toList(growable: false); - if (candidateIds.isEmpty) { - d( - 'AI retrieval no match: conversationId=$conversationId ' - 'query=${_previewText(normalizedQuery)}', - ); - return const []; - } - - final matchedMessages = await database.messageDao - .searchMessageByIds(candidateIds) - .get(); - final messagesById = { - for (final message in matchedMessages) message.messageId: message, - }; - final ordered = []; - for (final messageId in candidateIds) { - final message = messagesById[messageId]; - if (message != null) { - ordered.add(message); - } - } - ordered.sort((left, right) => left.createdAt.compareTo(right.createdAt)); - d( - 'AI retrieval matched: conversationId=$conversationId ' - 'query=${_previewText(normalizedQuery)} matches=${ordered.length}', - ); - return ordered; - } - - String? _latestRetrievalSeed(List recentMessages) { - for (final message in recentMessages) { - final content = _messagePlainText(message); - final normalized = _normalizeRetrievalQuery(content); - if (normalized != null) { - return normalized; - } - } - return null; - } - - String? _normalizeRetrievalQuery(String? query) { - final compact = query?.replaceAll(RegExp(r'\s+'), ' ').trim(); - if (compact == null || compact.isEmpty) { - return null; - } - if (compact.length <= _kAiRetrievalQueryMaxLength) { - return compact; - } - return compact.substring(0, _kAiRetrievalQueryMaxLength); - } - - String _conversationContextLine({ - required DateTime createdAt, - required String sender, - required String content, - }) => '[${createdAt.toIso8601String()}] $sender: $content'; - - String _messagePlainText(MessageItem message) => _messagePlainTextFromFields( - content: message.content, - mediaName: message.mediaName, - type: message.type, - ); - - String _searchMessagePlainText(SearchMessageDetailItem message) => - _messagePlainTextFromFields( - content: message.content, - mediaName: message.mediaName, - type: message.type, - ); - - String _messagePlainTextFromFields({ - required String? content, - required String? mediaName, - required String type, - }) { - if (content?.trim().isNotEmpty == true) { - return content!.trim(); - } - if (mediaName?.isNotEmpty == true) { - return '[$type] $mediaName'; - } - return '[$type]'; - } - Future _requestText( AiProviderConfig config, List messages, { @@ -622,192 +301,27 @@ class AiChatController { required Future Function(String chunk) onContent, required bool streamFinalResponse, String? conversationId, - }) async { - d( - 'AI request start: provider=${config.type.name} model=${config.model} ' - 'conversationId=$conversationId streamFinal=$streamFinalResponse ' - 'messages=${messages.length} tools=${conversationId != null}', - ); - final dio = - Dio( - BaseOptions( - baseUrl: config.baseUrl, - connectTimeout: const Duration(seconds: 20), - receiveTimeout: const Duration(minutes: 5), - sendTimeout: const Duration(seconds: 20), - headers: _strategyFor(config.type).headers(config), - ), - ) - ..interceptors.add( - InterceptorsWrapper( - onRequest: (options, handler) { - options.extra['ai_request_started_at'] = DateTime.now(); - d( - 'AI HTTP request: ${options.method} ${options.uri} ' - 'provider=${config.type.name} model=${config.model}', - ); - handler.next(options); - }, - onResponse: (response, handler) { - final startedAt = - response.requestOptions.extra['ai_request_started_at'] - as DateTime?; - d( - 'AI HTTP response: ${response.requestOptions.method} ' - '${response.requestOptions.uri} status=${response.statusCode} ' - 'elapsedMs=${startedAt == null ? -1 : DateTime.now().difference(startedAt).inMilliseconds}', - ); - handler.next(response); - }, - onError: (error, handler) { - final startedAt = - error.requestOptions.extra['ai_request_started_at'] - as DateTime?; - e( - 'AI HTTP error: ${error.requestOptions.method} ' - '${error.requestOptions.uri} ' - 'elapsedMs=${startedAt == null ? -1 : DateTime.now().difference(startedAt).inMilliseconds} ' - 'error=${error.message}', - error, - error.stackTrace, - ); - handler.next(error); - }, - ), - ) - ..applyProxy(database.settingProperties.activatedProxy); + }) => _providerRequester.requestText( + config, + messages, + proxy: database.settingProperties.activatedProxy, + cancelToken: cancelToken, + onContent: onContent, + streamFinalResponse: streamFinalResponse, + conversationId: conversationId, + onToolCall: _toolExecutorFor(conversationId), + ); + Future Function(AiToolCall toolCall)? _toolExecutorFor( + String? conversationId, + ) { if (conversationId == null) { - return _strategyFor(config.type).streamResponse( - dio: dio, - config: config, - messages: messages, - cancelToken: cancelToken, - onContent: onContent, - ); + return null; } - - return _requestWithTools( - dio, - config, - [...messages], + return (toolCall) => _executeConversationTool( conversationId: conversationId, - cancelToken: cancelToken, - onContent: onContent, - streamFinalResponse: streamFinalResponse, - ); - } - - Future _requestWithTools( - Dio dio, - AiProviderConfig config, - List messages, { - required String conversationId, - required CancelToken cancelToken, - required Future Function(String chunk) onContent, - required bool streamFinalResponse, - }) async { - for (var round = 0; round < _kAiToolMaxRounds; round++) { - d( - 'AI tool round start: conversationId=$conversationId ' - 'round=${round + 1}/$_kAiToolMaxRounds messages=${messages.length}', - ); - final response = await _strategyFor(config.type).completeResponse( - dio: dio, - config: config, - messages: messages, - tools: AiConversationToolKit.definitions, - cancelToken: cancelToken, - ); - d( - 'AI tool round response: conversationId=$conversationId ' - 'round=${round + 1} text=${_previewText(response.text)} ' - 'toolCalls=${_previewToolCalls(response.toolCalls)}', - ); - - if (!response.hasToolCalls) { - final text = response.text.trim(); - if (text.isEmpty) { - throw Exception('Empty AI response'); - } - if (streamFinalResponse) { - try { - d( - 'AI final stream start: conversationId=$conversationId ' - 'round=${round + 1}', - ); - return await _strategyFor(config.type).streamResponse( - dio: dio, - config: config, - messages: messages, - cancelToken: cancelToken, - onContent: onContent, - ); - } catch (error, stacktrace) { - e('AI final streaming fallback: $error, $stacktrace'); - await _emitBufferedText(text, onContent); - d( - 'AI final stream fallback: conversationId=$conversationId ' - 'round=${round + 1} text=${_previewText(text)}', - ); - return text; - } - } - await onContent(text); - d( - 'AI tool request done without stream: conversationId=$conversationId ' - 'round=${round + 1} text=${_previewText(text)}', - ); - return text; - } - - messages.add( - AiPromptMessage( - role: _kAiRoleAssistant, - content: response.text, - toolCalls: response.toolCalls, - ), - ); - for (final toolCall in response.toolCalls) { - final result = await _executeConversationTool( - conversationId: conversationId, - toolCall: toolCall, - ); - messages.add( - AiPromptMessage( - role: 'tool', - content: result.content, - toolCallId: result.toolCallId, - toolName: result.toolName, - toolPayload: result.payload, - ), - ); - } - } - - e( - 'AI exceeded tool call limit: conversationId=$conversationId ' - 'maxRounds=$_kAiToolMaxRounds', + toolCall: toolCall, ); - throw Exception('AI exceeded tool call limit'); - } - - Future _emitBufferedText( - String text, - Future Function(String chunk) onContent, - ) async { - final trimmed = text.trim(); - if (trimmed.isEmpty) { - return; - } - if (trimmed.length <= _kAiStreamFlushChars) { - await onContent(trimmed); - return; - } - for (var start = 0; start < trimmed.length; start += _kAiStreamFlushChars) { - final end = (start + _kAiStreamFlushChars).clamp(0, trimmed.length); - await onContent(trimmed.substring(start, end)); - } } Future _executeConversationTool({ @@ -841,12 +355,6 @@ class AiChatController { ); } } - - _AiProviderStrategy _strategyFor(AiProviderType type) => switch (type) { - AiProviderType.openaiCompatible => _openAiStrategy, - AiProviderType.anthropic => _anthropicStrategy, - AiProviderType.gemini => _geminiStrategy, - }; } String _previewText(String? text, {int maxLength = _kAiLogPreviewLength}) { @@ -872,709 +380,6 @@ String _previewJson(Object? value, {int maxLength = _kAiLogJsonPreviewLength}) { } } -String _previewToolCalls(List toolCalls) { - if (toolCalls.isEmpty) { - return '[]'; - } - return toolCalls - .map( - (toolCall) => - '${toolCall.name}#${toolCall.id}(${_previewJson(toolCall.arguments, maxLength: 120)})', - ) - .join(', '); -} - -extension on Iterable { - Iterable takeLast(int count) { - if (count <= 0) return const []; - final list = toList(); - if (list.length <= count) { - return list; - } - return list.sublist(list.length - count); - } -} - -abstract interface class _AiProviderStrategy { - const _AiProviderStrategy(); - - Map headers(AiProviderConfig config); - - Future<_AiCompletionResponse> completeResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required List tools, - required CancelToken cancelToken, - }); - - Future streamResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required CancelToken cancelToken, - required Future Function(String chunk) onContent, - }); -} - -class _OpenAiCompatibleStrategy implements _AiProviderStrategy { - const _OpenAiCompatibleStrategy(); - - @override - Map headers(AiProviderConfig config) => { - 'Authorization': 'Bearer ${config.apiKey}', - 'Content-Type': 'application/json', - }; - - @override - Future<_AiCompletionResponse> completeResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required List tools, - required CancelToken cancelToken, - }) async { - final response = await dio.post( - '/chat/completions', - data: { - 'model': config.model, - 'messages': messages.map(_openAiMessagePayload).toList(growable: false), - if (tools.isNotEmpty) - 'tools': tools - .map( - (tool) => { - 'type': 'function', - 'function': { - 'name': tool.name, - 'description': tool.description, - 'parameters': tool.inputSchema, - }, - }, - ) - .toList(growable: false), - if (tools.isNotEmpty) 'tool_choice': 'auto', - }, - cancelToken: cancelToken, - ); - - final body = _jsonMap(response.data); - final choices = body['choices'] as List?; - if (choices == null || choices.isEmpty) { - throw Exception('Empty AI response'); - } - final first = _jsonMap(choices.first); - final message = _jsonMap(first['message']); - final text = _stringContent(message['content']); - final toolCalls = (message['tool_calls'] as List? ?? const []) - .map((item) => _openAiToolCall(_jsonMap(item))) - .toList(growable: false); - if (text.trim().isEmpty && toolCalls.isEmpty) { - throw Exception('Empty AI response'); - } - return _AiCompletionResponse(text: text, toolCalls: toolCalls); - } - - @override - Future streamResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required CancelToken cancelToken, - required Future Function(String chunk) onContent, - }) async { - final response = await dio.post( - '/chat/completions', - data: { - 'model': config.model, - 'stream': true, - 'messages': messages.map(_openAiMessagePayload).toList(growable: false), - }, - options: Options(responseType: ResponseType.stream), - cancelToken: cancelToken, - ); - - final body = response.data; - if (body == null) { - throw Exception('Empty AI response'); - } - - final buffer = StringBuffer(); - await for (final data in _decodeSse(body.stream)) { - if (data == '[DONE]') { - continue; - } - - final json = jsonDecode(data); - if (json is! Map) { - continue; - } - - final choices = json['choices'] as List?; - if (choices == null || choices.isEmpty) { - continue; - } - - final first = choices.first; - if (first is! Map) { - continue; - } - - final delta = first['delta']; - if (delta is! Map) { - continue; - } - - final content = delta['content']; - if (content is String && content.isNotEmpty) { - buffer.write(content); - await onContent(content); - } - } - - final text = buffer.toString().trim(); - if (text.isEmpty) { - throw Exception('Empty AI response'); - } - return text; - } - - Map _openAiMessagePayload(AiPromptMessage message) => { - 'role': message.role, - 'content': message.content, - if (message.hasToolCalls) - 'tool_calls': message.toolCalls - .map( - (toolCall) => { - 'id': toolCall.id, - 'type': 'function', - 'function': { - 'name': toolCall.name, - 'arguments': jsonEncode(toolCall.arguments), - }, - }, - ) - .toList(growable: false), - if (message.isToolResult) 'tool_call_id': message.toolCallId, - }; - - AiToolCall _openAiToolCall(Map value) { - final function = _jsonMap(value['function']); - final name = function['name'] as String?; - if (name == null || name.isEmpty) { - throw Exception('Invalid AI tool call name'); - } - return AiToolCall( - id: value['id'] as String? ?? '${name}_${value.hashCode}', - name: name, - arguments: _toolArguments(function['arguments']), - ); - } -} - -class _AnthropicStrategy implements _AiProviderStrategy { - const _AnthropicStrategy(); - - @override - Map headers(AiProviderConfig config) => { - 'x-api-key': config.apiKey, - 'anthropic-version': '2023-06-01', - 'content-type': 'application/json', - }; - - @override - Future<_AiCompletionResponse> completeResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required List tools, - required CancelToken cancelToken, - }) async { - final response = await dio.post( - '/messages', - data: { - 'model': config.model, - 'max_tokens': 1024, - 'messages': messages - .where((message) => message.role != 'system') - .map(_anthropicMessagePayload) - .toList(growable: false), - 'system': messages - .where((message) => message.role == 'system') - .map((message) => message.content) - .where((content) => content.isNotEmpty) - .join('\n\n'), - if (tools.isNotEmpty) - 'tools': tools - .map( - (tool) => { - 'name': tool.name, - 'description': tool.description, - 'input_schema': tool.inputSchema, - }, - ) - .toList(growable: false), - }, - cancelToken: cancelToken, - ); - - final body = _jsonMap(response.data); - if (body['type'] == 'error') { - final error = _jsonMap(body['error']); - throw Exception(error['message'] ?? 'Anthropic request failed'); - } - - final content = body['content'] as List?; - if (content == null || content.isEmpty) { - throw Exception('Empty AI response'); - } - - final textBuffer = StringBuffer(); - final toolCalls = []; - for (final item in content) { - final block = _jsonMap(item); - switch (block['type']) { - case 'text': - final text = block['text']; - if (text is String && text.isNotEmpty) { - textBuffer.write(text); - } - case 'tool_use': - final name = block['name'] as String?; - if (name == null || name.isEmpty) { - throw Exception('Invalid AI tool call name'); - } - toolCalls.add( - AiToolCall( - id: block['id'] as String? ?? '${name}_${block.hashCode}', - name: name, - arguments: _toolArguments(block['input']), - ), - ); - } - } - - final text = textBuffer.toString(); - if (text.trim().isEmpty && toolCalls.isEmpty) { - throw Exception('Empty AI response'); - } - return _AiCompletionResponse(text: text, toolCalls: toolCalls); - } - - @override - Future streamResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required CancelToken cancelToken, - required Future Function(String chunk) onContent, - }) async { - final response = await dio.post( - '/messages', - data: { - 'model': config.model, - 'max_tokens': 1024, - 'stream': true, - 'messages': messages - .where((message) => message.role != 'system') - .map(_anthropicMessagePayload) - .toList(growable: false), - 'system': messages - .where((message) => message.role == 'system') - .map((message) => message.content) - .join('\n\n'), - }, - options: Options(responseType: ResponseType.stream), - cancelToken: cancelToken, - ); - - final body = response.data; - if (body == null) { - throw Exception('Empty AI response'); - } - - final buffer = StringBuffer(); - await for (final data in _decodeSse(body.stream)) { - final json = jsonDecode(data); - if (json is! Map) { - continue; - } - - final type = json['type'] as String?; - if (type == 'error') { - final error = json['error']; - if (error is Map) { - throw Exception(error['message'] ?? 'Anthropic request failed'); - } - throw Exception('Anthropic request failed'); - } - - if (type != 'content_block_delta') { - continue; - } - - final delta = json['delta']; - if (delta is! Map) { - continue; - } - - if (delta['type'] != 'text_delta') { - continue; - } - - final text = delta['text']; - if (text is String && text.isNotEmpty) { - buffer.write(text); - await onContent(text); - } - } - - final text = buffer.toString().trim(); - if (text.isEmpty) { - throw Exception('Empty AI response'); - } - return text; - } - - Map _anthropicMessagePayload(AiPromptMessage message) => { - 'role': message.isToolResult ? 'user' : message.role, - 'content': _anthropicContentBlocks(message), - }; - - List> _anthropicContentBlocks(AiPromptMessage message) { - if (message.isToolResult) { - return [ - { - 'type': 'tool_result', - 'tool_use_id': message.toolCallId, - 'content': message.content, - }, - ]; - } - - final blocks = >[]; - if (message.content.isNotEmpty) { - blocks.add({'type': 'text', 'text': message.content}); - } - for (final toolCall in message.toolCalls) { - blocks.add({ - 'type': 'tool_use', - 'id': toolCall.id, - 'name': toolCall.name, - 'input': toolCall.arguments, - }); - } - return blocks; - } -} - -class _GeminiStrategy implements _AiProviderStrategy { - const _GeminiStrategy(); - - @override - Map headers(AiProviderConfig config) => { - 'x-goog-api-key': config.apiKey, - 'content-type': 'application/json', - }; - - @override - Future<_AiCompletionResponse> completeResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required List tools, - required CancelToken cancelToken, - }) async { - final systemInstruction = messages - .where((message) => message.role == 'system') - .map((message) => message.content.trim()) - .where((content) => content.isNotEmpty) - .join('\n\n'); - final response = await dio.post( - '/models/${Uri.encodeComponent(config.model)}:generateContent', - data: { - 'contents': messages - .where((message) => message.role != 'system') - .map(_geminiMessagePayload) - .toList(growable: false), - if (systemInstruction.isNotEmpty) - 'system_instruction': { - 'parts': [ - {'text': systemInstruction}, - ], - }, - if (tools.isNotEmpty) - 'tools': [ - { - 'functionDeclarations': tools - .map( - (tool) => { - 'name': tool.name, - 'description': tool.description, - 'parameters': tool.inputSchema, - }, - ) - .toList(growable: false), - }, - ], - if (tools.isNotEmpty) - 'toolConfig': { - 'functionCallingConfig': {'mode': 'AUTO'}, - }, - 'generationConfig': { - 'candidateCount': 1, - }, - }, - cancelToken: cancelToken, - ); - - final body = _jsonMap(response.data); - final promptFeedback = body['promptFeedback']; - if (promptFeedback is Map) { - final blockReason = promptFeedback['blockReason']; - if (blockReason is String && blockReason.isNotEmpty) { - throw Exception('Gemini request blocked: $blockReason'); - } - } - - final candidates = body['candidates'] as List?; - if (candidates == null || candidates.isEmpty) { - throw Exception('Empty AI response'); - } - final first = _jsonMap(candidates.first); - final finishReason = first['finishReason']; - if (finishReason is String && - finishReason.isNotEmpty && - finishReason != 'STOP' && - finishReason != 'FINISH_REASON_UNSPECIFIED') { - throw Exception('Gemini request finished with reason: $finishReason'); - } - - final content = _jsonMap(first['content']); - final parts = content['parts'] as List?; - if (parts == null || parts.isEmpty) { - throw Exception('Empty AI response'); - } - - final textBuffer = StringBuffer(); - final toolCalls = []; - for (final item in parts) { - final part = _jsonMap(item); - final text = part['text']; - if (text is String && text.isNotEmpty) { - textBuffer.write(text); - } - final functionCall = part['functionCall']; - if (functionCall is Map) { - final name = functionCall['name'] as String?; - if (name == null || name.isEmpty) { - throw Exception('Invalid AI tool call name'); - } - toolCalls.add( - AiToolCall( - id: '${name}_${functionCall.hashCode}', - name: name, - arguments: _toolArguments(functionCall['args']), - ), - ); - } - } - - final text = textBuffer.toString(); - if (text.trim().isEmpty && toolCalls.isEmpty) { - throw Exception('Empty AI response'); - } - return _AiCompletionResponse(text: text, toolCalls: toolCalls); - } - - @override - Future streamResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required CancelToken cancelToken, - required Future Function(String chunk) onContent, - }) async { - final systemInstruction = messages - .where((message) => message.role == 'system') - .map((message) => message.content.trim()) - .where((content) => content.isNotEmpty) - .join('\n\n'); - - final contents = messages - .where((message) => message.role != 'system') - .map(_geminiMessagePayload) - .toList(growable: false); - - final response = await dio.post( - '/models/${Uri.encodeComponent(config.model)}:streamGenerateContent', - queryParameters: const {'alt': 'sse'}, - data: { - 'contents': contents, - if (systemInstruction.isNotEmpty) - 'system_instruction': { - 'parts': [ - {'text': systemInstruction}, - ], - }, - 'generationConfig': { - 'candidateCount': 1, - }, - }, - options: Options(responseType: ResponseType.stream), - cancelToken: cancelToken, - ); - - final body = response.data; - if (body == null) { - throw Exception('Empty AI response'); - } - - final buffer = StringBuffer(); - await for (final data in _decodeSse(body.stream)) { - final json = jsonDecode(data); - if (json is! Map) { - continue; - } - - final promptFeedback = json['promptFeedback']; - if (promptFeedback is Map) { - final blockReason = promptFeedback['blockReason']; - if (blockReason is String && blockReason.isNotEmpty) { - throw Exception('Gemini request blocked: $blockReason'); - } - } - - final candidates = json['candidates'] as List?; - if (candidates == null || candidates.isEmpty) { - continue; - } - - final first = candidates.first; - if (first is! Map) { - continue; - } - - final finishReason = first['finishReason']; - if (finishReason is String && - finishReason.isNotEmpty && - finishReason != 'STOP' && - finishReason != 'FINISH_REASON_UNSPECIFIED') { - throw Exception('Gemini request finished with reason: $finishReason'); - } - - final content = first['content']; - if (content is! Map) { - continue; - } - - final parts = content['parts'] as List?; - if (parts == null || parts.isEmpty) { - continue; - } - - for (final part in parts) { - if (part is! Map) { - continue; - } - final text = part['text']; - if (text is String && text.isNotEmpty) { - buffer.write(text); - await onContent(text); - } - } - } - - final text = buffer.toString().trim(); - if (text.isEmpty) { - throw Exception('Empty AI response'); - } - return text; - } - - Map _geminiMessagePayload(AiPromptMessage message) => { - 'role': message.role == _kAiRoleAssistant ? 'model' : 'user', - 'parts': _geminiMessageParts(message), - }; - - List> _geminiMessageParts(AiPromptMessage message) { - if (message.isToolResult) { - return [ - { - 'functionResponse': { - 'name': message.toolName, - 'response': message.toolPayload ?? {'content': message.content}, - }, - }, - ]; - } - - final parts = >[]; - if (message.content.isNotEmpty) { - parts.add({'text': message.content}); - } - for (final toolCall in message.toolCalls) { - parts.add({ - 'functionCall': { - 'name': toolCall.name, - 'args': toolCall.arguments, - }, - }); - } - return parts; - } -} - -class _AiCompletionResponse { - const _AiCompletionResponse({ - this.text = '', - this.toolCalls = const [], - }); - - final String text; - final List toolCalls; - - bool get hasToolCalls => toolCalls.isNotEmpty; -} - -Map _jsonMap(dynamic value) { - if (value is Map) { - return value; - } - if (value is Map) { - return value.map((key, value) => MapEntry('$key', value)); - } - throw Exception('Invalid AI response payload'); -} - -Map _toolArguments(dynamic value) { - if (value == null) { - return const {}; - } - if (value is String) { - final trimmed = value.trim(); - if (trimmed.isEmpty) { - return const {}; - } - final decoded = jsonDecode(trimmed); - return _jsonMap(decoded); - } - return _jsonMap(value); -} - -String _stringContent(dynamic value) { - if (value is String) { - return value; - } - if (value is List) { - return value - .whereType() - .map((item) => item['text']) - .whereType() - .join('\n'); - } - return ''; -} - class _StreamingMessageUpdater { _StreamingMessageUpdater({required this.dao, required this.messageId}); @@ -1614,34 +419,3 @@ class _StreamingMessageUpdater { ); } } - -Stream _decodeSse(Stream> stream) async* { - final buffer = StringBuffer(); - await for (final bytes in stream) { - final chunk = utf8.decode(bytes); - buffer.write(chunk.replaceAll('\r\n', '\n').replaceAll('\r', '\n')); - while (true) { - final current = buffer.toString(); - final separatorIndex = current.indexOf('\n\n'); - if (separatorIndex < 0) { - break; - } - - final rawEvent = current.substring(0, separatorIndex); - final remaining = current.substring(separatorIndex + 2); - buffer - ..clear() - ..write(remaining); - - final payload = rawEvent - .split('\n') - .where((line) => line.startsWith('data:')) - .map((line) => line.substring(5).trimLeft()) - .join('\n') - .trim(); - if (payload.isNotEmpty) { - yield payload; - } - } - } -} diff --git a/lib/ai/ai_chat_prompt_builder.dart b/lib/ai/ai_chat_prompt_builder.dart new file mode 100644 index 0000000000..54e25d6c58 --- /dev/null +++ b/lib/ai/ai_chat_prompt_builder.dart @@ -0,0 +1,370 @@ +import 'dart:async'; + +import 'package:mixin_logger/mixin_logger.dart'; + +import '../db/dao/message_dao.dart'; +import '../db/database.dart'; +import '../db/mixin_database.dart'; +import 'model/ai_prompt_message.dart'; +import 'tools/ai_conversation_tool_service.dart'; + +class AiChatPromptBuilder { + AiChatPromptBuilder(this.database); + + static const _aiRoleUser = 'user'; + static const _aiStatusPending = 'pending'; + static const _aiContextMessageLimit = 30; + static const _aiRetrievedMessageLimit = 6; + static const _aiHistoryLimit = 12; + static const _aiRetrievalQueryMaxLength = 120; + static const _aiLogPreviewLength = 240; + + final Database database; + + Future> buildPromptMessages( + String conversationId, + String input, + ) async { + final recentMessages = await database.messageDao + .messagesByConversationId(conversationId, _aiContextMessageLimit) + .get(); + final retrievedMessages = await _retrieveConversationMessages( + conversationId: conversationId, + recentMessages: recentMessages, + query: input, + ); + final aiMessages = await database.aiChatMessageDao.conversationMessages( + conversationId, + ); + + final promptMessages = [ + AiPromptMessage( + role: 'system', + content: + 'You are a local AI assistant inside a chat application. ' + 'Only use the provided current conversation context. ' + 'Help summarize, answer questions about the conversation, ' + 'and draft replies. Be concise and practical.', + ), + ]; + + _appendConversationToolInstruction(promptMessages, enabled: true); + _appendConversationContext( + promptMessages, + recentMessages: recentMessages, + retrievedMessages: retrievedMessages, + ); + + final history = aiMessages + .where((element) => element.status != _aiStatusPending) + .takeLast(_aiHistoryLimit); + for (final item in history) { + promptMessages.add( + AiPromptMessage(role: item.role, content: item.content), + ); + } + + promptMessages.add(AiPromptMessage(role: _aiRoleUser, content: input)); + d( + 'AI prompt built: conversationId=$conversationId ' + 'recent=${recentMessages.length} retrieved=${retrievedMessages.length} ' + 'history=${history.length} promptMessages=${promptMessages.length}', + ); + return promptMessages; + } + + Future> buildAssistPromptMessages({ + required String instruction, + required String? input, + required String? conversationId, + }) async { + final promptMessages = [ + AiPromptMessage( + role: 'system', + content: + 'You are an invisible writing assistant inside a chat app. ' + 'Return only the requested text. Do not add explanations, ' + 'labels, markdown fences, or greetings unless explicitly ' + 'requested.', + ), + ]; + + if (conversationId != null) { + _appendConversationToolInstruction(promptMessages, enabled: true); + final recentMessages = await database.messageDao + .messagesByConversationId(conversationId, _aiContextMessageLimit) + .get(); + final retrievedMessages = await _retrieveConversationMessages( + conversationId: conversationId, + recentMessages: recentMessages, + query: input ?? _latestRetrievalSeed(recentMessages), + ); + _appendConversationContext( + promptMessages, + recentMessages: recentMessages, + retrievedMessages: retrievedMessages, + ); + } + + final inputText = input?.trim(); + promptMessages.add( + AiPromptMessage( + role: _aiRoleUser, + content: [ + instruction.trim(), + if (inputText != null && inputText.isNotEmpty) '\nText:\n$inputText', + ].join('\n'), + ), + ); + d( + 'AI assist prompt built: conversationId=$conversationId ' + 'messages=${promptMessages.length}', + ); + return promptMessages; + } + + List buildConversationSummaryPromptMessages({ + required AiConversationToolStats stats, + required String? languageTag, + }) { + final outputLanguage = languageTag?.trim(); + return [ + AiPromptMessage( + role: 'system', + content: + 'You are a conversation summarizer inside a chat application. ' + 'For this task, you must use the available read-only ' + 'conversation tools to inspect the requested time range ' + 'before writing the final answer. Start by calling ' + 'list_conversation_chunks for the exact range, then ' + 'read_conversation_chunk until you have covered the full ' + 'range. Do not rely only on recent context or search for ' + 'this task.', + ), + AiPromptMessage( + role: 'system', + content: + 'Summaries must cover the requested range completely and ' + 'should include main topics, key decisions, action items, ' + 'unresolved questions, and notable follow-ups. Keep the ' + 'final answer concise but comprehensive.', + ), + AiPromptMessage( + role: 'user', + content: [ + 'Summarize the conversation messages in this time range.', + 'Conversation ID: ${stats.conversationId}', + 'Start time: ${stats.startInclusive?.toIso8601String() ?? 'unspecified'}', + 'End time: ${stats.endExclusive?.toIso8601String() ?? 'unspecified'}', + 'Messages in range: ${stats.messageCount}', + if (stats.firstMessageAt != null) + 'First message at: ${stats.firstMessageAt!.toIso8601String()}', + if (stats.lastMessageAt != null) + 'Last message at: ${stats.lastMessageAt!.toIso8601String()}', + if (outputLanguage != null && outputLanguage.isNotEmpty) + 'Write the final summary in $outputLanguage.', + 'Before finalizing, make sure you have covered every chunk in the range.', + 'Return only the summary text.', + ].join('\n'), + ), + ]; + } + + void _appendConversationToolInstruction( + List promptMessages, { + required bool enabled, + }) { + if (!enabled) { + return; + } + promptMessages.add( + AiPromptMessage( + role: 'system', + content: + 'Read-only conversation tools are available for the current ' + 'conversation. Use them when you need exhaustive coverage, ' + 'date-scoped summaries, statistics, older messages, or more ' + 'context than the provided messages. Do not call tools when ' + 'the provided context is already sufficient.', + ), + ); + } + + void _appendConversationContext( + List promptMessages, { + required List recentMessages, + required List retrievedMessages, + }) { + if (recentMessages.isNotEmpty) { + final lines = recentMessages.reversed + .map( + (message) => _conversationContextLine( + createdAt: message.createdAt, + sender: message.userFullName ?? message.userId, + content: _messagePlainText(message), + ), + ) + .join('\n'); + promptMessages.add( + AiPromptMessage( + role: 'system', + content: 'Current conversation recent messages:\n$lines', + ), + ); + } + + if (retrievedMessages.isEmpty) { + return; + } + + final lines = retrievedMessages + .map( + (message) => _conversationContextLine( + createdAt: message.createdAt, + sender: message.senderFullName ?? message.senderId, + content: _searchMessagePlainText(message), + ), + ) + .join('\n'); + promptMessages.add( + AiPromptMessage( + role: 'system', + content: + 'Relevant older conversation messages matched by search ' + '(use only if they help answer the current request):\n$lines', + ), + ); + } + + Future> _retrieveConversationMessages({ + required String conversationId, + required List recentMessages, + required String? query, + }) async { + final normalizedQuery = _normalizeRetrievalQuery(query); + if (normalizedQuery == null) { + d('AI retrieval skipped: conversationId=$conversationId empty query'); + return const []; + } + + final recentIds = recentMessages + .map((message) => message.messageId) + .toSet(); + final matchedIds = await database.ftsDatabase.fuzzySearchMessage( + query: normalizedQuery, + limit: _aiRetrievedMessageLimit + recentIds.length, + conversationIds: [conversationId], + ); + final candidateIds = matchedIds + .where((messageId) => !recentIds.contains(messageId)) + .take(_aiRetrievedMessageLimit) + .toList(growable: false); + if (candidateIds.isEmpty) { + d( + 'AI retrieval no match: conversationId=$conversationId ' + 'query=${_previewText(normalizedQuery)}', + ); + return const []; + } + + final matchedMessages = await database.messageDao + .searchMessageByIds(candidateIds) + .get(); + final messagesById = { + for (final message in matchedMessages) message.messageId: message, + }; + final ordered = []; + for (final messageId in candidateIds) { + final message = messagesById[messageId]; + if (message != null) { + ordered.add(message); + } + } + ordered.sort((left, right) => left.createdAt.compareTo(right.createdAt)); + d( + 'AI retrieval matched: conversationId=$conversationId ' + 'query=${_previewText(normalizedQuery)} matches=${ordered.length}', + ); + return ordered; + } + + String? _latestRetrievalSeed(List recentMessages) { + for (final message in recentMessages) { + final content = _messagePlainText(message); + final normalized = _normalizeRetrievalQuery(content); + if (normalized != null) { + return normalized; + } + } + return null; + } + + String? _normalizeRetrievalQuery(String? query) { + final compact = query?.replaceAll(RegExp(r'\s+'), ' ').trim(); + if (compact == null || compact.isEmpty) { + return null; + } + if (compact.length <= _aiRetrievalQueryMaxLength) { + return compact; + } + return compact.substring(0, _aiRetrievalQueryMaxLength); + } + + String _conversationContextLine({ + required DateTime createdAt, + required String sender, + required String content, + }) => '[${createdAt.toIso8601String()}] $sender: $content'; + + String _messagePlainText(MessageItem message) => _messagePlainTextFromFields( + content: message.content, + mediaName: message.mediaName, + type: message.type, + ); + + String _searchMessagePlainText(SearchMessageDetailItem message) => + _messagePlainTextFromFields( + content: message.content, + mediaName: message.mediaName, + type: message.type, + ); + + String _messagePlainTextFromFields({ + required String? content, + required String? mediaName, + required String type, + }) { + if (content?.trim().isNotEmpty == true) { + return content!.trim(); + } + if (mediaName?.isNotEmpty == true) { + return '[$type] $mediaName'; + } + return '[$type]'; + } +} + +String _previewText( + String? text, { + int maxLength = AiChatPromptBuilder._aiLogPreviewLength, +}) { + final compact = text?.replaceAll(RegExp(r'\s+'), ' ').trim() ?? ''; + if (compact.isEmpty) { + return '""'; + } + if (compact.length <= maxLength) { + return compact; + } + return '${compact.substring(0, maxLength)}...(${compact.length} chars)'; +} + +extension _IterableTakeLastExtension on Iterable { + Iterable takeLast(int count) { + if (count <= 0) return const []; + final list = toList(); + if (list.length <= count) { + return list; + } + return list.sublist(list.length - count); + } +} diff --git a/lib/ai/ai_provider_requester.dart b/lib/ai/ai_provider_requester.dart new file mode 100644 index 0000000000..4d58b493ec --- /dev/null +++ b/lib/ai/ai_provider_requester.dart @@ -0,0 +1,277 @@ +import 'dart:async'; +import 'dart:convert'; + +import 'package:dio/dio.dart'; +import 'package:mixin_logger/mixin_logger.dart'; + +import '../utils/proxy.dart'; +import 'model/ai_prompt_message.dart'; +import 'model/ai_provider_config.dart'; +import 'model/ai_provider_type.dart'; +import 'model/ai_tool.dart'; +import 'provider/ai_provider_strategy.dart'; +import 'provider/anthropic_strategy.dart'; +import 'provider/gemini_strategy.dart'; +import 'provider/openai_compatible_strategy.dart'; +import 'tools/ai_conversation_tool_service.dart'; + +class AiProviderRequester { + const AiProviderRequester(); + + static const _aiToolMaxRounds = 8; + static const _aiStreamFlushChars = 32; + static const _aiLogPreviewLength = 240; + static const _aiLogJsonPreviewLength = 480; + + static const _openAiStrategy = OpenAiCompatibleStrategy(); + static const _anthropicStrategy = AnthropicStrategy(); + static const _geminiStrategy = GeminiStrategy(); + + Future requestText( + AiProviderConfig config, + List messages, { + required ProxyConfig? proxy, + required CancelToken cancelToken, + required Future Function(String chunk) onContent, + required bool streamFinalResponse, + required String? conversationId, + Future Function(AiToolCall toolCall)? onToolCall, + }) async { + d( + 'AI request start: provider=${config.type.name} model=${config.model} ' + 'conversationId=$conversationId streamFinal=$streamFinalResponse ' + 'messages=${messages.length} ' + 'tools=${conversationId != null && onToolCall != null}', + ); + final dio = + Dio( + BaseOptions( + baseUrl: config.baseUrl, + connectTimeout: const Duration(seconds: 20), + receiveTimeout: const Duration(minutes: 5), + sendTimeout: const Duration(seconds: 20), + headers: _strategyFor(config.type).headers(config), + ), + ) + ..interceptors.add( + InterceptorsWrapper( + onRequest: (options, handler) { + options.extra['ai_request_started_at'] = DateTime.now(); + d( + 'AI HTTP request: ${options.method} ${options.uri} ' + 'provider=${config.type.name} model=${config.model}', + ); + handler.next(options); + }, + onResponse: (response, handler) { + final startedAt = + response.requestOptions.extra['ai_request_started_at'] + as DateTime?; + d( + 'AI HTTP response: ${response.requestOptions.method} ' + '${response.requestOptions.uri} ' + 'status=${response.statusCode} ' + 'elapsedMs=${startedAt == null ? -1 : DateTime.now().difference(startedAt).inMilliseconds}', + ); + handler.next(response); + }, + onError: (error, handler) { + final startedAt = + error.requestOptions.extra['ai_request_started_at'] + as DateTime?; + e( + 'AI HTTP error: ${error.requestOptions.method} ' + '${error.requestOptions.uri} ' + 'elapsedMs=${startedAt == null ? -1 : DateTime.now().difference(startedAt).inMilliseconds} ' + 'error=${error.message}', + error, + error.stackTrace, + ); + handler.next(error); + }, + ), + ) + ..applyProxy(proxy); + + if (conversationId == null || onToolCall == null) { + return _strategyFor(config.type).streamResponse( + dio: dio, + config: config, + messages: messages, + cancelToken: cancelToken, + onContent: onContent, + ); + } + + return _requestWithTools( + dio, + config, + [...messages], + conversationId: conversationId, + cancelToken: cancelToken, + onContent: onContent, + onToolCall: onToolCall, + streamFinalResponse: streamFinalResponse, + ); + } + + Future _requestWithTools( + Dio dio, + AiProviderConfig config, + List messages, { + required String conversationId, + required CancelToken cancelToken, + required Future Function(String chunk) onContent, + required Future Function(AiToolCall toolCall) + onToolCall, + required bool streamFinalResponse, + }) async { + for (var round = 0; round < _aiToolMaxRounds; round++) { + d( + 'AI tool round start: conversationId=$conversationId ' + 'round=${round + 1}/$_aiToolMaxRounds messages=${messages.length}', + ); + final response = await _strategyFor(config.type).completeResponse( + dio: dio, + config: config, + messages: messages, + tools: AiConversationToolKit.definitions, + cancelToken: cancelToken, + ); + d( + 'AI tool round response: conversationId=$conversationId ' + 'round=${round + 1} text=${_previewText(response.text)} ' + 'toolCalls=${_previewToolCalls(response.toolCalls)}', + ); + + if (!response.hasToolCalls) { + final text = response.text.trim(); + if (text.isEmpty) { + throw Exception('Empty AI response'); + } + if (streamFinalResponse) { + try { + d( + 'AI final stream start: conversationId=$conversationId ' + 'round=${round + 1}', + ); + return await _strategyFor(config.type).streamResponse( + dio: dio, + config: config, + messages: messages, + cancelToken: cancelToken, + onContent: onContent, + ); + } catch (error, stacktrace) { + e('AI final streaming fallback: $error, $stacktrace'); + await _emitBufferedText(text, onContent); + d( + 'AI final stream fallback: conversationId=$conversationId ' + 'round=${round + 1} text=${_previewText(text)}', + ); + return text; + } + } + await onContent(text); + d( + 'AI tool request done without stream: ' + 'conversationId=$conversationId ' + 'round=${round + 1} text=${_previewText(text)}', + ); + return text; + } + + messages.add( + AiPromptMessage( + role: 'assistant', + content: response.text, + toolCalls: response.toolCalls, + ), + ); + for (final toolCall in response.toolCalls) { + final result = await onToolCall(toolCall); + messages.add( + AiPromptMessage( + role: 'tool', + content: result.content, + toolCallId: result.toolCallId, + toolName: result.toolName, + toolPayload: result.payload, + ), + ); + } + } + + e( + 'AI exceeded tool call limit: conversationId=$conversationId ' + 'maxRounds=$_aiToolMaxRounds', + ); + throw Exception('AI exceeded tool call limit'); + } + + Future _emitBufferedText( + String text, + Future Function(String chunk) onContent, + ) async { + final trimmed = text.trim(); + if (trimmed.isEmpty) { + return; + } + if (trimmed.length <= _aiStreamFlushChars) { + await onContent(trimmed); + return; + } + for (var start = 0; start < trimmed.length; start += _aiStreamFlushChars) { + final end = (start + _aiStreamFlushChars).clamp(0, trimmed.length); + await onContent(trimmed.substring(start, end)); + } + } + + AiProviderStrategy _strategyFor(AiProviderType type) => switch (type) { + AiProviderType.openaiCompatible => _openAiStrategy, + AiProviderType.anthropic => _anthropicStrategy, + AiProviderType.gemini => _geminiStrategy, + }; +} + +String _previewText( + String? text, { + int maxLength = AiProviderRequester._aiLogPreviewLength, +}) { + final compact = text?.replaceAll(RegExp(r'\s+'), ' ').trim() ?? ''; + if (compact.isEmpty) { + return '""'; + } + if (compact.length <= maxLength) { + return compact; + } + return '${compact.substring(0, maxLength)}...(${compact.length} chars)'; +} + +String _previewJson( + Object? value, { + int maxLength = AiProviderRequester._aiLogJsonPreviewLength, +}) { + try { + final encoded = jsonEncode(value); + if (encoded.length <= maxLength) { + return encoded; + } + return '${encoded.substring(0, maxLength)}...(${encoded.length} chars)'; + } catch (_) { + return '$value'; + } +} + +String _previewToolCalls(List toolCalls) { + if (toolCalls.isEmpty) { + return '[]'; + } + return toolCalls + .map( + (toolCall) => + '${toolCall.name}#${toolCall.id}(' + '${_previewJson(toolCall.arguments, maxLength: 120)})', + ) + .join(', '); +} diff --git a/lib/ai/provider/ai_provider_strategy.dart b/lib/ai/provider/ai_provider_strategy.dart new file mode 100644 index 0000000000..a0b057d476 --- /dev/null +++ b/lib/ai/provider/ai_provider_strategy.dart @@ -0,0 +1,116 @@ +import 'dart:async'; +import 'dart:convert'; + +import 'package:dio/dio.dart'; + +import '../model/ai_prompt_message.dart'; +import '../model/ai_provider_config.dart'; +import '../model/ai_tool.dart'; + +abstract interface class AiProviderStrategy { + const AiProviderStrategy(); + + Map headers(AiProviderConfig config); + + Future completeResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required List tools, + required CancelToken cancelToken, + }); + + Future streamResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required CancelToken cancelToken, + required Future Function(String chunk) onContent, + }); +} + +class AiCompletionResponse { + const AiCompletionResponse({ + this.text = '', + this.toolCalls = const [], + }); + + final String text; + final List toolCalls; + + bool get hasToolCalls => toolCalls.isNotEmpty; +} + +final class AiProviderStrategySupport { + const AiProviderStrategySupport._(); + + static Map jsonMap(dynamic value) { + if (value is Map) { + return value; + } + if (value is Map) { + return value.map((key, value) => MapEntry('$key', value)); + } + throw Exception('Invalid AI response payload'); + } + + static Map toolArguments(dynamic value) { + if (value == null) { + return const {}; + } + if (value is String) { + final trimmed = value.trim(); + if (trimmed.isEmpty) { + return const {}; + } + final decoded = jsonDecode(trimmed); + return jsonMap(decoded); + } + return jsonMap(value); + } + + static String stringContent(dynamic value) { + if (value is String) { + return value; + } + if (value is List) { + return value + .whereType() + .map((item) => item['text']) + .whereType() + .join('\n'); + } + return ''; + } + + static Stream decodeSse(Stream> stream) async* { + final buffer = StringBuffer(); + await for (final bytes in stream) { + final chunk = utf8.decode(bytes); + buffer.write(chunk.replaceAll('\r\n', '\n').replaceAll('\r', '\n')); + while (true) { + final current = buffer.toString(); + final separatorIndex = current.indexOf('\n\n'); + if (separatorIndex < 0) { + break; + } + + final rawEvent = current.substring(0, separatorIndex); + final remaining = current.substring(separatorIndex + 2); + buffer + ..clear() + ..write(remaining); + + final payload = rawEvent + .split('\n') + .where((line) => line.startsWith('data:')) + .map((line) => line.substring(5).trimLeft()) + .join('\n') + .trim(); + if (payload.isNotEmpty) { + yield payload; + } + } + } + } +} diff --git a/lib/ai/provider/anthropic_strategy.dart b/lib/ai/provider/anthropic_strategy.dart new file mode 100644 index 0000000000..4199965ca3 --- /dev/null +++ b/lib/ai/provider/anthropic_strategy.dart @@ -0,0 +1,208 @@ +import 'dart:convert'; + +import 'package:dio/dio.dart'; + +import '../model/ai_prompt_message.dart'; +import '../model/ai_provider_config.dart'; +import '../model/ai_tool.dart'; +import 'ai_provider_strategy.dart'; + +class AnthropicStrategy implements AiProviderStrategy { + const AnthropicStrategy(); + + @override + Map headers(AiProviderConfig config) => { + 'x-api-key': config.apiKey, + 'anthropic-version': '2023-06-01', + 'content-type': 'application/json', + }; + + @override + Future completeResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required List tools, + required CancelToken cancelToken, + }) async { + final response = await dio.post( + '/messages', + data: { + 'model': config.model, + 'max_tokens': 1024, + 'messages': messages + .where((message) => message.role != 'system') + .map(_anthropicMessagePayload) + .toList(growable: false), + 'system': messages + .where((message) => message.role == 'system') + .map((message) => message.content) + .where((content) => content.isNotEmpty) + .join('\n\n'), + if (tools.isNotEmpty) + 'tools': tools + .map( + (tool) => { + 'name': tool.name, + 'description': tool.description, + 'input_schema': tool.inputSchema, + }, + ) + .toList(growable: false), + }, + cancelToken: cancelToken, + ); + + final body = AiProviderStrategySupport.jsonMap(response.data); + if (body['type'] == 'error') { + final error = AiProviderStrategySupport.jsonMap(body['error']); + throw Exception(error['message'] ?? 'Anthropic request failed'); + } + + final content = body['content'] as List?; + if (content == null || content.isEmpty) { + throw Exception('Empty AI response'); + } + + final textBuffer = StringBuffer(); + final toolCalls = []; + for (final item in content) { + final block = AiProviderStrategySupport.jsonMap(item); + switch (block['type']) { + case 'text': + final text = block['text']; + if (text is String && text.isNotEmpty) { + textBuffer.write(text); + } + case 'tool_use': + final name = block['name'] as String?; + if (name == null || name.isEmpty) { + throw Exception('Invalid AI tool call name'); + } + toolCalls.add( + AiToolCall( + id: block['id'] as String? ?? '${name}_${block.hashCode}', + name: name, + arguments: AiProviderStrategySupport.toolArguments( + block['input'], + ), + ), + ); + } + } + + final text = textBuffer.toString(); + if (text.trim().isEmpty && toolCalls.isEmpty) { + throw Exception('Empty AI response'); + } + return AiCompletionResponse(text: text, toolCalls: toolCalls); + } + + @override + Future streamResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required CancelToken cancelToken, + required Future Function(String chunk) onContent, + }) async { + final response = await dio.post( + '/messages', + data: { + 'model': config.model, + 'max_tokens': 1024, + 'stream': true, + 'messages': messages + .where((message) => message.role != 'system') + .map(_anthropicMessagePayload) + .toList(growable: false), + 'system': messages + .where((message) => message.role == 'system') + .map((message) => message.content) + .join('\n\n'), + }, + options: Options(responseType: ResponseType.stream), + cancelToken: cancelToken, + ); + + final body = response.data; + if (body == null) { + throw Exception('Empty AI response'); + } + + final buffer = StringBuffer(); + await for (final data in AiProviderStrategySupport.decodeSse(body.stream)) { + final json = jsonDecode(data); + if (json is! Map) { + continue; + } + + final type = json['type'] as String?; + if (type == 'error') { + final error = json['error']; + if (error is Map) { + throw Exception(error['message'] ?? 'Anthropic request failed'); + } + throw Exception('Anthropic request failed'); + } + + if (type != 'content_block_delta') { + continue; + } + + final delta = json['delta']; + if (delta is! Map) { + continue; + } + + if (delta['type'] != 'text_delta') { + continue; + } + + final text = delta['text']; + if (text is String && text.isNotEmpty) { + buffer.write(text); + await onContent(text); + } + } + + final text = buffer.toString().trim(); + if (text.isEmpty) { + throw Exception('Empty AI response'); + } + return text; + } + + Map _anthropicMessagePayload(AiPromptMessage message) => { + 'role': message.isToolResult ? 'user' : message.role, + 'content': _anthropicContentBlocks(message), + }; + + List> _anthropicContentBlocks( + AiPromptMessage message, + ) { + if (message.isToolResult) { + return [ + { + 'type': 'tool_result', + 'tool_use_id': message.toolCallId, + 'content': message.content, + }, + ]; + } + + final blocks = >[]; + if (message.content.isNotEmpty) { + blocks.add({'type': 'text', 'text': message.content}); + } + for (final toolCall in message.toolCalls) { + blocks.add({ + 'type': 'tool_use', + 'id': toolCall.id, + 'name': toolCall.name, + 'input': toolCall.arguments, + }); + } + return blocks; + } +} diff --git a/lib/ai/provider/gemini_strategy.dart b/lib/ai/provider/gemini_strategy.dart new file mode 100644 index 0000000000..b6030ba9b5 --- /dev/null +++ b/lib/ai/provider/gemini_strategy.dart @@ -0,0 +1,267 @@ +import 'dart:convert'; + +import 'package:dio/dio.dart'; + +import '../model/ai_prompt_message.dart'; +import '../model/ai_provider_config.dart'; +import '../model/ai_tool.dart'; +import 'ai_provider_strategy.dart'; + +class GeminiStrategy implements AiProviderStrategy { + const GeminiStrategy(); + + @override + Map headers(AiProviderConfig config) => { + 'x-goog-api-key': config.apiKey, + 'content-type': 'application/json', + }; + + @override + Future completeResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required List tools, + required CancelToken cancelToken, + }) async { + final systemInstruction = messages + .where((message) => message.role == 'system') + .map((message) => message.content.trim()) + .where((content) => content.isNotEmpty) + .join('\n\n'); + final response = await dio.post( + '/models/${Uri.encodeComponent(config.model)}:generateContent', + data: { + 'contents': messages + .where((message) => message.role != 'system') + .map(_geminiMessagePayload) + .toList(growable: false), + if (systemInstruction.isNotEmpty) + 'system_instruction': { + 'parts': [ + {'text': systemInstruction}, + ], + }, + if (tools.isNotEmpty) + 'tools': [ + { + 'functionDeclarations': tools + .map( + (tool) => { + 'name': tool.name, + 'description': tool.description, + 'parameters': tool.inputSchema, + }, + ) + .toList(growable: false), + }, + ], + if (tools.isNotEmpty) + 'toolConfig': { + 'functionCallingConfig': {'mode': 'AUTO'}, + }, + 'generationConfig': { + 'candidateCount': 1, + }, + }, + cancelToken: cancelToken, + ); + + final body = AiProviderStrategySupport.jsonMap(response.data); + final promptFeedback = body['promptFeedback']; + if (promptFeedback is Map) { + final blockReason = promptFeedback['blockReason']; + if (blockReason is String && blockReason.isNotEmpty) { + throw Exception('Gemini request blocked: $blockReason'); + } + } + + final candidates = body['candidates'] as List?; + if (candidates == null || candidates.isEmpty) { + throw Exception('Empty AI response'); + } + final first = AiProviderStrategySupport.jsonMap(candidates.first); + final finishReason = first['finishReason']; + if (finishReason is String && + finishReason.isNotEmpty && + finishReason != 'STOP' && + finishReason != 'FINISH_REASON_UNSPECIFIED') { + throw Exception('Gemini request finished with reason: $finishReason'); + } + + final content = AiProviderStrategySupport.jsonMap(first['content']); + final parts = content['parts'] as List?; + if (parts == null || parts.isEmpty) { + throw Exception('Empty AI response'); + } + + final textBuffer = StringBuffer(); + final toolCalls = []; + for (final item in parts) { + final part = AiProviderStrategySupport.jsonMap(item); + final text = part['text']; + if (text is String && text.isNotEmpty) { + textBuffer.write(text); + } + final functionCall = part['functionCall']; + if (functionCall is Map) { + final name = functionCall['name'] as String?; + if (name == null || name.isEmpty) { + throw Exception('Invalid AI tool call name'); + } + toolCalls.add( + AiToolCall( + id: '${name}_${functionCall.hashCode}', + name: name, + arguments: AiProviderStrategySupport.toolArguments( + functionCall['args'], + ), + ), + ); + } + } + + final text = textBuffer.toString(); + if (text.trim().isEmpty && toolCalls.isEmpty) { + throw Exception('Empty AI response'); + } + return AiCompletionResponse(text: text, toolCalls: toolCalls); + } + + @override + Future streamResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required CancelToken cancelToken, + required Future Function(String chunk) onContent, + }) async { + final systemInstruction = messages + .where((message) => message.role == 'system') + .map((message) => message.content.trim()) + .where((content) => content.isNotEmpty) + .join('\n\n'); + + final contents = messages + .where((message) => message.role != 'system') + .map(_geminiMessagePayload) + .toList(growable: false); + + final response = await dio.post( + '/models/${Uri.encodeComponent(config.model)}:streamGenerateContent', + queryParameters: const {'alt': 'sse'}, + data: { + 'contents': contents, + if (systemInstruction.isNotEmpty) + 'system_instruction': { + 'parts': [ + {'text': systemInstruction}, + ], + }, + 'generationConfig': { + 'candidateCount': 1, + }, + }, + options: Options(responseType: ResponseType.stream), + cancelToken: cancelToken, + ); + + final body = response.data; + if (body == null) { + throw Exception('Empty AI response'); + } + + final buffer = StringBuffer(); + await for (final data in AiProviderStrategySupport.decodeSse(body.stream)) { + final json = jsonDecode(data); + if (json is! Map) { + continue; + } + + final promptFeedback = json['promptFeedback']; + if (promptFeedback is Map) { + final blockReason = promptFeedback['blockReason']; + if (blockReason is String && blockReason.isNotEmpty) { + throw Exception('Gemini request blocked: $blockReason'); + } + } + + final candidates = json['candidates'] as List?; + if (candidates == null || candidates.isEmpty) { + continue; + } + + final first = candidates.first; + if (first is! Map) { + continue; + } + + final finishReason = first['finishReason']; + if (finishReason is String && + finishReason.isNotEmpty && + finishReason != 'STOP' && + finishReason != 'FINISH_REASON_UNSPECIFIED') { + throw Exception('Gemini request finished with reason: $finishReason'); + } + + final content = first['content']; + if (content is! Map) { + continue; + } + + final parts = content['parts'] as List?; + if (parts == null || parts.isEmpty) { + continue; + } + + for (final part in parts) { + if (part is! Map) { + continue; + } + final text = part['text']; + if (text is String && text.isNotEmpty) { + buffer.write(text); + await onContent(text); + } + } + } + + final text = buffer.toString().trim(); + if (text.isEmpty) { + throw Exception('Empty AI response'); + } + return text; + } + + Map _geminiMessagePayload(AiPromptMessage message) => { + 'role': message.role == 'assistant' ? 'model' : 'user', + 'parts': _geminiMessageParts(message), + }; + + List> _geminiMessageParts(AiPromptMessage message) { + if (message.isToolResult) { + return [ + { + 'functionResponse': { + 'name': message.toolName, + 'response': message.toolPayload ?? {'content': message.content}, + }, + }, + ]; + } + + final parts = >[]; + if (message.content.isNotEmpty) { + parts.add({'text': message.content}); + } + for (final toolCall in message.toolCalls) { + parts.add({ + 'functionCall': { + 'name': toolCall.name, + 'args': toolCall.arguments, + }, + }); + } + return parts; + } +} diff --git a/lib/ai/provider/openai_compatible_strategy.dart b/lib/ai/provider/openai_compatible_strategy.dart new file mode 100644 index 0000000000..dd6add34e3 --- /dev/null +++ b/lib/ai/provider/openai_compatible_strategy.dart @@ -0,0 +1,162 @@ +import 'dart:convert'; + +import 'package:dio/dio.dart'; + +import '../model/ai_prompt_message.dart'; +import '../model/ai_provider_config.dart'; +import '../model/ai_tool.dart'; +import 'ai_provider_strategy.dart'; + +class OpenAiCompatibleStrategy implements AiProviderStrategy { + const OpenAiCompatibleStrategy(); + + @override + Map headers(AiProviderConfig config) => { + 'Authorization': 'Bearer ${config.apiKey}', + 'Content-Type': 'application/json', + }; + + @override + Future completeResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required List tools, + required CancelToken cancelToken, + }) async { + final response = await dio.post( + '/chat/completions', + data: { + 'model': config.model, + 'messages': messages.map(_openAiMessagePayload).toList(growable: false), + if (tools.isNotEmpty) + 'tools': tools + .map( + (tool) => { + 'type': 'function', + 'function': { + 'name': tool.name, + 'description': tool.description, + 'parameters': tool.inputSchema, + }, + }, + ) + .toList(growable: false), + if (tools.isNotEmpty) 'tool_choice': 'auto', + }, + cancelToken: cancelToken, + ); + + final body = AiProviderStrategySupport.jsonMap(response.data); + final choices = body['choices'] as List?; + if (choices == null || choices.isEmpty) { + throw Exception('Empty AI response'); + } + final first = AiProviderStrategySupport.jsonMap(choices.first); + final message = AiProviderStrategySupport.jsonMap(first['message']); + final text = AiProviderStrategySupport.stringContent(message['content']); + final toolCalls = (message['tool_calls'] as List? ?? const []) + .map((item) => _openAiToolCall(AiProviderStrategySupport.jsonMap(item))) + .toList(growable: false); + if (text.trim().isEmpty && toolCalls.isEmpty) { + throw Exception('Empty AI response'); + } + return AiCompletionResponse(text: text, toolCalls: toolCalls); + } + + @override + Future streamResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required CancelToken cancelToken, + required Future Function(String chunk) onContent, + }) async { + final response = await dio.post( + '/chat/completions', + data: { + 'model': config.model, + 'stream': true, + 'messages': messages.map(_openAiMessagePayload).toList(growable: false), + }, + options: Options(responseType: ResponseType.stream), + cancelToken: cancelToken, + ); + + final body = response.data; + if (body == null) { + throw Exception('Empty AI response'); + } + + final buffer = StringBuffer(); + await for (final data in AiProviderStrategySupport.decodeSse(body.stream)) { + if (data == '[DONE]') { + continue; + } + + final json = jsonDecode(data); + if (json is! Map) { + continue; + } + + final choices = json['choices'] as List?; + if (choices == null || choices.isEmpty) { + continue; + } + + final first = choices.first; + if (first is! Map) { + continue; + } + + final delta = first['delta']; + if (delta is! Map) { + continue; + } + + final content = delta['content']; + if (content is String && content.isNotEmpty) { + buffer.write(content); + await onContent(content); + } + } + + final text = buffer.toString().trim(); + if (text.isEmpty) { + throw Exception('Empty AI response'); + } + return text; + } + + Map _openAiMessagePayload(AiPromptMessage message) => { + 'role': message.role, + 'content': message.content, + if (message.hasToolCalls) + 'tool_calls': message.toolCalls + .map( + (toolCall) => { + 'id': toolCall.id, + 'type': 'function', + 'function': { + 'name': toolCall.name, + 'arguments': jsonEncode(toolCall.arguments), + }, + }, + ) + .toList(growable: false), + if (message.isToolResult) 'tool_call_id': message.toolCallId, + }; + + AiToolCall _openAiToolCall(Map value) { + final function = AiProviderStrategySupport.jsonMap(value['function']); + final name = function['name'] as String?; + if (name == null || name.isEmpty) { + throw Exception('Invalid AI tool call name'); + } + return AiToolCall( + id: value['id'] as String? ?? '${name}_${value.hashCode}', + name: name, + arguments: AiProviderStrategySupport.toolArguments(function['arguments']), + ); + } +} diff --git a/pubspec.lock b/pubspec.lock index 24db9422e3..443b221c70 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -1270,10 +1270,11 @@ packages: mixin_markdown_widget: dependency: "direct main" description: - path: "../flutter-plugins/packages/mixin_markdown_widget" - relative: true - source: path - version: "0.1.0" + name: mixin_markdown_widget + sha256: ea1fd34d1eeb837e6be54641458800739080e4ee5a9ecc100536c82f0b69242f + url: "https://pub.dev" + source: hosted + version: "0.2.0" msix: dependency: "direct dev" description: diff --git a/pubspec.yaml b/pubspec.yaml index 2d231a8cd3..8c1dd875dd 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -96,8 +96,7 @@ dependencies: local_auth: ^3.0.1 lottie: ^3.3.3 map: ^2.0.2 - mixin_markdown_widget: - path: ../flutter-plugins/packages/mixin_markdown_widget + mixin_markdown_widget: ^0.2.0 mime: ^2.0.0 mixin_bot_sdk_dart: ^1.5.0 mixin_logger: ^0.1.3 From 2f803e620cf04716c547356c2e81889c63482050 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 12:43:43 +0800 Subject: [PATCH 20/52] feat: add customizable AI prompt templates with language support --- lib/ai/ai_chat_controller.dart | 68 +--- lib/ai/ai_chat_prompt_builder.dart | 214 ++++++---- lib/ai/model/ai_prompt_template.dart | 354 ++++++++++++++++ lib/ui/home/chat/input_container.dart | 30 +- lib/ui/setting/ai_prompt_settings_page.dart | 421 ++++++++++++++++++++ lib/ui/setting/ai_settings_page.dart | 44 ++ lib/utils/property/setting_property.dart | 40 ++ lib/widgets/message/message_ai_assist.dart | 24 +- test/ai/ai_prompt_template_test.dart | 49 +++ test/db/property_storage_test.dart | 23 ++ 10 files changed, 1096 insertions(+), 171 deletions(-) create mode 100644 lib/ai/model/ai_prompt_template.dart create mode 100644 lib/ui/setting/ai_prompt_settings_page.dart create mode 100644 test/ai/ai_prompt_template_test.dart diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index 43e0470855..be8a1d95f8 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -48,6 +48,7 @@ class AiChatController { Future assistText({ required String instruction, + required String language, String? input, String? conversationId, AiProviderConfig? provider, @@ -65,6 +66,7 @@ class AiChatController { final messages = await _promptBuilder.buildAssistPromptMessages( instruction: instruction, + language: language, input: input, conversationId: conversationId, ); @@ -98,6 +100,7 @@ class AiChatController { Future send({ required String conversationId, required String input, + required String language, AiProviderConfig? provider, void Function()? onInputAccepted, }) async { @@ -176,6 +179,7 @@ class AiChatController { final messages = await _promptBuilder.buildPromptMessages( conversationId, input, + language, ); final result = await _requestText( config, @@ -230,70 +234,6 @@ class AiChatController { _activeAiRequests[conversationId]?.cancel('AI generation stopped'); } - Future summarizeConversationRange({ - required String conversationId, - required DateTime startInclusive, - required DateTime endExclusive, - String? languageTag, - AiProviderConfig? provider, - }) async { - final config = provider ?? database.settingProperties.selectedAiProvider; - if (config == null) { - throw Exception('No AI provider configured'); - } - - final stats = await _conversationToolService.getConversationStats( - conversationId: conversationId, - startInclusive: startInclusive, - endExclusive: endExclusive, - ); - if (stats.messageCount <= 0) { - return 'No messages found in the selected time range.'; - } - - final messages = _promptBuilder.buildConversationSummaryPromptMessages( - stats: stats, - languageTag: languageTag, - ); - final cancelToken = CancelToken(); - _activeAiRequests[conversationId] = cancelToken; - try { - return await _requestText( - config, - messages, - cancelToken: cancelToken, - onContent: (_) async {}, - conversationId: conversationId, - streamFinalResponse: false, - ); - } finally { - if (_activeAiRequests[conversationId] == cancelToken) { - _activeAiRequests.remove(conversationId); - } - } - } - - Future summarizeConversationToday({ - required String conversationId, - String? languageTag, - AiProviderConfig? provider, - DateTime? now, - }) { - final localNow = now ?? DateTime.now(); - final startInclusive = DateTime( - localNow.year, - localNow.month, - localNow.day, - ); - return summarizeConversationRange( - conversationId: conversationId, - startInclusive: startInclusive, - endExclusive: startInclusive.add(const Duration(days: 1)), - languageTag: languageTag, - provider: provider, - ); - } - Future _requestText( AiProviderConfig config, List messages, { diff --git a/lib/ai/ai_chat_prompt_builder.dart b/lib/ai/ai_chat_prompt_builder.dart index 54e25d6c58..2c94500c11 100644 --- a/lib/ai/ai_chat_prompt_builder.dart +++ b/lib/ai/ai_chat_prompt_builder.dart @@ -6,7 +6,7 @@ import '../db/dao/message_dao.dart'; import '../db/database.dart'; import '../db/mixin_database.dart'; import 'model/ai_prompt_message.dart'; -import 'tools/ai_conversation_tool_service.dart'; +import 'model/ai_prompt_template.dart'; class AiChatPromptBuilder { AiChatPromptBuilder(this.database); @@ -24,7 +24,9 @@ class AiChatPromptBuilder { Future> buildPromptMessages( String conversationId, String input, + String language, ) async { + final now = DateTime.now(); final recentMessages = await database.messageDao .messagesByConversationId(conversationId, _aiContextMessageLimit) .get(); @@ -38,21 +40,36 @@ class AiChatPromptBuilder { ); final promptMessages = [ - AiPromptMessage( + ..._promptMessages( role: 'system', - content: - 'You are a local AI assistant inside a chat application. ' - 'Only use the provided current conversation context. ' - 'Help summarize, answer questions about the conversation, ' - 'and draft replies. Be concise and practical.', + content: renderAiPromptTemplate( + database.settingProperties.aiPromptTemplate( + AiPromptTemplateKey.chatSystem, + ), + buildAiPromptTemplateVariables( + conversationId: conversationId, + input: input, + language: language, + now: now, + ), + ), ), ]; - _appendConversationToolInstruction(promptMessages, enabled: true); + _appendConversationToolInstruction( + promptMessages, + enabled: true, + conversationId: conversationId, + language: language, + now: now, + ); _appendConversationContext( promptMessages, + conversationId: conversationId, recentMessages: recentMessages, retrievedMessages: retrievedMessages, + language: language, + now: now, ); final history = aiMessages @@ -64,7 +81,20 @@ class AiChatPromptBuilder { ); } - promptMessages.add(AiPromptMessage(role: _aiRoleUser, content: input)); + promptMessages.addAll( + _promptMessages( + role: _aiRoleUser, + content: renderAiPromptTemplate( + chatUserMessagePromptTemplate, + buildAiPromptTemplateVariables( + conversationId: conversationId, + input: input, + language: language, + now: now, + ), + ), + ), + ); d( 'AI prompt built: conversationId=$conversationId ' 'recent=${recentMessages.length} retrieved=${retrievedMessages.length} ' @@ -77,20 +107,38 @@ class AiChatPromptBuilder { required String instruction, required String? input, required String? conversationId, + required String language, }) async { + final now = DateTime.now(); + final inputText = input?.trim(); + final trimmedInstruction = instruction.trim(); final promptMessages = [ - AiPromptMessage( + ..._promptMessages( role: 'system', - content: - 'You are an invisible writing assistant inside a chat app. ' - 'Return only the requested text. Do not add explanations, ' - 'labels, markdown fences, or greetings unless explicitly ' - 'requested.', + content: renderAiPromptTemplate( + database.settingProperties.aiPromptTemplate( + AiPromptTemplateKey.assistSystem, + ), + buildAiPromptTemplateVariables( + conversationId: conversationId, + input: inputText, + instruction: trimmedInstruction, + inputSection: buildAiPromptInputSection(inputText), + language: language, + now: now, + ), + ), ), ]; if (conversationId != null) { - _appendConversationToolInstruction(promptMessages, enabled: true); + _appendConversationToolInstruction( + promptMessages, + enabled: true, + conversationId: conversationId, + language: language, + now: now, + ); final recentMessages = await database.messageDao .messagesByConversationId(conversationId, _aiContextMessageLimit) .get(); @@ -101,19 +149,28 @@ class AiChatPromptBuilder { ); _appendConversationContext( promptMessages, + conversationId: conversationId, recentMessages: recentMessages, retrievedMessages: retrievedMessages, + language: language, + now: now, ); } - final inputText = input?.trim(); - promptMessages.add( - AiPromptMessage( + promptMessages.addAll( + _promptMessages( role: _aiRoleUser, - content: [ - instruction.trim(), - if (inputText != null && inputText.isNotEmpty) '\nText:\n$inputText', - ].join('\n'), + content: renderAiPromptTemplate( + assistUserMessagePromptTemplate, + buildAiPromptTemplateVariables( + conversationId: conversationId, + input: inputText, + instruction: trimmedInstruction, + inputSection: buildAiPromptInputSection(inputText), + language: language, + now: now, + ), + ), ), ); d( @@ -123,77 +180,38 @@ class AiChatPromptBuilder { return promptMessages; } - List buildConversationSummaryPromptMessages({ - required AiConversationToolStats stats, - required String? languageTag, - }) { - final outputLanguage = languageTag?.trim(); - return [ - AiPromptMessage( - role: 'system', - content: - 'You are a conversation summarizer inside a chat application. ' - 'For this task, you must use the available read-only ' - 'conversation tools to inspect the requested time range ' - 'before writing the final answer. Start by calling ' - 'list_conversation_chunks for the exact range, then ' - 'read_conversation_chunk until you have covered the full ' - 'range. Do not rely only on recent context or search for ' - 'this task.', - ), - AiPromptMessage( - role: 'system', - content: - 'Summaries must cover the requested range completely and ' - 'should include main topics, key decisions, action items, ' - 'unresolved questions, and notable follow-ups. Keep the ' - 'final answer concise but comprehensive.', - ), - AiPromptMessage( - role: 'user', - content: [ - 'Summarize the conversation messages in this time range.', - 'Conversation ID: ${stats.conversationId}', - 'Start time: ${stats.startInclusive?.toIso8601String() ?? 'unspecified'}', - 'End time: ${stats.endExclusive?.toIso8601String() ?? 'unspecified'}', - 'Messages in range: ${stats.messageCount}', - if (stats.firstMessageAt != null) - 'First message at: ${stats.firstMessageAt!.toIso8601String()}', - if (stats.lastMessageAt != null) - 'Last message at: ${stats.lastMessageAt!.toIso8601String()}', - if (outputLanguage != null && outputLanguage.isNotEmpty) - 'Write the final summary in $outputLanguage.', - 'Before finalizing, make sure you have covered every chunk in the range.', - 'Return only the summary text.', - ].join('\n'), - ), - ]; - } - void _appendConversationToolInstruction( List promptMessages, { required bool enabled, + required String? conversationId, + required String language, + required DateTime now, }) { if (!enabled) { return; } - promptMessages.add( - AiPromptMessage( + promptMessages.addAll( + _promptMessages( role: 'system', - content: - 'Read-only conversation tools are available for the current ' - 'conversation. Use them when you need exhaustive coverage, ' - 'date-scoped summaries, statistics, older messages, or more ' - 'context than the provided messages. Do not call tools when ' - 'the provided context is already sufficient.', + content: renderAiPromptTemplate( + conversationToolInstructionPromptTemplate, + buildAiPromptTemplateVariables( + conversationId: conversationId, + language: language, + now: now, + ), + ), ), ); } void _appendConversationContext( List promptMessages, { + required String conversationId, required List recentMessages, required List retrievedMessages, + required String language, + required DateTime now, }) { if (recentMessages.isNotEmpty) { final lines = recentMessages.reversed @@ -205,10 +223,18 @@ class AiChatPromptBuilder { ), ) .join('\n'); - promptMessages.add( - AiPromptMessage( + promptMessages.addAll( + _promptMessages( role: 'system', - content: 'Current conversation recent messages:\n$lines', + content: renderAiPromptTemplate( + recentConversationContextPromptTemplate, + buildAiPromptTemplateVariables( + conversationId: conversationId, + messages: lines, + language: language, + now: now, + ), + ), ), ); } @@ -226,12 +252,18 @@ class AiChatPromptBuilder { ), ) .join('\n'); - promptMessages.add( - AiPromptMessage( + promptMessages.addAll( + _promptMessages( role: 'system', - content: - 'Relevant older conversation messages matched by search ' - '(use only if they help answer the current request):\n$lines', + content: renderAiPromptTemplate( + retrievedConversationContextPromptTemplate, + buildAiPromptTemplateVariables( + conversationId: conversationId, + messages: lines, + language: language, + now: now, + ), + ), ), ); } @@ -342,6 +374,16 @@ class AiChatPromptBuilder { } return '[$type]'; } + + List _promptMessages({ + required String role, + required String content, + }) { + if (content.trim().isEmpty) { + return const []; + } + return [AiPromptMessage(role: role, content: content)]; + } } String _previewText( diff --git a/lib/ai/model/ai_prompt_template.dart b/lib/ai/model/ai_prompt_template.dart new file mode 100644 index 0000000000..b63f09ee48 --- /dev/null +++ b/lib/ai/model/ai_prompt_template.dart @@ -0,0 +1,354 @@ +enum AiPromptTemplateGroup { conversation, messageAssist, draftAssist } + +extension AiPromptTemplateGroupExtension on AiPromptTemplateGroup { + String get title => switch (this) { + AiPromptTemplateGroup.conversation => 'Conversation', + AiPromptTemplateGroup.messageAssist => 'Message Assist', + AiPromptTemplateGroup.draftAssist => 'Draft Assist', + }; +} + +enum AiPromptVariable { + conversationId( + 'conversationId', + 'Conversation ID', + 'Current conversation ID.', + ), + currentIsoDateTime( + 'currentIsoDateTime', + 'Current ISO Date Time', + 'Current date and time in ISO 8601 format.', + ), + input( + 'input', + 'Input', + 'Current user input text.', + ), + instruction( + 'instruction', + 'Instruction', + 'Resolved assist instruction text.', + ), + inputSection( + 'inputSection', + 'Input Section', + 'Prebuilt input block, usually empty or starts with a new line.', + ), + language( + 'language', + 'Language', + 'Current locale language tag, for example en-US.', + ), + messages( + 'messages', + 'Messages', + 'Conversation message lines assembled by the app.', + ) + ; + + const AiPromptVariable(this.placeholder, this.title, this.description); + + final String placeholder; + final String title; + final String description; + + String get token => '{{$placeholder}}'; +} + +enum AiPromptTemplateKey { + chatSystem, + assistSystem, + messageTranslate, + messageExplain, + messageSuggestReplies, + draftPolish, + draftShorten, + draftPolite, + draftTranslate, + draftReplyWithContext, +} + +class AiPromptTemplateDefinition { + const AiPromptTemplateDefinition({ + required this.key, + required this.group, + required this.title, + required this.description, + required this.defaultValue, + required this.variables, + }); + + final AiPromptTemplateKey key; + final AiPromptTemplateGroup group; + final String title; + final String description; + final String defaultValue; + final List variables; +} + +const aiPromptTemplateDefinitions = [ + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.chatSystem, + group: AiPromptTemplateGroup.conversation, + title: 'Chat System Prompt', + description: 'Primary system prompt for AI chat mode.', + defaultValue: + 'You are a local AI assistant inside a chat application. ' + 'Only use the provided current conversation context. ' + 'The current time is {{currentIsoDateTime}}. ' + 'Unless the user explicitly asks to preserve the source language, ' + 'quote verbatim, translate, or use another language, respond in ' + '{{language}}. Help summarize, answer questions about the ' + 'conversation, and draft practical replies. Be concise.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.input, + AiPromptVariable.language, + ], + ), + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.assistSystem, + group: AiPromptTemplateGroup.conversation, + title: 'Assist System Prompt', + description: 'System prompt for invisible writing assist features.', + defaultValue: + 'You are an invisible writing assistant inside a chat app. ' + 'Return only the requested text. Unless the instruction explicitly ' + 'asks to keep the original language, quote verbatim, translate, or ' + 'use another language, return the result in {{language}}. Do not ' + 'add explanations, labels, markdown fences, or greetings unless ' + 'explicitly requested.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.input, + AiPromptVariable.instruction, + AiPromptVariable.inputSection, + AiPromptVariable.language, + ], + ), + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.messageTranslate, + group: AiPromptTemplateGroup.messageAssist, + title: 'Message Translate Prompt', + description: 'Instruction for translating one chat message.', + defaultValue: + 'Translate this chat message into {{language}}. ' + 'Return only the translation.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.input, + AiPromptVariable.language, + ], + ), + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.messageExplain, + group: AiPromptTemplateGroup.messageAssist, + title: 'Message Explain Prompt', + description: 'Instruction for explaining one chat message.', + defaultValue: + 'Explain this chat message clearly and concisely in {{language}}. ' + 'Clarify slang, abbreviations, technical terms, and implied meaning ' + 'when useful. Return only the explanation.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.input, + AiPromptVariable.language, + ], + ), + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.messageSuggestReplies, + group: AiPromptTemplateGroup.messageAssist, + title: 'Message Suggest Replies Prompt', + description: 'Instruction for generating suggested replies.', + defaultValue: + 'Suggest three concise, natural replies in {{language}} to this ' + 'chat message using the recent conversation context. ' + 'Return one reply per line, without numbering.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.input, + AiPromptVariable.language, + ], + ), + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.draftPolish, + group: AiPromptTemplateGroup.draftAssist, + title: 'Draft Polish Prompt', + description: 'Instruction for polishing the current draft.', + defaultValue: + 'Polish this draft for a chat message. The preferred output ' + 'language is {{language}}, but for this task keep the original ' + 'language of the input. Keep the original meaning and approximate ' + 'length. Return only the rewritten draft.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.input, + AiPromptVariable.language, + ], + ), + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.draftShorten, + group: AiPromptTemplateGroup.draftAssist, + title: 'Draft Shorten Prompt', + description: 'Instruction for making the draft shorter.', + defaultValue: + 'Rewrite this chat draft to be shorter and clearer. The preferred ' + 'output language is {{language}}, but for this task keep the ' + 'original language of the input. Keep the original intent. Return ' + 'only the rewritten draft.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.input, + AiPromptVariable.language, + ], + ), + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.draftPolite, + group: AiPromptTemplateGroup.draftAssist, + title: 'Draft Polite Prompt', + description: 'Instruction for making the draft more polite.', + defaultValue: + 'Rewrite this chat draft to sound polite, natural, and still ' + 'concise. The preferred output language is {{language}}, but for ' + 'this task keep the original language of the input. Return only ' + 'the rewritten draft.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.input, + AiPromptVariable.language, + ], + ), + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.draftTranslate, + group: AiPromptTemplateGroup.draftAssist, + title: 'Draft Translate Prompt', + description: 'Instruction for translating the current draft.', + defaultValue: + 'Translate this chat draft into {{language}}. ' + 'Return only the translation.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.input, + AiPromptVariable.language, + ], + ), + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.draftReplyWithContext, + group: AiPromptTemplateGroup.draftAssist, + title: 'Draft Reply With Context Prompt', + description: 'Instruction for replying from recent context.', + defaultValue: + 'Draft a concise, natural reply in {{language}} to the latest ' + 'conversation message using the recent context, unless the user ' + 'explicitly requires another language or preserving the source ' + 'language. Return only the reply text.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.input, + AiPromptVariable.language, + ], + ), +]; + +final _aiPromptTemplateDefinitionMap = { + for (final definition in aiPromptTemplateDefinitions) + definition.key: definition, +}; + +extension AiPromptTemplateKeyExtension on AiPromptTemplateKey { + AiPromptTemplateDefinition get definition => + _aiPromptTemplateDefinitionMap[this]!; + + String get storageKey => name; +} + +const chatUserMessagePromptTemplate = 'User request:\n{{input}}'; + +const assistUserMessagePromptTemplate = + 'Preferred output language: {{language}}\n' + 'Instruction:\n{{instruction}}{{inputSection}}'; + +const conversationToolInstructionPromptTemplate = + 'Read-only conversation tools are available for the current ' + 'conversation. Use them when you need exhaustive coverage, ' + 'date-scoped summaries, statistics, older messages, or more ' + 'context than the provided messages. When answering the user, ' + 'default to {{language}} unless the user explicitly requires ' + 'another language or preserving the source language. Do not call ' + 'tools when the provided context is already sufficient.'; + +const recentConversationContextPromptTemplate = + 'Current conversation recent messages:\n{{messages}}'; + +const retrievedConversationContextPromptTemplate = + 'Relevant older conversation messages matched by search ' + '(use only if they help answer the current request):\n{{messages}}'; + +Map buildAiPromptTemplateVariables({ + String? conversationId, + String? input, + String? instruction, + String? inputSection, + String? language, + String? messages, + DateTime? now, +}) { + final resolvedNow = now ?? DateTime.now(); + return { + AiPromptVariable.conversationId.placeholder: conversationId, + AiPromptVariable.currentIsoDateTime.placeholder: resolvedNow + .toIso8601String(), + AiPromptVariable.input.placeholder: input, + AiPromptVariable.instruction.placeholder: instruction, + AiPromptVariable.inputSection.placeholder: inputSection, + AiPromptVariable.language.placeholder: language, + AiPromptVariable.messages.placeholder: messages, + 'currentDate': _formatDate(resolvedNow), + 'currentTime': _formatTime(resolvedNow), + 'currentDateTime': _formatDateTime(resolvedNow), + }; +} + +String buildAiPromptInputSection(String? input) { + final compact = input?.trim(); + if (compact == null || compact.isEmpty) { + return ''; + } + return '\nText:\n$compact'; +} + +String renderAiPromptTemplate( + String template, + Map variables, +) => template.replaceAllMapped(_promptVariablePattern, (match) { + final key = match.group(1); + if (key == null || !variables.containsKey(key)) { + return match.group(0) ?? ''; + } + return variables[key] ?? ''; +}); + +final _promptVariablePattern = RegExp(r'\{\{\s*([a-zA-Z0-9_]+)\s*\}\}'); + +String _formatDateTime(DateTime value) => + '${_formatDate(value)} ${_formatTime(value)}'; + +String _formatDate(DateTime value) => + '${value.year.toString().padLeft(4, '0')}-' + '${value.month.toString().padLeft(2, '0')}-' + '${value.day.toString().padLeft(2, '0')}'; + +String _formatTime(DateTime value) => + '${value.hour.toString().padLeft(2, '0')}:' + '${value.minute.toString().padLeft(2, '0')}:' + '${value.second.toString().padLeft(2, '0')}'; diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 41d616cc5a..17a3606006 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -20,6 +20,7 @@ import 'package:simple_animations/simple_animations.dart'; import 'package:super_context_menu/super_context_menu.dart'; import '../../../ai/ai_chat_controller.dart'; +import '../../../ai/model/ai_prompt_template.dart'; import '../../../ai/model/ai_provider_config.dart'; import '../../../constants/constants.dart'; import '../../../constants/icon_fonts.dart'; @@ -590,18 +591,21 @@ Future _requestAiDraftAction( } final language = _currentLanguageTag(context); - final instruction = switch (action) { - AiDraftAction.polish => - 'Polish this draft for a chat message. Keep the original meaning, language, and approximate length.', - AiDraftAction.shorten => - 'Rewrite this chat draft to be shorter and clearer. Keep the original language and intent.', - AiDraftAction.polite => - 'Rewrite this chat draft to sound polite, natural, and still concise. Keep the original language.', - AiDraftAction.translate => - 'Translate this chat draft into $language. Return only the translation.', - AiDraftAction.replyWithContext => - 'Draft a concise, natural reply to the latest conversation message using the recent context. Return only the reply text.', + final templateKey = switch (action) { + AiDraftAction.polish => AiPromptTemplateKey.draftPolish, + AiDraftAction.shorten => AiPromptTemplateKey.draftShorten, + AiDraftAction.polite => AiPromptTemplateKey.draftPolite, + AiDraftAction.translate => AiPromptTemplateKey.draftTranslate, + AiDraftAction.replyWithContext => AiPromptTemplateKey.draftReplyWithContext, }; + final instruction = renderAiPromptTemplate( + context.database.settingProperties.aiPromptTemplate(templateKey), + buildAiPromptTemplateVariables( + conversationId: conversationId, + input: original, + language: language, + ), + ); final title = switch (action) { AiDraftAction.polish => 'Polish', AiDraftAction.shorten => 'Make shorter', @@ -614,6 +618,7 @@ Future _requestAiDraftAction( final controller = AiChatController(context.database); final result = await controller.assistText( instruction: instruction, + language: language, input: action == AiDraftAction.replyWithContext ? null : original, conversationId: conversationId, ); @@ -682,6 +687,7 @@ Future _sendMessage( final aiModeState = context.providerContainer.read( aiInputModeProvider(conversationId), ); + final language = _currentLanguageTag(context); if (text == '/ai') { final provider = context.database.settingProperties.selectedAiProvider; @@ -708,6 +714,7 @@ Future _sendMessage( await AiChatController(context.database).send( conversationId: conversationId, input: inlineAiInput, + language: language, provider: provider, onInputAccepted: () => textEditingController.text = '', ); @@ -735,6 +742,7 @@ Future _sendMessage( await AiChatController(context.database).send( conversationId: conversationId, input: text, + language: language, provider: provider, onInputAccepted: () => textEditingController.text = '', ); diff --git a/lib/ui/setting/ai_prompt_settings_page.dart b/lib/ui/setting/ai_prompt_settings_page.dart new file mode 100644 index 0000000000..45d9603f6f --- /dev/null +++ b/lib/ui/setting/ai_prompt_settings_page.dart @@ -0,0 +1,421 @@ +import 'package:flutter/material.dart'; +import 'package:flutter_hooks/flutter_hooks.dart'; +import 'package:hooks_riverpod/hooks_riverpod.dart'; + +import '../../ai/model/ai_prompt_template.dart'; +import '../../utils/extension/extension.dart'; +import '../../widgets/app_bar.dart'; +import '../../widgets/cell.dart'; +import '../../widgets/toast.dart'; +import '../provider/database_provider.dart'; + +class AiPromptSettingsPage extends HookConsumerWidget { + const AiPromptSettingsPage({super.key}); + + @override + Widget build(BuildContext context, WidgetRef ref) { + final database = ref.watch(databaseProvider).requireValue; + useListenable(database.settingProperties); + final customizedCount = aiPromptTemplateDefinitions + .where( + (definition) => + database.settingProperties.hasAiPromptTemplateOverride( + definition.key, + ), + ) + .length; + + return Scaffold( + backgroundColor: context.theme.background, + appBar: const MixinAppBar(title: Text('AI Prompt Templates')), + body: Align( + alignment: Alignment.topCenter, + child: SingleChildScrollView( + child: Padding( + padding: const EdgeInsets.only(top: 20, bottom: 20), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: context.theme.settingCellBackgroundColor, + child: Padding( + padding: const EdgeInsets.all(16), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + customizedCount == 0 + ? 'All prompts are using built-in defaults.' + : '$customizedCount prompt templates currently use custom overrides.', + style: TextStyle( + color: context.theme.text, + fontSize: 15, + fontWeight: FontWeight.w600, + ), + ), + const SizedBox(height: 8), + Text( + 'Templates support placeholders like {{conversationId}}, {{currentIsoDateTime}}, {{language}}, and {{input}}. Each editor shows the variables available for that prompt.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + height: 1.4, + ), + ), + const SizedBox(height: 8), + Text( + 'Leave a template empty to disable that prompt block. Saving the exact default text removes the custom override.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 13, + height: 1.4, + ), + ), + ], + ), + ), + ), + for (final group in AiPromptTemplateGroup.values) ...[ + _SectionLabel(title: group.title), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: Column( + children: [ + for ( + var i = 0; + i < + aiPromptTemplateDefinitions + .where((item) => item.group == group) + .length; + i++ + ) ...[ + _PromptTemplateCell( + definition: aiPromptTemplateDefinitions + .where((item) => item.group == group) + .elementAt(i), + ), + if (i != + aiPromptTemplateDefinitions + .where((item) => item.group == group) + .length - + 1) + Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: context.theme.divider, + ), + ], + ], + ), + ), + ], + ], + ), + ), + ), + ), + ); + } +} + +class _PromptTemplateCell extends HookConsumerWidget { + const _PromptTemplateCell({required this.definition}); + + final AiPromptTemplateDefinition definition; + + @override + Widget build(BuildContext context, WidgetRef ref) { + final database = ref.watch(databaseProvider).requireValue; + final currentValue = database.settingProperties.aiPromptTemplate( + definition.key, + ); + final isCustomized = database.settingProperties.hasAiPromptTemplateOverride( + definition.key, + ); + + return CellItem( + onTap: () => Navigator.of(context).push( + MaterialPageRoute( + builder: (_) => _AiPromptTemplateEditPage(definition: definition), + ), + ), + title: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text(definition.title), + const SizedBox(height: 4), + Text( + definition.description, + maxLines: 2, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 13, + height: 1.3, + ), + ), + ], + ), + description: SizedBox( + width: 120, + child: Text( + _statusText(currentValue, isCustomized), + textAlign: TextAlign.end, + maxLines: 2, + overflow: TextOverflow.ellipsis, + ), + ), + ); + } + + String _statusText(String value, bool customized) { + final compact = value.replaceAll(RegExp(r'\s+'), ' ').trim(); + final preview = compact.isEmpty ? 'Empty' : compact; + final prefix = customized ? 'Custom' : 'Default'; + return '$prefix · $preview'; + } +} + +class _AiPromptTemplateEditPage extends HookConsumerWidget { + const _AiPromptTemplateEditPage({required this.definition}); + + final AiPromptTemplateDefinition definition; + + @override + Widget build(BuildContext context, WidgetRef ref) { + final database = ref.watch(databaseProvider).requireValue; + useListenable(database.settingProperties); + final initialText = database.settingProperties.aiPromptTemplate( + definition.key, + ); + final controller = useTextEditingController(text: initialText); + final theme = context.theme; + final inputBackgroundColor = context.dynamicColor( + Colors.white, + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ); + final inputBorderColor = context.dynamicColor( + theme.divider, + darkColor: const Color.fromRGBO(255, 255, 255, 0.10), + ); + + void save() { + final value = controller.text; + if (value == definition.defaultValue) { + database.settingProperties.resetAiPromptTemplate(definition.key); + } else { + database.settingProperties.saveAiPromptTemplate(definition.key, value); + } + showToastSuccessful(); + Navigator.of(context).pop(); + } + + return Scaffold( + backgroundColor: theme.background, + appBar: MixinAppBar( + title: Text(definition.title), + actions: [ + TextButton( + onPressed: () => controller.text = definition.defaultValue, + child: Text( + 'Use Default', + style: TextStyle(color: theme.accent, fontSize: 16), + ), + ), + TextButton( + onPressed: save, + child: Text( + 'Save', + style: TextStyle(color: theme.accent, fontSize: 16), + ), + ), + ], + ), + body: Align( + alignment: Alignment.topCenter, + child: SingleChildScrollView( + child: Padding( + padding: const EdgeInsets.only(top: 20, bottom: 20), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + const _SectionLabel(title: 'Description'), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: theme.settingCellBackgroundColor, + child: Padding( + padding: const EdgeInsets.all(16), + child: Text( + definition.description, + style: TextStyle( + color: theme.text, + fontSize: 14, + height: 1.45, + ), + ), + ), + ), + const _SectionLabel(title: 'Variables'), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: theme.settingCellBackgroundColor, + child: Padding( + padding: const EdgeInsets.all(16), + child: _PromptVariableChipWrap( + variables: definition.variables, + onTap: (variable) => + _insertToken(controller, variable.token), + ), + ), + ), + Padding( + padding: const EdgeInsets.only(left: 20, bottom: 14, top: 10), + child: Text( + 'Hover to preview the description. Click a chip to insert it at the current cursor position.', + style: TextStyle( + color: theme.secondaryText, + fontSize: 14, + ), + ), + ), + const _SectionLabel(title: 'Template'), + ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 600), + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 10), + child: Container( + decoration: BoxDecoration( + color: inputBackgroundColor, + borderRadius: BorderRadius.circular(8), + border: Border.all(color: inputBorderColor), + ), + padding: const EdgeInsets.symmetric( + horizontal: 14, + vertical: 12, + ), + child: TextField( + controller: controller, + minLines: 10, + maxLines: null, + style: TextStyle( + color: theme.text, + fontSize: 15, + height: 1.45, + ), + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: definition.defaultValue, + hintStyle: TextStyle(color: theme.secondaryText), + ), + ), + ), + ), + ), + Padding( + padding: const EdgeInsets.only(left: 20, right: 20, top: 12), + child: Text( + 'Empty text disables this prompt block. Saving the exact default text removes the override and falls back to the built-in template.', + style: TextStyle( + color: theme.secondaryText, + fontSize: 13, + height: 1.4, + ), + ), + ), + ], + ), + ), + ), + ), + ); + } + + void _insertToken(TextEditingController controller, String token) { + final value = controller.value; + final selection = value.selection; + final hasSelection = selection.isValid; + final start = hasSelection ? selection.start : value.text.length; + final end = hasSelection ? selection.end : value.text.length; + final safeStart = start < 0 ? value.text.length : start; + final safeEnd = end < 0 ? value.text.length : end; + final nextText = value.text.replaceRange(safeStart, safeEnd, token); + controller.value = TextEditingValue( + text: nextText, + selection: TextSelection.collapsed(offset: safeStart + token.length), + ); + } +} + +class _PromptVariableChipWrap extends StatelessWidget { + const _PromptVariableChipWrap({required this.variables, required this.onTap}); + + final List variables; + final ValueChanged onTap; + + @override + Widget build(BuildContext context) { + final fillColor = context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.04), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ); + final outlineColor = context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.08), + darkColor: const Color.fromRGBO(255, 255, 255, 0.12), + ); + + return Wrap( + spacing: 8, + runSpacing: 8, + children: [ + for (final variable in variables) + Tooltip( + message: variable.description, + waitDuration: const Duration(milliseconds: 250), + child: ActionChip( + onPressed: () => onTap(variable), + label: Text( + variable.token, + style: TextStyle( + color: context.theme.text, + fontSize: 13, + fontWeight: FontWeight.w500, + ), + ), + labelPadding: const EdgeInsets.symmetric(horizontal: 2), + padding: const EdgeInsets.symmetric(horizontal: 8, vertical: 8), + side: BorderSide(color: outlineColor), + backgroundColor: fillColor, + surfaceTintColor: Colors.transparent, + shape: RoundedRectangleBorder( + borderRadius: BorderRadius.circular(999), + ), + ), + ), + ], + ); + } +} + +class _SectionLabel extends StatelessWidget { + const _SectionLabel({required this.title}); + + final String title; + + @override + Widget build(BuildContext context) => Padding( + padding: const EdgeInsets.only(left: 20, bottom: 10, top: 12), + child: Text( + title, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 13, + fontWeight: FontWeight.w600, + ), + ), + ); +} diff --git a/lib/ui/setting/ai_settings_page.dart b/lib/ui/setting/ai_settings_page.dart index 0d4bf59e76..667effb391 100644 --- a/lib/ui/setting/ai_settings_page.dart +++ b/lib/ui/setting/ai_settings_page.dart @@ -3,12 +3,14 @@ import 'package:flutter/material.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; +import '../../ai/model/ai_prompt_template.dart'; import '../../ai/model/ai_provider_config.dart'; import '../../utils/extension/extension.dart'; import '../../widgets/app_bar.dart'; import '../../widgets/cell.dart'; import '../../widgets/toast.dart'; import '../provider/database_provider.dart'; +import 'ai_prompt_settings_page.dart'; import 'ai_provider_edit_page.dart'; class AiSettingsPage extends HookConsumerWidget { @@ -21,6 +23,14 @@ class AiSettingsPage extends HookConsumerWidget { final providers = database.settingProperties.aiProviders; final selectedId = database.settingProperties.selectedAiProviderId; final selectedProvider = database.settingProperties.selectedAiProvider; + final customizedPromptCount = aiPromptTemplateDefinitions + .where( + (definition) => + database.settingProperties.hasAiPromptTemplateOverride( + definition.key, + ), + ) + .length; return Scaffold( backgroundColor: context.theme.background, @@ -33,6 +43,40 @@ class AiSettingsPage extends HookConsumerWidget { child: Column( crossAxisAlignment: CrossAxisAlignment.start, children: [ + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: context.theme.settingCellBackgroundColor, + child: CellItem( + title: const Text('Prompt Templates'), + leading: Icon( + Icons.tune_rounded, + color: context.theme.icon, + ), + description: Text( + customizedPromptCount == 0 + ? 'Default' + : '$customizedPromptCount custom', + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + trailing: null, + onTap: () => Navigator.of(context).push( + MaterialPageRoute( + builder: (_) => const AiPromptSettingsPage(), + ), + ), + ), + ), + Padding( + padding: const EdgeInsets.only(left: 20, bottom: 14, top: 10), + child: Text( + 'Customize chat prompts, assist prompts, and built-in variables like {{conversationId}}, {{currentIsoDateTime}}, and {{language}}.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), + ), + ), CellGroup( padding: const EdgeInsets.only(right: 10, left: 10), cellBackgroundColor: context.theme.settingCellBackgroundColor, diff --git a/lib/utils/property/setting_property.dart b/lib/utils/property/setting_property.dart index fa746a2e50..d752c8b427 100644 --- a/lib/utils/property/setting_property.dart +++ b/lib/utils/property/setting_property.dart @@ -2,6 +2,7 @@ import 'dart:convert'; import 'package:mixin_logger/mixin_logger.dart'; +import '../../ai/model/ai_prompt_template.dart'; import '../../ai/model/ai_provider_config.dart'; import '../../db/dao/property_dao.dart'; import '../../db/util/property_storage.dart'; @@ -14,6 +15,7 @@ const _kSelectedProxyKey = 'selected_proxy'; const _kProxyListKey = 'proxy_list'; const _kAiProviderListKey = 'ai_provider_list'; const _kSelectedAiProviderKey = 'selected_ai_provider'; +const _kAiPromptTemplateOverridesKey = 'ai_prompt_template_overrides'; class SettingPropertyStorage extends PropertyStorage { SettingPropertyStorage(PropertyDao dao) : super(PropertyGroup.setting, dao); @@ -127,4 +129,42 @@ class SettingPropertyStorage extends PropertyStorage { selectedAiProviderId = providers.firstOrNull?.id; } } + + Map get _aiPromptTemplateOverrides { + final json = get(_kAiPromptTemplateOverridesKey); + if (json == null || json.isEmpty) { + return {}; + } + try { + final map = jsonDecode(json) as Map; + return map.map( + (key, value) => MapEntry(key.toString(), value?.toString() ?? ''), + ); + } catch (error, stacktrace) { + e('load aiPromptTemplateOverrides error: $error, $stacktrace'); + return {}; + } + } + + String aiPromptTemplate(AiPromptTemplateKey key) { + final overrides = _aiPromptTemplateOverrides; + if (overrides.containsKey(key.storageKey)) { + return overrides[key.storageKey] ?? ''; + } + return key.definition.defaultValue; + } + + bool hasAiPromptTemplateOverride(AiPromptTemplateKey key) => + _aiPromptTemplateOverrides.containsKey(key.storageKey); + + void saveAiPromptTemplate(AiPromptTemplateKey key, String value) { + final overrides = _aiPromptTemplateOverrides; + overrides[key.storageKey] = value; + set(_kAiPromptTemplateOverridesKey, jsonEncode(overrides)); + } + + void resetAiPromptTemplate(AiPromptTemplateKey key) { + final overrides = _aiPromptTemplateOverrides..remove(key.storageKey); + set(_kAiPromptTemplateOverridesKey, jsonEncode(overrides)); + } } diff --git a/lib/widgets/message/message_ai_assist.dart b/lib/widgets/message/message_ai_assist.dart index e24dbba024..38b6a8a866 100644 --- a/lib/widgets/message/message_ai_assist.dart +++ b/lib/widgets/message/message_ai_assist.dart @@ -2,6 +2,7 @@ import 'package:equatable/equatable.dart'; import 'package:flutter/material.dart'; import '../../ai/ai_chat_controller.dart'; +import '../../ai/model/ai_prompt_template.dart'; import '../../db/mixin_database.dart'; import '../../ui/provider/recall_message_reedit_provider.dart'; import '../../utils/extension/extension.dart'; @@ -104,17 +105,19 @@ Future runMessageAiAction( final language = _currentLanguageTag(context); final provider = context.database.settingProperties.selectedAiProvider; final model = provider?.model; - final instruction = switch (action) { - MessageAiAction.translate => - 'Translate this chat message into $language. Return only the translation.', - MessageAiAction.explain => - 'Explain this chat message clearly and concisely in $language. ' - 'Clarify slang, abbreviations, technical terms, and implied meaning when useful. ' - 'Return only the explanation.', - MessageAiAction.suggestReplies => - 'Suggest three concise, natural replies in $language to this chat message ' - 'using the recent conversation context. Return one reply per line, without numbering.', + final templateKey = switch (action) { + MessageAiAction.translate => AiPromptTemplateKey.messageTranslate, + MessageAiAction.explain => AiPromptTemplateKey.messageExplain, + MessageAiAction.suggestReplies => AiPromptTemplateKey.messageSuggestReplies, }; + final instruction = renderAiPromptTemplate( + context.database.settingProperties.aiPromptTemplate(templateKey), + buildAiPromptTemplateVariables( + conversationId: message.conversationId, + input: input, + language: language, + ), + ); final title = switch (action) { MessageAiAction.translate => 'Translate', MessageAiAction.explain => 'Explain', @@ -128,6 +131,7 @@ Future runMessageAiAction( try { final result = await AiChatController(context.database).assistText( instruction: instruction, + language: language, input: input, conversationId: message.conversationId, provider: provider, diff --git a/test/ai/ai_prompt_template_test.dart b/test/ai/ai_prompt_template_test.dart new file mode 100644 index 0000000000..a07b00280c --- /dev/null +++ b/test/ai/ai_prompt_template_test.dart @@ -0,0 +1,49 @@ +import 'package:flutter_app/ai/model/ai_prompt_template.dart'; +import 'package:flutter_test/flutter_test.dart'; + +void main() { + group('AI prompt template', () { + test('renders known variables', () { + final result = renderAiPromptTemplate( + 'Conversation {{conversationId}} at {{currentIsoDateTime}} in {{language}} -> {{input}}', + buildAiPromptTemplateVariables( + conversationId: 'conversation-1', + input: 'hello', + language: 'zh-CN', + now: DateTime(2026, 4, 28, 9, 30, 15), + ), + ); + + expect( + result, + 'Conversation conversation-1 at 2026-04-28T09:30:15.000 in zh-CN -> hello', + ); + }); + + test('renders legacy date aliases for backwards compatibility', () { + final result = renderAiPromptTemplate( + '{{currentDate}} {{currentTime}} {{currentDateTime}}', + buildAiPromptTemplateVariables( + now: DateTime(2026, 4, 28, 9, 30, 15), + ), + ); + + expect(result, '2026-04-28 09:30:15 2026-04-28 09:30:15'); + }); + + test('keeps unknown variables unchanged', () { + final result = renderAiPromptTemplate( + 'Known={{input}} Unknown={{customValue}}', + buildAiPromptTemplateVariables(input: 'hello'), + ); + + expect(result, 'Known=hello Unknown={{customValue}}'); + }); + + test('builds input section only when input exists', () { + expect(buildAiPromptInputSection(' hello '), '\nText:\nhello'); + expect(buildAiPromptInputSection(' '), isEmpty); + expect(buildAiPromptInputSection(null), isEmpty); + }); + }); +} diff --git a/test/db/property_storage_test.dart b/test/db/property_storage_test.dart index a82b118fd4..d6d1dc5310 100644 --- a/test/db/property_storage_test.dart +++ b/test/db/property_storage_test.dart @@ -2,9 +2,11 @@ library; import 'package:drift/native.dart'; +import 'package:flutter_app/ai/model/ai_prompt_template.dart'; import 'package:flutter_app/db/mixin_database.dart'; import 'package:flutter_app/db/util/property_storage.dart'; import 'package:flutter_app/enum/property_group.dart'; +import 'package:flutter_app/utils/property/setting_property.dart'; import 'package:flutter_test/flutter_test.dart'; void main() { @@ -55,4 +57,25 @@ void main() { expect(storage.getList('test_list_string'), ['1', '2', '3']); expect(storage.getList('test_list_string'), ['1', '2', '3']); }); + + test('AI prompt template settings support override and reset', () async { + final database = MixinDatabase(NativeDatabase.memory()); + final storage = SettingPropertyStorage(database.propertyDao); + const key = AiPromptTemplateKey.chatSystem; + + expect(storage.aiPromptTemplate(key), key.definition.defaultValue); + expect(storage.hasAiPromptTemplateOverride(key), isFalse); + + storage.saveAiPromptTemplate(key, 'Custom prompt {{conversationId}}'); + expect(storage.aiPromptTemplate(key), 'Custom prompt {{conversationId}}'); + expect(storage.hasAiPromptTemplateOverride(key), isTrue); + + storage.saveAiPromptTemplate(key, ''); + expect(storage.aiPromptTemplate(key), isEmpty); + expect(storage.hasAiPromptTemplateOverride(key), isTrue); + + storage.resetAiPromptTemplate(key); + expect(storage.aiPromptTemplate(key), key.definition.defaultValue); + expect(storage.hasAiPromptTemplateOverride(key), isFalse); + }); } From 18da0d28433407fa5572425cbd9ff877131ba673 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 13:09:17 +0800 Subject: [PATCH 21/52] feat: add "Use & Send" action and replace "Append" with "Insert" in AI draft assist panel --- lib/ui/home/chat/ai_draft_assist_panel.dart | 33 +++++++++++++-------- lib/ui/home/chat/input_container.dart | 22 +++++++++++++- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/lib/ui/home/chat/ai_draft_assist_panel.dart b/lib/ui/home/chat/ai_draft_assist_panel.dart index 8bf92b886c..1373e14633 100644 --- a/lib/ui/home/chat/ai_draft_assist_panel.dart +++ b/lib/ui/home/chat/ai_draft_assist_panel.dart @@ -291,16 +291,18 @@ class AiDraftAssistInlineCandidate extends StatelessWidget { required this.viewState, required this.onDismiss, required this.onCopy, - required this.onAppend, + required this.onInsert, required this.onReplace, + required this.onUseAndSend, super.key, }); final AiDraftAssistViewState viewState; final VoidCallback onDismiss; final VoidCallback onCopy; - final VoidCallback onAppend; + final VoidCallback onInsert; final VoidCallback onReplace; + final VoidCallback onUseAndSend; @override Widget build(BuildContext context) { @@ -337,8 +339,9 @@ class AiDraftAssistInlineCandidate extends StatelessWidget { result: viewState.result ?? '', onDismiss: onDismiss, onCopy: onCopy, - onAppend: onAppend, + onInsert: onInsert, onReplace: onReplace, + onUseAndSend: onUseAndSend, ), AiDraftAssistPhase.error => _AiDraftAssistInlineError( error: viewState.error ?? 'Unknown error', @@ -358,16 +361,18 @@ class _AiDraftAssistInlineResult extends StatelessWidget { required this.result, required this.onDismiss, required this.onCopy, - required this.onAppend, + required this.onInsert, required this.onReplace, + required this.onUseAndSend, }); final AiDraftAction? action; final String result; final VoidCallback onDismiss; final VoidCallback onCopy; - final VoidCallback onAppend; + final VoidCallback onInsert; final VoidCallback onReplace; + final VoidCallback onUseAndSend; @override Widget build(BuildContext context) => Column( @@ -385,6 +390,12 @@ class _AiDraftAssistInlineResult extends StatelessWidget { ), ), ), + _AiDraftInlineIconButton( + icon: Icons.copy_all_rounded, + color: context.theme.secondaryText, + onTap: onCopy, + ), + const SizedBox(width: 4), _AiDraftInlineIconButton( icon: Icons.close_rounded, color: context.theme.secondaryText, @@ -415,18 +426,16 @@ class _AiDraftAssistInlineResult extends StatelessWidget { runSpacing: 8, children: [ _AiDraftInlineTextButton( - title: 'Copy', - onTap: onCopy, - secondary: true, - ), - _AiDraftInlineTextButton( - title: 'Append', - onTap: onAppend, + title: 'Insert', + onTap: onInsert, secondary: true, ), + if (action == AiDraftAction.replyWithContext) + _AiDraftInlineTextButton(title: 'Use & Send', onTap: onUseAndSend), _AiDraftInlineTextButton( title: 'Replace Draft', onTap: onReplace, + secondary: action == AiDraftAction.replyWithContext, ), ], ), diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 17a3606006..8e7f4f927c 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -332,7 +332,7 @@ class _InputContainer extends HookConsumerWidget { Clipboard.setData(ClipboardData(text: result)); showToastSuccessful(context: context); }, - onAppend: () { + onInsert: () { final result = aiDraftAssistState.value.result; if (result == null) return; applyAiDraftAssistResult( @@ -342,6 +342,26 @@ class _InputContainer extends HookConsumerWidget { ); dismissAiDraftAssist(); }, + onUseAndSend: () { + final result = aiDraftAssistState.value.result; + if (result == null || conversationId == null) { + return; + } + textEditingController.value = TextEditingValue( + text: result, + selection: TextSelection.collapsed( + offset: result.length, + ), + ); + dismissAiDraftAssist(); + unawaited( + _sendMessage( + context, + textEditingController, + conversationId: conversationId, + ), + ); + }, onReplace: () { final result = aiDraftAssistState.value.result; if (result == null) return; From 99ece417aa29b510cef21ac2aa34a6a6fbb86aa5 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 13:32:24 +0800 Subject: [PATCH 22/52] feat: remove AI message handling and introduce AI Assistant page --- lib/ui/home/bloc/message_bloc.dart | 169 ------- lib/ui/home/chat/chat_bar.dart | 26 + lib/ui/home/chat/chat_page.dart | 201 ++------ lib/ui/home/chat/input_container.dart | 50 +- .../chat_slide_page/ai_assistant_page.dart | 457 ++++++++++++++++++ lib/widgets/ai/ai_message_card.dart | 106 ++-- 6 files changed, 591 insertions(+), 418 deletions(-) create mode 100644 lib/ui/home/chat_slide_page/ai_assistant_page.dart diff --git a/lib/ui/home/bloc/message_bloc.dart b/lib/ui/home/bloc/message_bloc.dart index dacd95dd42..e606123881 100644 --- a/lib/ui/home/bloc/message_bloc.dart +++ b/lib/ui/home/bloc/message_bloc.dart @@ -25,25 +25,6 @@ abstract class _MessageEvent extends Equatable { List get props => []; } -class ChatTimelineItem extends Equatable { - const ChatTimelineItem.message(this.message) : aiMessage = null; - - const ChatTimelineItem.ai(this.aiMessage) : message = null; - - final MessageItem? message; - final AiChatMessage? aiMessage; - - bool get isMessage => message != null; - bool get isAiMessage => aiMessage != null; - - String get id => message?.messageId ?? aiMessage!.id; - - DateTime get createdAt => message?.createdAt ?? aiMessage!.createdAt; - - @override - List get props => [message, aiMessage]; -} - class _MessageJumpCurrentEvent extends _MessageEvent {} class _MessageInitEvent extends _MessageEvent { @@ -92,21 +73,11 @@ class _MessageDeleteEvent extends _MessageEvent { List get props => [messageId]; } -class _AiMessagesChangedEvent extends _MessageEvent { - _AiMessagesChangedEvent(this.data); - - final List data; - - @override - List get props => [data]; -} - class MessageState extends Equatable { MessageState({ this.top = const [], this.center, this.bottom = const [], - this.aiMessages = const [], this.conversationId, this.isLatest = false, this.isOldest = false, @@ -130,7 +101,6 @@ class MessageState extends Equatable { final List top; final MessageItem? center; final List bottom; - final List aiMessages; final bool isLatest; final bool isOldest; final String? lastReadMessageId; @@ -142,7 +112,6 @@ class MessageState extends Equatable { top, center, bottom, - aiMessages, isLatest, isOldest, lastReadMessageId, @@ -163,86 +132,11 @@ class MessageState extends Equatable { ...bottom, ]; - List get visibleAiMessages { - if (aiMessages.isEmpty) return const []; - - final messages = list; - if (messages.isEmpty) { - return aiMessages.toList()..sort(_compareAiMessages); - } - - final messageIds = messages.map((message) => message.messageId).toSet(); - final start = topMessage?.createdAt; - final end = bottomMessage?.createdAt; - - bool inLoadedRange(DateTime? value) { - if (value == null || start == null || end == null) return false; - return !value.isBefore(start) && !value.isAfter(end); - } - - final visible = aiMessages.where((message) { - final anchorMessageId = message.anchorMessageId; - if (anchorMessageId != null && messageIds.contains(anchorMessageId)) { - return true; - } - return inLoadedRange(message.anchorCreatedAt) || - inLoadedRange(message.createdAt); - }).toList()..sort(_compareAiMessages); - - return visible; - } - - List get timeline { - final messages = list; - final visibleAi = visibleAiMessages; - - if (messages.isEmpty) { - return visibleAi.map(ChatTimelineItem.ai).toList(); - } - - final anchored = >{}; - final floating = []; - - for (final aiMessage in visibleAi) { - final anchorMessageId = aiMessage.anchorMessageId; - if (anchorMessageId != null) { - anchored.putIfAbsent(anchorMessageId, () => []).add(aiMessage); - } else { - floating.add(aiMessage); - } - } - - final result = []; - var floatingIndex = 0; - - for (final message in messages) { - while (floatingIndex < floating.length && - !floating[floatingIndex].createdAt.isAfter(message.createdAt)) { - result.add(ChatTimelineItem.ai(floating[floatingIndex])); - floatingIndex++; - } - - result.add(ChatTimelineItem.message(message)); - final anchoredMessages = anchored[message.messageId]; - if (anchoredMessages != null) { - result.addAll(anchoredMessages.map(ChatTimelineItem.ai)); - } - } - - while (floatingIndex < floating.length) { - result.add(ChatTimelineItem.ai(floating[floatingIndex])); - floatingIndex++; - } - - return result; - } - MessageState copyWith({ String? conversationId, List? top, MessageItem? center, List? bottom, - List? aiMessages, bool? isLatest, bool? isOldest, String? lastReadMessageId, @@ -252,7 +146,6 @@ class MessageState extends Equatable { top: top ?? this.top, center: center ?? this.center, bottom: bottom ?? this.bottom, - aiMessages: aiMessages ?? this.aiMessages, isLatest: isLatest ?? this.isLatest, isOldest: isOldest ?? this.isOldest, lastReadMessageId: lastReadMessageId ?? this.lastReadMessageId, @@ -261,7 +154,6 @@ class MessageState extends Equatable { MessageState _copyWithJumpCurrentState() => MessageState( top: list.toList(), - aiMessages: aiMessages, refreshKey: Object(), conversationId: conversationId, isLatest: isLatest, @@ -275,7 +167,6 @@ class MessageState extends Equatable { conversationId: conversationId, top: top, bottom: bottom, - aiMessages: aiMessages, isLatest: isLatest, isOldest: isOldest, lastReadMessageId: lastReadMessageId, @@ -298,38 +189,6 @@ class MessageState extends Equatable { } } -int _compareAiMessages(AiChatMessage a, AiChatMessage b) { - final anchorCompare = _compareNullableDateTime( - a.anchorCreatedAt, - b.anchorCreatedAt, - ); - if (anchorCompare != 0) return anchorCompare; - - final createdAtCompare = a.createdAt.compareTo(b.createdAt); - if (createdAtCompare != 0) return createdAtCompare; - - final roleCompare = _compareAiRoles(a.role, b.role); - if (roleCompare != 0) return roleCompare; - - return a.id.compareTo(b.id); -} - -int _compareNullableDateTime(DateTime? a, DateTime? b) { - if (a == null && b == null) return 0; - if (a == null) return 1; - if (b == null) return -1; - return a.compareTo(b); -} - -int _compareAiRoles(String a, String b) { - const order = { - 'user': 0, - 'assistant': 1, - }; - - return (order[a] ?? order.length).compareTo(order[b] ?? order.length); -} - class MessageBloc extends Bloc<_MessageEvent, MessageState> with SubscribeMixin { MessageBloc({ @@ -359,10 +218,6 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> _onEvent, transformer: sequential(), ); - on<_AiMessagesChangedEvent>( - _onEvent, - transformer: restartable(), - ); on<_MessageScrollEvent>( _onEvent, transformer: restartable(), @@ -413,22 +268,6 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> .listen((state) => add(_MessageInsertOrReplaceEvent(state))), ); - addSubscription( - conversationNotifier.stream - .startWith(conversationNotifier.state) - .map((event) => event?.conversationId) - .distinct() - .switchMap((conversationId) { - if (conversationId == null) { - return Stream.value(const []); - } - return database.aiChatMessageDao.watchConversationMessages( - conversationId, - ); - }) - .listen((state) => add(_AiMessagesChangedEvent(state))), - ); - addSubscription( DataBaseEventBus.instance.deleteMessageIdStream.listen((messageIds) { messageIds.forEach((messageId) { @@ -486,8 +325,6 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> } else if (event is _MessageDeleteEvent) { final messageState = state.removeMessage(event.messageId); emit(_pretreatment(messageState)); - } else if (event is _AiMessagesChangedEvent) { - emit(_pretreatment(state.copyWith(aiMessages: event.data))); } else { if (event is _MessageLoadMoreEvent) { if (event is _MessageLoadAfterEvent) { @@ -571,16 +408,12 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> limit, centerMessageId: _centerMessageId, ); - final aiMessages = await database.aiChatMessageDao.conversationMessages( - conversationId, - ); return state.copyWith( conversationId: conversationId, center: state.center, bottom: state.bottom, top: state.top, - aiMessages: aiMessages, ); } @@ -596,7 +429,6 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> return MessageState( top: list.reversed.toList(), - aiMessages: state.aiMessages, isLatest: true, isOldest: list.length < limit, ); @@ -636,7 +468,6 @@ class MessageBloc extends Bloc<_MessageEvent, MessageState> top: topList, center: center, bottom: bottomList, - aiMessages: state.aiMessages, isLatest: isLatest, isOldest: isOldest, ); diff --git a/lib/ui/home/chat/chat_bar.dart b/lib/ui/home/chat/chat_bar.dart index 6e516f8960..f70c75445c 100644 --- a/lib/ui/home/chat/chat_bar.dart +++ b/lib/ui/home/chat/chat_bar.dart @@ -26,6 +26,8 @@ class ChatBar extends HookConsumerWidget { @override Widget build(BuildContext context, WidgetRef ref) { + useListenable(context.database.settingProperties); + final actionColor = context.theme.icon; final chatSideCubit = context.read(); @@ -40,6 +42,11 @@ class ChatBar extends HookConsumerWidget { final conversation = ref.watch(conversationProvider); final inMultiSelectMode = ref.watch(hasSelectedMessageProvider); + final hasAvailableAiModel = + context.database.settingProperties.selectedAiProvider?.model + .trim() + .isNotEmpty == + true; MoveWindowBarrier toggleInfoPageWrapper({ required Widget child, @@ -137,6 +144,25 @@ class ChatBar extends HookConsumerWidget { ), ) else ...[ + if (hasAvailableAiModel) + MoveWindowBarrier( + child: ActionButton( + color: actionColor, + onTap: () { + final cubit = context.read(); + if (cubit.state.pages.lastOrNull?.name == + ChatSideCubit.aiAssistantPage) { + return cubit.pop(); + } + cubit.replace(ChatSideCubit.aiAssistantPage); + }, + child: Icon( + Icons.auto_awesome_rounded, + size: 20, + color: actionColor, + ), + ), + ), MoveWindowBarrier( child: ActionButton( name: Resources.assetsImagesIcSearchSvg, diff --git a/lib/ui/home/chat/chat_page.dart b/lib/ui/home/chat/chat_page.dart index e2f1e7fcc0..8697933d2b 100644 --- a/lib/ui/home/chat/chat_page.dart +++ b/lib/ui/home/chat/chat_page.dart @@ -1,5 +1,4 @@ import 'dart:io'; -import 'dart:math' as math; import 'package:desktop_drop/desktop_drop.dart'; import 'package:flutter/material.dart'; @@ -22,7 +21,6 @@ import '../../../utils/extension/extension.dart'; import '../../../utils/hook.dart'; import '../../../widgets/action_button.dart'; import '../../../widgets/actions/actions.dart'; -import '../../../widgets/ai/ai_message_card.dart'; import '../../../widgets/animated_visibility.dart'; import '../../../widgets/clamping_custom_scroll_view/clamping_custom_scroll_view.dart'; import '../../../widgets/conversation/mute_dialog.dart'; @@ -30,7 +28,6 @@ import '../../../widgets/dash_path_border.dart'; import '../../../widgets/dialog.dart'; import '../../../widgets/high_light_text.dart'; import '../../../widgets/interactive_decorated_box.dart'; -import '../../../widgets/markdown.dart'; import '../../../widgets/menu.dart'; import '../../../widgets/message/message.dart'; import '../../../widgets/message/message_bubble.dart'; @@ -45,6 +42,7 @@ import '../../provider/message_selection_provider.dart'; import '../../provider/pending_jump_message_provider.dart'; import '../bloc/blink_cubit.dart'; import '../bloc/message_bloc.dart'; +import '../chat_slide_page/ai_assistant_page.dart'; import '../chat_slide_page/chat_info_page.dart'; import '../chat_slide_page/circle_manager_page.dart'; import '../chat_slide_page/disappear_message_page.dart'; @@ -74,6 +72,7 @@ class ChatSideCubit extends AbstractResponsiveNavigatorCubit { static const sharedApps = 'sharedApps'; static const groupsInCommon = 'groupsInCommon'; static const disappearMessages = 'disappearMessages'; + static const aiAssistantPage = 'aiAssistantPage'; @override MaterialPage route(String name, Object? arguments) { @@ -132,6 +131,12 @@ class ChatSideCubit extends AbstractResponsiveNavigatorCubit { name: disappearMessages, child: _ChatSidePageBuilder(DisappearMessagePage.new), ); + case aiAssistantPage: + return const MaterialPage( + key: ValueKey(aiAssistantPage), + name: aiAssistantPage, + child: _ChatSidePageBuilder(AiAssistantPage.new), + ); default: throw ArgumentError('Invalid route'); } @@ -573,24 +578,11 @@ class _List extends HookConsumerWidget { final state = useBlocState( when: (state) => state.conversationId != null, ); + final key = ValueKey((state.conversationId, state.refreshKey)); + final top = state.top; final center = state.center; - final timeline = state.timeline; - - final centerTimelineIndex = center == null - ? null - : timeline.indexWhere( - (item) => item.message?.messageId == center.messageId, - ); - final topTimeline = centerTimelineIndex == null - ? timeline - : timeline.take(centerTimelineIndex).toList(); - final centerTimeline = centerTimelineIndex == null - ? null - : timeline[centerTimelineIndex]; - final bottomTimeline = centerTimelineIndex == null - ? const [] - : timeline.skip(centerTimelineIndex + 1).toList(); + final bottom = state.bottom; final keyRef = useRef>({}); @@ -612,128 +604,6 @@ class _List extends HookConsumerWidget { context, ).scrollController; - MessageItem? prevMessageOf( - ChatTimelineItem item, - List items, - ) { - final index = items.indexOf(item); - if (index <= 0) return null; - for (var i = index - 1; i >= 0; i--) { - final message = items[i].message; - if (message != null) return message; - } - return null; - } - - MessageItem? nextMessageOf( - ChatTimelineItem item, - List items, - ) { - final index = items.indexOf(item); - if (index == -1 || index >= items.length - 1) return null; - for (var i = index + 1; i < items.length; i++) { - final message = items[i].message; - if (message != null) return message; - } - return null; - } - - AiChatMessage? prevAiOf( - ChatTimelineItem item, - List items, - ) { - final index = items.indexOf(item); - if (index <= 0) return null; - for (var i = index - 1; i >= 0; i--) { - final aiMessage = items[i].aiMessage; - if (aiMessage != null) return aiMessage; - } - return null; - } - - AiChatMessage? nextAiOf( - ChatTimelineItem item, - List items, - ) { - final index = items.indexOf(item); - if (index == -1 || index >= items.length - 1) return null; - for (var i = index + 1; i < items.length; i++) { - final aiMessage = items[i].aiMessage; - if (aiMessage != null) return aiMessage; - } - return null; - } - - ({String key, String data})? markdownWarmupEntryOf(ChatTimelineItem item) { - final aiMessage = item.aiMessage; - if (aiMessage != null) { - if (aiMessage.role == 'user' || aiMessage.status == 'error') { - return null; - } - final content = aiMessage.content.trim(); - if (content.isEmpty) return null; - return ( - key: buildMarkdownCacheKey( - namespace: 'ai', - id: aiMessage.id, - ), - data: content, - ); - } - - final message = item.message; - if (message == null || !message.type.isPost) return null; - final content = (message.content ?? '').postOptimize(); - if (content.isEmpty) return null; - return ( - key: buildMarkdownCacheKey( - namespace: 'post', - id: message.messageId, - ), - data: content, - ); - } - - void warmupMarkdownAround(int index) { - final start = math.max(0, index - 6); - final end = math.min(timeline.length, index + 7); - final entries = <({String key, String data})>[]; - for (var i = start; i < end; i++) { - final entry = markdownWarmupEntryOf(timeline[i]); - if (entry != null) { - entries.add(entry); - } - } - markdownControllerCache.warmupAll(entries); - } - - Widget buildTimelineChild(ChatTimelineItem item, int index) { - warmupMarkdownAround(index); - final prevDateTime = index > 0 ? timeline[index - 1].createdAt : null; - if (item.isAiMessage) { - return MessageDayTimeItem( - key: ValueKey('ai-daytime-${item.id}'), - dateTime: item.createdAt, - prevDateTime: prevDateTime, - child: AiMessageCard( - key: ValueKey('ai-${item.id}'), - message: item.aiMessage!, - prev: prevAiOf(item, timeline), - next: nextAiOf(item, timeline), - ), - ); - } - final message = item.message!; - return MessageItemWidget( - key: keyRef.value[message.messageId], - prev: prevMessageOf(item, timeline), - prevDateTime: prevDateTime, - message: message, - next: nextMessageOf(item, timeline), - lastReadMessageId: state.lastReadMessageId, - ); - } - return MessageDayTimeViewportWidget.chatPage( key: key, bottomKey: bottomKey, @@ -756,31 +626,50 @@ class _List extends HookConsumerWidget { context, index, ) { - final actualIndex = topTimeline.length - index - 1; - return buildTimelineChild(topTimeline[actualIndex], actualIndex); - }, childCount: topTimeline.length), + final actualIndex = top.length - index - 1; + final messageItem = top[actualIndex]; + return MessageItemWidget( + key: keyRef.value[messageItem.messageId], + prev: top.getOrNull(actualIndex - 1), + message: messageItem, + next: + top.getOrNull(actualIndex + 1) ?? + center ?? + bottom.lastOrNull, + lastReadMessageId: state.lastReadMessageId, + ); + }, childCount: top.length), ), SliverToBoxAdapter( key: key, child: Builder( builder: (context) { - if (centerTimeline == null) return const SizedBox(); - return buildTimelineChild(centerTimeline, centerTimelineIndex!); + if (center == null) return const SizedBox(); + return MessageItemWidget( + key: keyRef.value[center.messageId], + prev: top.lastOrNull, + message: center, + next: bottom.firstOrNull, + lastReadMessageId: state.lastReadMessageId, + ); }, ), ), SliverList( key: bottomKey, - delegate: SliverChildBuilderDelegate( - ( - context, - index, - ) => buildTimelineChild( - bottomTimeline[index], - (centerTimelineIndex ?? -1) + index + 1, - ), - childCount: bottomTimeline.length, - ), + delegate: SliverChildBuilderDelegate(( + context, + index, + ) { + final messageItem = bottom[index]; + return MessageItemWidget( + key: keyRef.value[messageItem.messageId], + prev: bottom.getOrNull(index - 1) ?? center ?? top.lastOrNull, + message: messageItem, + next: bottom.getOrNull(index + 1), + lastReadMessageId: state.lastReadMessageId, + ); + }, childCount: bottom.length), ), const SliverToBoxAdapter(child: SizedBox(height: 10)), ], diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 8e7f4f927c..07449513fc 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -701,21 +701,15 @@ Future _sendMessage( return; } - final aiModeController = context.providerContainer.read( - aiInputModeProvider(conversationId).notifier, - ); - final aiModeState = context.providerContainer.read( - aiInputModeProvider(conversationId), - ); - final language = _currentLanguageTag(context); - if (text == '/ai') { final provider = context.database.settingProperties.selectedAiProvider; - if (provider == null) { + if (provider == null || provider.model.trim().isEmpty) { showToastFailed(ToastError('Please add an AI provider first')); return; } - aiModeController.enter(providerId: provider.id, model: provider.model); + unawaited( + context.read().replace(ChatSideCubit.aiAssistantPage), + ); textEditingController.text = ''; return; } @@ -725,44 +719,18 @@ Future _sendMessage( : null; if (inlineAiInput != null && inlineAiInput.isNotEmpty) { final provider = context.database.settingProperties.selectedAiProvider; - if (provider == null) { + if (provider == null || provider.model.trim().isEmpty) { showToastFailed(ToastError('Please add an AI provider first')); return; } - aiModeController.enter(providerId: provider.id, model: provider.model); - try { - await AiChatController(context.database).send( - conversationId: conversationId, - input: inlineAiInput, - language: language, - provider: provider, - onInputAccepted: () => textEditingController.text = '', - ); - } catch (error, _) { - showToastFailed(error); - } - return; - } - - if (aiModeState.enabled) { - final provider = _resolveAiModeProvider( - selectedAiProvider: context.database.settingProperties.selectedAiProvider, - enabledAiProviders: context.database.settingProperties.aiProviders - .whereType() - .where((element) => element.enabled) - .toList(), - providerId: aiModeState.providerId, - selectedModel: aiModeState.model, + unawaited( + context.read().replace(ChatSideCubit.aiAssistantPage), ); - if (provider == null) { - showToastFailed(ToastError('Please add an AI provider first')); - return; - } try { await AiChatController(context.database).send( conversationId: conversationId, - input: text, - language: language, + input: inlineAiInput, + language: _currentLanguageTag(context), provider: provider, onInputAccepted: () => textEditingController.text = '', ); diff --git a/lib/ui/home/chat_slide_page/ai_assistant_page.dart b/lib/ui/home/chat_slide_page/ai_assistant_page.dart new file mode 100644 index 0000000000..11c8815078 --- /dev/null +++ b/lib/ui/home/chat_slide_page/ai_assistant_page.dart @@ -0,0 +1,457 @@ +import 'dart:async'; + +import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; +import 'package:flutter_hooks/flutter_hooks.dart'; +import 'package:hooks_riverpod/hooks_riverpod.dart'; + +import '../../../ai/ai_chat_controller.dart'; +import '../../../ai/model/ai_provider_config.dart'; +import '../../../constants/constants.dart'; +import '../../../db/mixin_database.dart'; +import '../../../utils/extension/extension.dart'; +import '../../../utils/hook.dart'; +import '../../../widgets/action_button.dart'; +import '../../../widgets/ai/ai_message_card.dart'; +import '../../../widgets/app_bar.dart'; +import '../../../widgets/empty.dart'; +import '../../../widgets/menu.dart'; +import '../../../widgets/message/message_day_time.dart'; +import '../../../widgets/toast.dart'; +import '../../provider/ai_input_mode_provider.dart'; +import '../../provider/conversation_provider.dart'; + +const _aiAssistantTitle = 'AI Assistant'; +const _aiAssistantEmpty = 'Ask AI about this conversation'; +const _aiAssistantInputHint = 'Ask about this conversation'; +const _aiAssistantUnavailable = 'Add a usable AI model in Settings first'; + +class AiAssistantPage extends HookConsumerWidget { + const AiAssistantPage(this.conversationState, {super.key}); + + final ConversationState conversationState; + + @override + Widget build(BuildContext context, WidgetRef ref) { + useListenable(context.database.settingProperties); + + final conversationId = conversationState.conversationId; + final aiModeState = ref.watch(aiInputModeProvider(conversationId)); + final aiModeNotifier = ref.read( + aiInputModeProvider(conversationId).notifier, + ); + final enabledAiProviders = context.database.settingProperties.aiProviders + .whereType() + .where((item) => item.enabled && item.model.trim().isNotEmpty) + .toList(growable: false); + final aiProvider = _resolveAiAssistantProvider( + selectedAiProvider: context.database.settingProperties.selectedAiProvider, + enabledAiProviders: enabledAiProviders, + providerId: aiModeState.providerId, + selectedModel: aiModeState.model, + ); + final messages = + useMemoizedStream( + () => context.database.aiChatMessageDao.watchConversationMessages( + conversationId, + ), + keys: [conversationId], + initialData: const [], + ).data ?? + const []; + final requestInFlight = messages.any(isActivePendingAiMessage); + final textEditingController = useMemoized( + TextEditingController.new, + [conversationId], + ); + final scrollController = useScrollController(); + final focusNode = useFocusNode(); + final lastMessage = messages.lastOrNull; + + useEffect(() { + if (!context.textFieldAutoGainFocus) { + focusNode.unfocus(); + return null; + } + focusNode.requestFocus(); + return null; + }, [conversationId]); + + useEffect(() { + WidgetsBinding.instance.addPostFrameCallback((_) { + if (!scrollController.hasClients) return; + final position = scrollController.position; + final shouldStickToBottom = + messages.length <= 2 || + !position.hasContentDimensions || + position.maxScrollExtent - position.pixels < 96 || + lastMessage?.role == 'user'; + if (!shouldStickToBottom) return; + unawaited( + scrollController.animateTo( + position.maxScrollExtent, + duration: const Duration(milliseconds: 180), + curve: Curves.easeOutCubic, + ), + ); + }); + return null; + }, [messages.length, lastMessage?.updatedAt, lastMessage?.content]); + + Future send() async { + final text = textEditingController.text.trim(); + if (text.isEmpty) return; + if (text.length > kMaxTextLength) { + showToastFailed(ToastError(context.l10n.contentTooLong)); + return; + } + if (aiProvider == null) { + showToastFailed(ToastError(_aiAssistantUnavailable)); + return; + } + + try { + await AiChatController(context.database).send( + conversationId: conversationId, + input: text, + language: _currentLanguageTag(context), + provider: aiProvider, + onInputAccepted: textEditingController.clear, + ); + } catch (error, _) { + showToastFailed(error); + } + } + + return Scaffold( + backgroundColor: context.theme.primary, + appBar: const MixinAppBar(title: Text(_aiAssistantTitle)), + body: Column( + children: [ + if (aiProvider != null) + _AiAssistantModeBar( + provider: aiProvider, + enabledAiProviders: enabledAiProviders, + onProviderSelected: (value) => aiModeNotifier.updateProvider( + providerId: value.id, + model: value.model, + ), + onModelSelected: aiModeNotifier.updateModel, + ), + Expanded( + child: messages.isEmpty + ? const Empty(text: _aiAssistantEmpty) + : ListView.builder( + controller: scrollController, + padding: const EdgeInsets.fromLTRB(16, 12, 16, 20), + itemCount: messages.length, + itemBuilder: (context, index) { + final message = messages[index]; + return MessageDayTimeItem( + key: ValueKey('assistant-${message.id}'), + dateTime: message.createdAt, + prevDateTime: index > 0 + ? messages[index - 1].createdAt + : null, + child: AiMessageCard( + message: message, + prev: index > 0 ? messages[index - 1] : null, + next: index < messages.length - 1 + ? messages[index + 1] + : null, + ), + ); + }, + ), + ), + _AiAssistantComposer( + focusNode: focusNode, + textEditingController: textEditingController, + enabled: aiProvider != null, + requestInFlight: requestInFlight, + onSend: send, + onStop: () => + AiChatController(context.database).stop(conversationId), + ), + ], + ), + ); + } +} + +class _AiAssistantModeBar extends StatelessWidget { + const _AiAssistantModeBar({ + required this.provider, + required this.enabledAiProviders, + required this.onProviderSelected, + required this.onModelSelected, + }); + + final AiProviderConfig provider; + final List enabledAiProviders; + final ValueChanged onProviderSelected; + final ValueChanged onModelSelected; + + @override + Widget build(BuildContext context) { + final providerOptions = enabledAiProviders + .map( + (item) => CustomPopupMenuItem( + title: item.name, + value: item, + ), + ) + .toList(growable: false); + final modelOptions = provider.models + .where((item) => item.trim().isNotEmpty) + .map( + (item) => CustomPopupMenuItem( + title: item.trim(), + value: item.trim(), + ), + ) + .toList(growable: false); + + return Container( + width: double.infinity, + padding: const EdgeInsets.fromLTRB(16, 12, 16, 0), + child: DecoratedBox( + decoration: BoxDecoration( + color: context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.06), + ), + borderRadius: const BorderRadius.all(Radius.circular(12)), + border: Border.all( + color: context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.05), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ), + ), + ), + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 10), + child: Row( + children: [ + Icon( + Icons.auto_awesome_rounded, + size: 14, + color: context.theme.ai.accent, + ), + const SizedBox(width: 10), + Expanded( + child: Row( + children: [ + Flexible( + child: _AiModeChip( + icon: Icons.hub_rounded, + label: provider.name, + items: providerOptions, + enabled: providerOptions.length > 1, + onSelected: onProviderSelected, + ), + ), + const SizedBox(width: 8), + Flexible( + child: _AiModeChip( + icon: Icons.tune_rounded, + label: provider.model, + items: modelOptions, + enabled: modelOptions.length > 1, + onSelected: onModelSelected, + ), + ), + ], + ), + ), + ], + ), + ), + ), + ); + } +} + +class _AiModeChip extends StatelessWidget { + const _AiModeChip({ + required this.icon, + required this.label, + required this.items, + required this.onSelected, + required this.enabled, + }); + + final IconData icon; + final String label; + final List> items; + final ValueChanged onSelected; + final bool enabled; + + @override + Widget build(BuildContext context) { + final child = Row( + children: [ + Icon(icon, size: 13, color: context.theme.secondaryText), + const SizedBox(width: 6), + Expanded( + child: Text( + label, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + fontWeight: FontWeight.w500, + ), + ), + ), + if (enabled) ...[ + const SizedBox(width: 2), + Icon( + Icons.keyboard_arrow_down_rounded, + size: 14, + color: context.theme.secondaryText, + ), + ], + ], + ); + + if (!enabled || items.isEmpty) return child; + + return CustomPopupMenuButton( + itemBuilder: (_) => items, + onSelected: onSelected, + color: Colors.transparent, + useActionButton: false, + child: child, + ); + } +} + +class _AiAssistantComposer extends StatelessWidget { + const _AiAssistantComposer({ + required this.focusNode, + required this.textEditingController, + required this.enabled, + required this.requestInFlight, + required this.onSend, + required this.onStop, + }); + + final FocusNode focusNode; + final TextEditingController textEditingController; + final bool enabled; + final bool requestInFlight; + final VoidCallback onSend; + final VoidCallback onStop; + + @override + Widget build(BuildContext context) { + final buttonColor = !enabled + ? context.theme.secondaryText + : requestInFlight + ? context.theme.red + : context.theme.accent; + + return Container( + padding: const EdgeInsets.fromLTRB(16, 12, 16, 16), + decoration: BoxDecoration( + color: context.theme.primary, + border: Border(top: BorderSide(color: context.theme.divider)), + ), + child: DecoratedBox( + decoration: BoxDecoration( + borderRadius: const BorderRadius.all(Radius.circular(14)), + color: context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ), + ), + child: Padding( + padding: const EdgeInsets.fromLTRB(12, 10, 8, 10), + child: Row( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ + Expanded( + child: TextField( + focusNode: focusNode, + controller: textEditingController, + enabled: enabled, + minLines: 1, + maxLines: 6, + inputFormatters: [ + LengthLimitingTextInputFormatter(kMaxTextLength), + ], + style: TextStyle( + color: context.theme.text, + fontSize: 14, + height: 1.4, + ), + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: enabled + ? _aiAssistantInputHint + : _aiAssistantUnavailable, + hintStyle: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), + ), + ), + ), + const SizedBox(width: 8), + ActionButton( + padding: const EdgeInsets.all(6), + size: 20, + interactive: enabled, + onTap: requestInFlight ? onStop : onSend, + child: Icon( + requestInFlight + ? Icons.stop_rounded + : Icons.arrow_upward_rounded, + size: 18, + color: buttonColor, + ), + ), + ], + ), + ), + ), + ); + } +} + +AiProviderConfig? _resolveAiAssistantProvider({ + required AiProviderConfig? selectedAiProvider, + required List enabledAiProviders, + required String? providerId, + required String? selectedModel, +}) { + var provider = selectedAiProvider; + if (providerId != null) { + for (final item in enabledAiProviders) { + if (item.id == providerId) { + provider = item; + break; + } + } + } + if (provider == null || provider.model.trim().isEmpty) { + provider = enabledAiProviders.firstOrNull; + } + if (provider == null) return null; + + final trimmedModel = selectedModel?.trim(); + if (trimmedModel == null || trimmedModel.isEmpty) return provider; + if (!provider.models.contains(trimmedModel)) return provider; + if (provider.model == trimmedModel) return provider; + return provider.copyWith(defaultModel: trimmedModel, model: trimmedModel); +} + +String _currentLanguageTag(BuildContext context) { + final locale = Localizations.localeOf(context); + final countryCode = locale.countryCode; + if (countryCode == null || countryCode.isEmpty) return locale.languageCode; + return '${locale.languageCode}-$countryCode'; +} diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index f55061ec90..12113e2cf3 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -20,6 +20,9 @@ import '../message/message_layout.dart'; import '../message/message_style.dart'; import '../qr_code.dart'; +const _aiAssistantTitle = 'AI Assistant'; +const _copyAiMessageTitle = 'Copy AI Message'; + class AiMessageCard extends StatelessWidget { const AiMessageCard({ required this.message, @@ -41,84 +44,81 @@ class AiMessageCard extends StatelessWidget { final sameRoleNext = next?.role == message.role; final mergedWithPrev = sameDayPrev && sameRolePrev; final mergedWithNext = sameDayNext && sameRoleNext; - final showAssistantMeta = !isUser && !mergedWithPrev; - final bubbleColor = _bubbleColor( - context, - isUser: isUser, - status: message.status, - ); - final body = _AiBubble( - isCurrentUser: isUser, - showNip: !mergedWithNext && !showAssistantMeta, - color: bubbleColor, - child: ConstrainedBox( - constraints: const BoxConstraints(maxWidth: 420), - child: _AiMessageBody(message: message), - ), - ); - final content = Column( - mainAxisSize: MainAxisSize.min, - crossAxisAlignment: isUser - ? CrossAxisAlignment.end - : CrossAxisAlignment.start, - children: [ - if (showAssistantMeta) - Padding( - padding: const EdgeInsets.only(left: 10, bottom: 2), - child: Text( - 'AI Assistant', - style: TextStyle( - color: context.theme.secondaryText, - fontSize: context.messageStyle.statusFontSize, - fontWeight: FontWeight.w500, - ), - ), - ), - body, - ], + final body = ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 420), + child: _AiMessageBody(message: message), ); if (isUser) { return Padding( padding: EdgeInsets.only( - left: 65, - right: 16, - top: mergedWithPrev ? 0 : 8, - bottom: 2, + left: 72, + right: 8, + top: mergedWithPrev ? 4 : 14, + bottom: 4, ), child: Align( alignment: Alignment.centerRight, child: _AiMessageMenu( message: message, - child: content, + child: _AiBubble( + isCurrentUser: true, + showNip: !mergedWithNext, + color: _bubbleColor( + context, + isUser: true, + status: message.status, + ), + child: body, + ), ), ), ); } return Padding( - padding: EdgeInsets.only(top: mergedWithPrev ? 0 : 8, bottom: 2), + padding: EdgeInsets.only( + left: 8, + right: 44, + top: mergedWithPrev ? 6 : 18, + bottom: 6, + ), child: Row( - mainAxisSize: MainAxisSize.min, crossAxisAlignment: CrossAxisAlignment.start, children: [ - const SizedBox(width: 8), SizedBox( width: 32, - child: showAssistantMeta + child: !mergedWithPrev ? _AiAvatar(thinking: message.status == 'pending') : null, ), - Flexible( + const SizedBox(width: 12), + Expanded( child: Padding( - padding: const EdgeInsets.symmetric(vertical: 2), - child: _AiMessageMenu( - message: message, - child: content, + padding: const EdgeInsets.only(top: 1), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + if (!mergedWithPrev) + Padding( + padding: const EdgeInsets.only(bottom: 6), + child: Text( + _aiAssistantTitle, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: context.messageStyle.statusFontSize, + fontWeight: FontWeight.w600, + ), + ), + ), + _AiMessageMenu( + message: message, + child: body, + ), + ], ), ), ), - const SizedBox(width: 65), ], ), ); @@ -142,7 +142,9 @@ class _AiMessageBody extends StatelessWidget { Widget body; final textStyle = TextStyle( - color: context.theme.text, + color: message.status == 'error' + ? context.theme.ai.error + : context.theme.text, fontSize: context.messageStyle.primaryFontSize, height: 1.45, ); @@ -364,7 +366,7 @@ class _AiMessageMenu extends StatelessWidget { [ MenuAction( image: MenuImage.icon(Icons.data_object), - title: 'Copy AI message', + title: _copyAiMessageTitle, callback: () { Clipboard.setData(ClipboardData(text: message.toString())); }, From 815252fd3078064ed8e43c4759e73e5181b00624 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 14:36:52 +0800 Subject: [PATCH 23/52] feat(ai-assistant): add new components for AI assistant UI and message handling --- lib/db/dao/ai_chat_message_dao.dart | 42 ++ .../ai_assistant/composer.dart | 257 ++++++++++++ .../ai_assistant/constants.dart | 6 + .../chat_slide_page/ai_assistant/helpers.dart | 37 ++ .../ai_assistant/message_list.dart | 264 ++++++++++++ .../chat_slide_page/ai_assistant_page.dart | 382 ++---------------- lib/widgets/ai/ai_message_card.dart | 297 ++------------ 7 files changed, 665 insertions(+), 620 deletions(-) create mode 100644 lib/ui/home/chat_slide_page/ai_assistant/composer.dart create mode 100644 lib/ui/home/chat_slide_page/ai_assistant/constants.dart create mode 100644 lib/ui/home/chat_slide_page/ai_assistant/helpers.dart create mode 100644 lib/ui/home/chat_slide_page/ai_assistant/message_list.dart diff --git a/lib/db/dao/ai_chat_message_dao.dart b/lib/db/dao/ai_chat_message_dao.dart index 12ddfe414a..a43ad38348 100644 --- a/lib/db/dao/ai_chat_message_dao.dart +++ b/lib/db/dao/ai_chat_message_dao.dart @@ -26,6 +26,22 @@ class AiChatMessageDao extends DatabaseAccessor ])) .watch(); + Stream> watchLatestConversationMessages( + String conversationId, + int limit, + ) => + (select( + db.aiChatMessages, + ) + ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..orderBy([ + (tbl) => OrderingTerm.desc(tbl.createdAt), + (tbl) => OrderingTerm.desc(tbl.id), + ]) + ..limit(limit)) + .watch() + .map((items) => items.reversed.toList(growable: false)); + Future> conversationMessages(String conversationId) => (select( db.aiChatMessages, @@ -37,6 +53,32 @@ class AiChatMessageDao extends DatabaseAccessor ])) .get(); + Future> beforeConversationMessages({ + required String conversationId, + required AiChatMessage before, + required int limit, + }) async { + final beforeCreatedAt = before.createdAt.millisecondsSinceEpoch; + final list = + await (select( + db.aiChatMessages, + ) + ..where( + (tbl) => + tbl.conversationId.equals(conversationId) & + (tbl.createdAt.isSmallerThanValue(beforeCreatedAt) | + (tbl.createdAt.equals(beforeCreatedAt) & + tbl.id.isSmallerThanValue(before.id))), + ) + ..orderBy([ + (tbl) => OrderingTerm.desc(tbl.createdAt), + (tbl) => OrderingTerm.desc(tbl.id), + ]) + ..limit(limit)) + .get(); + return list.reversed.toList(growable: false); + } + Future insertMessage(AiChatMessagesCompanion row) => into(db.aiChatMessages).insertOnConflictUpdate(row); diff --git a/lib/ui/home/chat_slide_page/ai_assistant/composer.dart b/lib/ui/home/chat_slide_page/ai_assistant/composer.dart new file mode 100644 index 0000000000..9838c48e54 --- /dev/null +++ b/lib/ui/home/chat_slide_page/ai_assistant/composer.dart @@ -0,0 +1,257 @@ +import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; + +import '../../../../ai/model/ai_provider_config.dart'; +import '../../../../constants/constants.dart'; +import '../../../../utils/extension/extension.dart'; +import '../../../../widgets/action_button.dart'; +import '../../../../widgets/menu.dart'; +import 'constants.dart'; + +class AiAssistantComposer extends StatelessWidget { + const AiAssistantComposer({ + required this.focusNode, + required this.textEditingController, + required this.enabled, + required this.enabledAiProviders, + required this.requestInFlight, + required this.onSend, + required this.onStop, + required this.onProviderSelected, + required this.onModelSelected, + this.provider, + super.key, + }); + + final FocusNode focusNode; + final TextEditingController textEditingController; + final bool enabled; + final AiProviderConfig? provider; + final List enabledAiProviders; + final bool requestInFlight; + final VoidCallback onSend; + final VoidCallback onStop; + final ValueChanged onProviderSelected; + final ValueChanged onModelSelected; + + @override + Widget build(BuildContext context) { + final buttonColor = !enabled + ? context.theme.secondaryText + : requestInFlight + ? context.theme.red + : context.theme.accent; + + return Container( + padding: const EdgeInsets.fromLTRB(16, 12, 16, 16), + decoration: BoxDecoration( + color: context.theme.primary, + border: Border(top: BorderSide(color: context.theme.divider)), + ), + child: DecoratedBox( + decoration: BoxDecoration( + borderRadius: const BorderRadius.all(Radius.circular(14)), + color: context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ), + ), + child: Padding( + padding: const EdgeInsets.fromLTRB(12, 10, 8, 10), + child: Column( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + if (provider != null) ...[ + _AiAssistantModeBar( + provider: provider!, + enabledAiProviders: enabledAiProviders, + onProviderSelected: onProviderSelected, + onModelSelected: onModelSelected, + ), + const SizedBox(height: 2), + ], + Row( + crossAxisAlignment: CrossAxisAlignment.end, + children: [ + Expanded( + child: TextField( + focusNode: focusNode, + controller: textEditingController, + enabled: enabled, + minLines: 1, + maxLines: 6, + inputFormatters: [ + LengthLimitingTextInputFormatter(kMaxTextLength), + ], + style: TextStyle( + color: context.theme.text, + fontSize: 14, + height: 1.4, + ), + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: enabled + ? aiAssistantInputHint + : aiAssistantUnavailable, + hintStyle: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), + ), + ), + ), + const SizedBox(width: 8), + ActionButton( + padding: const EdgeInsets.all(6), + size: 20, + interactive: enabled, + onTap: requestInFlight ? onStop : onSend, + child: Icon( + requestInFlight + ? Icons.stop_rounded + : Icons.arrow_upward_rounded, + size: 18, + color: buttonColor, + ), + ), + ], + ), + ], + ), + ), + ), + ); + } +} + +class _AiAssistantModeBar extends StatelessWidget { + const _AiAssistantModeBar({ + required this.provider, + required this.enabledAiProviders, + required this.onProviderSelected, + required this.onModelSelected, + }); + + final AiProviderConfig provider; + final List enabledAiProviders; + final ValueChanged onProviderSelected; + final ValueChanged onModelSelected; + + @override + Widget build(BuildContext context) { + final providerOptions = enabledAiProviders + .map( + (item) => CustomPopupMenuItem( + title: item.name, + value: item, + ), + ) + .toList(growable: false); + final modelOptions = provider.models + .where((item) => item.trim().isNotEmpty) + .map( + (item) => CustomPopupMenuItem( + title: item.trim(), + value: item.trim(), + ), + ) + .toList(growable: false); + + return DecoratedBox( + decoration: BoxDecoration( + border: Border( + bottom: BorderSide( + color: context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.05), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ), + ), + ), + ), + child: Padding( + padding: const EdgeInsets.only(bottom: 8), + child: Row( + children: [ + Flexible( + child: _AiModeChip( + icon: Icons.hub_rounded, + label: provider.name, + items: providerOptions, + enabled: providerOptions.length > 1, + onSelected: onProviderSelected, + ), + ), + const SizedBox(width: 8), + Flexible( + child: _AiModeChip( + icon: Icons.tune_rounded, + label: provider.model, + items: modelOptions, + enabled: modelOptions.length > 1, + onSelected: onModelSelected, + ), + ), + ], + ), + ), + ); + } +} + +class _AiModeChip extends StatelessWidget { + const _AiModeChip({ + required this.icon, + required this.label, + required this.items, + required this.onSelected, + required this.enabled, + }); + + final IconData icon; + final String label; + final List> items; + final ValueChanged onSelected; + final bool enabled; + + @override + Widget build(BuildContext context) { + final child = Row( + children: [ + Icon(icon, size: 13, color: context.theme.secondaryText), + const SizedBox(width: 6), + Expanded( + child: Text( + label, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + fontWeight: FontWeight.w500, + ), + ), + ), + if (enabled) ...[ + const SizedBox(width: 2), + Icon( + Icons.keyboard_arrow_down_rounded, + size: 14, + color: context.theme.secondaryText, + ), + ], + ], + ); + + if (!enabled || items.isEmpty) return child; + + return CustomPopupMenuButton( + itemBuilder: (_) => items, + onSelected: onSelected, + color: Colors.transparent, + useActionButton: false, + child: child, + ); + } +} diff --git a/lib/ui/home/chat_slide_page/ai_assistant/constants.dart b/lib/ui/home/chat_slide_page/ai_assistant/constants.dart new file mode 100644 index 0000000000..e35e29800b --- /dev/null +++ b/lib/ui/home/chat_slide_page/ai_assistant/constants.dart @@ -0,0 +1,6 @@ +const aiAssistantTitle = 'AI Assistant'; +const aiAssistantEmpty = 'Ask AI about this conversation'; +const aiAssistantInputHint = 'Ask about this conversation'; +const aiAssistantUnavailable = 'Add a usable AI model in Settings first'; +const aiAssistantStickToBottomDistance = 96.0; +const aiAssistantMessagePageLimit = 80; diff --git a/lib/ui/home/chat_slide_page/ai_assistant/helpers.dart b/lib/ui/home/chat_slide_page/ai_assistant/helpers.dart new file mode 100644 index 0000000000..4690ddb015 --- /dev/null +++ b/lib/ui/home/chat_slide_page/ai_assistant/helpers.dart @@ -0,0 +1,37 @@ +import 'package:flutter/widgets.dart'; + +import '../../../../ai/model/ai_provider_config.dart'; + +AiProviderConfig? resolveAiAssistantProvider({ + required AiProviderConfig? selectedAiProvider, + required List enabledAiProviders, + required String? providerId, + required String? selectedModel, +}) { + var provider = selectedAiProvider; + if (providerId != null) { + for (final item in enabledAiProviders) { + if (item.id == providerId) { + provider = item; + break; + } + } + } + if (provider == null || provider.model.trim().isEmpty) { + provider = enabledAiProviders.firstOrNull; + } + if (provider == null) return null; + + final trimmedModel = selectedModel?.trim(); + if (trimmedModel == null || trimmedModel.isEmpty) return provider; + if (!provider.models.contains(trimmedModel)) return provider; + if (provider.model == trimmedModel) return provider; + return provider.copyWith(defaultModel: trimmedModel, model: trimmedModel); +} + +String currentLanguageTag(BuildContext context) { + final locale = Localizations.localeOf(context); + final countryCode = locale.countryCode; + if (countryCode == null || countryCode.isEmpty) return locale.languageCode; + return '${locale.languageCode}-$countryCode'; +} diff --git a/lib/ui/home/chat_slide_page/ai_assistant/message_list.dart b/lib/ui/home/chat_slide_page/ai_assistant/message_list.dart new file mode 100644 index 0000000000..1c1fbcdec2 --- /dev/null +++ b/lib/ui/home/chat_slide_page/ai_assistant/message_list.dart @@ -0,0 +1,264 @@ +import 'dart:async'; + +import 'package:flutter/material.dart'; +import 'package:flutter_hooks/flutter_hooks.dart'; + +import '../../../../db/mixin_database.dart'; +import '../../../../utils/extension/extension.dart'; +import '../../../../widgets/ai/ai_message_card.dart'; +import '../../../../widgets/clamping_custom_scroll_view/clamping_custom_scroll_view.dart'; +import '../../../../widgets/empty.dart'; +import '../../../../widgets/message/message_day_time.dart'; +import '../../../../widgets/toast.dart'; +import 'constants.dart'; + +class AiAssistantMessageList extends HookWidget { + const AiAssistantMessageList({ + required this.conversationId, + required this.latestMessages, + super.key, + }); + + final String conversationId; + final List latestMessages; + + @override + Widget build(BuildContext context) { + final olderMessages = useState(const []); + final isLoadingOlder = useState(false); + final isOldest = useState(false); + final messages = useMemoized( + () => _mergeAiMessages([...olderMessages.value, ...latestMessages]), + [olderMessages.value, latestMessages], + ); + final centerKey = useMemoized( + () => ValueKey('ai-list-center-$conversationId'), + [conversationId], + ); + final topKey = useMemoized( + () => GlobalKey(debugLabel: 'ai list top'), + ); + final bottomKey = useMemoized( + () => GlobalKey(debugLabel: 'ai list bottom'), + ); + final scrollController = useScrollController(); + final lastMessage = messages.lastOrNull; + final shouldStickToBottomRef = useRef(true); + final initialMessagesDisplayedRef = useRef(false); + final lastUserMessageIdRef = useRef(null); + final previousLatestMessagesRef = useRef(const []); + + void scrollToCurrent({required bool animated}) { + WidgetsBinding.instance.addPostFrameCallback((_) { + if (!scrollController.hasClients) return; + final position = scrollController.position; + if (!position.hasContentDimensions) return; + if (animated) { + unawaited( + scrollController.animateTo( + position.maxScrollExtent, + duration: const Duration(milliseconds: 160), + curve: Curves.easeOutCubic, + ), + ); + } else { + scrollController.jumpTo(position.maxScrollExtent); + } + }); + } + + Future loadOlderMessages() async { + if (isLoadingOlder.value || isOldest.value || messages.isEmpty) { + return; + } + + final before = messages.first; + isLoadingOlder.value = true; + + try { + final list = await context.database.aiChatMessageDao + .beforeConversationMessages( + conversationId: conversationId, + before: before, + limit: aiAssistantMessagePageLimit, + ); + olderMessages.value = _mergeAiMessages([ + ...list, + ...olderMessages.value, + ]); + isOldest.value = list.length < aiAssistantMessagePageLimit; + } catch (error, _) { + showToastFailed(error); + } finally { + isLoadingOlder.value = false; + } + } + + useEffect(() { + olderMessages.value = const []; + isLoadingOlder.value = false; + isOldest.value = false; + shouldStickToBottomRef.value = true; + initialMessagesDisplayedRef.value = false; + lastUserMessageIdRef.value = null; + previousLatestMessagesRef.value = const []; + return null; + }, [conversationId]); + + useEffect(() { + final previousLatestMessages = previousLatestMessagesRef.value; + previousLatestMessagesRef.value = latestMessages; + + if (olderMessages.value.isEmpty || + previousLatestMessages.isEmpty || + latestMessages.isEmpty) { + return null; + } + + final latestIds = latestMessages.map((item) => item.id).toSet(); + final firstLatestMessage = latestMessages.first; + final droppedMessages = previousLatestMessages + .where( + (item) => + !latestIds.contains(item.id) && + _compareAiMessages(item, firstLatestMessage) < 0, + ) + .toList(growable: false); + if (droppedMessages.isNotEmpty) { + olderMessages.value = _mergeAiMessages([ + ...olderMessages.value, + ...droppedMessages, + ]); + } + + return null; + }, [latestMessages]); + + useEffect(() { + if (olderMessages.value.isEmpty) { + isOldest.value = latestMessages.length < aiAssistantMessagePageLimit; + } + return null; + }, [latestMessages, olderMessages.value]); + + useEffect(() { + void updateStickToBottom() { + if (!scrollController.hasClients) return; + final position = scrollController.position; + if (!position.hasContentDimensions) return; + shouldStickToBottomRef.value = + position.maxScrollExtent - position.pixels < + aiAssistantStickToBottomDistance; + } + + scrollController.addListener(updateStickToBottom); + return () => scrollController.removeListener(updateStickToBottom); + }, [scrollController]); + + useEffect(() { + if (messages.isEmpty) return null; + if (!initialMessagesDisplayedRef.value) { + initialMessagesDisplayedRef.value = true; + return null; + } + + final lastMessageIsUser = lastMessage?.role == 'user'; + final hasNewUserMessage = + lastMessageIsUser && lastMessage?.id != lastUserMessageIdRef.value; + if (hasNewUserMessage) { + lastUserMessageIdRef.value = lastMessage?.id; + shouldStickToBottomRef.value = true; + } + + if (!shouldStickToBottomRef.value) return null; + scrollToCurrent(animated: hasNewUserMessage); + return null; + }, [messages.length, lastMessage?.updatedAt, lastMessage?.content]); + + if (messages.isEmpty) { + return const Empty(text: aiAssistantEmpty); + } + + return NotificationListener( + onNotification: (notification) { + shouldStickToBottomRef.value = + notification.metrics.maxScrollExtent - notification.metrics.pixels < + aiAssistantStickToBottomDistance; + if (notification is ScrollUpdateNotification && + (notification.scrollDelta ?? 0) < 0) { + final dimension = notification.metrics.viewportDimension / 2; + if ((notification.metrics.minScrollExtent - + notification.metrics.pixels) + .abs() < + dimension) { + unawaited(loadOlderMessages()); + } + } + return false; + }, + child: MessageDayTimeViewportWidget.chatPage( + key: ValueKey(conversationId), + bottomKey: bottomKey, + center: null, + topKey: topKey, + scrollController: scrollController, + centerKey: null, + child: ClampingCustomScrollView( + key: centerKey, + center: centerKey, + controller: scrollController, + anchor: 0.3, + physics: const ClampingScrollPhysics(), + slivers: [ + SliverPadding( + padding: const EdgeInsets.fromLTRB(16, 12, 16, 0), + sliver: SliverList( + key: topKey, + delegate: SliverChildBuilderDelegate(( + context, + index, + ) { + final actualIndex = messages.length - index - 1; + final message = messages[actualIndex]; + return MessageDayTimeItem( + key: ValueKey('assistant-${message.id}'), + dateTime: message.createdAt, + prevDateTime: actualIndex > 0 + ? messages[actualIndex - 1].createdAt + : null, + child: AiMessageCard( + message: message, + prev: actualIndex > 0 ? messages[actualIndex - 1] : null, + next: actualIndex < messages.length - 1 + ? messages[actualIndex + 1] + : null, + ), + ); + }, childCount: messages.length), + ), + ), + SliverToBoxAdapter(key: centerKey), + SliverPadding( + key: bottomKey, + padding: const EdgeInsets.only(bottom: 20), + ), + ], + ), + ), + ); + } +} + +List _mergeAiMessages(Iterable messages) { + final map = {}; + for (final message in messages) { + map[message.id] = message; + } + return map.values.toList(growable: false)..sort(_compareAiMessages); +} + +int _compareAiMessages(AiChatMessage a, AiChatMessage b) { + final createdAtResult = a.createdAt.compareTo(b.createdAt); + if (createdAtResult != 0) return createdAtResult; + return a.id.compareTo(b.id); +} diff --git a/lib/ui/home/chat_slide_page/ai_assistant_page.dart b/lib/ui/home/chat_slide_page/ai_assistant_page.dart index 11c8815078..7133a8dd2b 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant_page.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant_page.dart @@ -1,7 +1,4 @@ -import 'dart:async'; - import 'package:flutter/material.dart'; -import 'package:flutter/services.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; @@ -11,20 +8,14 @@ import '../../../constants/constants.dart'; import '../../../db/mixin_database.dart'; import '../../../utils/extension/extension.dart'; import '../../../utils/hook.dart'; -import '../../../widgets/action_button.dart'; -import '../../../widgets/ai/ai_message_card.dart'; import '../../../widgets/app_bar.dart'; -import '../../../widgets/empty.dart'; -import '../../../widgets/menu.dart'; -import '../../../widgets/message/message_day_time.dart'; import '../../../widgets/toast.dart'; import '../../provider/ai_input_mode_provider.dart'; import '../../provider/conversation_provider.dart'; - -const _aiAssistantTitle = 'AI Assistant'; -const _aiAssistantEmpty = 'Ask AI about this conversation'; -const _aiAssistantInputHint = 'Ask about this conversation'; -const _aiAssistantUnavailable = 'Add a usable AI model in Settings first'; +import 'ai_assistant/composer.dart'; +import 'ai_assistant/constants.dart'; +import 'ai_assistant/helpers.dart'; +import 'ai_assistant/message_list.dart'; class AiAssistantPage extends HookConsumerWidget { const AiAssistantPage(this.conversationState, {super.key}); @@ -44,29 +35,29 @@ class AiAssistantPage extends HookConsumerWidget { .whereType() .where((item) => item.enabled && item.model.trim().isNotEmpty) .toList(growable: false); - final aiProvider = _resolveAiAssistantProvider( + final aiProvider = resolveAiAssistantProvider( selectedAiProvider: context.database.settingProperties.selectedAiProvider, enabledAiProviders: enabledAiProviders, providerId: aiModeState.providerId, selectedModel: aiModeState.model, ); - final messages = + final latestMessages = useMemoizedStream( - () => context.database.aiChatMessageDao.watchConversationMessages( - conversationId, - ), + () => + context.database.aiChatMessageDao.watchLatestConversationMessages( + conversationId, + aiAssistantMessagePageLimit, + ), keys: [conversationId], initialData: const [], ).data ?? const []; - final requestInFlight = messages.any(isActivePendingAiMessage); + final requestInFlight = latestMessages.any(isActivePendingAiMessage); final textEditingController = useMemoized( TextEditingController.new, [conversationId], ); - final scrollController = useScrollController(); final focusNode = useFocusNode(); - final lastMessage = messages.lastOrNull; useEffect(() { if (!context.textFieldAutoGainFocus) { @@ -77,27 +68,6 @@ class AiAssistantPage extends HookConsumerWidget { return null; }, [conversationId]); - useEffect(() { - WidgetsBinding.instance.addPostFrameCallback((_) { - if (!scrollController.hasClients) return; - final position = scrollController.position; - final shouldStickToBottom = - messages.length <= 2 || - !position.hasContentDimensions || - position.maxScrollExtent - position.pixels < 96 || - lastMessage?.role == 'user'; - if (!shouldStickToBottom) return; - unawaited( - scrollController.animateTo( - position.maxScrollExtent, - duration: const Duration(milliseconds: 180), - curve: Curves.easeOutCubic, - ), - ); - }); - return null; - }, [messages.length, lastMessage?.updatedAt, lastMessage?.content]); - Future send() async { final text = textEditingController.text.trim(); if (text.isEmpty) return; @@ -106,7 +76,7 @@ class AiAssistantPage extends HookConsumerWidget { return; } if (aiProvider == null) { - showToastFailed(ToastError(_aiAssistantUnavailable)); + showToastFailed(ToastError(aiAssistantUnavailable)); return; } @@ -114,7 +84,7 @@ class AiAssistantPage extends HookConsumerWidget { await AiChatController(context.database).send( conversationId: conversationId, input: text, - language: _currentLanguageTag(context), + language: currentLanguageTag(context), provider: aiProvider, onInputAccepted: textEditingController.clear, ); @@ -125,333 +95,33 @@ class AiAssistantPage extends HookConsumerWidget { return Scaffold( backgroundColor: context.theme.primary, - appBar: const MixinAppBar(title: Text(_aiAssistantTitle)), + appBar: const MixinAppBar(title: Text(aiAssistantTitle)), body: Column( children: [ - if (aiProvider != null) - _AiAssistantModeBar( - provider: aiProvider, - enabledAiProviders: enabledAiProviders, - onProviderSelected: (value) => aiModeNotifier.updateProvider( - providerId: value.id, - model: value.model, - ), - onModelSelected: aiModeNotifier.updateModel, - ), Expanded( - child: messages.isEmpty - ? const Empty(text: _aiAssistantEmpty) - : ListView.builder( - controller: scrollController, - padding: const EdgeInsets.fromLTRB(16, 12, 16, 20), - itemCount: messages.length, - itemBuilder: (context, index) { - final message = messages[index]; - return MessageDayTimeItem( - key: ValueKey('assistant-${message.id}'), - dateTime: message.createdAt, - prevDateTime: index > 0 - ? messages[index - 1].createdAt - : null, - child: AiMessageCard( - message: message, - prev: index > 0 ? messages[index - 1] : null, - next: index < messages.length - 1 - ? messages[index + 1] - : null, - ), - ); - }, - ), + child: AiAssistantMessageList( + conversationId: conversationId, + latestMessages: latestMessages, + ), ), - _AiAssistantComposer( + AiAssistantComposer( focusNode: focusNode, textEditingController: textEditingController, enabled: aiProvider != null, + provider: aiProvider, + enabledAiProviders: enabledAiProviders, requestInFlight: requestInFlight, onSend: send, onStop: () => AiChatController(context.database).stop(conversationId), - ), - ], - ), - ); - } -} - -class _AiAssistantModeBar extends StatelessWidget { - const _AiAssistantModeBar({ - required this.provider, - required this.enabledAiProviders, - required this.onProviderSelected, - required this.onModelSelected, - }); - - final AiProviderConfig provider; - final List enabledAiProviders; - final ValueChanged onProviderSelected; - final ValueChanged onModelSelected; - - @override - Widget build(BuildContext context) { - final providerOptions = enabledAiProviders - .map( - (item) => CustomPopupMenuItem( - title: item.name, - value: item, - ), - ) - .toList(growable: false); - final modelOptions = provider.models - .where((item) => item.trim().isNotEmpty) - .map( - (item) => CustomPopupMenuItem( - title: item.trim(), - value: item.trim(), - ), - ) - .toList(growable: false); - - return Container( - width: double.infinity, - padding: const EdgeInsets.fromLTRB(16, 12, 16, 0), - child: DecoratedBox( - decoration: BoxDecoration( - color: context.dynamicColor( - const Color.fromRGBO(245, 247, 250, 1), - darkColor: const Color.fromRGBO(255, 255, 255, 0.06), - ), - borderRadius: const BorderRadius.all(Radius.circular(12)), - border: Border.all( - color: context.dynamicColor( - const Color.fromRGBO(0, 0, 0, 0.05), - darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + onProviderSelected: (value) => aiModeNotifier.updateProvider( + providerId: value.id, + model: value.model, ), - ), - ), - child: Padding( - padding: const EdgeInsets.symmetric(horizontal: 12, vertical: 10), - child: Row( - children: [ - Icon( - Icons.auto_awesome_rounded, - size: 14, - color: context.theme.ai.accent, - ), - const SizedBox(width: 10), - Expanded( - child: Row( - children: [ - Flexible( - child: _AiModeChip( - icon: Icons.hub_rounded, - label: provider.name, - items: providerOptions, - enabled: providerOptions.length > 1, - onSelected: onProviderSelected, - ), - ), - const SizedBox(width: 8), - Flexible( - child: _AiModeChip( - icon: Icons.tune_rounded, - label: provider.model, - items: modelOptions, - enabled: modelOptions.length > 1, - onSelected: onModelSelected, - ), - ), - ], - ), - ), - ], - ), - ), - ), - ); - } -} - -class _AiModeChip extends StatelessWidget { - const _AiModeChip({ - required this.icon, - required this.label, - required this.items, - required this.onSelected, - required this.enabled, - }); - - final IconData icon; - final String label; - final List> items; - final ValueChanged onSelected; - final bool enabled; - - @override - Widget build(BuildContext context) { - final child = Row( - children: [ - Icon(icon, size: 13, color: context.theme.secondaryText), - const SizedBox(width: 6), - Expanded( - child: Text( - label, - maxLines: 1, - overflow: TextOverflow.ellipsis, - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 12, - fontWeight: FontWeight.w500, - ), - ), - ), - if (enabled) ...[ - const SizedBox(width: 2), - Icon( - Icons.keyboard_arrow_down_rounded, - size: 14, - color: context.theme.secondaryText, + onModelSelected: aiModeNotifier.updateModel, ), ], - ], - ); - - if (!enabled || items.isEmpty) return child; - - return CustomPopupMenuButton( - itemBuilder: (_) => items, - onSelected: onSelected, - color: Colors.transparent, - useActionButton: false, - child: child, - ); - } -} - -class _AiAssistantComposer extends StatelessWidget { - const _AiAssistantComposer({ - required this.focusNode, - required this.textEditingController, - required this.enabled, - required this.requestInFlight, - required this.onSend, - required this.onStop, - }); - - final FocusNode focusNode; - final TextEditingController textEditingController; - final bool enabled; - final bool requestInFlight; - final VoidCallback onSend; - final VoidCallback onStop; - - @override - Widget build(BuildContext context) { - final buttonColor = !enabled - ? context.theme.secondaryText - : requestInFlight - ? context.theme.red - : context.theme.accent; - - return Container( - padding: const EdgeInsets.fromLTRB(16, 12, 16, 16), - decoration: BoxDecoration( - color: context.theme.primary, - border: Border(top: BorderSide(color: context.theme.divider)), - ), - child: DecoratedBox( - decoration: BoxDecoration( - borderRadius: const BorderRadius.all(Radius.circular(14)), - color: context.dynamicColor( - const Color.fromRGBO(245, 247, 250, 1), - darkColor: const Color.fromRGBO(255, 255, 255, 0.08), - ), - ), - child: Padding( - padding: const EdgeInsets.fromLTRB(12, 10, 8, 10), - child: Row( - crossAxisAlignment: CrossAxisAlignment.end, - children: [ - Expanded( - child: TextField( - focusNode: focusNode, - controller: textEditingController, - enabled: enabled, - minLines: 1, - maxLines: 6, - inputFormatters: [ - LengthLimitingTextInputFormatter(kMaxTextLength), - ], - style: TextStyle( - color: context.theme.text, - fontSize: 14, - height: 1.4, - ), - decoration: InputDecoration( - isDense: true, - border: InputBorder.none, - hintText: enabled - ? _aiAssistantInputHint - : _aiAssistantUnavailable, - hintStyle: TextStyle( - color: context.theme.secondaryText, - fontSize: 14, - ), - ), - ), - ), - const SizedBox(width: 8), - ActionButton( - padding: const EdgeInsets.all(6), - size: 20, - interactive: enabled, - onTap: requestInFlight ? onStop : onSend, - child: Icon( - requestInFlight - ? Icons.stop_rounded - : Icons.arrow_upward_rounded, - size: 18, - color: buttonColor, - ), - ), - ], - ), - ), ), ); } } - -AiProviderConfig? _resolveAiAssistantProvider({ - required AiProviderConfig? selectedAiProvider, - required List enabledAiProviders, - required String? providerId, - required String? selectedModel, -}) { - var provider = selectedAiProvider; - if (providerId != null) { - for (final item in enabledAiProviders) { - if (item.id == providerId) { - provider = item; - break; - } - } - } - if (provider == null || provider.model.trim().isEmpty) { - provider = enabledAiProviders.firstOrNull; - } - if (provider == null) return null; - - final trimmedModel = selectedModel?.trim(); - if (trimmedModel == null || trimmedModel.isEmpty) return provider; - if (!provider.models.contains(trimmedModel)) return provider; - if (provider.model == trimmedModel) return provider; - return provider.copyWith(defaultModel: trimmedModel, model: trimmedModel); -} - -String _currentLanguageTag(BuildContext context) { - final locale = Localizations.localeOf(context); - final countryCode = locale.countryCode; - if (countryCode == null || countryCode.isEmpty) return locale.languageCode; - return '${locale.languageCode}-$countryCode'; -} diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index 12113e2cf3..cfd030d667 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -2,11 +2,9 @@ import 'package:flutter/material.dart' hide SelectableRegion, SelectableRegionState; import 'package:flutter/rendering.dart' show SelectedContent, SelectionStatus; import 'package:flutter/services.dart'; -import 'package:flutter_hooks/flutter_hooks.dart'; -import 'package:flutter_svg/svg.dart'; +import 'package:intl/intl.dart'; import 'package:super_context_menu/super_context_menu.dart'; -import '../../constants/resources.dart'; import '../../db/mixin_database.dart' hide Offset; import '../../utils/datetime_format_utils.dart'; import '../../utils/extension/extension.dart'; @@ -20,7 +18,6 @@ import '../message/message_layout.dart'; import '../message/message_style.dart'; import '../qr_code.dart'; -const _aiAssistantTitle = 'AI Assistant'; const _copyAiMessageTitle = 'Copy AI Message'; class AiMessageCard extends StatelessWidget { @@ -44,10 +41,12 @@ class AiMessageCard extends StatelessWidget { final sameRoleNext = next?.role == message.role; final mergedWithPrev = sameDayPrev && sameRolePrev; final mergedWithNext = sameDayNext && sameRoleNext; - final body = ConstrainedBox( - constraints: const BoxConstraints(maxWidth: 420), - child: _AiMessageBody(message: message), - ); + final body = isUser + ? ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 420), + child: _AiMessageBody(message: message), + ) + : _AiMessageBody(message: message); if (isUser) { return Padding( @@ -78,48 +77,12 @@ class AiMessageCard extends StatelessWidget { return Padding( padding: EdgeInsets.only( - left: 8, - right: 44, top: mergedWithPrev ? 6 : 18, bottom: 6, ), - child: Row( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - SizedBox( - width: 32, - child: !mergedWithPrev - ? _AiAvatar(thinking: message.status == 'pending') - : null, - ), - const SizedBox(width: 12), - Expanded( - child: Padding( - padding: const EdgeInsets.only(top: 1), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - if (!mergedWithPrev) - Padding( - padding: const EdgeInsets.only(bottom: 6), - child: Text( - _aiAssistantTitle, - style: TextStyle( - color: context.theme.secondaryText, - fontSize: context.messageStyle.statusFontSize, - fontWeight: FontWeight.w600, - ), - ), - ), - _AiMessageMenu( - message: message, - child: body, - ), - ], - ), - ), - ), - ], + child: _AiMessageMenu( + message: message, + child: body, ), ); } @@ -134,11 +97,6 @@ class _AiMessageBody extends StatelessWidget { Widget build(BuildContext context) { final isUser = message.role == 'user'; final text = _displayText(message); - final statusColor = _statusColor( - context, - isUser: isUser, - status: message.status, - ); Widget body; final textStyle = TextStyle( @@ -173,8 +131,6 @@ class _AiMessageBody extends StatelessWidget { dateAndStatus: _AiFooter( isUser: isUser, model: message.model, - status: message.status, - color: statusColor, dateTime: message.createdAt, ), ); @@ -253,70 +209,6 @@ class _AiBubble extends StatelessWidget { } } -class _AiAvatar extends HookWidget { - const _AiAvatar({required this.thinking}); - - final bool thinking; - - @override - Widget build(BuildContext context) { - final aiColors = context.theme.ai; - final background = aiColors.avatarBackground; - final foreground = aiColors.accent; - final disableAnimations = - MediaQuery.maybeOf(context)?.disableAnimations ?? false; - final controller = useAnimationController( - duration: const Duration(milliseconds: 1800), - ); - useEffect(() { - if (!thinking || disableAnimations) { - controller - ..stop() - ..value = 0; - return null; - } - controller.repeat(); - return null; - }, [thinking, disableAnimations, controller]); - - final progress = useAnimation( - CurvedAnimation(parent: controller, curve: Curves.easeInOut), - ); - final scale = !thinking || disableAnimations - ? 1.0 - : 1 + (0.03 * (0.5 - (progress - 0.5).abs()) * 2); - final glowAlpha = !thinking || disableAnimations ? 0.0 : 0.16 * progress; - - return Transform.scale( - scale: scale, - child: Container( - width: 32, - height: 32, - decoration: BoxDecoration( - color: background, - shape: BoxShape.circle, - boxShadow: glowAlpha == 0 - ? null - : [ - BoxShadow( - color: foreground.withValues(alpha: glowAlpha), - blurRadius: 10, - spreadRadius: 0.5, - ), - ], - ), - alignment: Alignment.center, - child: SvgPicture.asset( - Resources.assetsImagesBotFillSvg, - width: 18, - height: 18, - colorFilter: ColorFilter.mode(foreground, BlendMode.srcIn), - ), - ), - ); - } -} - class _AiMessageMenu extends StatelessWidget { const _AiMessageMenu({ required this.message, @@ -410,158 +302,51 @@ SelectedContent? _findSelectedContent(BuildContext context) { return null; } -class _AiStatusBadge extends HookWidget { - const _AiStatusBadge({ - required this.isUser, - required this.model, - required this.status, - required this.color, - }); - - final bool isUser; - final String? model; - final String status; - final Color color; - - @override - Widget build(BuildContext context) { - final trimmedModel = isUser ? null : model?.trim(); - final icon = status == 'pending' - ? _AiThinkingIndicator(color: color) - : Icon( - _statusIcon(messageRoleIsUser: isUser, status: status), - size: 12, - color: color, - ); - - if (trimmedModel == null || trimmedModel.isEmpty) { - return icon; - } - - return Row( - mainAxisSize: MainAxisSize.min, - children: [ - icon, - const SizedBox(width: 4), - Text( - trimmedModel, - style: TextStyle( - fontSize: context.messageStyle.statusFontSize, - color: color, - ), - ), - ], - ); - } -} - class _AiFooter extends StatelessWidget { const _AiFooter({ required this.isUser, required this.model, - required this.status, - required this.color, required this.dateTime, }); final bool isUser; final String? model; - final String status; - final Color color; final DateTime dateTime; - @override - Widget build(BuildContext context) => MessageMetaRow( - dateTime: dateTime, - trailingSpacing: 4, - trailing: _AiStatusBadge( - isUser: isUser, - model: model, - status: status, - color: color, - ), - ); -} - -class _AiThinkingIndicator extends HookWidget { - const _AiThinkingIndicator({required this.color}); - - final Color color; - @override Widget build(BuildContext context) { - final disableAnimations = - MediaQuery.maybeOf(context)?.disableAnimations ?? false; - - if (disableAnimations) { - return Icon(Icons.more_horiz_rounded, size: 12, color: color); + if (isUser) { + return MessageMetaRow(dateTime: dateTime); } - final controller = useAnimationController( - duration: const Duration(milliseconds: 1200), + final metaColor = context.dynamicColor( + const Color.fromRGBO(131, 145, 158, 1), + darkColor: const Color.fromRGBO(128, 131, 134, 1), ); - useEffect(() { - controller.repeat(); - return null; - }, [controller]); + final textStyle = TextStyle( + fontSize: context.messageStyle.statusFontSize, + color: metaColor, + ); + final dateTimeText = DateFormat.Hm().format(dateTime.toLocal()); + final trimmedModel = isUser ? null : model?.trim(); - return RotationTransition( - turns: controller, - child: CustomPaint( - size: const Size.square(12), - painter: _AiThinkingIndicatorPainter(color: color), + return SelectionContainer.disabled( + child: SizedBox( + width: double.infinity, + child: Row( + children: [ + Text(dateTimeText, style: textStyle), + if (trimmedModel != null && trimmedModel.isNotEmpty) ...[ + const Spacer(), + Text(trimmedModel, style: textStyle), + ], + ], + ), ), ); } } -class _AiThinkingIndicatorPainter extends CustomPainter { - const _AiThinkingIndicatorPainter({required this.color}); - - final Color color; - - @override - void paint(Canvas canvas, Size size) { - final center = size.center(Offset.zero); - final radius = (size.width / 2) - 1; - - final track = Paint() - ..color = color.withValues(alpha: 0.22) - ..style = PaintingStyle.stroke - ..strokeWidth = 1.2 - ..strokeCap = StrokeCap.round; - - final arc = Paint() - ..color = color - ..style = PaintingStyle.stroke - ..strokeWidth = 1.4 - ..strokeCap = StrokeCap.round; - - canvas - ..drawCircle(center, radius, track) - ..drawArc( - Rect.fromCircle(center: center, radius: radius), - -1.2, - 1.95, - false, - arc, - ); - } - - @override - bool shouldRepaint(covariant _AiThinkingIndicatorPainter oldDelegate) => - oldDelegate.color != color; -} - -IconData _statusIcon({ - required bool messageRoleIsUser, - required String status, -}) { - if (status == 'error') return Icons.error_outline_rounded; - if (messageRoleIsUser) return Icons.auto_awesome_rounded; - return Icons.smart_toy_rounded; -} - Color _bubbleColor( BuildContext context, { required bool isUser, @@ -578,22 +363,6 @@ Color _bubbleColor( return context.theme.ai.assistantBubble; } -Color _statusColor( - BuildContext context, { - required bool isUser, - required String status, -}) { - if (status == 'error') { - return context.theme.ai.error; - } - - if (isUser) { - return context.theme.green; - } - - return context.theme.ai.accent; -} - String _menuCopyText(AiChatMessage message) => _displayText(message); String _displayText(AiChatMessage message) { From e9858cc1f70a955282da04cf7a9f6d53ab17acac Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 15:26:19 +0800 Subject: [PATCH 24/52] feat(ai): enhance metadata handling and tool activity tracking --- lib/ai/ai_chat_controller.dart | 65 ++- lib/ai/ai_provider_requester.dart | 57 +-- lib/ai/model/ai_chat_metadata.dart | 84 ++++ lib/ai/provider/ai_provider_strategy.dart | 2 + .../provider/openai_compatible_strategy.dart | 146 ++++++ lib/db/dao/ai_chat_message_dao.dart | 25 + lib/db/mixin_database.dart | 5 +- lib/db/mixin_database.g.dart | 68 +++ lib/db/moor/mixin.drift | 1 + .../ai_assistant/composer.dart | 321 ++++++++----- lib/ui/setting/ai_prompt_settings_page.dart | 317 ++++++------- lib/ui/setting/ai_provider_edit_page.dart | 427 +++++++++--------- lib/ui/setting/ai_settings_page.dart | 185 ++++---- lib/ui/setting/setting_page.dart | 2 +- lib/widgets/ai/ai_message_card.dart | 96 +++- 15 files changed, 1197 insertions(+), 604 deletions(-) create mode 100644 lib/ai/model/ai_chat_metadata.dart diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index be8a1d95f8..fe70f57552 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -11,6 +11,7 @@ import '../db/database.dart'; import '../db/mixin_database.dart'; import 'ai_chat_prompt_builder.dart'; import 'ai_provider_requester.dart'; +import 'model/ai_chat_metadata.dart'; import 'model/ai_prompt_message.dart'; import 'model/ai_provider_config.dart'; import 'model/ai_tool.dart'; @@ -82,7 +83,6 @@ class AiChatController { cancelToken: cancelToken, onContent: (_) async {}, conversationId: conversationId, - streamFinalResponse: false, ); d( 'AI assist done: provider=${config.type.name} model=${config.model} ' @@ -129,6 +129,7 @@ class AiChatController { ); final now = DateTime.now(); + final assistantCreatedAt = now.add(const Duration(milliseconds: 1)); final userMessageId = _uuid.v4(); final assistantMessageId = _uuid.v4(); final cancelToken = CancelToken(); @@ -163,8 +164,9 @@ class AiChatController { content: '', status: _kAiStatusPending, model: Value(config.model), - createdAt: now, - updatedAt: now, + metadata: Value(createAiMessageMetadata(config)), + createdAt: assistantCreatedAt, + updatedAt: assistantCreatedAt, ), ); @@ -187,7 +189,7 @@ class AiChatController { cancelToken: cancelToken, onContent: updater.append, conversationId: conversationId, - streamFinalResponse: true, + assistantMessageId: assistantMessageId, ); await updater.flush(contentOverride: result, force: true); await database.aiChatMessageDao.updateMessageStatus( @@ -239,33 +241,38 @@ class AiChatController { List messages, { required CancelToken cancelToken, required Future Function(String chunk) onContent, - required bool streamFinalResponse, String? conversationId, + String? assistantMessageId, }) => _providerRequester.requestText( config, messages, proxy: database.settingProperties.activatedProxy, cancelToken: cancelToken, onContent: onContent, - streamFinalResponse: streamFinalResponse, conversationId: conversationId, - onToolCall: _toolExecutorFor(conversationId), + onToolCall: _toolExecutorFor( + conversationId, + assistantMessageId: assistantMessageId, + ), ); Future Function(AiToolCall toolCall)? _toolExecutorFor( - String? conversationId, - ) { + String? conversationId, { + String? assistantMessageId, + }) { if (conversationId == null) { return null; } return (toolCall) => _executeConversationTool( conversationId: conversationId, + assistantMessageId: assistantMessageId, toolCall: toolCall, ); } Future _executeConversationTool({ required String conversationId, + required String? assistantMessageId, required AiToolCall toolCall, }) async { final stopwatch = Stopwatch()..start(); @@ -274,6 +281,10 @@ class AiChatController { 'tool=${toolCall.name} id=${toolCall.id} ' 'arguments=${_previewJson(toolCall.arguments)}', ); + await _appendAssistantToolEvent( + assistantMessageId, + createAiToolCallEvent(toolCall), + ); try { final result = await _conversationTools.execute( conversationId: conversationId, @@ -285,9 +296,27 @@ class AiChatController { 'elapsedMs=${stopwatch.elapsedMilliseconds} ' 'result=${_previewJson(result.payload)}', ); + await _appendAssistantToolEvent( + assistantMessageId, + createAiToolResultEvent( + toolCall: toolCall, + status: 'done', + elapsedMs: stopwatch.elapsedMilliseconds, + resultPreview: _previewJson(result.payload), + ), + ); return result; } catch (error, stacktrace) { e('AI tool execution error: $error, $stacktrace'); + await _appendAssistantToolEvent( + assistantMessageId, + createAiToolResultEvent( + toolCall: toolCall, + status: 'error', + elapsedMs: stopwatch.elapsedMilliseconds, + errorText: error.toString(), + ), + ); return AiToolExecutionResult( toolCallId: toolCall.id, toolName: toolCall.name, @@ -295,6 +324,24 @@ class AiChatController { ); } } + + Future _appendAssistantToolEvent( + String? assistantMessageId, + Map event, + ) async { + if (assistantMessageId == null) { + return; + } + try { + await database.aiChatMessageDao.appendMessageMetadataToolEvent( + assistantMessageId, + event, + updatedAt: DateTime.now(), + ); + } catch (error, stacktrace) { + e('AI tool metadata update error: $error, $stacktrace'); + } + } } String _previewText(String? text, {int maxLength = _kAiLogPreviewLength}) { diff --git a/lib/ai/ai_provider_requester.dart b/lib/ai/ai_provider_requester.dart index 4d58b493ec..55480f3bb9 100644 --- a/lib/ai/ai_provider_requester.dart +++ b/lib/ai/ai_provider_requester.dart @@ -33,14 +33,12 @@ class AiProviderRequester { required ProxyConfig? proxy, required CancelToken cancelToken, required Future Function(String chunk) onContent, - required bool streamFinalResponse, required String? conversationId, Future Function(AiToolCall toolCall)? onToolCall, }) async { d( 'AI request start: provider=${config.type.name} model=${config.model} ' - 'conversationId=$conversationId streamFinal=$streamFinalResponse ' - 'messages=${messages.length} ' + 'conversationId=$conversationId messages=${messages.length} ' 'tools=${conversationId != null && onToolCall != null}', ); final dio = @@ -111,7 +109,6 @@ class AiProviderRequester { cancelToken: cancelToken, onContent: onContent, onToolCall: onToolCall, - streamFinalResponse: streamFinalResponse, ); } @@ -124,20 +121,29 @@ class AiProviderRequester { required Future Function(String chunk) onContent, required Future Function(AiToolCall toolCall) onToolCall, - required bool streamFinalResponse, }) async { for (var round = 0; round < _aiToolMaxRounds; round++) { d( 'AI tool round start: conversationId=$conversationId ' 'round=${round + 1}/$_aiToolMaxRounds messages=${messages.length}', ); - final response = await _strategyFor(config.type).completeResponse( - dio: dio, - config: config, - messages: messages, - tools: AiConversationToolKit.definitions, - cancelToken: cancelToken, - ); + final strategy = _strategyFor(config.type); + final response = strategy is OpenAiCompatibleStrategy + ? await strategy.streamCompleteResponse( + dio: dio, + config: config, + messages: messages, + tools: AiConversationToolKit.definitions, + cancelToken: cancelToken, + onContent: onContent, + ) + : await strategy.completeResponse( + dio: dio, + config: config, + messages: messages, + tools: AiConversationToolKit.definitions, + cancelToken: cancelToken, + ); d( 'AI tool round response: conversationId=$conversationId ' 'round=${round + 1} text=${_previewText(response.text)} ' @@ -149,32 +155,11 @@ class AiProviderRequester { if (text.isEmpty) { throw Exception('Empty AI response'); } - if (streamFinalResponse) { - try { - d( - 'AI final stream start: conversationId=$conversationId ' - 'round=${round + 1}', - ); - return await _strategyFor(config.type).streamResponse( - dio: dio, - config: config, - messages: messages, - cancelToken: cancelToken, - onContent: onContent, - ); - } catch (error, stacktrace) { - e('AI final streaming fallback: $error, $stacktrace'); - await _emitBufferedText(text, onContent); - d( - 'AI final stream fallback: conversationId=$conversationId ' - 'round=${round + 1} text=${_previewText(text)}', - ); - return text; - } + if (!response.contentEmitted) { + await _emitBufferedText(text, onContent); } - await onContent(text); d( - 'AI tool request done without stream: ' + 'AI tool request done: ' 'conversationId=$conversationId ' 'round=${round + 1} text=${_previewText(text)}', ); diff --git a/lib/ai/model/ai_chat_metadata.dart b/lib/ai/model/ai_chat_metadata.dart new file mode 100644 index 0000000000..90a86ee9f7 --- /dev/null +++ b/lib/ai/model/ai_chat_metadata.dart @@ -0,0 +1,84 @@ +import 'dart:convert'; + +import 'ai_provider_config.dart'; +import 'ai_tool.dart'; + +const aiMetadataToolEventsKey = 'toolEvents'; +const aiToolEventTypeCall = 'tool_call'; +const aiToolEventTypeResult = 'tool_result'; + +String createAiMessageMetadata(AiProviderConfig provider) => jsonEncode({ + 'provider': { + 'id': provider.id, + 'type': provider.type.name, + 'model': provider.model, + }, + aiMetadataToolEventsKey: const >[], +}); + +Map decodeAiMessageMetadata(String? metadata) { + if (metadata == null || metadata.trim().isEmpty) { + return {}; + } + try { + final decoded = jsonDecode(metadata); + if (decoded is Map) { + return decoded; + } + if (decoded is Map) { + return decoded.map((key, value) => MapEntry('$key', value)); + } + } catch (_) { + return {}; + } + return {}; +} + +String appendAiToolEventToMetadata( + String? metadata, + Map event, +) { + final root = decodeAiMessageMetadata(metadata); + final currentEvents = root[aiMetadataToolEventsKey]; + final events = currentEvents is List + ? currentEvents.toList(growable: true) + : []; + root[aiMetadataToolEventsKey] = events..add(event); + return jsonEncode(root); +} + +Map createAiToolCallEvent(AiToolCall toolCall) => { + 'type': aiToolEventTypeCall, + 'id': toolCall.id, + 'name': toolCall.name, + 'arguments': toolCall.arguments, + 'createdAt': DateTime.now().toUtc().toIso8601String(), +}; + +Map createAiToolResultEvent({ + required AiToolCall toolCall, + required String status, + required int elapsedMs, + String? resultPreview, + String? errorText, +}) => { + 'type': aiToolEventTypeResult, + 'id': toolCall.id, + 'name': toolCall.name, + 'status': status, + 'elapsedMs': elapsedMs, + 'resultPreview': resultPreview, + 'errorText': errorText, + 'createdAt': DateTime.now().toUtc().toIso8601String(), +}..removeWhere((_, value) => value == null); + +List> aiMetadataToolEvents(String? metadata) { + final events = decodeAiMessageMetadata(metadata)[aiMetadataToolEventsKey]; + if (events is! List) { + return const >[]; + } + return events + .whereType() + .map((event) => event.map((key, value) => MapEntry('$key', value))) + .toList(growable: false); +} diff --git a/lib/ai/provider/ai_provider_strategy.dart b/lib/ai/provider/ai_provider_strategy.dart index a0b057d476..10df379512 100644 --- a/lib/ai/provider/ai_provider_strategy.dart +++ b/lib/ai/provider/ai_provider_strategy.dart @@ -33,10 +33,12 @@ class AiCompletionResponse { const AiCompletionResponse({ this.text = '', this.toolCalls = const [], + this.contentEmitted = false, }); final String text; final List toolCalls; + final bool contentEmitted; bool get hasToolCalls => toolCalls.isNotEmpty; } diff --git a/lib/ai/provider/openai_compatible_strategy.dart b/lib/ai/provider/openai_compatible_strategy.dart index dd6add34e3..32330e14fe 100644 --- a/lib/ai/provider/openai_compatible_strategy.dart +++ b/lib/ai/provider/openai_compatible_strategy.dart @@ -128,6 +128,105 @@ class OpenAiCompatibleStrategy implements AiProviderStrategy { return text; } + Future streamCompleteResponse({ + required Dio dio, + required AiProviderConfig config, + required List messages, + required List tools, + required CancelToken cancelToken, + required Future Function(String chunk) onContent, + }) async { + final response = await dio.post( + '/chat/completions', + data: { + 'model': config.model, + 'stream': true, + 'messages': messages.map(_openAiMessagePayload).toList(growable: false), + if (tools.isNotEmpty) + 'tools': tools + .map( + (tool) => { + 'type': 'function', + 'function': { + 'name': tool.name, + 'description': tool.description, + 'parameters': tool.inputSchema, + }, + }, + ) + .toList(growable: false), + if (tools.isNotEmpty) 'tool_choice': 'auto', + }, + options: Options(responseType: ResponseType.stream), + cancelToken: cancelToken, + ); + + final body = response.data; + if (body == null) { + throw Exception('Empty AI response'); + } + + final textBuffer = StringBuffer(); + var contentEmitted = false; + final toolCallBuilders = {}; + await for (final data in AiProviderStrategySupport.decodeSse(body.stream)) { + if (data == '[DONE]') { + continue; + } + + final json = jsonDecode(data); + if (json is! Map) { + continue; + } + + final choices = json['choices'] as List?; + if (choices == null || choices.isEmpty) { + continue; + } + + final first = choices.first; + if (first is! Map) { + continue; + } + + final delta = first['delta']; + if (delta is! Map) { + continue; + } + + final toolCalls = delta['tool_calls']; + if (toolCalls is List) { + for (final item in toolCalls) { + if (item is Map) { + _appendOpenAiToolCallDelta(toolCallBuilders, item); + } + } + } + + final content = delta['content']; + if (content is String && content.isNotEmpty) { + textBuffer.write(content); + if (toolCallBuilders.isEmpty) { + contentEmitted = true; + await onContent(content); + } + } + } + + final text = textBuffer.toString(); + final toolCalls = toolCallBuilders.values + .map((builder) => builder.build()) + .toList(growable: false); + if (text.trim().isEmpty && toolCalls.isEmpty) { + throw Exception('Empty AI response'); + } + return AiCompletionResponse( + text: text, + toolCalls: toolCalls, + contentEmitted: contentEmitted, + ); + } + Map _openAiMessagePayload(AiPromptMessage message) => { 'role': message.role, 'content': message.content, @@ -159,4 +258,51 @@ class OpenAiCompatibleStrategy implements AiProviderStrategy { arguments: AiProviderStrategySupport.toolArguments(function['arguments']), ); } + + void _appendOpenAiToolCallDelta( + Map builders, + Map value, + ) { + final index = value['index']; + final toolCallIndex = index is int ? index : builders.length; + final builder = builders.putIfAbsent( + toolCallIndex, + _OpenAiToolCallBuilder.new, + ); + + final id = value['id']; + if (id is String && id.isNotEmpty) { + builder.id = id; + } + + final function = value['function']; + if (function is Map) { + final name = function['name']; + if (name is String && name.isNotEmpty) { + builder.name = name; + } + final arguments = function['arguments']; + if (arguments is String && arguments.isNotEmpty) { + builder.arguments.write(arguments); + } + } + } +} + +final class _OpenAiToolCallBuilder { + String? id; + String? name; + final StringBuffer arguments = StringBuffer(); + + AiToolCall build() { + final toolName = name; + if (toolName == null || toolName.isEmpty) { + throw Exception('Invalid AI tool call name'); + } + return AiToolCall( + id: id ?? '${toolName}_$hashCode', + name: toolName, + arguments: AiProviderStrategySupport.toolArguments(arguments.toString()), + ); + } } diff --git a/lib/db/dao/ai_chat_message_dao.dart b/lib/db/dao/ai_chat_message_dao.dart index a43ad38348..4a1929f83f 100644 --- a/lib/db/dao/ai_chat_message_dao.dart +++ b/lib/db/dao/ai_chat_message_dao.dart @@ -1,5 +1,6 @@ import 'package:drift/drift.dart'; +import '../../ai/model/ai_chat_metadata.dart'; import '../mixin_database.dart'; part 'ai_chat_message_dao.g.dart'; @@ -106,6 +107,30 @@ class AiChatMessageDao extends DatabaseAccessor ), ); + Future appendMessageMetadataToolEvent( + String id, + Map event, { + required DateTime updatedAt, + }) async { + await transaction(() async { + final message = await (select( + db.aiChatMessages, + )..where((tbl) => tbl.id.equals(id))).getSingleOrNull(); + if (message == null) { + return; + } + final metadata = appendAiToolEventToMetadata(message.metadata, event); + await (update( + db.aiChatMessages, + )..where((tbl) => tbl.id.equals(id))).write( + AiChatMessagesCompanion( + metadata: Value(metadata), + updatedAt: Value(updatedAt), + ), + ); + }); + } + Future deleteConversationMessages(String conversationId) => (delete( db.aiChatMessages, )..where((tbl) => tbl.conversationId.equals(conversationId))).go(); diff --git a/lib/db/mixin_database.dart b/lib/db/mixin_database.dart index d4f4bd9568..ab55e76c35 100644 --- a/lib/db/mixin_database.dart +++ b/lib/db/mixin_database.dart @@ -101,7 +101,7 @@ class MixinDatabase extends _$MixinDatabase { MixinDatabase(super.e); @override - int get schemaVersion => 30; + int get schemaVersion => 31; final eventBus = DataBaseEventBus.instance; @@ -296,6 +296,9 @@ class MixinDatabase extends _$MixinDatabase { aiChatMessages.anchorCreatedAt, ); } + if (from <= 30) { + await _addColumnIfNotExists(m, aiChatMessages, aiChatMessages.metadata); + } }, beforeOpen: (details) async { if (details.hadUpgrade && details.versionBefore! <= 20) { diff --git a/lib/db/mixin_database.g.dart b/lib/db/mixin_database.g.dart index 0aeb8bfca1..94de2d2c84 100644 --- a/lib/db/mixin_database.g.dart +++ b/lib/db/mixin_database.g.dart @@ -17771,6 +17771,17 @@ class AiChatMessages extends Table requiredDuringInsert: false, $customConstraints: '', ); + static const VerificationMeta _metadataMeta = const VerificationMeta( + 'metadata', + ); + late final GeneratedColumn metadata = GeneratedColumn( + 'metadata', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); late final GeneratedColumnWithTypeConverter createdAt = GeneratedColumn( 'created_at', @@ -17801,6 +17812,7 @@ class AiChatMessages extends Table status, model, errorText, + metadata, createdAt, updatedAt, ]; @@ -17885,6 +17897,12 @@ class AiChatMessages extends Table errorText.isAcceptableOrUnknown(data['error_text']!, _errorTextMeta), ); } + if (data.containsKey('metadata')) { + context.handle( + _metadataMeta, + metadata.isAcceptableOrUnknown(data['metadata']!, _metadataMeta), + ); + } return context; } @@ -17936,6 +17954,10 @@ class AiChatMessages extends Table DriftSqlType.string, data['${effectivePrefix}error_text'], ), + metadata: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}metadata'], + ), createdAt: AiChatMessages.$convertercreatedAt.fromSql( attachedDatabase.typeMapping.read( DriftSqlType.int, @@ -17981,6 +18003,7 @@ class AiChatMessage extends DataClass implements Insertable { final String status; final String? model; final String? errorText; + final String? metadata; final DateTime createdAt; final DateTime updatedAt; const AiChatMessage({ @@ -17994,6 +18017,7 @@ class AiChatMessage extends DataClass implements Insertable { required this.status, this.model, this.errorText, + this.metadata, required this.createdAt, required this.updatedAt, }); @@ -18020,6 +18044,9 @@ class AiChatMessage extends DataClass implements Insertable { if (!nullToAbsent || errorText != null) { map['error_text'] = Variable(errorText); } + if (!nullToAbsent || metadata != null) { + map['metadata'] = Variable(metadata); + } { map['created_at'] = Variable( AiChatMessages.$convertercreatedAt.toSql(createdAt), @@ -18053,6 +18080,9 @@ class AiChatMessage extends DataClass implements Insertable { errorText: errorText == null && nullToAbsent ? const Value.absent() : Value(errorText), + metadata: metadata == null && nullToAbsent + ? const Value.absent() + : Value(metadata), createdAt: Value(createdAt), updatedAt: Value(updatedAt), ); @@ -18076,6 +18106,7 @@ class AiChatMessage extends DataClass implements Insertable { status: serializer.fromJson(json['status']), model: serializer.fromJson(json['model']), errorText: serializer.fromJson(json['error_text']), + metadata: serializer.fromJson(json['metadata']), createdAt: serializer.fromJson(json['created_at']), updatedAt: serializer.fromJson(json['updated_at']), ); @@ -18094,6 +18125,7 @@ class AiChatMessage extends DataClass implements Insertable { 'status': serializer.toJson(status), 'model': serializer.toJson(model), 'error_text': serializer.toJson(errorText), + 'metadata': serializer.toJson(metadata), 'created_at': serializer.toJson(createdAt), 'updated_at': serializer.toJson(updatedAt), }; @@ -18110,6 +18142,7 @@ class AiChatMessage extends DataClass implements Insertable { String? status, Value model = const Value.absent(), Value errorText = const Value.absent(), + Value metadata = const Value.absent(), DateTime? createdAt, DateTime? updatedAt, }) => AiChatMessage( @@ -18127,6 +18160,7 @@ class AiChatMessage extends DataClass implements Insertable { status: status ?? this.status, model: model.present ? model.value : this.model, errorText: errorText.present ? errorText.value : this.errorText, + metadata: metadata.present ? metadata.value : this.metadata, createdAt: createdAt ?? this.createdAt, updatedAt: updatedAt ?? this.updatedAt, ); @@ -18150,6 +18184,7 @@ class AiChatMessage extends DataClass implements Insertable { status: data.status.present ? data.status.value : this.status, model: data.model.present ? data.model.value : this.model, errorText: data.errorText.present ? data.errorText.value : this.errorText, + metadata: data.metadata.present ? data.metadata.value : this.metadata, createdAt: data.createdAt.present ? data.createdAt.value : this.createdAt, updatedAt: data.updatedAt.present ? data.updatedAt.value : this.updatedAt, ); @@ -18168,6 +18203,7 @@ class AiChatMessage extends DataClass implements Insertable { ..write('status: $status, ') ..write('model: $model, ') ..write('errorText: $errorText, ') + ..write('metadata: $metadata, ') ..write('createdAt: $createdAt, ') ..write('updatedAt: $updatedAt') ..write(')')) @@ -18186,6 +18222,7 @@ class AiChatMessage extends DataClass implements Insertable { status, model, errorText, + metadata, createdAt, updatedAt, ); @@ -18203,6 +18240,7 @@ class AiChatMessage extends DataClass implements Insertable { other.status == this.status && other.model == this.model && other.errorText == this.errorText && + other.metadata == this.metadata && other.createdAt == this.createdAt && other.updatedAt == this.updatedAt); } @@ -18218,6 +18256,7 @@ class AiChatMessagesCompanion extends UpdateCompanion { final Value status; final Value model; final Value errorText; + final Value metadata; final Value createdAt; final Value updatedAt; final Value rowid; @@ -18232,6 +18271,7 @@ class AiChatMessagesCompanion extends UpdateCompanion { this.status = const Value.absent(), this.model = const Value.absent(), this.errorText = const Value.absent(), + this.metadata = const Value.absent(), this.createdAt = const Value.absent(), this.updatedAt = const Value.absent(), this.rowid = const Value.absent(), @@ -18247,6 +18287,7 @@ class AiChatMessagesCompanion extends UpdateCompanion { required String status, this.model = const Value.absent(), this.errorText = const Value.absent(), + this.metadata = const Value.absent(), required DateTime createdAt, required DateTime updatedAt, this.rowid = const Value.absent(), @@ -18269,6 +18310,7 @@ class AiChatMessagesCompanion extends UpdateCompanion { Expression? status, Expression? model, Expression? errorText, + Expression? metadata, Expression? createdAt, Expression? updatedAt, Expression? rowid, @@ -18284,6 +18326,7 @@ class AiChatMessagesCompanion extends UpdateCompanion { if (status != null) 'status': status, if (model != null) 'model': model, if (errorText != null) 'error_text': errorText, + if (metadata != null) 'metadata': metadata, if (createdAt != null) 'created_at': createdAt, if (updatedAt != null) 'updated_at': updatedAt, if (rowid != null) 'rowid': rowid, @@ -18301,6 +18344,7 @@ class AiChatMessagesCompanion extends UpdateCompanion { Value? status, Value? model, Value? errorText, + Value? metadata, Value? createdAt, Value? updatedAt, Value? rowid, @@ -18316,6 +18360,7 @@ class AiChatMessagesCompanion extends UpdateCompanion { status: status ?? this.status, model: model ?? this.model, errorText: errorText ?? this.errorText, + metadata: metadata ?? this.metadata, createdAt: createdAt ?? this.createdAt, updatedAt: updatedAt ?? this.updatedAt, rowid: rowid ?? this.rowid, @@ -18357,6 +18402,9 @@ class AiChatMessagesCompanion extends UpdateCompanion { if (errorText.present) { map['error_text'] = Variable(errorText.value); } + if (metadata.present) { + map['metadata'] = Variable(metadata.value); + } if (createdAt.present) { map['created_at'] = Variable( AiChatMessages.$convertercreatedAt.toSql(createdAt.value), @@ -18386,6 +18434,7 @@ class AiChatMessagesCompanion extends UpdateCompanion { ..write('status: $status, ') ..write('model: $model, ') ..write('errorText: $errorText, ') + ..write('metadata: $metadata, ') ..write('createdAt: $createdAt, ') ..write('updatedAt: $updatedAt, ') ..write('rowid: $rowid') @@ -28764,6 +28813,7 @@ typedef $AiChatMessagesCreateCompanionBuilder = required String status, Value model, Value errorText, + Value metadata, required DateTime createdAt, required DateTime updatedAt, Value rowid, @@ -28780,6 +28830,7 @@ typedef $AiChatMessagesUpdateCompanionBuilder = Value status, Value model, Value errorText, + Value metadata, Value createdAt, Value updatedAt, Value rowid, @@ -28845,6 +28896,11 @@ class $AiChatMessagesFilterComposer builder: (column) => ColumnFilters(column), ); + ColumnFilters get metadata => $composableBuilder( + column: $table.metadata, + builder: (column) => ColumnFilters(column), + ); + ColumnWithTypeConverterFilters get createdAt => $composableBuilder( column: $table.createdAt, @@ -28917,6 +28973,11 @@ class $AiChatMessagesOrderingComposer builder: (column) => ColumnOrderings(column), ); + ColumnOrderings get metadata => $composableBuilder( + column: $table.metadata, + builder: (column) => ColumnOrderings(column), + ); + ColumnOrderings get createdAt => $composableBuilder( column: $table.createdAt, builder: (column) => ColumnOrderings(column), @@ -28976,6 +29037,9 @@ class $AiChatMessagesAnnotationComposer GeneratedColumn get errorText => $composableBuilder(column: $table.errorText, builder: (column) => column); + GeneratedColumn get metadata => + $composableBuilder(column: $table.metadata, builder: (column) => column); + GeneratedColumnWithTypeConverter get createdAt => $composableBuilder(column: $table.createdAt, builder: (column) => column); @@ -29024,6 +29088,7 @@ class $AiChatMessagesTableManager Value status = const Value.absent(), Value model = const Value.absent(), Value errorText = const Value.absent(), + Value metadata = const Value.absent(), Value createdAt = const Value.absent(), Value updatedAt = const Value.absent(), Value rowid = const Value.absent(), @@ -29038,6 +29103,7 @@ class $AiChatMessagesTableManager status: status, model: model, errorText: errorText, + metadata: metadata, createdAt: createdAt, updatedAt: updatedAt, rowid: rowid, @@ -29054,6 +29120,7 @@ class $AiChatMessagesTableManager required String status, Value model = const Value.absent(), Value errorText = const Value.absent(), + Value metadata = const Value.absent(), required DateTime createdAt, required DateTime updatedAt, Value rowid = const Value.absent(), @@ -29068,6 +29135,7 @@ class $AiChatMessagesTableManager status: status, model: model, errorText: errorText, + metadata: metadata, createdAt: createdAt, updatedAt: updatedAt, rowid: rowid, diff --git a/lib/db/moor/mixin.drift b/lib/db/moor/mixin.drift index 805839e0b3..dbafc21d8b 100644 --- a/lib/db/moor/mixin.drift +++ b/lib/db/moor/mixin.drift @@ -84,6 +84,7 @@ CREATE TABLE ai_chat_messages ( status TEXT NOT NULL, model TEXT, error_text TEXT, + metadata TEXT, created_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, updated_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, PRIMARY KEY(id) diff --git a/lib/ui/home/chat_slide_page/ai_assistant/composer.dart b/lib/ui/home/chat_slide_page/ai_assistant/composer.dart index 9838c48e54..9abacb5bd5 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant/composer.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant/composer.dart @@ -1,10 +1,15 @@ +import 'dart:ui' as ui show BoxHeightStyle; + import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; import '../../../../ai/model/ai_provider_config.dart'; import '../../../../constants/constants.dart'; +import '../../../../constants/resources.dart'; import '../../../../utils/extension/extension.dart'; import '../../../../widgets/action_button.dart'; +import '../../../../widgets/actions/actions.dart'; +import '../../../../widgets/high_light_text.dart'; import '../../../../widgets/menu.dart'; import 'constants.dart'; @@ -36,96 +41,174 @@ class AiAssistantComposer extends StatelessWidget { @override Widget build(BuildContext context) { - final buttonColor = !enabled - ? context.theme.secondaryText - : requestInFlight - ? context.theme.red - : context.theme.accent; + final fieldColor = context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ); return Container( - padding: const EdgeInsets.fromLTRB(16, 12, 16, 16), - decoration: BoxDecoration( - color: context.theme.primary, - border: Border(top: BorderSide(color: context.theme.divider)), - ), - child: DecoratedBox( - decoration: BoxDecoration( - borderRadius: const BorderRadius.all(Radius.circular(14)), - color: context.dynamicColor( - const Color.fromRGBO(245, 247, 250, 1), - darkColor: const Color.fromRGBO(255, 255, 255, 0.08), - ), - ), - child: Padding( - padding: const EdgeInsets.fromLTRB(12, 10, 8, 10), - child: Column( - mainAxisSize: MainAxisSize.min, - crossAxisAlignment: CrossAxisAlignment.start, + padding: const EdgeInsets.fromLTRB(16, 8, 16, 8), + color: context.theme.primary, + child: Column( + mainAxisSize: MainAxisSize.min, + children: [ + if (provider != null) ...[ + _AiAssistantModeBar( + provider: provider!, + enabledAiProviders: enabledAiProviders, + onProviderSelected: onProviderSelected, + onModelSelected: onModelSelected, + ), + const SizedBox(height: 8), + ], + Row( + crossAxisAlignment: CrossAxisAlignment.end, children: [ - if (provider != null) ...[ - _AiAssistantModeBar( - provider: provider!, - enabledAiProviders: enabledAiProviders, - onProviderSelected: onProviderSelected, - onModelSelected: onModelSelected, - ), - const SizedBox(height: 2), - ], - Row( - crossAxisAlignment: CrossAxisAlignment.end, - children: [ - Expanded( - child: TextField( - focusNode: focusNode, - controller: textEditingController, - enabled: enabled, - minLines: 1, - maxLines: 6, - inputFormatters: [ - LengthLimitingTextInputFormatter(kMaxTextLength), - ], - style: TextStyle( - color: context.theme.text, - fontSize: 14, - height: 1.4, - ), - decoration: InputDecoration( - isDense: true, - border: InputBorder.none, - hintText: enabled - ? aiAssistantInputHint - : aiAssistantUnavailable, - hintStyle: TextStyle( - color: context.theme.secondaryText, - fontSize: 14, + Expanded( + child: Container( + constraints: const BoxConstraints(minHeight: 40), + decoration: BoxDecoration( + borderRadius: const BorderRadius.all(Radius.circular(10)), + color: fieldColor, + ), + alignment: Alignment.center, + child: ValueListenableBuilder( + valueListenable: textEditingController, + builder: (context, value, child) { + final hasInputText = value.text.trim().isNotEmpty; + final canSend = + enabled && + !requestInFlight && + hasInputText && + value.composing.composed; + + return FocusableActionDetector( + autofocus: true, + shortcuts: { + if (canSend) + const SingleActivator(LogicalKeyboardKey.enter): + const _SendMessageIntent(), + const SingleActivator(LogicalKeyboardKey.escape): + const EscapeIntent(), + }, + actions: { + _SendMessageIntent: CallbackAction( + onInvoke: (_) { + onSend(); + return null; + }, + ), + EscapeIntent: CallbackAction( + onInvoke: (_) { + focusNode.unfocus(); + return null; + }, + ), + }, + child: Stack( + children: [ + TextField( + focusNode: focusNode, + controller: textEditingController, + enabled: enabled, + minLines: 1, + maxLines: 7, + inputFormatters: [ + LengthLimitingTextInputFormatter( + kMaxTextLength, + ), + ], + textAlignVertical: TextAlignVertical.center, + style: TextStyle( + color: context.theme.text, + fontSize: 14, + ), + decoration: const InputDecoration( + isDense: true, + border: InputBorder.none, + enabledBorder: InputBorder.none, + focusedBorder: InputBorder.none, + contentPadding: EdgeInsets.only( + left: 10, + right: 10, + top: 8, + bottom: 8, + ), + ), + selectionHeightStyle: + ui.BoxHeightStyle.includeLineSpacingMiddle, + contextMenuBuilder: (context, state) => + MixinAdaptiveSelectionToolbar( + editableTextState: state, + ), + ), + if (!hasInputText) + Positioned.fill( + left: 8, + child: Align( + alignment: Alignment.centerLeft, + child: IgnorePointer( + child: Text( + enabled + ? aiAssistantInputHint + : aiAssistantUnavailable, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + ), + ), + ), + ], ), - ), - ), + ); + }, ), - const SizedBox(width: 8), - ActionButton( - padding: const EdgeInsets.all(6), - size: 20, - interactive: enabled, - onTap: requestInFlight ? onStop : onSend, - child: Icon( - requestInFlight - ? Icons.stop_rounded - : Icons.arrow_upward_rounded, - size: 18, + ), + ), + const SizedBox(width: 10), + ValueListenableBuilder( + valueListenable: textEditingController, + builder: (context, value, child) { + final hasInputText = value.text.trim().isNotEmpty; + final interactive = + enabled && (requestInFlight || hasInputText); + final buttonColor = + !enabled || (!requestInFlight && !hasInputText) + ? context.theme.secondaryText + : requestInFlight + ? context.theme.red + : context.theme.accent; + + return AnimatedOpacity( + duration: const Duration(milliseconds: 180), + opacity: interactive ? 1 : 0.45, + child: ActionButton( + name: requestInFlight + ? Resources.assetsImagesRecordStopSvg + : Resources.assetsImagesIcSendSvg, color: buttonColor, + interactive: interactive, + onTap: requestInFlight ? onStop : onSend, ), - ), - ], + ); + }, ), ], ), - ), + ], ), ); } } +class _SendMessageIntent extends Intent { + const _SendMessageIntent(); +} + class _AiAssistantModeBar extends StatelessWidget { const _AiAssistantModeBar({ required this.provider, @@ -159,47 +242,62 @@ class _AiAssistantModeBar extends StatelessWidget { ) .toList(growable: false); - return DecoratedBox( - decoration: BoxDecoration( - border: Border( - bottom: BorderSide( - color: context.dynamicColor( - const Color.fromRGBO(0, 0, 0, 0.05), - darkColor: const Color.fromRGBO(255, 255, 255, 0.08), - ), - ), - ), - ), - child: Padding( - padding: const EdgeInsets.only(bottom: 8), - child: Row( - children: [ - Flexible( - child: _AiModeChip( - icon: Icons.hub_rounded, - label: provider.name, - items: providerOptions, - enabled: providerOptions.length > 1, - onSelected: onProviderSelected, + return SizedBox( + width: double.infinity, + height: 30, + child: LayoutBuilder( + builder: (context, constraints) { + const spacing = 10.0; + const dividerSpace = 21.0; + final availableWidth = constraints.maxWidth - spacing - dividerSpace; + + return Row( + children: [ + ConstrainedBox( + constraints: BoxConstraints( + maxWidth: availableWidth > 0 ? availableWidth / 2 : 0, + ), + child: _AiModeChip( + icon: Icons.hub_rounded, + label: provider.name, + items: providerOptions, + enabled: providerOptions.length > 1, + onSelected: onProviderSelected, + ), ), - ), - const SizedBox(width: 8), - Flexible( - child: _AiModeChip( - icon: Icons.tune_rounded, - label: provider.model, - items: modelOptions, - enabled: modelOptions.length > 1, - onSelected: onModelSelected, + const SizedBox(width: 10), + _AiModeDivider(), + const SizedBox(width: 10), + Expanded( + child: _AiModeChip( + icon: Icons.tune_rounded, + label: provider.model, + items: modelOptions, + enabled: modelOptions.length > 1, + fill: true, + onSelected: onModelSelected, + ), ), - ), - ], - ), + ], + ); + }, ), ); } } +class _AiModeDivider extends StatelessWidget { + @override + Widget build(BuildContext context) => Container( + width: 1, + height: 14, + color: context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.08), + darkColor: const Color.fromRGBO(255, 255, 255, 0.1), + ), + ); +} + class _AiModeChip extends StatelessWidget { const _AiModeChip({ required this.icon, @@ -207,6 +305,7 @@ class _AiModeChip extends StatelessWidget { required this.items, required this.onSelected, required this.enabled, + this.fill = false, }); final IconData icon; @@ -214,14 +313,16 @@ class _AiModeChip extends StatelessWidget { final List> items; final ValueChanged onSelected; final bool enabled; + final bool fill; @override Widget build(BuildContext context) { final child = Row( + mainAxisSize: fill ? MainAxisSize.max : MainAxisSize.min, children: [ Icon(icon, size: 13, color: context.theme.secondaryText), const SizedBox(width: 6), - Expanded( + Flexible( child: Text( label, maxLines: 1, diff --git a/lib/ui/setting/ai_prompt_settings_page.dart b/lib/ui/setting/ai_prompt_settings_page.dart index 45d9603f6f..41a737548f 100644 --- a/lib/ui/setting/ai_prompt_settings_page.dart +++ b/lib/ui/setting/ai_prompt_settings_page.dart @@ -30,90 +30,94 @@ class AiPromptSettingsPage extends HookConsumerWidget { appBar: const MixinAppBar(title: Text('AI Prompt Templates')), body: Align( alignment: Alignment.topCenter, - child: SingleChildScrollView( - child: Padding( - padding: const EdgeInsets.only(top: 20, bottom: 20), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - CellGroup( - padding: const EdgeInsets.only(right: 10, left: 10), - cellBackgroundColor: context.theme.settingCellBackgroundColor, - child: Padding( - padding: const EdgeInsets.all(16), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Text( - customizedCount == 0 - ? 'All prompts are using built-in defaults.' - : '$customizedCount prompt templates currently use custom overrides.', - style: TextStyle( - color: context.theme.text, - fontSize: 15, - fontWeight: FontWeight.w600, - ), - ), - const SizedBox(height: 8), - Text( - 'Templates support placeholders like {{conversationId}}, {{currentIsoDateTime}}, {{language}}, and {{input}}. Each editor shows the variables available for that prompt.', - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 14, - height: 1.4, - ), - ), - const SizedBox(height: 8), - Text( - 'Leave a template empty to disable that prompt block. Saving the exact default text removes the custom override.', - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 13, - height: 1.4, - ), - ), - ], - ), - ), - ), - for (final group in AiPromptTemplateGroup.values) ...[ - _SectionLabel(title: group.title), + child: ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 600), + child: SingleChildScrollView( + child: Padding( + padding: const EdgeInsets.only(top: 20, bottom: 20), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ CellGroup( padding: const EdgeInsets.only(right: 10, left: 10), cellBackgroundColor: context.theme.settingCellBackgroundColor, - child: Column( - children: [ - for ( - var i = 0; - i < - aiPromptTemplateDefinitions - .where((item) => item.group == group) - .length; - i++ - ) ...[ - _PromptTemplateCell( - definition: aiPromptTemplateDefinitions - .where((item) => item.group == group) - .elementAt(i), + child: Padding( + padding: const EdgeInsets.all(16), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + customizedCount == 0 + ? 'All prompts are using built-in defaults.' + : '$customizedCount prompt templates currently use custom overrides.', + style: TextStyle( + color: context.theme.text, + fontSize: 15, + fontWeight: FontWeight.w600, + ), + ), + const SizedBox(height: 8), + Text( + 'Templates support placeholders like {{conversationId}}, {{currentIsoDateTime}}, {{language}}, and {{input}}. Each editor shows the variables available for that prompt.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + height: 1.4, + ), ), - if (i != - aiPromptTemplateDefinitions - .where((item) => item.group == group) - .length - - 1) - Divider( - height: 0.5, - indent: 16, - endIndent: 16, - color: context.theme.divider, + const SizedBox(height: 8), + Text( + 'Leave a template empty to disable that prompt block. Saving the exact default text removes the custom override.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 13, + height: 1.4, ), + ), ], - ], + ), ), ), + for (final group in AiPromptTemplateGroup.values) ...[ + _SectionLabel(title: group.title), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: Column( + children: [ + for ( + var i = 0; + i < + aiPromptTemplateDefinitions + .where((item) => item.group == group) + .length; + i++ + ) ...[ + _PromptTemplateCell( + definition: aiPromptTemplateDefinitions + .where((item) => item.group == group) + .elementAt(i), + ), + if (i != + aiPromptTemplateDefinitions + .where((item) => item.group == group) + .length - + 1) + Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: context.theme.divider, + ), + ], + ], + ), + ), + ], ], - ], + ), ), ), ), @@ -237,97 +241,108 @@ class _AiPromptTemplateEditPage extends HookConsumerWidget { ), body: Align( alignment: Alignment.topCenter, - child: SingleChildScrollView( - child: Padding( - padding: const EdgeInsets.only(top: 20, bottom: 20), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - const _SectionLabel(title: 'Description'), - CellGroup( - padding: const EdgeInsets.only(right: 10, left: 10), - cellBackgroundColor: theme.settingCellBackgroundColor, - child: Padding( - padding: const EdgeInsets.all(16), - child: Text( - definition.description, - style: TextStyle( - color: theme.text, - fontSize: 14, - height: 1.45, + child: ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 600), + child: SingleChildScrollView( + child: Padding( + padding: const EdgeInsets.only(top: 20, bottom: 20), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + const _SectionLabel(title: 'Description'), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: theme.settingCellBackgroundColor, + child: Padding( + padding: const EdgeInsets.all(16), + child: Text( + definition.description, + style: TextStyle( + color: theme.text, + fontSize: 14, + height: 1.45, + ), ), ), ), - ), - const _SectionLabel(title: 'Variables'), - CellGroup( - padding: const EdgeInsets.only(right: 10, left: 10), - cellBackgroundColor: theme.settingCellBackgroundColor, - child: Padding( - padding: const EdgeInsets.all(16), - child: _PromptVariableChipWrap( - variables: definition.variables, - onTap: (variable) => - _insertToken(controller, variable.token), + const _SectionLabel(title: 'Variables'), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: theme.settingCellBackgroundColor, + child: Padding( + padding: const EdgeInsets.all(16), + child: _PromptVariableChipWrap( + variables: definition.variables, + onTap: (variable) => + _insertToken(controller, variable.token), + ), ), ), - ), - Padding( - padding: const EdgeInsets.only(left: 20, bottom: 14, top: 10), - child: Text( - 'Hover to preview the description. Click a chip to insert it at the current cursor position.', - style: TextStyle( - color: theme.secondaryText, - fontSize: 14, + Padding( + padding: const EdgeInsets.only( + left: 20, + bottom: 14, + top: 10, ), - ), - ), - const _SectionLabel(title: 'Template'), - ConstrainedBox( - constraints: const BoxConstraints(maxWidth: 600), - child: Padding( - padding: const EdgeInsets.symmetric(horizontal: 10), - child: Container( - decoration: BoxDecoration( - color: inputBackgroundColor, - borderRadius: BorderRadius.circular(8), - border: Border.all(color: inputBorderColor), - ), - padding: const EdgeInsets.symmetric( - horizontal: 14, - vertical: 12, + child: Text( + 'Hover to preview the description. Click a chip to insert it at the current cursor position.', + style: TextStyle( + color: theme.secondaryText, + fontSize: 14, ), - child: TextField( - controller: controller, - minLines: 10, - maxLines: null, - style: TextStyle( - color: theme.text, - fontSize: 15, - height: 1.45, + ), + ), + const _SectionLabel(title: 'Template'), + ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 600), + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 10), + child: Container( + decoration: BoxDecoration( + color: inputBackgroundColor, + borderRadius: BorderRadius.circular(8), + border: Border.all(color: inputBorderColor), ), - decoration: InputDecoration( - isDense: true, - border: InputBorder.none, - hintText: definition.defaultValue, - hintStyle: TextStyle(color: theme.secondaryText), + padding: const EdgeInsets.symmetric( + horizontal: 14, + vertical: 12, + ), + child: TextField( + controller: controller, + minLines: 10, + maxLines: null, + style: TextStyle( + color: theme.text, + fontSize: 15, + height: 1.45, + ), + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: definition.defaultValue, + hintStyle: TextStyle(color: theme.secondaryText), + ), ), ), ), ), - ), - Padding( - padding: const EdgeInsets.only(left: 20, right: 20, top: 12), - child: Text( - 'Empty text disables this prompt block. Saving the exact default text removes the override and falls back to the built-in template.', - style: TextStyle( - color: theme.secondaryText, - fontSize: 13, - height: 1.4, + Padding( + padding: const EdgeInsets.only( + left: 20, + right: 20, + top: 12, + ), + child: Text( + 'Empty text disables this prompt block. Saving the exact default text removes the override and falls back to the built-in template.', + style: TextStyle( + color: theme.secondaryText, + fontSize: 13, + height: 1.4, + ), ), ), - ), - ], + ], + ), ), ), ), diff --git a/lib/ui/setting/ai_provider_edit_page.dart b/lib/ui/setting/ai_provider_edit_page.dart index 574301d7a3..d5dd230978 100644 --- a/lib/ui/setting/ai_provider_edit_page.dart +++ b/lib/ui/setting/ai_provider_edit_page.dart @@ -166,242 +166,253 @@ class AiProviderEditPage extends HookConsumerWidget { ), body: Align( alignment: Alignment.topCenter, - child: SingleChildScrollView( - child: Padding( - padding: const EdgeInsets.only(top: 20, bottom: 20), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - const _SectionLabel( - title: 'Provider', - ), - CellGroup( - padding: const EdgeInsets.only(right: 10, left: 10), - cellBackgroundColor: context.theme.settingCellBackgroundColor, - child: Column( - children: [ - _FormFieldCell( - label: 'Display Name', - backgroundColor: inputBackgroundColor, - borderColor: inputBorderColor, - child: TextField( - controller: nameController, - style: TextStyle( - color: theme.text, - fontSize: 16, - ), - decoration: InputDecoration( - isDense: true, - border: InputBorder.none, - hintText: - 'OpenAI / Anthropic / Gemini / Self-hosted', - hintStyle: TextStyle(color: theme.secondaryText), - ), - ), - ), - _CellDivider(color: theme.divider), - _FormFieldCell( - label: 'Provider Type', - backgroundColor: inputBackgroundColor, - borderColor: inputBorderColor, - child: DropdownButtonHideUnderline( - child: DropdownButton( - value: providerType.value, - isExpanded: true, - dropdownColor: theme.popUp, + child: ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 600), + child: SingleChildScrollView( + child: Padding( + padding: const EdgeInsets.only(top: 20, bottom: 20), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + const _SectionLabel( + title: 'Provider', + ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: Column( + children: [ + _FormFieldCell( + label: 'Display Name', + backgroundColor: inputBackgroundColor, + borderColor: inputBorderColor, + child: TextField( + controller: nameController, style: TextStyle( color: theme.text, fontSize: 16, ), - iconEnabledColor: inputIconColor, - onChanged: (value) { - if (value == null || - value == providerType.value) { - return; - } - final previousType = providerType.value; - providerType.value = value; - if (initial == null) { - final suggestion = _defaultBaseUrlFor(value); - final current = baseUrlController.text.trim(); - final replaceCurrent = - current.isEmpty || - current == _defaultBaseUrlFor(previousType); - if (replaceCurrent && suggestion.isNotEmpty) { - baseUrlController.text = suggestion; + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: + 'OpenAI / Anthropic / Gemini / Self-hosted', + hintStyle: TextStyle(color: theme.secondaryText), + ), + ), + ), + _CellDivider(color: theme.divider), + _FormFieldCell( + label: 'Provider Type', + backgroundColor: inputBackgroundColor, + borderColor: inputBorderColor, + child: DropdownButtonHideUnderline( + child: DropdownButton( + value: providerType.value, + isExpanded: true, + dropdownColor: theme.popUp, + style: TextStyle( + color: theme.text, + fontSize: 16, + ), + iconEnabledColor: inputIconColor, + onChanged: (value) { + if (value == null || + value == providerType.value) { + return; + } + final previousType = providerType.value; + providerType.value = value; + if (initial == null) { + final suggestion = _defaultBaseUrlFor(value); + final current = baseUrlController.text.trim(); + final replaceCurrent = + current.isEmpty || + current == + _defaultBaseUrlFor(previousType); + if (replaceCurrent && suggestion.isNotEmpty) { + baseUrlController.text = suggestion; + } } - } - }, - items: AiProviderType.values - .map( - (type) => DropdownMenuItem( - value: type, - child: Text( - switch (type) { - AiProviderType.anthropic => 'Anthropic', - AiProviderType.gemini => 'Gemini', - AiProviderType.openaiCompatible => - 'OpenAI Compatible', - }, + }, + items: AiProviderType.values + .map( + (type) => DropdownMenuItem( + value: type, + child: Text( + switch (type) { + AiProviderType.anthropic => + 'Anthropic', + AiProviderType.gemini => 'Gemini', + AiProviderType.openaiCompatible => + 'OpenAI Compatible', + }, + ), ), - ), - ) - .toList(), + ) + .toList(), + ), ), ), + ], + ), + ), + const _SectionLabel( + title: 'Endpoint', + ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: theme.settingCellBackgroundColor, + child: _FormFieldCell( + label: 'Base URL', + backgroundColor: inputBackgroundColor, + borderColor: inputBorderColor, + child: TextField( + controller: baseUrlController, + keyboardType: TextInputType.url, + style: TextStyle( + color: theme.text, + fontSize: 16, + ), + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: _baseUrlHintFor(providerType.value), + hintStyle: TextStyle(color: theme.secondaryText), + ), ), - ], + ), ), - ), - const _SectionLabel( - title: 'Endpoint', - ), - CellGroup( - padding: const EdgeInsets.only(right: 10, left: 10), - cellBackgroundColor: theme.settingCellBackgroundColor, - child: _FormFieldCell( - label: 'Base URL', - backgroundColor: inputBackgroundColor, - borderColor: inputBorderColor, - child: TextField( - controller: baseUrlController, - keyboardType: TextInputType.url, + Padding( + padding: const EdgeInsets.only( + left: 20, + bottom: 14, + top: 10, + ), + child: Text( + _baseUrlHelperTextFor(providerType.value), style: TextStyle( - color: theme.text, - fontSize: 16, - ), - decoration: InputDecoration( - isDense: true, - border: InputBorder.none, - hintText: _baseUrlHintFor(providerType.value), - hintStyle: TextStyle(color: theme.secondaryText), + color: context.theme.secondaryText, + fontSize: 14, ), ), ), - ), - Padding( - padding: const EdgeInsets.only(left: 20, bottom: 14, top: 10), - child: Text( - _baseUrlHelperTextFor(providerType.value), - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 14, - ), + const _SectionLabel( + title: 'Authorization', ), - ), - const _SectionLabel( - title: 'Authorization', - ), - CellGroup( - padding: const EdgeInsets.only(right: 10, left: 10), - cellBackgroundColor: context.theme.settingCellBackgroundColor, - child: Column( - children: [ - _FormFieldCell( - label: 'API Key', - backgroundColor: inputBackgroundColor, - borderColor: inputBorderColor, - trailing: IconButton( - onPressed: () => - obscureApiKey.value = !obscureApiKey.value, - icon: Icon( - obscureApiKey.value - ? Icons.visibility_outlined - : Icons.visibility_off_outlined, - size: 20, - color: inputIconColor, - ), - ), - child: TextField( - controller: apiKeyController, - obscureText: obscureApiKey.value, - style: TextStyle( - color: theme.text, - fontSize: 16, + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: Column( + children: [ + _FormFieldCell( + label: 'API Key', + backgroundColor: inputBackgroundColor, + borderColor: inputBorderColor, + trailing: IconButton( + onPressed: () => + obscureApiKey.value = !obscureApiKey.value, + icon: Icon( + obscureApiKey.value + ? Icons.visibility_outlined + : Icons.visibility_off_outlined, + size: 20, + color: inputIconColor, + ), ), - decoration: InputDecoration( - isDense: true, - border: InputBorder.none, - hintText: _apiKeyHintFor(providerType.value), - hintStyle: TextStyle(color: theme.secondaryText), + child: TextField( + controller: apiKeyController, + obscureText: obscureApiKey.value, + style: TextStyle( + color: theme.text, + fontSize: 16, + ), + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: _apiKeyHintFor(providerType.value), + hintStyle: TextStyle(color: theme.secondaryText), + ), ), ), - ), - ], + ], + ), ), - ), - const _SectionLabel( - title: 'Models', - ), - CellGroup( - padding: const EdgeInsets.only(right: 10, left: 10), - cellBackgroundColor: theme.settingCellBackgroundColor, - child: Column( - children: [ - CellItem( - title: const Text('Default Model'), - description: Text( - defaultModel.value.isEmpty - ? 'No default model yet' - : defaultModel.value, - maxLines: 1, - overflow: TextOverflow.ellipsis, + const _SectionLabel( + title: 'Models', + ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: theme.settingCellBackgroundColor, + child: Column( + children: [ + CellItem( + title: const Text('Default Model'), + description: Text( + defaultModel.value.isEmpty + ? 'No default model yet' + : defaultModel.value, + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + trailing: null, ), - trailing: null, - ), - _CellDivider(color: context.theme.divider), - CellItem( - title: const Text('Add Model'), - leading: Icon(Icons.add, color: context.theme.icon), - trailing: null, - onTap: showModelDialog, - ), - if (models.value.isEmpty) ...[ _CellDivider(color: context.theme.divider), - Padding( - padding: const EdgeInsets.symmetric( - horizontal: 16, - vertical: 20, - ), - child: Row( - children: [ - Icon( - Icons.view_list_outlined, - size: 18, - color: theme.secondaryText, - ), - const SizedBox(width: 10), - Expanded( - child: Text( - 'No models yet. Add at least one model before saving.', - style: TextStyle( - color: theme.secondaryText, - fontSize: 14, - ), - ), - ), - ], - ), + CellItem( + title: const Text('Add Model'), + leading: Icon(Icons.add, color: context.theme.icon), + trailing: null, + onTap: showModelDialog, ), - ] else ...[ - for (var i = 0; i < models.value.length; i++) ...[ + if (models.value.isEmpty) ...[ _CellDivider(color: context.theme.divider), - _ModelItem( - model: models.value[i], - selected: models.value[i] == defaultModel.value, - onTap: () => defaultModel.value = models.value[i], - onEdit: () => showModelDialog( - initialValue: models.value[i], - index: i, + Padding( + padding: const EdgeInsets.symmetric( + horizontal: 16, + vertical: 20, + ), + child: Row( + children: [ + Icon( + Icons.view_list_outlined, + size: 18, + color: theme.secondaryText, + ), + const SizedBox(width: 10), + Expanded( + child: Text( + 'No models yet. Add at least one model before saving.', + style: TextStyle( + color: theme.secondaryText, + fontSize: 14, + ), + ), + ), + ], ), - onDelete: () => removeModelAt(i), ), + ] else ...[ + for (var i = 0; i < models.value.length; i++) ...[ + _CellDivider(color: context.theme.divider), + _ModelItem( + model: models.value[i], + selected: models.value[i] == defaultModel.value, + onTap: () => defaultModel.value = models.value[i], + onEdit: () => showModelDialog( + initialValue: models.value[i], + index: i, + ), + onDelete: () => removeModelAt(i), + ), + ], ], ], - ], + ), ), - ), - ], + ], + ), ), ), ), diff --git a/lib/ui/setting/ai_settings_page.dart b/lib/ui/setting/ai_settings_page.dart index 667effb391..19e370db1d 100644 --- a/lib/ui/setting/ai_settings_page.dart +++ b/lib/ui/setting/ai_settings_page.dart @@ -37,85 +37,37 @@ class AiSettingsPage extends HookConsumerWidget { appBar: const MixinAppBar(title: Text('AI Settings')), body: Align( alignment: Alignment.topCenter, - child: SingleChildScrollView( - child: Padding( - padding: const EdgeInsets.only(top: 40), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - CellGroup( - padding: const EdgeInsets.only(right: 10, left: 10), - cellBackgroundColor: context.theme.settingCellBackgroundColor, - child: CellItem( - title: const Text('Prompt Templates'), - leading: Icon( - Icons.tune_rounded, - color: context.theme.icon, - ), - description: Text( - customizedPromptCount == 0 - ? 'Default' - : '$customizedPromptCount custom', - maxLines: 1, - overflow: TextOverflow.ellipsis, - ), - trailing: null, - onTap: () => Navigator.of(context).push( - MaterialPageRoute( - builder: (_) => const AiPromptSettingsPage(), - ), - ), - ), - ), - Padding( - padding: const EdgeInsets.only(left: 20, bottom: 14, top: 10), - child: Text( - 'Customize chat prompts, assist prompts, and built-in variables like {{conversationId}}, {{currentIsoDateTime}}, and {{language}}.', - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 14, - ), - ), - ), - CellGroup( - padding: const EdgeInsets.only(right: 10, left: 10), - cellBackgroundColor: context.theme.settingCellBackgroundColor, - child: CellItem( - title: const Text('Add Provider'), - leading: Icon(Icons.add, color: context.theme.icon), - trailing: null, - onTap: () => Navigator.of(context).push( - MaterialPageRoute( - builder: (_) => const AiProviderEditPage(), - ), - ), - ), - ), - Padding( - padding: const EdgeInsets.only(left: 20, bottom: 14, top: 10), - child: Text( - providers.isEmpty - ? 'Add an AI provider to enable AI mode in chat.' - : 'The selected provider is used by default in AI mode.', - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 14, - ), - ), - ), - if (providers.isNotEmpty) ...[ + child: ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 600), + child: SingleChildScrollView( + child: Padding( + padding: const EdgeInsets.only(top: 40), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ CellGroup( padding: const EdgeInsets.only(right: 10, left: 10), cellBackgroundColor: context.theme.settingCellBackgroundColor, child: CellItem( - title: const Text('Default Provider'), + title: const Text('Prompt Templates'), + leading: Icon( + Icons.tune_rounded, + color: context.theme.icon, + ), description: Text( - _providerSummary(selectedProvider), + customizedPromptCount == 0 + ? 'Default' + : '$customizedPromptCount custom', maxLines: 1, overflow: TextOverflow.ellipsis, ), trailing: null, + onTap: () => Navigator.of(context).push( + MaterialPageRoute( + builder: (_) => const AiPromptSettingsPage(), + ), + ), ), ), Padding( @@ -125,7 +77,7 @@ class AiSettingsPage extends HookConsumerWidget { top: 10, ), child: Text( - 'Each API endpoint can contain multiple models. One default model is used for new AI requests.', + 'Customize chat prompts, assist prompts, and built-in variables like {{conversationId}}, {{currentIsoDateTime}}, and {{language}}.', style: TextStyle( color: context.theme.secondaryText, fontSize: 14, @@ -136,26 +88,87 @@ class AiSettingsPage extends HookConsumerWidget { padding: const EdgeInsets.only(right: 10, left: 10), cellBackgroundColor: context.theme.settingCellBackgroundColor, - child: Column( - children: [ - for (var i = 0; i < providers.length; i++) ...[ - _ProviderCell( - provider: providers[i], - selected: selectedId == providers[i].id, - ), - if (i != providers.length - 1) - Divider( - height: 0.5, - indent: 16, - endIndent: 16, - color: context.theme.divider, + child: CellItem( + title: const Text('Add Provider'), + leading: Icon(Icons.add, color: context.theme.icon), + trailing: null, + onTap: () => Navigator.of(context).push( + MaterialPageRoute( + builder: (_) => const AiProviderEditPage(), + ), + ), + ), + ), + Padding( + padding: const EdgeInsets.only( + left: 20, + bottom: 14, + top: 10, + ), + child: Text( + providers.isEmpty + ? 'Add an AI provider to enable AI mode in chat.' + : 'The selected provider is used by default in AI mode.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), + ), + ), + if (providers.isNotEmpty) ...[ + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: CellItem( + title: const Text('Default Provider'), + description: Text( + _providerSummary(selectedProvider), + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + trailing: null, + ), + ), + Padding( + padding: const EdgeInsets.only( + left: 20, + bottom: 14, + top: 10, + ), + child: Text( + 'Each API endpoint can contain multiple models. One default model is used for new AI requests.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), + ), + ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: Column( + children: [ + for (var i = 0; i < providers.length; i++) ...[ + _ProviderCell( + provider: providers[i], + selected: selectedId == providers[i].id, ), + if (i != providers.length - 1) + Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: context.theme.divider, + ), + ], ], - ], + ), ), - ), + ], ], - ], + ), ), ), ), diff --git a/lib/ui/setting/setting_page.dart b/lib/ui/setting/setting_page.dart index 9481c60ee3..1d6d80dbe5 100644 --- a/lib/ui/setting/setting_page.dart +++ b/lib/ui/setting/setting_page.dart @@ -138,7 +138,7 @@ class SettingPage extends HookConsumerWidget { ResponsiveNavigatorStateNotifier.appearancePage, title: context.l10n.appearance, ), - _Item( + const _Item( leadingAssetName: Resources.assetsImagesIcAppearanceSvg, pageName: diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index cfd030d667..5ea64b3020 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -5,6 +5,7 @@ import 'package:flutter/services.dart'; import 'package:intl/intl.dart'; import 'package:super_context_menu/super_context_menu.dart'; +import '../../ai/model/ai_chat_metadata.dart'; import '../../db/mixin_database.dart' hide Offset; import '../../utils/datetime_format_utils.dart'; import '../../utils/extension/extension.dart'; @@ -97,6 +98,10 @@ class _AiMessageBody extends StatelessWidget { Widget build(BuildContext context) { final isUser = message.role == 'user'; final text = _displayText(message); + final isPendingAssistant = + !isUser && + message.status == 'pending' && + message.content.trim().isEmpty; Widget body; final textStyle = TextStyle( @@ -107,7 +112,9 @@ class _AiMessageBody extends StatelessWidget { height: 1.45, ); - if (isUser || message.status == 'error') { + if (isPendingAssistant) { + body = _AiPendingAssistantActivity(message: message, style: textStyle); + } else if (isUser || message.status == 'error') { body = _AiSelectableText(text: text, style: textStyle); } else { final cacheKey = buildMarkdownCacheKey( @@ -137,6 +144,47 @@ class _AiMessageBody extends StatelessWidget { } } +class _AiPendingAssistantActivity extends StatelessWidget { + const _AiPendingAssistantActivity({ + required this.message, + required this.style, + }); + + final AiChatMessage message; + final TextStyle style; + + @override + Widget build(BuildContext context) { + final text = _pendingAssistantText(message); + final color = context.dynamicColor( + const Color.fromRGBO(131, 145, 158, 1), + darkColor: const Color.fromRGBO(128, 131, 134, 1), + ); + + return SelectionContainer.disabled( + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + SizedBox.square( + dimension: 14, + child: CircularProgressIndicator( + strokeWidth: 2, + valueColor: AlwaysStoppedAnimation(color), + ), + ), + const SizedBox(width: 8), + Flexible( + child: Text( + text, + style: style.copyWith(color: color), + ), + ), + ], + ), + ); + } +} + class _AiSelectableText extends StatefulWidget { const _AiSelectableText({ required this.text, @@ -371,6 +419,50 @@ String _displayText(AiChatMessage message) { if (message.status == 'error') { return message.errorText ?? 'Request failed'; } - if (message.status == 'pending') return 'Thinking...'; + if (message.status == 'pending') return _pendingAssistantText(message); return message.errorText ?? 'No response'; } + +String _pendingAssistantText(AiChatMessage message) { + final activeToolName = _activeToolName(message.metadata); + if (activeToolName != null) { + return _toolActivityText(activeToolName); + } + return 'Thinking...'; +} + +String? _activeToolName(String? metadata) { + final events = aiMetadataToolEvents(metadata); + if (events.isEmpty) { + return null; + } + + final finishedToolCallIds = events + .where((event) => event['type'] == aiToolEventTypeResult) + .map((event) => event['id']) + .whereType() + .toSet(); + + for (final event in events.reversed) { + if (event['type'] != aiToolEventTypeCall) { + continue; + } + final id = event['id']; + if (id is String && finishedToolCallIds.contains(id)) { + continue; + } + final name = event['name']; + if (name is String && name.isNotEmpty) { + return name; + } + } + return null; +} + +String _toolActivityText(String toolName) => switch (toolName) { + 'get_conversation_stats' => 'Reading conversation stats...', + 'list_conversation_chunks' => 'Planning conversation read...', + 'read_conversation_chunk' => 'Reading conversation...', + 'search_conversation_messages' => 'Searching conversation...', + _ => 'Using tool...', +}; From 337bce5a2e1171e2774984da647b0530665203a6 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 16:17:38 +0800 Subject: [PATCH 25/52] refactor: replace custom AI tool handling with Genkit integration --- lib/ai/ai_chat_controller.dart | 81 +++-- lib/ai/ai_chat_prompt_builder.dart | 144 +------- lib/ai/ai_provider_requester.dart | 329 +++++++----------- lib/ai/model/ai_chat_metadata.dart | 20 +- lib/ai/model/ai_prompt_message.dart | 19 +- lib/ai/model/ai_prompt_template.dart | 4 - lib/ai/model/ai_tool.dart | 28 -- lib/ai/provider/ai_provider_strategy.dart | 118 ------- lib/ai/provider/anthropic_strategy.dart | 208 ----------- lib/ai/provider/gemini_strategy.dart | 267 -------------- .../provider/openai_compatible_strategy.dart | 308 ---------------- .../tools/ai_conversation_tool_service.dart | 34 +- pubspec.lock | 105 +++++- pubspec.yaml | 10 +- 14 files changed, 321 insertions(+), 1354 deletions(-) delete mode 100644 lib/ai/provider/ai_provider_strategy.dart delete mode 100644 lib/ai/provider/anthropic_strategy.dart delete mode 100644 lib/ai/provider/gemini_strategy.dart delete mode 100644 lib/ai/provider/openai_compatible_strategy.dart diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index fe70f57552..ccb32bde96 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -3,7 +3,9 @@ import 'dart:convert'; import 'package:dio/dio.dart'; import 'package:drift/drift.dart'; +import 'package:genkit/genkit.dart' as genkit; import 'package:mixin_logger/mixin_logger.dart'; +import 'package:schemantic/schemantic.dart'; import 'package:uuid/uuid.dart'; import '../db/dao/ai_chat_message_dao.dart'; @@ -250,59 +252,85 @@ class AiChatController { cancelToken: cancelToken, onContent: onContent, conversationId: conversationId, - onToolCall: _toolExecutorFor( + tools: _toolsFor( conversationId, assistantMessageId: assistantMessageId, ), ); - Future Function(AiToolCall toolCall)? _toolExecutorFor( + List, Map>>? _toolsFor( String? conversationId, { String? assistantMessageId, }) { if (conversationId == null) { return null; } - return (toolCall) => _executeConversationTool( - conversationId: conversationId, - assistantMessageId: assistantMessageId, - toolCall: toolCall, - ); + return AiConversationToolKit.definitions + .map( + (definition) => + genkit.Tool, Map>( + name: definition.name, + description: definition.description, + inputSchema: _schemaFor(definition), + fn: (input, context) async { + final request = context.toolRequest?.toolRequest; + return _executeConversationTool( + conversationId: conversationId, + assistantMessageId: assistantMessageId, + id: request?.ref ?? '${definition.name}_${input.hashCode}', + name: request?.name ?? definition.name, + arguments: input, + ); + }, + ), + ) + .toList(growable: false); } - Future _executeConversationTool({ + SchemanticType> _schemaFor( + AiToolDefinition definition, + ) => SchemanticType.from>( + jsonSchema: definition.inputSchema.map(MapEntry.new), + parse: _jsonMap, + ); + + Future> _executeConversationTool({ required String conversationId, required String? assistantMessageId, - required AiToolCall toolCall, + required String id, + required String name, + required Map arguments, }) async { final stopwatch = Stopwatch()..start(); d( 'AI tool execute start: conversationId=$conversationId ' - 'tool=${toolCall.name} id=${toolCall.id} ' - 'arguments=${_previewJson(toolCall.arguments)}', + 'tool=$name id=$id ' + 'arguments=${_previewJson(arguments)}', ); await _appendAssistantToolEvent( assistantMessageId, - createAiToolCallEvent(toolCall), + createAiToolCallEvent(id: id, name: name, arguments: arguments), ); try { final result = await _conversationTools.execute( conversationId: conversationId, - call: toolCall, + name: name, + arguments: arguments, ); d( 'AI tool execute done: conversationId=$conversationId ' - 'tool=${toolCall.name} id=${toolCall.id} ' + 'tool=$name id=$id ' 'elapsedMs=${stopwatch.elapsedMilliseconds} ' - 'result=${_previewJson(result.payload)}', + 'result=${_previewJson(result)}', ); await _appendAssistantToolEvent( assistantMessageId, createAiToolResultEvent( - toolCall: toolCall, + id: id, + name: name, status: 'done', elapsedMs: stopwatch.elapsedMilliseconds, - resultPreview: _previewJson(result.payload), + resultPreview: _previewJson(result), ), ); return result; @@ -311,17 +339,14 @@ class AiChatController { await _appendAssistantToolEvent( assistantMessageId, createAiToolResultEvent( - toolCall: toolCall, + id: id, + name: name, status: 'error', elapsedMs: stopwatch.elapsedMilliseconds, errorText: error.toString(), ), ); - return AiToolExecutionResult( - toolCallId: toolCall.id, - toolName: toolCall.name, - payload: {'error': '$error'}, - ); + return {'error': '$error'}; } } @@ -367,6 +392,16 @@ String _previewJson(Object? value, {int maxLength = _kAiLogJsonPreviewLength}) { } } +Map _jsonMap(dynamic value) { + if (value is Map) { + return value; + } + if (value is Map) { + return value.map((key, value) => MapEntry('$key', value)); + } + throw Exception('Invalid AI tool arguments'); +} + class _StreamingMessageUpdater { _StreamingMessageUpdater({required this.dao, required this.messageId}); diff --git a/lib/ai/ai_chat_prompt_builder.dart b/lib/ai/ai_chat_prompt_builder.dart index 2c94500c11..ca01d40072 100644 --- a/lib/ai/ai_chat_prompt_builder.dart +++ b/lib/ai/ai_chat_prompt_builder.dart @@ -1,8 +1,5 @@ -import 'dart:async'; - import 'package:mixin_logger/mixin_logger.dart'; -import '../db/dao/message_dao.dart'; import '../db/database.dart'; import '../db/mixin_database.dart'; import 'model/ai_prompt_message.dart'; @@ -14,10 +11,7 @@ class AiChatPromptBuilder { static const _aiRoleUser = 'user'; static const _aiStatusPending = 'pending'; static const _aiContextMessageLimit = 30; - static const _aiRetrievedMessageLimit = 6; static const _aiHistoryLimit = 12; - static const _aiRetrievalQueryMaxLength = 120; - static const _aiLogPreviewLength = 240; final Database database; @@ -30,11 +24,6 @@ class AiChatPromptBuilder { final recentMessages = await database.messageDao .messagesByConversationId(conversationId, _aiContextMessageLimit) .get(); - final retrievedMessages = await _retrieveConversationMessages( - conversationId: conversationId, - recentMessages: recentMessages, - query: input, - ); final aiMessages = await database.aiChatMessageDao.conversationMessages( conversationId, ); @@ -67,7 +56,6 @@ class AiChatPromptBuilder { promptMessages, conversationId: conversationId, recentMessages: recentMessages, - retrievedMessages: retrievedMessages, language: language, now: now, ); @@ -97,7 +85,7 @@ class AiChatPromptBuilder { ); d( 'AI prompt built: conversationId=$conversationId ' - 'recent=${recentMessages.length} retrieved=${retrievedMessages.length} ' + 'recent=${recentMessages.length} ' 'history=${history.length} promptMessages=${promptMessages.length}', ); return promptMessages; @@ -142,16 +130,10 @@ class AiChatPromptBuilder { final recentMessages = await database.messageDao .messagesByConversationId(conversationId, _aiContextMessageLimit) .get(); - final retrievedMessages = await _retrieveConversationMessages( - conversationId: conversationId, - recentMessages: recentMessages, - query: input ?? _latestRetrievalSeed(recentMessages), - ); _appendConversationContext( promptMessages, conversationId: conversationId, recentMessages: recentMessages, - retrievedMessages: retrievedMessages, language: language, now: now, ); @@ -209,7 +191,6 @@ class AiChatPromptBuilder { List promptMessages, { required String conversationId, required List recentMessages, - required List retrievedMessages, required String language, required DateTime now, }) { @@ -238,108 +219,6 @@ class AiChatPromptBuilder { ), ); } - - if (retrievedMessages.isEmpty) { - return; - } - - final lines = retrievedMessages - .map( - (message) => _conversationContextLine( - createdAt: message.createdAt, - sender: message.senderFullName ?? message.senderId, - content: _searchMessagePlainText(message), - ), - ) - .join('\n'); - promptMessages.addAll( - _promptMessages( - role: 'system', - content: renderAiPromptTemplate( - retrievedConversationContextPromptTemplate, - buildAiPromptTemplateVariables( - conversationId: conversationId, - messages: lines, - language: language, - now: now, - ), - ), - ), - ); - } - - Future> _retrieveConversationMessages({ - required String conversationId, - required List recentMessages, - required String? query, - }) async { - final normalizedQuery = _normalizeRetrievalQuery(query); - if (normalizedQuery == null) { - d('AI retrieval skipped: conversationId=$conversationId empty query'); - return const []; - } - - final recentIds = recentMessages - .map((message) => message.messageId) - .toSet(); - final matchedIds = await database.ftsDatabase.fuzzySearchMessage( - query: normalizedQuery, - limit: _aiRetrievedMessageLimit + recentIds.length, - conversationIds: [conversationId], - ); - final candidateIds = matchedIds - .where((messageId) => !recentIds.contains(messageId)) - .take(_aiRetrievedMessageLimit) - .toList(growable: false); - if (candidateIds.isEmpty) { - d( - 'AI retrieval no match: conversationId=$conversationId ' - 'query=${_previewText(normalizedQuery)}', - ); - return const []; - } - - final matchedMessages = await database.messageDao - .searchMessageByIds(candidateIds) - .get(); - final messagesById = { - for (final message in matchedMessages) message.messageId: message, - }; - final ordered = []; - for (final messageId in candidateIds) { - final message = messagesById[messageId]; - if (message != null) { - ordered.add(message); - } - } - ordered.sort((left, right) => left.createdAt.compareTo(right.createdAt)); - d( - 'AI retrieval matched: conversationId=$conversationId ' - 'query=${_previewText(normalizedQuery)} matches=${ordered.length}', - ); - return ordered; - } - - String? _latestRetrievalSeed(List recentMessages) { - for (final message in recentMessages) { - final content = _messagePlainText(message); - final normalized = _normalizeRetrievalQuery(content); - if (normalized != null) { - return normalized; - } - } - return null; - } - - String? _normalizeRetrievalQuery(String? query) { - final compact = query?.replaceAll(RegExp(r'\s+'), ' ').trim(); - if (compact == null || compact.isEmpty) { - return null; - } - if (compact.length <= _aiRetrievalQueryMaxLength) { - return compact; - } - return compact.substring(0, _aiRetrievalQueryMaxLength); } String _conversationContextLine({ @@ -354,13 +233,6 @@ class AiChatPromptBuilder { type: message.type, ); - String _searchMessagePlainText(SearchMessageDetailItem message) => - _messagePlainTextFromFields( - content: message.content, - mediaName: message.mediaName, - type: message.type, - ); - String _messagePlainTextFromFields({ required String? content, required String? mediaName, @@ -386,20 +258,6 @@ class AiChatPromptBuilder { } } -String _previewText( - String? text, { - int maxLength = AiChatPromptBuilder._aiLogPreviewLength, -}) { - final compact = text?.replaceAll(RegExp(r'\s+'), ' ').trim() ?? ''; - if (compact.isEmpty) { - return '""'; - } - if (compact.length <= maxLength) { - return compact; - } - return '${compact.substring(0, maxLength)}...(${compact.length} chars)'; -} - extension _IterableTakeLastExtension on Iterable { Iterable takeLast(int count) { if (count <= 0) return const []; diff --git a/lib/ai/ai_provider_requester.dart b/lib/ai/ai_provider_requester.dart index 55480f3bb9..87f19b9423 100644 --- a/lib/ai/ai_provider_requester.dart +++ b/lib/ai/ai_provider_requester.dart @@ -1,31 +1,24 @@ import 'dart:async'; -import 'dart:convert'; +import 'dart:io'; import 'package:dio/dio.dart'; +import 'package:genkit/genkit.dart' as genkit; +import 'package:genkit/plugin.dart' as genkit_plugin; +import 'package:genkit_anthropic/genkit_anthropic.dart'; +import 'package:genkit_google_genai/genkit_google_genai.dart'; +import 'package:genkit_openai/genkit_openai.dart'; import 'package:mixin_logger/mixin_logger.dart'; import '../utils/proxy.dart'; import 'model/ai_prompt_message.dart'; import 'model/ai_provider_config.dart'; import 'model/ai_provider_type.dart'; -import 'model/ai_tool.dart'; -import 'provider/ai_provider_strategy.dart'; -import 'provider/anthropic_strategy.dart'; -import 'provider/gemini_strategy.dart'; -import 'provider/openai_compatible_strategy.dart'; -import 'tools/ai_conversation_tool_service.dart'; class AiProviderRequester { const AiProviderRequester(); static const _aiToolMaxRounds = 8; - static const _aiStreamFlushChars = 32; static const _aiLogPreviewLength = 240; - static const _aiLogJsonPreviewLength = 480; - - static const _openAiStrategy = OpenAiCompatibleStrategy(); - static const _anthropicStrategy = AnthropicStrategy(); - static const _geminiStrategy = GeminiStrategy(); Future requestText( AiProviderConfig config, @@ -34,189 +27,163 @@ class AiProviderRequester { required CancelToken cancelToken, required Future Function(String chunk) onContent, required String? conversationId, - Future Function(AiToolCall toolCall)? onToolCall, + List? tools, }) async { d( 'AI request start: provider=${config.type.name} model=${config.model} ' 'conversationId=$conversationId messages=${messages.length} ' - 'tools=${conversationId != null && onToolCall != null}', + 'tools=${tools?.length ?? 0}', ); - final dio = - Dio( - BaseOptions( - baseUrl: config.baseUrl, - connectTimeout: const Duration(seconds: 20), - receiveTimeout: const Duration(minutes: 5), - sendTimeout: const Duration(seconds: 20), - headers: _strategyFor(config.type).headers(config), - ), - ) - ..interceptors.add( - InterceptorsWrapper( - onRequest: (options, handler) { - options.extra['ai_request_started_at'] = DateTime.now(); - d( - 'AI HTTP request: ${options.method} ${options.uri} ' - 'provider=${config.type.name} model=${config.model}', - ); - handler.next(options); - }, - onResponse: (response, handler) { - final startedAt = - response.requestOptions.extra['ai_request_started_at'] - as DateTime?; - d( - 'AI HTTP response: ${response.requestOptions.method} ' - '${response.requestOptions.uri} ' - 'status=${response.statusCode} ' - 'elapsedMs=${startedAt == null ? -1 : DateTime.now().difference(startedAt).inMilliseconds}', - ); - handler.next(response); - }, - onError: (error, handler) { - final startedAt = - error.requestOptions.extra['ai_request_started_at'] - as DateTime?; - e( - 'AI HTTP error: ${error.requestOptions.method} ' - '${error.requestOptions.uri} ' - 'elapsedMs=${startedAt == null ? -1 : DateTime.now().difference(startedAt).inMilliseconds} ' - 'error=${error.message}', - error, - error.stackTrace, - ); - handler.next(error); - }, - ), - ) - ..applyProxy(proxy); + if (cancelToken.isCancelled) { + throw Exception('AI generation stopped'); + } - if (conversationId == null || onToolCall == null) { - return _strategyFor(config.type).streamResponse( - dio: dio, - config: config, - messages: messages, + return _runWithProxy( + proxy, + () => _requestTextWithGenkit( + config, + messages, cancelToken: cancelToken, onContent: onContent, - ); - } - - return _requestWithTools( - dio, - config, - [...messages], - conversationId: conversationId, - cancelToken: cancelToken, - onContent: onContent, - onToolCall: onToolCall, + conversationId: conversationId, + tools: tools, + ), ); } - Future _requestWithTools( - Dio dio, + Future _requestTextWithGenkit( AiProviderConfig config, List messages, { - required String conversationId, required CancelToken cancelToken, required Future Function(String chunk) onContent, - required Future Function(AiToolCall toolCall) - onToolCall, + required String? conversationId, + required List? tools, }) async { - for (var round = 0; round < _aiToolMaxRounds; round++) { - d( - 'AI tool round start: conversationId=$conversationId ' - 'round=${round + 1}/$_aiToolMaxRounds messages=${messages.length}', + final ai = _createGenkit(config); + try { + final cancelFuture = cancelToken.whenCancel.then((_) {}); + final stream = ai.generateStream( + messages: messages.map(_genkitMessage).toList(growable: false), + model: _modelFor(config), + tools: tools, + toolChoice: tools == null ? null : 'auto', + maxTurns: _aiToolMaxRounds, ); - final strategy = _strategyFor(config.type); - final response = strategy is OpenAiCompatibleStrategy - ? await strategy.streamCompleteResponse( - dio: dio, - config: config, - messages: messages, - tools: AiConversationToolKit.definitions, - cancelToken: cancelToken, - onContent: onContent, - ) - : await strategy.completeResponse( - dio: dio, - config: config, - messages: messages, - tools: AiConversationToolKit.definitions, - cancelToken: cancelToken, - ); - d( - 'AI tool round response: conversationId=$conversationId ' - 'round=${round + 1} text=${_previewText(response.text)} ' - 'toolCalls=${_previewToolCalls(response.toolCalls)}', + + final subscriptionCompleter = Completer(); + late final StreamSubscription> + subscription; + subscription = stream.listen( + (chunk) { + final text = chunk.text; + if (text.isEmpty) { + return; + } + subscription.pause(); + unawaited( + Future.sync(() => onContent(text)) + .catchError((Object error, StackTrace stackTrace) { + if (!subscriptionCompleter.isCompleted) { + subscriptionCompleter.completeError(error, stackTrace); + } + }) + .whenComplete(subscription.resume), + ); + }, + onError: subscriptionCompleter.completeError, + onDone: subscriptionCompleter.complete, + cancelOnError: true, ); - if (!response.hasToolCalls) { - final text = response.text.trim(); - if (text.isEmpty) { - throw Exception('Empty AI response'); - } - if (!response.contentEmitted) { - await _emitBufferedText(text, onContent); - } - d( - 'AI tool request done: ' - 'conversationId=$conversationId ' - 'round=${round + 1} text=${_previewText(text)}', - ); - return text; + await Future.any([ + subscriptionCompleter.future, + cancelFuture.then((_) async { + await subscription.cancel(); + throw Exception('AI generation stopped'); + }), + ]); + + final response = await Future.any([ + stream.onResult, + cancelFuture.then>((_) { + throw Exception('AI generation stopped'); + }), + ]); + final text = response.text.trim(); + if (text.isEmpty) { + throw Exception('Empty AI response'); } - - messages.add( - AiPromptMessage( - role: 'assistant', - content: response.text, - toolCalls: response.toolCalls, - ), + d( + 'AI request done: provider=${config.type.name} model=${config.model} ' + 'conversationId=$conversationId text=${_previewText(text)}', ); - for (final toolCall in response.toolCalls) { - final result = await onToolCall(toolCall); - messages.add( - AiPromptMessage( - role: 'tool', - content: result.content, - toolCallId: result.toolCallId, - toolName: result.toolName, - toolPayload: result.payload, - ), - ); - } + return text; + } finally { + await ai.shutdown(); } - - e( - 'AI exceeded tool call limit: conversationId=$conversationId ' - 'maxRounds=$_aiToolMaxRounds', - ); - throw Exception('AI exceeded tool call limit'); } - Future _emitBufferedText( - String text, - Future Function(String chunk) onContent, - ) async { - final trimmed = text.trim(); - if (trimmed.isEmpty) { - return; + Future _runWithProxy( + ProxyConfig? proxy, + Future Function() fn, + ) { + if (proxy == null) { + return fn(); } - if (trimmed.length <= _aiStreamFlushChars) { - await onContent(trimmed); - return; - } - for (var start = 0; start < trimmed.length; start += _aiStreamFlushChars) { - final end = (start + _aiStreamFlushChars).clamp(0, trimmed.length); - await onContent(trimmed.substring(start, end)); + if (proxy.type == ProxyType.socks5) { + d('AI Genkit request does not support SOCKS5 proxy: ${proxy.toUri()}'); + return fn(); } + return HttpOverrides.runZoned( + fn, + createHttpClient: (context) => + HttpClient(context: context)..setProxy(proxy), + ); } - AiProviderStrategy _strategyFor(AiProviderType type) => switch (type) { - AiProviderType.openaiCompatible => _openAiStrategy, - AiProviderType.anthropic => _anthropicStrategy, - AiProviderType.gemini => _geminiStrategy, + genkit.Genkit _createGenkit(AiProviderConfig config) => genkit.Genkit( + plugins: [_pluginFor(config)], + model: _modelFor(config), + isDevEnv: false, + ); + + genkit_plugin.GenkitPlugin _pluginFor(AiProviderConfig config) => + switch (config.type) { + AiProviderType.openaiCompatible => openAI( + apiKey: config.apiKey, + baseUrl: _emptyToNull(config.baseUrl), + models: [CustomModelDefinition(name: config.model)], + ), + AiProviderType.anthropic => anthropic( + apiKey: config.apiKey, + baseUrl: _emptyToNull(config.baseUrl), + ), + AiProviderType.gemini => googleAI(apiKey: config.apiKey), + }; + + genkit.ModelRef _modelFor(AiProviderConfig config) => + switch (config.type) { + AiProviderType.openaiCompatible => openAI.model(config.model), + AiProviderType.anthropic => anthropic.model(config.model), + AiProviderType.gemini => googleAI.gemini(config.model), + }; + + genkit.Message _genkitMessage(AiPromptMessage message) => genkit.Message( + role: _roleFor(message.role), + content: [genkit.TextPart(text: message.content)], + ); + + genkit.Role _roleFor(String role) => switch (role) { + 'system' => genkit.Role.system, + 'assistant' => genkit.Role.model, + 'tool' => genkit.Role.tool, + _ => genkit.Role.user, }; + + String? _emptyToNull(String value) { + final trimmed = value.trim(); + return trimmed.isEmpty ? null : trimmed; + } } String _previewText( @@ -232,31 +199,3 @@ String _previewText( } return '${compact.substring(0, maxLength)}...(${compact.length} chars)'; } - -String _previewJson( - Object? value, { - int maxLength = AiProviderRequester._aiLogJsonPreviewLength, -}) { - try { - final encoded = jsonEncode(value); - if (encoded.length <= maxLength) { - return encoded; - } - return '${encoded.substring(0, maxLength)}...(${encoded.length} chars)'; - } catch (_) { - return '$value'; - } -} - -String _previewToolCalls(List toolCalls) { - if (toolCalls.isEmpty) { - return '[]'; - } - return toolCalls - .map( - (toolCall) => - '${toolCall.name}#${toolCall.id}(' - '${_previewJson(toolCall.arguments, maxLength: 120)})', - ) - .join(', '); -} diff --git a/lib/ai/model/ai_chat_metadata.dart b/lib/ai/model/ai_chat_metadata.dart index 90a86ee9f7..a8a325d7ae 100644 --- a/lib/ai/model/ai_chat_metadata.dart +++ b/lib/ai/model/ai_chat_metadata.dart @@ -1,7 +1,6 @@ import 'dart:convert'; import 'ai_provider_config.dart'; -import 'ai_tool.dart'; const aiMetadataToolEventsKey = 'toolEvents'; const aiToolEventTypeCall = 'tool_call'; @@ -47,24 +46,29 @@ String appendAiToolEventToMetadata( return jsonEncode(root); } -Map createAiToolCallEvent(AiToolCall toolCall) => { +Map createAiToolCallEvent({ + required String id, + required String name, + required Map arguments, +}) => { 'type': aiToolEventTypeCall, - 'id': toolCall.id, - 'name': toolCall.name, - 'arguments': toolCall.arguments, + 'id': id, + 'name': name, + 'arguments': arguments, 'createdAt': DateTime.now().toUtc().toIso8601String(), }; Map createAiToolResultEvent({ - required AiToolCall toolCall, + required String id, + required String name, required String status, required int elapsedMs, String? resultPreview, String? errorText, }) => { 'type': aiToolEventTypeResult, - 'id': toolCall.id, - 'name': toolCall.name, + 'id': id, + 'name': name, 'status': status, 'elapsedMs': elapsedMs, 'resultPreview': resultPreview, diff --git a/lib/ai/model/ai_prompt_message.dart b/lib/ai/model/ai_prompt_message.dart index 0574dccdeb..c7ab2f94a4 100644 --- a/lib/ai/model/ai_prompt_message.dart +++ b/lib/ai/model/ai_prompt_message.dart @@ -1,23 +1,6 @@ -import 'ai_tool.dart'; - class AiPromptMessage { - AiPromptMessage({ - required this.role, - required this.content, - List? toolCalls, - this.toolCallId, - this.toolName, - this.toolPayload, - }) : toolCalls = toolCalls ?? const []; + AiPromptMessage({required this.role, required this.content}); final String role; final String content; - final List toolCalls; - final String? toolCallId; - final String? toolName; - final Map? toolPayload; - - bool get hasToolCalls => toolCalls.isNotEmpty; - - bool get isToolResult => role == 'tool'; } diff --git a/lib/ai/model/ai_prompt_template.dart b/lib/ai/model/ai_prompt_template.dart index b63f09ee48..09b4b452ac 100644 --- a/lib/ai/model/ai_prompt_template.dart +++ b/lib/ai/model/ai_prompt_template.dart @@ -290,10 +290,6 @@ const conversationToolInstructionPromptTemplate = const recentConversationContextPromptTemplate = 'Current conversation recent messages:\n{{messages}}'; -const retrievedConversationContextPromptTemplate = - 'Relevant older conversation messages matched by search ' - '(use only if they help answer the current request):\n{{messages}}'; - Map buildAiPromptTemplateVariables({ String? conversationId, String? input, diff --git a/lib/ai/model/ai_tool.dart b/lib/ai/model/ai_tool.dart index 76e2fbff5b..690f729a8d 100644 --- a/lib/ai/model/ai_tool.dart +++ b/lib/ai/model/ai_tool.dart @@ -1,5 +1,3 @@ -import 'dart:convert'; - class AiToolDefinition { const AiToolDefinition({ required this.name, @@ -11,29 +9,3 @@ class AiToolDefinition { final String description; final Map inputSchema; } - -class AiToolCall { - const AiToolCall({ - required this.id, - required this.name, - required this.arguments, - }); - - final String id; - final String name; - final Map arguments; -} - -class AiToolExecutionResult { - const AiToolExecutionResult({ - required this.toolCallId, - required this.toolName, - required this.payload, - }); - - final String toolCallId; - final String toolName; - final Map payload; - - String get content => jsonEncode(payload); -} diff --git a/lib/ai/provider/ai_provider_strategy.dart b/lib/ai/provider/ai_provider_strategy.dart deleted file mode 100644 index 10df379512..0000000000 --- a/lib/ai/provider/ai_provider_strategy.dart +++ /dev/null @@ -1,118 +0,0 @@ -import 'dart:async'; -import 'dart:convert'; - -import 'package:dio/dio.dart'; - -import '../model/ai_prompt_message.dart'; -import '../model/ai_provider_config.dart'; -import '../model/ai_tool.dart'; - -abstract interface class AiProviderStrategy { - const AiProviderStrategy(); - - Map headers(AiProviderConfig config); - - Future completeResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required List tools, - required CancelToken cancelToken, - }); - - Future streamResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required CancelToken cancelToken, - required Future Function(String chunk) onContent, - }); -} - -class AiCompletionResponse { - const AiCompletionResponse({ - this.text = '', - this.toolCalls = const [], - this.contentEmitted = false, - }); - - final String text; - final List toolCalls; - final bool contentEmitted; - - bool get hasToolCalls => toolCalls.isNotEmpty; -} - -final class AiProviderStrategySupport { - const AiProviderStrategySupport._(); - - static Map jsonMap(dynamic value) { - if (value is Map) { - return value; - } - if (value is Map) { - return value.map((key, value) => MapEntry('$key', value)); - } - throw Exception('Invalid AI response payload'); - } - - static Map toolArguments(dynamic value) { - if (value == null) { - return const {}; - } - if (value is String) { - final trimmed = value.trim(); - if (trimmed.isEmpty) { - return const {}; - } - final decoded = jsonDecode(trimmed); - return jsonMap(decoded); - } - return jsonMap(value); - } - - static String stringContent(dynamic value) { - if (value is String) { - return value; - } - if (value is List) { - return value - .whereType() - .map((item) => item['text']) - .whereType() - .join('\n'); - } - return ''; - } - - static Stream decodeSse(Stream> stream) async* { - final buffer = StringBuffer(); - await for (final bytes in stream) { - final chunk = utf8.decode(bytes); - buffer.write(chunk.replaceAll('\r\n', '\n').replaceAll('\r', '\n')); - while (true) { - final current = buffer.toString(); - final separatorIndex = current.indexOf('\n\n'); - if (separatorIndex < 0) { - break; - } - - final rawEvent = current.substring(0, separatorIndex); - final remaining = current.substring(separatorIndex + 2); - buffer - ..clear() - ..write(remaining); - - final payload = rawEvent - .split('\n') - .where((line) => line.startsWith('data:')) - .map((line) => line.substring(5).trimLeft()) - .join('\n') - .trim(); - if (payload.isNotEmpty) { - yield payload; - } - } - } - } -} diff --git a/lib/ai/provider/anthropic_strategy.dart b/lib/ai/provider/anthropic_strategy.dart deleted file mode 100644 index 4199965ca3..0000000000 --- a/lib/ai/provider/anthropic_strategy.dart +++ /dev/null @@ -1,208 +0,0 @@ -import 'dart:convert'; - -import 'package:dio/dio.dart'; - -import '../model/ai_prompt_message.dart'; -import '../model/ai_provider_config.dart'; -import '../model/ai_tool.dart'; -import 'ai_provider_strategy.dart'; - -class AnthropicStrategy implements AiProviderStrategy { - const AnthropicStrategy(); - - @override - Map headers(AiProviderConfig config) => { - 'x-api-key': config.apiKey, - 'anthropic-version': '2023-06-01', - 'content-type': 'application/json', - }; - - @override - Future completeResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required List tools, - required CancelToken cancelToken, - }) async { - final response = await dio.post( - '/messages', - data: { - 'model': config.model, - 'max_tokens': 1024, - 'messages': messages - .where((message) => message.role != 'system') - .map(_anthropicMessagePayload) - .toList(growable: false), - 'system': messages - .where((message) => message.role == 'system') - .map((message) => message.content) - .where((content) => content.isNotEmpty) - .join('\n\n'), - if (tools.isNotEmpty) - 'tools': tools - .map( - (tool) => { - 'name': tool.name, - 'description': tool.description, - 'input_schema': tool.inputSchema, - }, - ) - .toList(growable: false), - }, - cancelToken: cancelToken, - ); - - final body = AiProviderStrategySupport.jsonMap(response.data); - if (body['type'] == 'error') { - final error = AiProviderStrategySupport.jsonMap(body['error']); - throw Exception(error['message'] ?? 'Anthropic request failed'); - } - - final content = body['content'] as List?; - if (content == null || content.isEmpty) { - throw Exception('Empty AI response'); - } - - final textBuffer = StringBuffer(); - final toolCalls = []; - for (final item in content) { - final block = AiProviderStrategySupport.jsonMap(item); - switch (block['type']) { - case 'text': - final text = block['text']; - if (text is String && text.isNotEmpty) { - textBuffer.write(text); - } - case 'tool_use': - final name = block['name'] as String?; - if (name == null || name.isEmpty) { - throw Exception('Invalid AI tool call name'); - } - toolCalls.add( - AiToolCall( - id: block['id'] as String? ?? '${name}_${block.hashCode}', - name: name, - arguments: AiProviderStrategySupport.toolArguments( - block['input'], - ), - ), - ); - } - } - - final text = textBuffer.toString(); - if (text.trim().isEmpty && toolCalls.isEmpty) { - throw Exception('Empty AI response'); - } - return AiCompletionResponse(text: text, toolCalls: toolCalls); - } - - @override - Future streamResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required CancelToken cancelToken, - required Future Function(String chunk) onContent, - }) async { - final response = await dio.post( - '/messages', - data: { - 'model': config.model, - 'max_tokens': 1024, - 'stream': true, - 'messages': messages - .where((message) => message.role != 'system') - .map(_anthropicMessagePayload) - .toList(growable: false), - 'system': messages - .where((message) => message.role == 'system') - .map((message) => message.content) - .join('\n\n'), - }, - options: Options(responseType: ResponseType.stream), - cancelToken: cancelToken, - ); - - final body = response.data; - if (body == null) { - throw Exception('Empty AI response'); - } - - final buffer = StringBuffer(); - await for (final data in AiProviderStrategySupport.decodeSse(body.stream)) { - final json = jsonDecode(data); - if (json is! Map) { - continue; - } - - final type = json['type'] as String?; - if (type == 'error') { - final error = json['error']; - if (error is Map) { - throw Exception(error['message'] ?? 'Anthropic request failed'); - } - throw Exception('Anthropic request failed'); - } - - if (type != 'content_block_delta') { - continue; - } - - final delta = json['delta']; - if (delta is! Map) { - continue; - } - - if (delta['type'] != 'text_delta') { - continue; - } - - final text = delta['text']; - if (text is String && text.isNotEmpty) { - buffer.write(text); - await onContent(text); - } - } - - final text = buffer.toString().trim(); - if (text.isEmpty) { - throw Exception('Empty AI response'); - } - return text; - } - - Map _anthropicMessagePayload(AiPromptMessage message) => { - 'role': message.isToolResult ? 'user' : message.role, - 'content': _anthropicContentBlocks(message), - }; - - List> _anthropicContentBlocks( - AiPromptMessage message, - ) { - if (message.isToolResult) { - return [ - { - 'type': 'tool_result', - 'tool_use_id': message.toolCallId, - 'content': message.content, - }, - ]; - } - - final blocks = >[]; - if (message.content.isNotEmpty) { - blocks.add({'type': 'text', 'text': message.content}); - } - for (final toolCall in message.toolCalls) { - blocks.add({ - 'type': 'tool_use', - 'id': toolCall.id, - 'name': toolCall.name, - 'input': toolCall.arguments, - }); - } - return blocks; - } -} diff --git a/lib/ai/provider/gemini_strategy.dart b/lib/ai/provider/gemini_strategy.dart deleted file mode 100644 index b6030ba9b5..0000000000 --- a/lib/ai/provider/gemini_strategy.dart +++ /dev/null @@ -1,267 +0,0 @@ -import 'dart:convert'; - -import 'package:dio/dio.dart'; - -import '../model/ai_prompt_message.dart'; -import '../model/ai_provider_config.dart'; -import '../model/ai_tool.dart'; -import 'ai_provider_strategy.dart'; - -class GeminiStrategy implements AiProviderStrategy { - const GeminiStrategy(); - - @override - Map headers(AiProviderConfig config) => { - 'x-goog-api-key': config.apiKey, - 'content-type': 'application/json', - }; - - @override - Future completeResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required List tools, - required CancelToken cancelToken, - }) async { - final systemInstruction = messages - .where((message) => message.role == 'system') - .map((message) => message.content.trim()) - .where((content) => content.isNotEmpty) - .join('\n\n'); - final response = await dio.post( - '/models/${Uri.encodeComponent(config.model)}:generateContent', - data: { - 'contents': messages - .where((message) => message.role != 'system') - .map(_geminiMessagePayload) - .toList(growable: false), - if (systemInstruction.isNotEmpty) - 'system_instruction': { - 'parts': [ - {'text': systemInstruction}, - ], - }, - if (tools.isNotEmpty) - 'tools': [ - { - 'functionDeclarations': tools - .map( - (tool) => { - 'name': tool.name, - 'description': tool.description, - 'parameters': tool.inputSchema, - }, - ) - .toList(growable: false), - }, - ], - if (tools.isNotEmpty) - 'toolConfig': { - 'functionCallingConfig': {'mode': 'AUTO'}, - }, - 'generationConfig': { - 'candidateCount': 1, - }, - }, - cancelToken: cancelToken, - ); - - final body = AiProviderStrategySupport.jsonMap(response.data); - final promptFeedback = body['promptFeedback']; - if (promptFeedback is Map) { - final blockReason = promptFeedback['blockReason']; - if (blockReason is String && blockReason.isNotEmpty) { - throw Exception('Gemini request blocked: $blockReason'); - } - } - - final candidates = body['candidates'] as List?; - if (candidates == null || candidates.isEmpty) { - throw Exception('Empty AI response'); - } - final first = AiProviderStrategySupport.jsonMap(candidates.first); - final finishReason = first['finishReason']; - if (finishReason is String && - finishReason.isNotEmpty && - finishReason != 'STOP' && - finishReason != 'FINISH_REASON_UNSPECIFIED') { - throw Exception('Gemini request finished with reason: $finishReason'); - } - - final content = AiProviderStrategySupport.jsonMap(first['content']); - final parts = content['parts'] as List?; - if (parts == null || parts.isEmpty) { - throw Exception('Empty AI response'); - } - - final textBuffer = StringBuffer(); - final toolCalls = []; - for (final item in parts) { - final part = AiProviderStrategySupport.jsonMap(item); - final text = part['text']; - if (text is String && text.isNotEmpty) { - textBuffer.write(text); - } - final functionCall = part['functionCall']; - if (functionCall is Map) { - final name = functionCall['name'] as String?; - if (name == null || name.isEmpty) { - throw Exception('Invalid AI tool call name'); - } - toolCalls.add( - AiToolCall( - id: '${name}_${functionCall.hashCode}', - name: name, - arguments: AiProviderStrategySupport.toolArguments( - functionCall['args'], - ), - ), - ); - } - } - - final text = textBuffer.toString(); - if (text.trim().isEmpty && toolCalls.isEmpty) { - throw Exception('Empty AI response'); - } - return AiCompletionResponse(text: text, toolCalls: toolCalls); - } - - @override - Future streamResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required CancelToken cancelToken, - required Future Function(String chunk) onContent, - }) async { - final systemInstruction = messages - .where((message) => message.role == 'system') - .map((message) => message.content.trim()) - .where((content) => content.isNotEmpty) - .join('\n\n'); - - final contents = messages - .where((message) => message.role != 'system') - .map(_geminiMessagePayload) - .toList(growable: false); - - final response = await dio.post( - '/models/${Uri.encodeComponent(config.model)}:streamGenerateContent', - queryParameters: const {'alt': 'sse'}, - data: { - 'contents': contents, - if (systemInstruction.isNotEmpty) - 'system_instruction': { - 'parts': [ - {'text': systemInstruction}, - ], - }, - 'generationConfig': { - 'candidateCount': 1, - }, - }, - options: Options(responseType: ResponseType.stream), - cancelToken: cancelToken, - ); - - final body = response.data; - if (body == null) { - throw Exception('Empty AI response'); - } - - final buffer = StringBuffer(); - await for (final data in AiProviderStrategySupport.decodeSse(body.stream)) { - final json = jsonDecode(data); - if (json is! Map) { - continue; - } - - final promptFeedback = json['promptFeedback']; - if (promptFeedback is Map) { - final blockReason = promptFeedback['blockReason']; - if (blockReason is String && blockReason.isNotEmpty) { - throw Exception('Gemini request blocked: $blockReason'); - } - } - - final candidates = json['candidates'] as List?; - if (candidates == null || candidates.isEmpty) { - continue; - } - - final first = candidates.first; - if (first is! Map) { - continue; - } - - final finishReason = first['finishReason']; - if (finishReason is String && - finishReason.isNotEmpty && - finishReason != 'STOP' && - finishReason != 'FINISH_REASON_UNSPECIFIED') { - throw Exception('Gemini request finished with reason: $finishReason'); - } - - final content = first['content']; - if (content is! Map) { - continue; - } - - final parts = content['parts'] as List?; - if (parts == null || parts.isEmpty) { - continue; - } - - for (final part in parts) { - if (part is! Map) { - continue; - } - final text = part['text']; - if (text is String && text.isNotEmpty) { - buffer.write(text); - await onContent(text); - } - } - } - - final text = buffer.toString().trim(); - if (text.isEmpty) { - throw Exception('Empty AI response'); - } - return text; - } - - Map _geminiMessagePayload(AiPromptMessage message) => { - 'role': message.role == 'assistant' ? 'model' : 'user', - 'parts': _geminiMessageParts(message), - }; - - List> _geminiMessageParts(AiPromptMessage message) { - if (message.isToolResult) { - return [ - { - 'functionResponse': { - 'name': message.toolName, - 'response': message.toolPayload ?? {'content': message.content}, - }, - }, - ]; - } - - final parts = >[]; - if (message.content.isNotEmpty) { - parts.add({'text': message.content}); - } - for (final toolCall in message.toolCalls) { - parts.add({ - 'functionCall': { - 'name': toolCall.name, - 'args': toolCall.arguments, - }, - }); - } - return parts; - } -} diff --git a/lib/ai/provider/openai_compatible_strategy.dart b/lib/ai/provider/openai_compatible_strategy.dart deleted file mode 100644 index 32330e14fe..0000000000 --- a/lib/ai/provider/openai_compatible_strategy.dart +++ /dev/null @@ -1,308 +0,0 @@ -import 'dart:convert'; - -import 'package:dio/dio.dart'; - -import '../model/ai_prompt_message.dart'; -import '../model/ai_provider_config.dart'; -import '../model/ai_tool.dart'; -import 'ai_provider_strategy.dart'; - -class OpenAiCompatibleStrategy implements AiProviderStrategy { - const OpenAiCompatibleStrategy(); - - @override - Map headers(AiProviderConfig config) => { - 'Authorization': 'Bearer ${config.apiKey}', - 'Content-Type': 'application/json', - }; - - @override - Future completeResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required List tools, - required CancelToken cancelToken, - }) async { - final response = await dio.post( - '/chat/completions', - data: { - 'model': config.model, - 'messages': messages.map(_openAiMessagePayload).toList(growable: false), - if (tools.isNotEmpty) - 'tools': tools - .map( - (tool) => { - 'type': 'function', - 'function': { - 'name': tool.name, - 'description': tool.description, - 'parameters': tool.inputSchema, - }, - }, - ) - .toList(growable: false), - if (tools.isNotEmpty) 'tool_choice': 'auto', - }, - cancelToken: cancelToken, - ); - - final body = AiProviderStrategySupport.jsonMap(response.data); - final choices = body['choices'] as List?; - if (choices == null || choices.isEmpty) { - throw Exception('Empty AI response'); - } - final first = AiProviderStrategySupport.jsonMap(choices.first); - final message = AiProviderStrategySupport.jsonMap(first['message']); - final text = AiProviderStrategySupport.stringContent(message['content']); - final toolCalls = (message['tool_calls'] as List? ?? const []) - .map((item) => _openAiToolCall(AiProviderStrategySupport.jsonMap(item))) - .toList(growable: false); - if (text.trim().isEmpty && toolCalls.isEmpty) { - throw Exception('Empty AI response'); - } - return AiCompletionResponse(text: text, toolCalls: toolCalls); - } - - @override - Future streamResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required CancelToken cancelToken, - required Future Function(String chunk) onContent, - }) async { - final response = await dio.post( - '/chat/completions', - data: { - 'model': config.model, - 'stream': true, - 'messages': messages.map(_openAiMessagePayload).toList(growable: false), - }, - options: Options(responseType: ResponseType.stream), - cancelToken: cancelToken, - ); - - final body = response.data; - if (body == null) { - throw Exception('Empty AI response'); - } - - final buffer = StringBuffer(); - await for (final data in AiProviderStrategySupport.decodeSse(body.stream)) { - if (data == '[DONE]') { - continue; - } - - final json = jsonDecode(data); - if (json is! Map) { - continue; - } - - final choices = json['choices'] as List?; - if (choices == null || choices.isEmpty) { - continue; - } - - final first = choices.first; - if (first is! Map) { - continue; - } - - final delta = first['delta']; - if (delta is! Map) { - continue; - } - - final content = delta['content']; - if (content is String && content.isNotEmpty) { - buffer.write(content); - await onContent(content); - } - } - - final text = buffer.toString().trim(); - if (text.isEmpty) { - throw Exception('Empty AI response'); - } - return text; - } - - Future streamCompleteResponse({ - required Dio dio, - required AiProviderConfig config, - required List messages, - required List tools, - required CancelToken cancelToken, - required Future Function(String chunk) onContent, - }) async { - final response = await dio.post( - '/chat/completions', - data: { - 'model': config.model, - 'stream': true, - 'messages': messages.map(_openAiMessagePayload).toList(growable: false), - if (tools.isNotEmpty) - 'tools': tools - .map( - (tool) => { - 'type': 'function', - 'function': { - 'name': tool.name, - 'description': tool.description, - 'parameters': tool.inputSchema, - }, - }, - ) - .toList(growable: false), - if (tools.isNotEmpty) 'tool_choice': 'auto', - }, - options: Options(responseType: ResponseType.stream), - cancelToken: cancelToken, - ); - - final body = response.data; - if (body == null) { - throw Exception('Empty AI response'); - } - - final textBuffer = StringBuffer(); - var contentEmitted = false; - final toolCallBuilders = {}; - await for (final data in AiProviderStrategySupport.decodeSse(body.stream)) { - if (data == '[DONE]') { - continue; - } - - final json = jsonDecode(data); - if (json is! Map) { - continue; - } - - final choices = json['choices'] as List?; - if (choices == null || choices.isEmpty) { - continue; - } - - final first = choices.first; - if (first is! Map) { - continue; - } - - final delta = first['delta']; - if (delta is! Map) { - continue; - } - - final toolCalls = delta['tool_calls']; - if (toolCalls is List) { - for (final item in toolCalls) { - if (item is Map) { - _appendOpenAiToolCallDelta(toolCallBuilders, item); - } - } - } - - final content = delta['content']; - if (content is String && content.isNotEmpty) { - textBuffer.write(content); - if (toolCallBuilders.isEmpty) { - contentEmitted = true; - await onContent(content); - } - } - } - - final text = textBuffer.toString(); - final toolCalls = toolCallBuilders.values - .map((builder) => builder.build()) - .toList(growable: false); - if (text.trim().isEmpty && toolCalls.isEmpty) { - throw Exception('Empty AI response'); - } - return AiCompletionResponse( - text: text, - toolCalls: toolCalls, - contentEmitted: contentEmitted, - ); - } - - Map _openAiMessagePayload(AiPromptMessage message) => { - 'role': message.role, - 'content': message.content, - if (message.hasToolCalls) - 'tool_calls': message.toolCalls - .map( - (toolCall) => { - 'id': toolCall.id, - 'type': 'function', - 'function': { - 'name': toolCall.name, - 'arguments': jsonEncode(toolCall.arguments), - }, - }, - ) - .toList(growable: false), - if (message.isToolResult) 'tool_call_id': message.toolCallId, - }; - - AiToolCall _openAiToolCall(Map value) { - final function = AiProviderStrategySupport.jsonMap(value['function']); - final name = function['name'] as String?; - if (name == null || name.isEmpty) { - throw Exception('Invalid AI tool call name'); - } - return AiToolCall( - id: value['id'] as String? ?? '${name}_${value.hashCode}', - name: name, - arguments: AiProviderStrategySupport.toolArguments(function['arguments']), - ); - } - - void _appendOpenAiToolCallDelta( - Map builders, - Map value, - ) { - final index = value['index']; - final toolCallIndex = index is int ? index : builders.length; - final builder = builders.putIfAbsent( - toolCallIndex, - _OpenAiToolCallBuilder.new, - ); - - final id = value['id']; - if (id is String && id.isNotEmpty) { - builder.id = id; - } - - final function = value['function']; - if (function is Map) { - final name = function['name']; - if (name is String && name.isNotEmpty) { - builder.name = name; - } - final arguments = function['arguments']; - if (arguments is String && arguments.isNotEmpty) { - builder.arguments.write(arguments); - } - } - } -} - -final class _OpenAiToolCallBuilder { - String? id; - String? name; - final StringBuffer arguments = StringBuffer(); - - AiToolCall build() { - final toolName = name; - if (toolName == null || toolName.isEmpty) { - throw Exception('Invalid AI tool call name'); - } - return AiToolCall( - id: id ?? '${toolName}_$hashCode', - name: toolName, - arguments: AiProviderStrategySupport.toolArguments(arguments.toString()), - ); - } -} diff --git a/lib/ai/tools/ai_conversation_tool_service.dart b/lib/ai/tools/ai_conversation_tool_service.dart index 8ba90468ea..cd97868cef 100644 --- a/lib/ai/tools/ai_conversation_tool_service.dart +++ b/lib/ai/tools/ai_conversation_tool_service.dart @@ -495,12 +495,12 @@ class AiConversationToolKit { ), ]; - Future execute({ + Future> execute({ required String conversationId, - required AiToolCall call, + required String name, + required Map arguments, }) async { - final arguments = call.arguments; - switch (call.name) { + switch (name) { case 'get_conversation_stats': final (startInclusive, endExclusive) = _parseRange(arguments); final stats = await service.getConversationStats( @@ -508,11 +508,7 @@ class AiConversationToolKit { startInclusive: startInclusive, endExclusive: endExclusive, ); - return AiToolExecutionResult( - toolCallId: call.id, - toolName: call.name, - payload: stats.toJson(), - ); + return stats.toJson(); case 'list_conversation_chunks': final (startInclusive, endExclusive) = _parseRange(arguments); final chunkSize = _parseInt( @@ -528,11 +524,7 @@ class AiConversationToolKit { startInclusive: startInclusive, endExclusive: endExclusive, ); - return AiToolExecutionResult( - toolCallId: call.id, - toolName: call.name, - payload: chunks.toJson(), - ); + return chunks.toJson(); case 'read_conversation_chunk': final (startInclusive, endExclusive) = _parseRange(arguments); final offset = _parseInt( @@ -556,11 +548,7 @@ class AiConversationToolKit { startInclusive: startInclusive, endExclusive: endExclusive, ); - return AiToolExecutionResult( - toolCallId: call.id, - toolName: call.name, - payload: page.toJson(), - ); + return page.toJson(); case 'search_conversation_messages': final query = _parseRequiredString(arguments, 'query'); final limit = _parseInt( @@ -575,13 +563,9 @@ class AiConversationToolKit { query: query, limit: limit, ); - return AiToolExecutionResult( - toolCallId: call.id, - toolName: call.name, - payload: result.toJson(), - ); + return result.toJson(); default: - throw UnsupportedError('Unknown conversation tool: ${call.name}'); + throw UnsupportedError('Unknown conversation tool: $name'); } } diff --git a/pubspec.lock b/pubspec.lock index 443b221c70..7868e5cc35 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -41,6 +41,14 @@ packages: url: "https://pub.dev" source: hosted version: "2.0.3" + anthropic_sdk_dart: + dependency: transitive + description: + name: anthropic_sdk_dart + sha256: b0e91039942930341b24e3871d5b9481b6d31528b711eb752f413f4d1f5980eb + url: "https://pub.dev" + source: hosted + version: "1.5.0" archive: dependency: "direct main" description: @@ -442,10 +450,10 @@ packages: dependency: transitive description: name: dlibphonenumber - sha256: "95d8e08c6f750e81c5303efd16085db4e7c696b4c306c99617548f81b1854f0c" + sha256: df96f4bdb14b0a47664de8ec7cd1de92e4e9b9bc4e34330d2da9088b0ba71e59 url: "https://pub.dev" source: hosted - version: "1.1.47" + version: "1.1.62" drift: dependency: "direct main" description: @@ -486,6 +494,14 @@ packages: url: "https://pub.dev" source: hosted version: "1.0.5" + email_validator: + dependency: transitive + description: + name: email_validator + sha256: b19aa5d92fdd76fbc65112060c94d45ba855105a28bb6e462de7ff03b12fa1fb + url: "https://pub.dev" + source: hosted + version: "3.0.0" emojis: dependency: "direct main" description: @@ -803,6 +819,38 @@ packages: url: "https://github.com/boyan01/gal.git" source: git version: "2.1.3" + genkit: + dependency: "direct main" + description: + name: genkit + sha256: ccc84935593e6f12447c8d4b27034fceb4573cbc3d51b76d32aa9bbbf8d3badd + url: "https://pub.dev" + source: hosted + version: "0.12.1" + genkit_anthropic: + dependency: "direct main" + description: + name: genkit_anthropic + sha256: "3a902993fdd4cefca4d10ee76df3cf064622247d74c5d0f2b619e5c2b6022727" + url: "https://pub.dev" + source: hosted + version: "0.2.4" + genkit_google_genai: + dependency: "direct main" + description: + name: genkit_google_genai + sha256: "95f798f10776e9078251f7e6ae91e4e5c9cb0b960fc22d50a6baf59fdbcecfbc" + url: "https://pub.dev" + source: hosted + version: "0.2.4" + genkit_openai: + dependency: "direct main" + description: + name: genkit_openai + sha256: "398b3d7a5bc08fb6764ea0a5244151cca31263e67ba9598981e81238109874d1" + url: "https://pub.dev" + source: hosted + version: "0.2.4" get_it: dependency: transitive description: @@ -1091,6 +1139,14 @@ packages: url: "https://pub.dev" source: hosted version: "4.11.0" + json_schema_builder: + dependency: transitive + description: + name: json_schema_builder + sha256: "65035d48d028401ad0ffc8c2f173209c7b1441e465a942a0f909070fae33170c" + url: "https://pub.dev" + source: hosted + version: "0.1.3" json_serializable: dependency: "direct dev" description: @@ -1134,10 +1190,11 @@ packages: libsignal_protocol_dart: dependency: "direct main" description: - name: libsignal_protocol_dart - sha256: "2b18de43016474ab85d21553a88f59d6f4fea8c2eddf35be7e24ab5f8969a81d" - url: "https://pub.dev" - source: hosted + path: "." + ref: "9f7dcbd61850eb5a056d28de3d5758bc08153a0d" + resolved-ref: "9f7dcbd61850eb5a056d28de3d5758bc08153a0d" + url: "https://github.com/MixinNetwork/libsignal_protocol_dart.git" + source: git version: "0.7.4" lints: dependency: transitive @@ -1412,6 +1469,22 @@ packages: url: "https://pub.dev" source: hosted version: "0.0.3" + openai_dart: + dependency: transitive + description: + name: openai_dart + sha256: "13763068d8bf87f7e0ebdb8bf365bada7a7538696380e68ec5d37644d97ff519" + url: "https://pub.dev" + source: hosted + version: "2.0.0" + opentelemetry: + dependency: transitive + description: + name: opentelemetry + sha256: "92d63a2e0731d34a7548add82420b8f3819ccda569f9bdfdcc4b25e00fe88da4" + url: "https://pub.dev" + source: hosted + version: "0.18.11" optional: dependency: transitive description: @@ -1608,10 +1681,10 @@ packages: dependency: transitive description: name: protobuf - sha256: "579fe5557eae58e3adca2e999e38f02441d8aa908703854a9e0a0f47fa857731" + sha256: "75ec242d22e950bdcc79ee38dd520ce4ee0bc491d7fadc4ea47694604d22bf06" url: "https://pub.dev" source: hosted - version: "4.1.0" + version: "6.0.0" protocol_handler: dependency: "direct main" description: @@ -1692,6 +1765,14 @@ packages: url: "https://pub.dev" source: hosted version: "3.0.2" + quiver: + dependency: transitive + description: + name: quiver + sha256: ea0b925899e64ecdfbf9c7becb60d5b50e706ade44a85b2363be2a22d88117d2 + url: "https://pub.dev" + source: hosted + version: "3.2.2" rational: dependency: transitive description: @@ -1740,6 +1821,14 @@ packages: url: "https://pub.dev" source: hosted version: "0.28.0" + schemantic: + dependency: "direct main" + description: + name: schemantic + sha256: "8c143bf964c18a0f2c0c6053d71599ab4985567d86a289d373c171c9190ccb9d" + url: "https://pub.dev" + source: hosted + version: "0.1.1" screen_retriever: dependency: transitive description: diff --git a/pubspec.yaml b/pubspec.yaml index 8c1dd875dd..86b27c0bdd 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -92,7 +92,10 @@ dependencies: intl_phone_number_input: ^0.7.5 isolate: ^2.1.0 json_annotation: ^4.11.0 - libsignal_protocol_dart: ^0.7.4 + libsignal_protocol_dart: + git: + url: https://github.com/MixinNetwork/libsignal_protocol_dart.git + ref: 9f7dcbd61850eb5a056d28de3d5758bc08153a0d local_auth: ^3.0.1 lottie: ^3.3.3 map: ^2.0.2 @@ -166,6 +169,11 @@ dependencies: ref: 08c1ce40eb6abfad6049fb6aad8bd30312ec5319 path: packages/data_detector envied: ^1.3.4 + genkit: ^0.12.1 + genkit_openai: ^0.2.4 + genkit_anthropic: ^0.2.4 + genkit_google_genai: ^0.2.4 + schemantic: ^0.1.1 dev_dependencies: build_runner: ^2.13.1 From ac1cc1a9bb3ee912efb6e1badd79b64e6e27f9bc Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 16:45:30 +0800 Subject: [PATCH 26/52] refactor: simplify AI tool handling and improve role mapping --- lib/ai/ai_chat_controller.dart | 130 +---- lib/ai/ai_chat_prompt_builder.dart | 17 +- lib/ai/ai_provider_requester.dart | 16 +- lib/ai/model/ai_prompt_message.dart | 23 +- lib/ai/model/ai_tool.dart | 11 - .../tools/ai_conversation_tool_service.dart | 537 ++++++++++++------ lib/ui/setting/ai_provider_edit_page.dart | 81 +-- test/ai/ai_provider_requester_test.dart | 70 +++ 8 files changed, 523 insertions(+), 362 deletions(-) delete mode 100644 lib/ai/model/ai_tool.dart create mode 100644 test/ai/ai_provider_requester_test.dart diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index ccb32bde96..9fecd4ec8e 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -1,11 +1,8 @@ import 'dart:async'; -import 'dart:convert'; import 'package:dio/dio.dart'; import 'package:drift/drift.dart'; -import 'package:genkit/genkit.dart' as genkit; import 'package:mixin_logger/mixin_logger.dart'; -import 'package:schemantic/schemantic.dart'; import 'package:uuid/uuid.dart'; import '../db/dao/ai_chat_message_dao.dart'; @@ -16,7 +13,6 @@ import 'ai_provider_requester.dart'; import 'model/ai_chat_metadata.dart'; import 'model/ai_prompt_message.dart'; import 'model/ai_provider_config.dart'; -import 'model/ai_tool.dart'; import 'tools/ai_conversation_tool_service.dart'; const _kAiRoleUser = 'user'; @@ -27,7 +23,6 @@ const _kAiStatusError = 'error'; const _kAiStreamFlushChars = 32; const _kAiStreamFlushInterval = Duration(milliseconds: 80); const _kAiLogPreviewLength = 240; -const _kAiLogJsonPreviewLength = 480; final kAiRuntimeStartedAt = DateTime.now(); final _activeAiRequests = {}; @@ -252,104 +247,15 @@ class AiChatController { cancelToken: cancelToken, onContent: onContent, conversationId: conversationId, - tools: _toolsFor( - conversationId, - assistantMessageId: assistantMessageId, - ), + tools: conversationId == null + ? null + : _conversationTools.genkitTools( + conversationId: conversationId, + onEvent: (event) => + _appendAssistantToolEvent(assistantMessageId, event), + ), ); - List, Map>>? _toolsFor( - String? conversationId, { - String? assistantMessageId, - }) { - if (conversationId == null) { - return null; - } - return AiConversationToolKit.definitions - .map( - (definition) => - genkit.Tool, Map>( - name: definition.name, - description: definition.description, - inputSchema: _schemaFor(definition), - fn: (input, context) async { - final request = context.toolRequest?.toolRequest; - return _executeConversationTool( - conversationId: conversationId, - assistantMessageId: assistantMessageId, - id: request?.ref ?? '${definition.name}_${input.hashCode}', - name: request?.name ?? definition.name, - arguments: input, - ); - }, - ), - ) - .toList(growable: false); - } - - SchemanticType> _schemaFor( - AiToolDefinition definition, - ) => SchemanticType.from>( - jsonSchema: definition.inputSchema.map(MapEntry.new), - parse: _jsonMap, - ); - - Future> _executeConversationTool({ - required String conversationId, - required String? assistantMessageId, - required String id, - required String name, - required Map arguments, - }) async { - final stopwatch = Stopwatch()..start(); - d( - 'AI tool execute start: conversationId=$conversationId ' - 'tool=$name id=$id ' - 'arguments=${_previewJson(arguments)}', - ); - await _appendAssistantToolEvent( - assistantMessageId, - createAiToolCallEvent(id: id, name: name, arguments: arguments), - ); - try { - final result = await _conversationTools.execute( - conversationId: conversationId, - name: name, - arguments: arguments, - ); - d( - 'AI tool execute done: conversationId=$conversationId ' - 'tool=$name id=$id ' - 'elapsedMs=${stopwatch.elapsedMilliseconds} ' - 'result=${_previewJson(result)}', - ); - await _appendAssistantToolEvent( - assistantMessageId, - createAiToolResultEvent( - id: id, - name: name, - status: 'done', - elapsedMs: stopwatch.elapsedMilliseconds, - resultPreview: _previewJson(result), - ), - ); - return result; - } catch (error, stacktrace) { - e('AI tool execution error: $error, $stacktrace'); - await _appendAssistantToolEvent( - assistantMessageId, - createAiToolResultEvent( - id: id, - name: name, - status: 'error', - elapsedMs: stopwatch.elapsedMilliseconds, - errorText: error.toString(), - ), - ); - return {'error': '$error'}; - } - } - Future _appendAssistantToolEvent( String? assistantMessageId, Map event, @@ -380,28 +286,6 @@ String _previewText(String? text, {int maxLength = _kAiLogPreviewLength}) { return '${compact.substring(0, maxLength)}...(${compact.length} chars)'; } -String _previewJson(Object? value, {int maxLength = _kAiLogJsonPreviewLength}) { - try { - final encoded = jsonEncode(value); - if (encoded.length <= maxLength) { - return encoded; - } - return '${encoded.substring(0, maxLength)}...(${encoded.length} chars)'; - } catch (_) { - return '$value'; - } -} - -Map _jsonMap(dynamic value) { - if (value is Map) { - return value; - } - if (value is Map) { - return value.map((key, value) => MapEntry('$key', value)); - } - throw Exception('Invalid AI tool arguments'); -} - class _StreamingMessageUpdater { _StreamingMessageUpdater({required this.dao, required this.messageId}); diff --git a/lib/ai/ai_chat_prompt_builder.dart b/lib/ai/ai_chat_prompt_builder.dart index ca01d40072..202ad5ad50 100644 --- a/lib/ai/ai_chat_prompt_builder.dart +++ b/lib/ai/ai_chat_prompt_builder.dart @@ -8,7 +8,6 @@ import 'model/ai_prompt_template.dart'; class AiChatPromptBuilder { AiChatPromptBuilder(this.database); - static const _aiRoleUser = 'user'; static const _aiStatusPending = 'pending'; static const _aiContextMessageLimit = 30; static const _aiHistoryLimit = 12; @@ -30,7 +29,7 @@ class AiChatPromptBuilder { final promptMessages = [ ..._promptMessages( - role: 'system', + role: AiPromptRole.system, content: renderAiPromptTemplate( database.settingProperties.aiPromptTemplate( AiPromptTemplateKey.chatSystem, @@ -65,13 +64,13 @@ class AiChatPromptBuilder { .takeLast(_aiHistoryLimit); for (final item in history) { promptMessages.add( - AiPromptMessage(role: item.role, content: item.content), + AiPromptMessage(role: AiPromptRole(item.role), content: item.content), ); } promptMessages.addAll( _promptMessages( - role: _aiRoleUser, + role: AiPromptRole.user, content: renderAiPromptTemplate( chatUserMessagePromptTemplate, buildAiPromptTemplateVariables( @@ -102,7 +101,7 @@ class AiChatPromptBuilder { final trimmedInstruction = instruction.trim(); final promptMessages = [ ..._promptMessages( - role: 'system', + role: AiPromptRole.system, content: renderAiPromptTemplate( database.settingProperties.aiPromptTemplate( AiPromptTemplateKey.assistSystem, @@ -141,7 +140,7 @@ class AiChatPromptBuilder { promptMessages.addAll( _promptMessages( - role: _aiRoleUser, + role: AiPromptRole.user, content: renderAiPromptTemplate( assistUserMessagePromptTemplate, buildAiPromptTemplateVariables( @@ -174,7 +173,7 @@ class AiChatPromptBuilder { } promptMessages.addAll( _promptMessages( - role: 'system', + role: AiPromptRole.system, content: renderAiPromptTemplate( conversationToolInstructionPromptTemplate, buildAiPromptTemplateVariables( @@ -206,7 +205,7 @@ class AiChatPromptBuilder { .join('\n'); promptMessages.addAll( _promptMessages( - role: 'system', + role: AiPromptRole.system, content: renderAiPromptTemplate( recentConversationContextPromptTemplate, buildAiPromptTemplateVariables( @@ -248,7 +247,7 @@ class AiChatPromptBuilder { } List _promptMessages({ - required String role, + required AiPromptRole role, required String content, }) { if (content.trim().isEmpty) { diff --git a/lib/ai/ai_provider_requester.dart b/lib/ai/ai_provider_requester.dart index 87f19b9423..b49daf40f5 100644 --- a/lib/ai/ai_provider_requester.dart +++ b/lib/ai/ai_provider_requester.dart @@ -63,7 +63,9 @@ class AiProviderRequester { try { final cancelFuture = cancelToken.whenCancel.then((_) {}); final stream = ai.generateStream( - messages: messages.map(_genkitMessage).toList(growable: false), + messages: messages + .map((message) => message.toGenkitMessage()) + .toList(growable: false), model: _modelFor(config), tools: tools, toolChoice: tools == null ? null : 'auto', @@ -168,18 +170,6 @@ class AiProviderRequester { AiProviderType.gemini => googleAI.gemini(config.model), }; - genkit.Message _genkitMessage(AiPromptMessage message) => genkit.Message( - role: _roleFor(message.role), - content: [genkit.TextPart(text: message.content)], - ); - - genkit.Role _roleFor(String role) => switch (role) { - 'system' => genkit.Role.system, - 'assistant' => genkit.Role.model, - 'tool' => genkit.Role.tool, - _ => genkit.Role.user, - }; - String? _emptyToNull(String value) { final trimmed = value.trim(); return trimmed.isEmpty ? null : trimmed; diff --git a/lib/ai/model/ai_prompt_message.dart b/lib/ai/model/ai_prompt_message.dart index c7ab2f94a4..cb67ebd8df 100644 --- a/lib/ai/model/ai_prompt_message.dart +++ b/lib/ai/model/ai_prompt_message.dart @@ -1,6 +1,27 @@ +import 'package:genkit/genkit.dart' as genkit; + +extension type AiPromptRole(String value) { + static AiPromptRole get system => AiPromptRole('system'); + static AiPromptRole get user => AiPromptRole('user'); + static AiPromptRole get assistant => AiPromptRole('assistant'); + static AiPromptRole get tool => AiPromptRole('tool'); + + genkit.Role toGenkitRole() => switch (value) { + 'system' => genkit.Role.system, + 'assistant' => genkit.Role.model, + 'tool' => genkit.Role.tool, + _ => genkit.Role.user, + }; +} + class AiPromptMessage { AiPromptMessage({required this.role, required this.content}); - final String role; + final AiPromptRole role; final String content; + + genkit.Message toGenkitMessage() => genkit.Message( + role: role.toGenkitRole(), + content: [genkit.TextPart(text: content)], + ); } diff --git a/lib/ai/model/ai_tool.dart b/lib/ai/model/ai_tool.dart deleted file mode 100644 index 690f729a8d..0000000000 --- a/lib/ai/model/ai_tool.dart +++ /dev/null @@ -1,11 +0,0 @@ -class AiToolDefinition { - const AiToolDefinition({ - required this.name, - required this.description, - required this.inputSchema, - }); - - final String name; - final String description; - final Map inputSchema; -} diff --git a/lib/ai/tools/ai_conversation_tool_service.dart b/lib/ai/tools/ai_conversation_tool_service.dart index cd97868cef..a10002686f 100644 --- a/lib/ai/tools/ai_conversation_tool_service.dart +++ b/lib/ai/tools/ai_conversation_tool_service.dart @@ -1,14 +1,23 @@ +import 'dart:convert'; import 'dart:math' as math; +import 'package:genkit/genkit.dart' as genkit; +import 'package:mixin_logger/mixin_logger.dart'; +import 'package:schemantic/schemantic.dart'; + import '../../db/dao/message_dao.dart'; import '../../db/database.dart'; import '../../db/mixin_database.dart'; -import '../model/ai_tool.dart'; +import '../model/ai_chat_metadata.dart'; const _kDefaultConversationChunkSize = 100; const _kMaxConversationChunkSize = 200; const _kDefaultConversationSearchLimit = 8; const _kMaxConversationSearchLimit = 20; +const _kAiToolLogPreviewLength = 480; + +typedef AiConversationToolEventSink = + Future Function(Map event); class AiConversationToolMessage { const AiConversationToolMessage({ @@ -401,226 +410,420 @@ class AiConversationToolKit { final AiConversationToolService service; - static const definitions = [ - AiToolDefinition( + List genkitTools({ + required String conversationId, + AiConversationToolEventSink? onEvent, + }) => [ + genkit.Tool>( name: 'get_conversation_stats', description: 'Get message counts and boundary timestamps for the current conversation or a specific time range.', - inputSchema: { - 'type': 'object', - 'properties': { - 'start_time': { - 'type': 'string', - 'description': 'Optional inclusive ISO-8601 start time.', - }, - 'end_time': { - 'type': 'string', - 'description': 'Optional exclusive ISO-8601 end time.', - }, + inputSchema: GetConversationStatsInput.schema, + fn: (input, context) => _executeTool( + conversationId: conversationId, + name: 'get_conversation_stats', + arguments: input.toArguments(), + context: context, + onEvent: onEvent, + fn: () async { + final stats = await service.getConversationStats( + conversationId: conversationId, + startInclusive: input.startInclusive, + endExclusive: input.endExclusive, + ); + return stats.toJson(); }, - 'additionalProperties': false, - }, + ), ), - AiToolDefinition( + genkit.Tool>( name: 'list_conversation_chunks', description: 'List chunk offsets that can be used to read the current conversation in fixed-size batches, optionally scoped to a time range.', - inputSchema: { - 'type': 'object', - 'properties': { - 'chunk_size': { - 'type': 'integer', - 'description': 'Optional chunk size between 1 and 200.', - }, - 'start_time': { - 'type': 'string', - 'description': 'Optional inclusive ISO-8601 start time.', - }, - 'end_time': { - 'type': 'string', - 'description': 'Optional exclusive ISO-8601 end time.', - }, + inputSchema: ListConversationChunksInput.schema, + fn: (input, context) => _executeTool( + conversationId: conversationId, + name: 'list_conversation_chunks', + arguments: input.toArguments(), + context: context, + onEvent: onEvent, + fn: () async { + final chunks = await service.listConversationChunks( + conversationId: conversationId, + chunkSize: input.chunkSize, + startInclusive: input.startInclusive, + endExclusive: input.endExclusive, + ); + return chunks.toJson(); }, - 'additionalProperties': false, - }, + ), ), - AiToolDefinition( + genkit.Tool>( name: 'read_conversation_chunk', description: 'Read a batch of messages from the current conversation by offset and limit, optionally scoped to a time range.', - inputSchema: { - 'type': 'object', - 'properties': { - 'offset': { - 'type': 'integer', - 'description': 'Zero-based offset into the matching message list.', - }, - 'limit': { - 'type': 'integer', - 'description': 'Number of messages to read, between 1 and 200.', - }, - 'start_time': { - 'type': 'string', - 'description': 'Optional inclusive ISO-8601 start time.', - }, - 'end_time': { - 'type': 'string', - 'description': 'Optional exclusive ISO-8601 end time.', - }, + inputSchema: ReadConversationChunkInput.schema, + fn: (input, context) => _executeTool( + conversationId: conversationId, + name: 'read_conversation_chunk', + arguments: input.toArguments(), + context: context, + onEvent: onEvent, + fn: () async { + final page = await service.readConversationChunk( + conversationId: conversationId, + offset: input.offset, + limit: input.limit, + startInclusive: input.startInclusive, + endExclusive: input.endExclusive, + ); + return page.toJson(); }, - 'required': ['offset'], - 'additionalProperties': false, - }, + ), ), - AiToolDefinition( + genkit.Tool>( name: 'search_conversation_messages', description: 'Search the current conversation for messages relevant to a query string.', - inputSchema: { - 'type': 'object', - 'properties': { - 'query': { - 'type': 'string', - 'description': 'Search query text.', - }, - 'limit': { - 'type': 'integer', - 'description': - 'Maximum number of matches to return, between 1 and 20.', - }, + inputSchema: SearchConversationMessagesInput.schema, + fn: (input, context) => _executeTool( + conversationId: conversationId, + name: 'search_conversation_messages', + arguments: input.toArguments(), + context: context, + onEvent: onEvent, + fn: () async { + final result = await service.searchConversationMessages( + conversationId: conversationId, + query: input.query, + limit: input.limit, + ); + return result.toJson(); }, - 'required': ['query'], - 'additionalProperties': false, - }, + ), ), ]; - Future> execute({ + Future> _executeTool({ required String conversationId, required String name, required Map arguments, + required genkit.ToolFnArgs context, + required Future> Function() fn, + required AiConversationToolEventSink? onEvent, }) async { - switch (name) { - case 'get_conversation_stats': - final (startInclusive, endExclusive) = _parseRange(arguments); - final stats = await service.getConversationStats( - conversationId: conversationId, - startInclusive: startInclusive, - endExclusive: endExclusive, - ); - return stats.toJson(); - case 'list_conversation_chunks': - final (startInclusive, endExclusive) = _parseRange(arguments); - final chunkSize = _parseInt( + final request = context.toolRequest?.toolRequest; + final id = request?.ref ?? '${name}_${arguments.hashCode}'; + final stopwatch = Stopwatch()..start(); + d( + 'AI tool execute start: conversationId=$conversationId ' + 'tool=$name id=$id arguments=${_previewJson(arguments)}', + ); + await onEvent?.call( + createAiToolCallEvent(id: id, name: name, arguments: arguments), + ); + try { + final result = await fn(); + d( + 'AI tool execute done: conversationId=$conversationId ' + 'tool=$name id=$id elapsedMs=${stopwatch.elapsedMilliseconds} ' + 'result=${_previewJson(result)}', + ); + await onEvent?.call( + createAiToolResultEvent( + id: id, + name: name, + status: 'done', + elapsedMs: stopwatch.elapsedMilliseconds, + resultPreview: _previewJson(result), + ), + ); + return result; + } catch (error, stacktrace) { + e('AI tool execution error: $error, $stacktrace'); + await onEvent?.call( + createAiToolResultEvent( + id: id, + name: name, + status: 'error', + elapsedMs: stopwatch.elapsedMilliseconds, + errorText: error.toString(), + ), + ); + return {'error': '$error'}; + } + } +} + +class GetConversationStatsInput { + const GetConversationStatsInput({ + this.startInclusive, + this.endExclusive, + }); + + final DateTime? startInclusive; + final DateTime? endExclusive; + + static final schema = SchemanticType.from( + jsonSchema: _rangeSchema(), + parse: (value) { + final arguments = _jsonMap(value); + final (startInclusive, endExclusive) = _parseRange(arguments); + return GetConversationStatsInput( + startInclusive: startInclusive, + endExclusive: endExclusive, + ); + }, + ); + + Map toArguments() => { + 'start_time': startInclusive?.toIso8601String(), + 'end_time': endExclusive?.toIso8601String(), + }..removeWhere((_, value) => value == null); +} + +class ListConversationChunksInput { + const ListConversationChunksInput({ + required this.chunkSize, + this.startInclusive, + this.endExclusive, + }); + + final int chunkSize; + final DateTime? startInclusive; + final DateTime? endExclusive; + + static final schema = SchemanticType.from( + jsonSchema: _rangeSchema( + properties: { + 'chunk_size': { + 'type': 'integer', + 'description': 'Optional chunk size between 1 and 200.', + }, + }, + ), + parse: (value) { + final arguments = _jsonMap(value); + final (startInclusive, endExclusive) = _parseRange(arguments); + return ListConversationChunksInput( + chunkSize: _parseInt( arguments, 'chunk_size', defaultValue: _kDefaultConversationChunkSize, min: 1, max: _kMaxConversationChunkSize, - ); - final chunks = await service.listConversationChunks( - conversationId: conversationId, - chunkSize: chunkSize, - startInclusive: startInclusive, - endExclusive: endExclusive, - ); - return chunks.toJson(); - case 'read_conversation_chunk': - final (startInclusive, endExclusive) = _parseRange(arguments); - final offset = _parseInt( + ), + startInclusive: startInclusive, + endExclusive: endExclusive, + ); + }, + ); + + Map toArguments() => { + 'chunk_size': chunkSize, + 'start_time': startInclusive?.toIso8601String(), + 'end_time': endExclusive?.toIso8601String(), + }..removeWhere((_, value) => value == null); +} + +class ReadConversationChunkInput { + const ReadConversationChunkInput({ + required this.offset, + required this.limit, + this.startInclusive, + this.endExclusive, + }); + + final int offset; + final int limit; + final DateTime? startInclusive; + final DateTime? endExclusive; + + static final schema = SchemanticType.from( + jsonSchema: _rangeSchema( + properties: { + 'offset': { + 'type': 'integer', + 'description': 'Zero-based offset into the matching message list.', + }, + 'limit': { + 'type': 'integer', + 'description': 'Number of messages to read, between 1 and 200.', + }, + }, + required: ['offset'], + ), + parse: (value) { + final arguments = _jsonMap(value); + final (startInclusive, endExclusive) = _parseRange(arguments); + return ReadConversationChunkInput( + offset: _parseInt( arguments, 'offset', defaultValue: 0, min: 0, max: 1 << 20, - ); - final limit = _parseInt( + ), + limit: _parseInt( arguments, 'limit', defaultValue: _kDefaultConversationChunkSize, min: 1, max: _kMaxConversationChunkSize, - ); - final page = await service.readConversationChunk( - conversationId: conversationId, - offset: offset, - limit: limit, - startInclusive: startInclusive, - endExclusive: endExclusive, - ); - return page.toJson(); - case 'search_conversation_messages': - final query = _parseRequiredString(arguments, 'query'); - final limit = _parseInt( + ), + startInclusive: startInclusive, + endExclusive: endExclusive, + ); + }, + ); + + Map toArguments() => { + 'offset': offset, + 'limit': limit, + 'start_time': startInclusive?.toIso8601String(), + 'end_time': endExclusive?.toIso8601String(), + }..removeWhere((_, value) => value == null); +} + +class SearchConversationMessagesInput { + const SearchConversationMessagesInput({ + required this.query, + required this.limit, + }); + + final String query; + final int limit; + + static final schema = SchemanticType.from( + jsonSchema: { + 'type': 'object', + 'properties': { + 'query': { + 'type': 'string', + 'description': 'Search query text.', + }, + 'limit': { + 'type': 'integer', + 'description': + 'Maximum number of matches to return, between 1 and 20.', + }, + }, + 'required': ['query'], + 'additionalProperties': false, + }, + parse: (value) { + final arguments = _jsonMap(value); + return SearchConversationMessagesInput( + query: _parseRequiredString(arguments, 'query'), + limit: _parseInt( arguments, 'limit', defaultValue: _kDefaultConversationSearchLimit, min: 1, max: _kMaxConversationSearchLimit, - ); - final result = await service.searchConversationMessages( - conversationId: conversationId, - query: query, - limit: limit, - ); - return result.toJson(); - default: - throw UnsupportedError('Unknown conversation tool: $name'); - } + ), + ); + }, + ); + + Map toArguments() => { + 'query': query, + 'limit': limit, + }; +} + +Map _rangeSchema({ + Map properties = const {}, + List required = const [], +}) => { + 'type': 'object', + 'properties': { + 'start_time': { + 'type': 'string', + 'description': 'Optional inclusive ISO-8601 start time.', + }, + 'end_time': { + 'type': 'string', + 'description': 'Optional exclusive ISO-8601 end time.', + }, + ...properties, + }, + if (required.isNotEmpty) 'required': required, + 'additionalProperties': false, +}; + +(DateTime?, DateTime?) _parseRange(Map arguments) { + final startInclusive = _parseDateTime(arguments, 'start_time'); + final endExclusive = _parseDateTime(arguments, 'end_time'); + if (startInclusive != null && + endExclusive != null && + !endExclusive.isAfter(startInclusive)) { + throw const FormatException('end_time must be later than start_time'); } + return (startInclusive, endExclusive); +} - (DateTime?, DateTime?) _parseRange(Map arguments) { - final startInclusive = _parseDateTime(arguments, 'start_time'); - final endExclusive = _parseDateTime(arguments, 'end_time'); - if (startInclusive != null && - endExclusive != null && - !endExclusive.isAfter(startInclusive)) { - throw const FormatException('end_time must be later than start_time'); - } - return (startInclusive, endExclusive); +DateTime? _parseDateTime(Map arguments, String key) { + final raw = arguments[key]; + if (raw == null) { + return null; } + if (raw is! String || raw.trim().isEmpty) { + throw FormatException('$key must be an ISO-8601 string'); + } + final value = DateTime.tryParse(raw.trim()); + if (value == null) { + throw FormatException('$key must be a valid ISO-8601 string'); + } + return value; +} - DateTime? _parseDateTime(Map arguments, String key) { - final raw = arguments[key]; - if (raw == null) { - return null; - } - if (raw is! String || raw.trim().isEmpty) { - throw FormatException('$key must be an ISO-8601 string'); - } - final value = DateTime.tryParse(raw.trim()); - if (value == null) { - throw FormatException('$key must be a valid ISO-8601 string'); - } - return value; +int _parseInt( + Map arguments, + String key, { + required int defaultValue, + required int min, + required int max, +}) { + final raw = arguments[key]; + if (raw == null) { + return defaultValue; } + final value = switch (raw) { + final int value => value, + final String value => + int.tryParse(value.trim()) ?? + (throw FormatException('$key must be an integer')), + _ => throw FormatException('$key must be an integer'), + }; + return value.clamp(min, max); +} - int _parseInt( - Map arguments, - String key, { - required int defaultValue, - required int min, - required int max, - }) { - final raw = arguments[key]; - if (raw == null) { - return defaultValue; - } - final value = switch (raw) { - final int value => value, - final String value => - int.tryParse(value.trim()) ?? - (throw FormatException('$key must be an integer')), - _ => throw FormatException('$key must be an integer'), - }; - return value.clamp(min, max); +String _parseRequiredString(Map arguments, String key) { + final raw = arguments[key]; + if (raw is! String || raw.trim().isEmpty) { + throw FormatException('$key must be a non-empty string'); } + return raw.trim(); +} + +Map _jsonMap(dynamic value) { + if (value is Map) { + return value; + } + if (value is Map) { + return value.map((key, value) => MapEntry('$key', value)); + } + throw Exception('Invalid AI tool arguments'); +} - String _parseRequiredString(Map arguments, String key) { - final raw = arguments[key]; - if (raw is! String || raw.trim().isEmpty) { - throw FormatException('$key must be a non-empty string'); +String _previewJson(Object? value) { + try { + final encoded = jsonEncode(value); + if (encoded.length <= _kAiToolLogPreviewLength) { + return encoded; } - return raw.trim(); + return '${encoded.substring(0, _kAiToolLogPreviewLength)}...(${encoded.length} chars)'; + } catch (_) { + return '$value'; } } diff --git a/lib/ui/setting/ai_provider_edit_page.dart b/lib/ui/setting/ai_provider_edit_page.dart index d5dd230978..e0dbc79e21 100644 --- a/lib/ui/setting/ai_provider_edit_page.dart +++ b/lib/ui/setting/ai_provider_edit_page.dart @@ -117,7 +117,9 @@ class AiProviderEditPage extends HookConsumerWidget { TextButton( onPressed: () { final name = nameController.text.trim(); - final baseUrl = baseUrlController.text.trim(); + final baseUrl = providerType.value == AiProviderType.gemini + ? '' + : baseUrlController.text.trim(); final apiKey = apiKeyController.text.trim(); final normalizedModels = _normalizeModels(models.value); final resolvedDefaultModel = _resolveDefaultModel( @@ -125,7 +127,8 @@ class AiProviderEditPage extends HookConsumerWidget { defaultModel.value, ); if (name.isEmpty || - baseUrl.isEmpty || + (providerType.value != AiProviderType.gemini && + baseUrl.isEmpty) || apiKey.isEmpty || normalizedModels.isEmpty || resolvedDefaultModel.isEmpty) { @@ -258,46 +261,48 @@ class AiProviderEditPage extends HookConsumerWidget { ], ), ), - const _SectionLabel( - title: 'Endpoint', - ), - CellGroup( - padding: const EdgeInsets.only(right: 10, left: 10), - cellBackgroundColor: theme.settingCellBackgroundColor, - child: _FormFieldCell( - label: 'Base URL', - backgroundColor: inputBackgroundColor, - borderColor: inputBorderColor, - child: TextField( - controller: baseUrlController, - keyboardType: TextInputType.url, - style: TextStyle( - color: theme.text, - fontSize: 16, - ), - decoration: InputDecoration( - isDense: true, - border: InputBorder.none, - hintText: _baseUrlHintFor(providerType.value), - hintStyle: TextStyle(color: theme.secondaryText), + if (providerType.value != AiProviderType.gemini) ...[ + const _SectionLabel( + title: 'Endpoint', + ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: theme.settingCellBackgroundColor, + child: _FormFieldCell( + label: 'Base URL', + backgroundColor: inputBackgroundColor, + borderColor: inputBorderColor, + child: TextField( + controller: baseUrlController, + keyboardType: TextInputType.url, + style: TextStyle( + color: theme.text, + fontSize: 16, + ), + decoration: InputDecoration( + isDense: true, + border: InputBorder.none, + hintText: _baseUrlHintFor(providerType.value), + hintStyle: TextStyle(color: theme.secondaryText), + ), ), ), ), - ), - Padding( - padding: const EdgeInsets.only( - left: 20, - bottom: 14, - top: 10, - ), - child: Text( - _baseUrlHelperTextFor(providerType.value), - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 14, + Padding( + padding: const EdgeInsets.only( + left: 20, + bottom: 14, + top: 10, + ), + child: Text( + _baseUrlHelperTextFor(providerType.value), + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), ), ), - ), + ], const _SectionLabel( title: 'Authorization', ), @@ -440,7 +445,7 @@ class AiProviderEditPage extends HookConsumerWidget { static String _defaultBaseUrlFor(AiProviderType type) => switch (type) { AiProviderType.openaiCompatible => '', AiProviderType.anthropic => 'https://api.anthropic.com/v1', - AiProviderType.gemini => 'https://generativelanguage.googleapis.com/v1beta', + AiProviderType.gemini => '', }; static String _baseUrlHintFor(AiProviderType type) => switch (type) { diff --git a/test/ai/ai_provider_requester_test.dart b/test/ai/ai_provider_requester_test.dart new file mode 100644 index 0000000000..bb55fd83fe --- /dev/null +++ b/test/ai/ai_provider_requester_test.dart @@ -0,0 +1,70 @@ +import 'package:dio/dio.dart'; +import 'package:flutter_app/ai/ai_provider_requester.dart'; +import 'package:flutter_app/ai/model/ai_prompt_message.dart'; +import 'package:flutter_app/ai/model/ai_provider_config.dart'; +import 'package:flutter_app/ai/model/ai_provider_type.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:genkit/genkit.dart' as genkit; + +void main() { + group('AI provider requester', () { + test('maps prompt messages to Genkit messages', () { + final userMessage = AiPromptMessage( + role: AiPromptRole.user, + content: 'hello', + ).toGenkitMessage(); + final assistantMessage = AiPromptMessage( + role: AiPromptRole.assistant, + content: 'hi', + ).toGenkitMessage(); + final systemMessage = AiPromptMessage( + role: AiPromptRole.system, + content: 'rules', + ).toGenkitMessage(); + final unknownMessage = AiPromptMessage( + role: AiPromptRole('unknown'), + content: 'fallback', + ).toGenkitMessage(); + + expect(userMessage.role, genkit.Role.user); + expect(userMessage.text, 'hello'); + expect(assistantMessage.role, genkit.Role.model); + expect(assistantMessage.text, 'hi'); + expect(systemMessage.role, genkit.Role.system); + expect(systemMessage.text, 'rules'); + expect(unknownMessage.role, genkit.Role.user); + expect(unknownMessage.text, 'fallback'); + }); + + test('throws before creating a request when cancelled', () async { + final cancelToken = CancelToken()..cancel('stopped'); + + await expectLater( + const AiProviderRequester().requestText( + AiProviderConfig( + id: 'provider-id', + name: 'Provider', + type: AiProviderType.openaiCompatible, + baseUrl: 'https://api.example.com/v1', + apiKey: 'key', + model: 'test-model', + ), + [ + AiPromptMessage(role: AiPromptRole.user, content: 'hello'), + ], + proxy: null, + cancelToken: cancelToken, + onContent: (_) async {}, + conversationId: null, + ), + throwsA( + isA().having( + (error) => error.toString(), + 'message', + contains('AI generation stopped'), + ), + ), + ); + }); + }); +} From a1d62b15b31547bc564b0d53dc99270e85ea6432 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 17:07:25 +0800 Subject: [PATCH 27/52] refactor: extract user and response message cards into separate components --- lib/widgets/ai/ai_message_card.dart | 230 ++++++++++++---------- lib/widgets/ai/ai_text_result_dialog.dart | 124 ------------ 2 files changed, 126 insertions(+), 228 deletions(-) delete mode 100644 lib/widgets/ai/ai_text_result_dialog.dart diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index 5ea64b3020..75feab79ed 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -42,103 +42,140 @@ class AiMessageCard extends StatelessWidget { final sameRoleNext = next?.role == message.role; final mergedWithPrev = sameDayPrev && sameRolePrev; final mergedWithNext = sameDayNext && sameRoleNext; - final body = isUser - ? ConstrainedBox( - constraints: const BoxConstraints(maxWidth: 420), - child: _AiMessageBody(message: message), - ) - : _AiMessageBody(message: message); if (isUser) { - return Padding( - padding: EdgeInsets.only( - left: 72, - right: 8, - top: mergedWithPrev ? 4 : 14, - bottom: 4, - ), - child: Align( - alignment: Alignment.centerRight, - child: _AiMessageMenu( - message: message, - child: _AiBubble( - isCurrentUser: true, - showNip: !mergedWithNext, - color: _bubbleColor( - context, - isUser: true, - status: message.status, - ), - child: body, - ), - ), - ), + return _AiUserMessageCard( + message: message, + mergedWithPrev: mergedWithPrev, + mergedWithNext: mergedWithNext, ); } - return Padding( - padding: EdgeInsets.only( - top: mergedWithPrev ? 6 : 18, - bottom: 6, - ), + return _AiResponseMessageCard( + message: message, + mergedWithPrev: mergedWithPrev, + ); + } +} + +class _AiUserMessageCard extends StatelessWidget { + const _AiUserMessageCard({ + required this.message, + required this.mergedWithPrev, + required this.mergedWithNext, + }); + + final AiChatMessage message; + final bool mergedWithPrev; + final bool mergedWithNext; + + @override + Widget build(BuildContext context) => Padding( + padding: EdgeInsets.only( + left: 36, + right: 8, + top: mergedWithPrev ? 4 : 14, + bottom: 4, + ), + child: Align( + alignment: Alignment.centerRight, child: _AiMessageMenu( message: message, - child: body, + child: _AiBubble( + isCurrentUser: true, + showNip: !mergedWithNext, + color: context.theme.ai.userBubble, + child: ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 420), + child: MessageLayout( + spacing: 6, + content: _AiUserMessageBody(message: message), + dateAndStatus: MessageMetaRow(dateTime: message.createdAt), + ), + ), + ), ), - ); - } + ), + ); +} + +class _AiResponseMessageCard extends StatelessWidget { + const _AiResponseMessageCard({ + required this.message, + required this.mergedWithPrev, + }); + + final AiChatMessage message; + final bool mergedWithPrev; + + @override + Widget build(BuildContext context) => Padding( + padding: EdgeInsets.only( + top: mergedWithPrev ? 6 : 18, + bottom: 6, + ), + child: _AiMessageMenu( + message: message, + child: Column( + spacing: 6, + children: [ + _AiResponseMessageBody(message: message), + const SizedBox(height: 4), + _AiResponseFooter( + model: message.model, + dateTime: message.createdAt, + ), + ], + ), + ), + ); } -class _AiMessageBody extends StatelessWidget { - const _AiMessageBody({required this.message}); +class _AiUserMessageBody extends StatelessWidget { + const _AiUserMessageBody({required this.message}); + + final AiChatMessage message; + + @override + Widget build(BuildContext context) => _AiSelectableText( + text: _displayText(message), + style: _aiMessageTextStyle(context, message), + ); +} + +class _AiResponseMessageBody extends StatelessWidget { + const _AiResponseMessageBody({required this.message}); final AiChatMessage message; @override Widget build(BuildContext context) { - final isUser = message.role == 'user'; - final text = _displayText(message); final isPendingAssistant = - !isUser && - message.status == 'pending' && - message.content.trim().isEmpty; - - Widget body; - final textStyle = TextStyle( - color: message.status == 'error' - ? context.theme.ai.error - : context.theme.text, - fontSize: context.messageStyle.primaryFontSize, - height: 1.45, - ); + message.status == 'pending' && message.content.trim().isEmpty; + final textStyle = _aiMessageTextStyle(context, message); if (isPendingAssistant) { - body = _AiPendingAssistantActivity(message: message, style: textStyle); - } else if (isUser || message.status == 'error') { - body = _AiSelectableText(text: text, style: textStyle); - } else { - final cacheKey = buildMarkdownCacheKey( - namespace: 'ai', - id: message.id, - ); - body = DefaultTextStyle.merge( + return _AiPendingAssistantActivity(message: message, style: textStyle); + } + + if (message.status == 'error') { + return _AiSelectableText( + text: _displayText(message), style: textStyle, - child: MarkdownColumn( - data: text, - selectable: true, - cacheKey: cacheKey, - streaming: message.status == 'pending', - ), ); } - return MessageLayout( - spacing: 6, - content: body, - dateAndStatus: _AiFooter( - isUser: isUser, - model: message.model, - dateTime: message.createdAt, + final cacheKey = buildMarkdownCacheKey( + namespace: 'ai', + id: message.id, + ); + return DefaultTextStyle.merge( + style: textStyle, + child: MarkdownColumn( + data: _displayText(message), + selectable: true, + cacheKey: cacheKey, + streaming: message.status == 'pending', ), ); } @@ -224,6 +261,14 @@ class _AiSelectableTextState extends State<_AiSelectableText> { } } +TextStyle _aiMessageTextStyle(BuildContext context, AiChatMessage message) => + TextStyle( + color: message.status == 'error' + ? context.theme.ai.error + : context.theme.text, + fontSize: context.messageStyle.primaryFontSize, + ); + class _AiBubble extends StatelessWidget { const _AiBubble({ required this.child, @@ -268,7 +313,7 @@ class _AiMessageMenu extends StatelessWidget { @override Widget build(BuildContext context) { - final content = _menuCopyText(message); + final content = _displayText(message); return Builder( builder: (childContext) => CustomContextMenuWidget( @@ -350,23 +395,17 @@ SelectedContent? _findSelectedContent(BuildContext context) { return null; } -class _AiFooter extends StatelessWidget { - const _AiFooter({ - required this.isUser, +class _AiResponseFooter extends StatelessWidget { + const _AiResponseFooter({ required this.model, required this.dateTime, }); - final bool isUser; final String? model; final DateTime dateTime; @override Widget build(BuildContext context) { - if (isUser) { - return MessageMetaRow(dateTime: dateTime); - } - final metaColor = context.dynamicColor( const Color.fromRGBO(131, 145, 158, 1), darkColor: const Color.fromRGBO(128, 131, 134, 1), @@ -376,16 +415,17 @@ class _AiFooter extends StatelessWidget { color: metaColor, ); final dateTimeText = DateFormat.Hm().format(dateTime.toLocal()); - final trimmedModel = isUser ? null : model?.trim(); + final trimmedModel = model?.trim(); return SelectionContainer.disabled( child: SizedBox( width: double.infinity, child: Row( children: [ + const SizedBox(width: 4), Text(dateTimeText, style: textStyle), if (trimmedModel != null && trimmedModel.isNotEmpty) ...[ - const Spacer(), + const SizedBox(width: 12), Text(trimmedModel, style: textStyle), ], ], @@ -395,24 +435,6 @@ class _AiFooter extends StatelessWidget { } } -Color _bubbleColor( - BuildContext context, { - required bool isUser, - required String status, -}) { - if (status == 'error') { - return context.theme.ai.errorBubble; - } - - if (isUser) { - return context.theme.ai.userBubble; - } - - return context.theme.ai.assistantBubble; -} - -String _menuCopyText(AiChatMessage message) => _displayText(message); - String _displayText(AiChatMessage message) { final content = message.content.trim(); if (content.isNotEmpty) return content; diff --git a/lib/widgets/ai/ai_text_result_dialog.dart b/lib/widgets/ai/ai_text_result_dialog.dart deleted file mode 100644 index d4dab9589b..0000000000 --- a/lib/widgets/ai/ai_text_result_dialog.dart +++ /dev/null @@ -1,124 +0,0 @@ -import 'package:flutter/material.dart'; -import 'package:flutter/services.dart'; - -import '../../utils/extension/extension.dart'; -import '../dialog.dart'; -import '../toast.dart'; - -enum AiTextResultAction { replace, insert } - -Future showAiTextResultDialog({ - required BuildContext context, - required String title, - required String result, - String? original, - bool allowReplace = true, -}) => showMixinDialog( - context: context, - constraints: const BoxConstraints(maxWidth: 560), - child: AlertDialogLayout( - minWidth: 420, - minHeight: 0, - titleMarginBottom: 20, - title: Text(title), - content: _AiTextResultContent(original: original, result: result), - actions: [ - MixinButton( - backgroundTransparent: true, - child: const Text('Copy'), - onTap: () { - Clipboard.setData(ClipboardData(text: result)); - showToastSuccessful(context: context); - Navigator.pop(context); - }, - ), - const MixinButton( - backgroundTransparent: true, - value: AiTextResultAction.insert, - child: Text('Insert'), - ), - if (allowReplace) - const MixinButton( - value: AiTextResultAction.replace, - child: Text('Replace'), - ), - ], - ), -); - -class _AiTextResultContent extends StatelessWidget { - const _AiTextResultContent({required this.result, this.original}); - - final String? original; - final String result; - - @override - Widget build(BuildContext context) { - final original = this.original?.trim(); - return ConstrainedBox( - constraints: const BoxConstraints(maxHeight: 420), - child: SingleChildScrollView( - child: DefaultTextStyle.merge( - style: TextStyle( - color: context.theme.text, - fontSize: 14, - fontWeight: FontWeight.normal, - height: 1.45, - ), - child: Column( - crossAxisAlignment: CrossAxisAlignment.start, - mainAxisSize: MainAxisSize.min, - children: [ - if (original != null && original.isNotEmpty) ...[ - const _SectionLabel('Original'), - _TextBlock(original), - const SizedBox(height: 16), - ], - const _SectionLabel('AI'), - _TextBlock(result), - ], - ), - ), - ), - ); - } -} - -class _SectionLabel extends StatelessWidget { - const _SectionLabel(this.text); - - final String text; - - @override - Widget build(BuildContext context) => Padding( - padding: const EdgeInsets.only(bottom: 6), - child: Text( - text, - style: TextStyle( - color: context.theme.secondaryText, - fontSize: 12, - fontWeight: FontWeight.w500, - ), - ), - ); -} - -class _TextBlock extends StatelessWidget { - const _TextBlock(this.text); - - final String text; - - @override - Widget build(BuildContext context) => Container( - width: double.infinity, - padding: const EdgeInsets.all(12), - decoration: BoxDecoration( - color: context.dynamicColor( - const Color.fromRGBO(245, 247, 250, 1), - darkColor: const Color.fromRGBO(255, 255, 255, 0.08), - ), - borderRadius: const BorderRadius.all(Radius.circular(6)), - ), - child: SelectableText(text), - ); -} From 91b66be71e672fbba263087f69da689a276fbafc Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 17:12:51 +0800 Subject: [PATCH 28/52] feat: add model testing functionality in AI provider edit page --- lib/ui/setting/ai_provider_edit_page.dart | 83 +++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/lib/ui/setting/ai_provider_edit_page.dart b/lib/ui/setting/ai_provider_edit_page.dart index e0dbc79e21..66429690aa 100644 --- a/lib/ui/setting/ai_provider_edit_page.dart +++ b/lib/ui/setting/ai_provider_edit_page.dart @@ -1,8 +1,11 @@ +import 'package:dio/dio.dart'; import 'package:flutter/material.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; import 'package:uuid/uuid.dart'; +import '../../ai/ai_provider_requester.dart'; +import '../../ai/model/ai_prompt_message.dart'; import '../../ai/model/ai_provider_config.dart'; import '../../ai/model/ai_provider_type.dart'; import '../../utils/extension/extension.dart'; @@ -53,6 +56,7 @@ class AiProviderEditPage extends HookConsumerWidget { ), ); final obscureApiKey = useState(true); + final testingModel = useState(null); useEffect(() { if (initial != null) return null; @@ -109,6 +113,64 @@ class AiProviderEditPage extends HookConsumerWidget { ); } + Future testModel(String model) async { + final name = nameController.text.trim().isEmpty + ? 'Test Provider' + : nameController.text.trim(); + final baseUrl = providerType.value == AiProviderType.gemini + ? '' + : baseUrlController.text.trim(); + final apiKey = apiKeyController.text.trim(); + if ((providerType.value != AiProviderType.gemini && baseUrl.isEmpty) || + apiKey.isEmpty || + model.trim().isEmpty) { + showToastFailed(ToastError('Please complete provider settings first')); + return; + } + + testingModel.value = model; + showToastLoading(context: context); + final stopwatch = Stopwatch()..start(); + try { + await const AiProviderRequester().requestText( + AiProviderConfig( + id: initial?.id ?? const Uuid().v4(), + name: name, + type: providerType.value, + baseUrl: baseUrl, + apiKey: apiKey, + model: model, + models: [model], + defaultModel: model, + ), + [ + AiPromptMessage( + role: AiPromptRole.user, + content: 'Reply with exactly: OK', + ), + ], + proxy: database.settingProperties.activatedProxy, + cancelToken: CancelToken(), + onContent: (_) async {}, + conversationId: null, + ); + stopwatch.stop(); + if (!context.mounted) return; + showToast( + 'Model works · ${stopwatch.elapsedMilliseconds} ms', + context: context, + ); + } catch (error) { + stopwatch.stop(); + if (!context.mounted) return; + showToastFailed(error, context: context); + } finally { + if (context.mounted && testingModel.value == model) { + testingModel.value = null; + } + } + } + return Scaffold( backgroundColor: theme.background, appBar: MixinAppBar( @@ -404,7 +466,9 @@ class AiProviderEditPage extends HookConsumerWidget { _ModelItem( model: models.value[i], selected: models.value[i] == defaultModel.value, + testing: testingModel.value == models.value[i], onTap: () => defaultModel.value = models.value[i], + onTest: () => testModel(models.value[i]), onEdit: () => showModelDialog( initialValue: models.value[i], index: i, @@ -567,14 +631,18 @@ class _ModelItem extends StatelessWidget { const _ModelItem({ required this.model, required this.selected, + required this.testing, required this.onTap, + required this.onTest, required this.onEdit, required this.onDelete, }); final String model; final bool selected; + final bool testing; final VoidCallback onTap; + final VoidCallback onTest; final VoidCallback onEdit; final VoidCallback onDelete; @@ -618,6 +686,21 @@ class _ModelItem extends StatelessWidget { trailing: Row( mainAxisSize: MainAxisSize.min, children: [ + IconButton( + tooltip: 'Test model', + onPressed: testing ? null : onTest, + icon: testing + ? SizedBox.square( + dimension: 18, + child: CircularProgressIndicator( + strokeWidth: 2, + valueColor: AlwaysStoppedAnimation( + context.theme.accent, + ), + ), + ) + : Icon(Icons.speed_rounded, color: context.theme.icon), + ), IconButton( onPressed: onEdit, icon: Icon(Icons.edit_outlined, color: context.theme.icon), From 2c5f62a14216af7ab9cbe51b79cc651da0b9c7a7 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 17:25:19 +0800 Subject: [PATCH 29/52] feat: integrate TOON format for tool results and update tool result handling --- lib/ai/model/ai_prompt_template.dart | 10 +++-- .../tools/ai_conversation_tool_service.dart | 44 +++++++++++++++---- pubspec.lock | 9 ++++ pubspec.yaml | 4 ++ 4 files changed, 54 insertions(+), 13 deletions(-) diff --git a/lib/ai/model/ai_prompt_template.dart b/lib/ai/model/ai_prompt_template.dart index 09b4b452ac..89968b5e3e 100644 --- a/lib/ai/model/ai_prompt_template.dart +++ b/lib/ai/model/ai_prompt_template.dart @@ -282,10 +282,12 @@ const conversationToolInstructionPromptTemplate = 'Read-only conversation tools are available for the current ' 'conversation. Use them when you need exhaustive coverage, ' 'date-scoped summaries, statistics, older messages, or more ' - 'context than the provided messages. When answering the user, ' - 'default to {{language}} unless the user explicitly requires ' - 'another language or preserving the source language. Do not call ' - 'tools when the provided context is already sufficient.'; + 'context than the provided messages. Tool results are returned in ' + 'TOON format, a compact tabular notation for structured data. ' + 'When answering the user, default to {{language}} unless the user ' + 'explicitly requires another language or preserving the source ' + 'language. Do not call tools when the provided context is already ' + 'sufficient.'; const recentConversationContextPromptTemplate = 'Current conversation recent messages:\n{{messages}}'; diff --git a/lib/ai/tools/ai_conversation_tool_service.dart b/lib/ai/tools/ai_conversation_tool_service.dart index a10002686f..bd0ea8f3fa 100644 --- a/lib/ai/tools/ai_conversation_tool_service.dart +++ b/lib/ai/tools/ai_conversation_tool_service.dart @@ -4,6 +4,7 @@ import 'dart:math' as math; import 'package:genkit/genkit.dart' as genkit; import 'package:mixin_logger/mixin_logger.dart'; import 'package:schemantic/schemantic.dart'; +import 'package:toon_format/toon_format.dart'; import '../../db/dao/message_dao.dart'; import '../../db/database.dart'; @@ -414,7 +415,7 @@ class AiConversationToolKit { required String conversationId, AiConversationToolEventSink? onEvent, }) => [ - genkit.Tool>( + genkit.Tool( name: 'get_conversation_stats', description: 'Get message counts and boundary timestamps for the current conversation or a specific time range.', @@ -435,7 +436,7 @@ class AiConversationToolKit { }, ), ), - genkit.Tool>( + genkit.Tool( name: 'list_conversation_chunks', description: 'List chunk offsets that can be used to read the current conversation in fixed-size batches, optionally scoped to a time range.', @@ -457,7 +458,7 @@ class AiConversationToolKit { }, ), ), - genkit.Tool>( + genkit.Tool( name: 'read_conversation_chunk', description: 'Read a batch of messages from the current conversation by offset and limit, optionally scoped to a time range.', @@ -480,7 +481,7 @@ class AiConversationToolKit { }, ), ), - genkit.Tool>( + genkit.Tool( name: 'search_conversation_messages', description: 'Search the current conversation for messages relevant to a query string.', @@ -503,7 +504,7 @@ class AiConversationToolKit { ), ]; - Future> _executeTool({ + Future _executeTool({ required String conversationId, required String name, required Map arguments, @@ -523,10 +524,11 @@ class AiConversationToolKit { ); try { final result = await fn(); + final encodedResult = _encodeToolResult(result); d( 'AI tool execute done: conversationId=$conversationId ' 'tool=$name id=$id elapsedMs=${stopwatch.elapsedMilliseconds} ' - 'result=${_previewJson(result)}', + 'result=${_previewText(encodedResult)}', ); await onEvent?.call( createAiToolResultEvent( @@ -534,10 +536,10 @@ class AiConversationToolKit { name: name, status: 'done', elapsedMs: stopwatch.elapsedMilliseconds, - resultPreview: _previewJson(result), + resultPreview: _previewText(encodedResult), ), ); - return result; + return encodedResult; } catch (error, stacktrace) { e('AI tool execution error: $error, $stacktrace'); await onEvent?.call( @@ -549,7 +551,7 @@ class AiConversationToolKit { errorText: error.toString(), ), ); - return {'error': '$error'}; + return _encodeToolResult({'error': '$error'}); } } } @@ -827,3 +829,27 @@ String _previewJson(Object? value) { return '$value'; } } + +String _encodeToolResult(Map result) => + encode(_stripNullValues(result)); + +Object? _stripNullValues(Object? value) { + if (value is Map) { + return { + for (final entry in value.entries) + if (entry.value != null) entry.key: _stripNullValues(entry.value), + }; + } + if (value is List) { + return value.map(_stripNullValues).toList(growable: false); + } + return value; +} + +String _previewText(String value) { + final compact = value.replaceAll(RegExp(r'\s+'), ' ').trim(); + if (compact.length <= _kAiToolLogPreviewLength) { + return compact; + } + return '${compact.substring(0, _kAiToolLogPreviewLength)}...(${compact.length} chars)'; +} diff --git a/pubspec.lock b/pubspec.lock index 7868e5cc35..bcec75369f 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -2098,6 +2098,15 @@ packages: url: "https://pub.dev" source: hosted version: "0.11.0" + toon_format: + dependency: "direct main" + description: + path: "." + ref: "51fa0e9311837b84c24e30827b53891041378448" + resolved-ref: "51fa0e9311837b84c24e30827b53891041378448" + url: "https://github.com/toon-format/toon-dart.git" + source: git + version: "0.1.0" tuple: dependency: transitive description: diff --git a/pubspec.yaml b/pubspec.yaml index 86b27c0bdd..ebbeef908e 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -174,6 +174,10 @@ dependencies: genkit_anthropic: ^0.2.4 genkit_google_genai: ^0.2.4 schemantic: ^0.1.1 + toon_format: + git: + url: https://github.com/toon-format/toon-dart.git + ref: 51fa0e9311837b84c24e30827b53891041378448 dev_dependencies: build_runner: ^2.13.1 From 1b2d28880b92e764fb58e2242d0fc9a32e9c0920 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 17:35:07 +0800 Subject: [PATCH 30/52] refactor: simplify data models and enhance search functionality --- .../tools/ai_conversation_tool_service.dart | 213 ++++++++++-------- 1 file changed, 115 insertions(+), 98 deletions(-) diff --git a/lib/ai/tools/ai_conversation_tool_service.dart b/lib/ai/tools/ai_conversation_tool_service.dart index bd0ea8f3fa..e92aa4785b 100644 --- a/lib/ai/tools/ai_conversation_tool_service.dart +++ b/lib/ai/tools/ai_conversation_tool_service.dart @@ -16,6 +16,8 @@ const _kMaxConversationChunkSize = 200; const _kDefaultConversationSearchLimit = 8; const _kMaxConversationSearchLimit = 20; const _kAiToolLogPreviewLength = 480; +const _kMaxConversationMessageTextLength = 1000; +const _kSearchMessageSnippetRadius = 240; typedef AiConversationToolEventSink = Future Function(Map event); @@ -24,7 +26,6 @@ class AiConversationToolMessage { const AiConversationToolMessage({ required this.messageId, required this.createdAt, - required this.senderId, required this.senderName, required this.type, required this.text, @@ -32,15 +33,13 @@ class AiConversationToolMessage { final String messageId; final DateTime createdAt; - final String senderId; final String senderName; final String type; final String text; Map toJson() => { 'message_id': messageId, - 'created_at': createdAt.toIso8601String(), - 'sender_id': senderId, + 'created_at': _formatToolDateTime(createdAt), 'sender_name': senderName, 'type': type, 'text': text, @@ -49,28 +48,23 @@ class AiConversationToolMessage { class AiConversationToolStats { const AiConversationToolStats({ - required this.conversationId, required this.messageCount, - required this.startInclusive, - required this.endExclusive, this.firstMessageAt, this.lastMessageAt, }); - final String conversationId; final int messageCount; - final DateTime? startInclusive; - final DateTime? endExclusive; final DateTime? firstMessageAt; final DateTime? lastMessageAt; Map toJson() => { - 'conversation_id': conversationId, 'message_count': messageCount, - 'start_time': startInclusive?.toIso8601String(), - 'end_time': endExclusive?.toIso8601String(), - 'first_message_at': firstMessageAt?.toIso8601String(), - 'last_message_at': lastMessageAt?.toIso8601String(), + 'first_message_at': firstMessageAt == null + ? null + : _formatToolDateTime(firstMessageAt!), + 'last_message_at': lastMessageAt == null + ? null + : _formatToolDateTime(lastMessageAt!), }; } @@ -94,62 +88,38 @@ class AiConversationToolChunk { class AiConversationToolChunkList { const AiConversationToolChunkList({ - required this.conversationId, - required this.chunkSize, required this.totalMessages, - required this.startInclusive, - required this.endExclusive, required this.chunks, }); - final String conversationId; - final int chunkSize; final int totalMessages; - final DateTime? startInclusive; - final DateTime? endExclusive; final List chunks; Map toJson() => { - 'conversation_id': conversationId, - 'chunk_size': chunkSize, 'total_messages': totalMessages, 'total_chunks': chunks.length, - 'start_time': startInclusive?.toIso8601String(), - 'end_time': endExclusive?.toIso8601String(), 'chunks': chunks.map((chunk) => chunk.toJson()).toList(growable: false), }; } class AiConversationToolChunkPage { const AiConversationToolChunkPage({ - required this.conversationId, required this.offset, - required this.limit, required this.totalMessages, - required this.startInclusive, - required this.endExclusive, required this.messages, required this.nextOffset, }); - final String conversationId; final int offset; - final int limit; final int totalMessages; - final DateTime? startInclusive; - final DateTime? endExclusive; final List messages; final int? nextOffset; Map toJson() => { - 'conversation_id': conversationId, 'offset': offset, - 'limit': limit, 'total_messages': totalMessages, 'returned_count': messages.length, 'next_offset': nextOffset, - 'start_time': startInclusive?.toIso8601String(), - 'end_time': endExclusive?.toIso8601String(), 'messages': messages .map((message) => message.toJson()) .toList(growable: false), @@ -158,22 +128,16 @@ class AiConversationToolChunkPage { class AiConversationToolSearchResult { const AiConversationToolSearchResult({ - required this.conversationId, - required this.query, - required this.limit, required this.messages, + required this.nextAnchorId, }); - final String conversationId; - final String query; - final int limit; final List messages; + final String? nextAnchorId; Map toJson() => { - 'conversation_id': conversationId, - 'query': query, - 'limit': limit, 'returned_count': messages.length, + 'next_anchor_id': nextAnchorId, 'messages': messages .map((message) => message.toJson()) .toList(growable: false), @@ -206,6 +170,7 @@ abstract interface class AiConversationToolService { required String conversationId, required String query, required int limit, + String? anchorMessageId, }); } @@ -253,10 +218,7 @@ class DatabaseAiConversationToolService implements AiConversationToolService { } return AiConversationToolStats( - conversationId: conversationId, messageCount: messageCount, - startInclusive: startInclusive, - endExclusive: endExclusive, firstMessageAt: firstMessageAt, lastMessageAt: lastMessageAt, ); @@ -289,11 +251,7 @@ class DatabaseAiConversationToolService implements AiConversationToolService { ); } return AiConversationToolChunkList( - conversationId: conversationId, - chunkSize: chunkSize, totalMessages: totalMessages, - startInclusive: startInclusive, - endExclusive: endExclusive, chunks: chunks, ); } @@ -330,12 +288,8 @@ class DatabaseAiConversationToolService implements AiConversationToolService { : null; return AiConversationToolChunkPage( - conversationId: conversationId, offset: safeOffset, - limit: limit, totalMessages: totalMessages, - startInclusive: startInclusive, - endExclusive: endExclusive, messages: messages.map(_messageItemToToolMessage).toList(growable: false), nextOffset: nextOffset, ); @@ -346,19 +300,19 @@ class DatabaseAiConversationToolService implements AiConversationToolService { required String conversationId, required String query, required int limit, + String? anchorMessageId, }) async { final messages = await database.fuzzySearchMessage( query: query, limit: limit, conversationIds: [conversationId], + anchorMessageId: anchorMessageId, ); return AiConversationToolSearchResult( - conversationId: conversationId, - query: query, - limit: limit, messages: messages - .map(_searchMessageToToolMessage) + .map((message) => _searchMessageToToolMessage(message, query: query)) .toList(growable: false), + nextAnchorId: messages.length < limit ? null : messages.last.messageId, ); } @@ -366,28 +320,30 @@ class DatabaseAiConversationToolService implements AiConversationToolService { AiConversationToolMessage( messageId: message.messageId, createdAt: message.createdAt, - senderId: message.userId, senderName: message.userFullName ?? message.userId, type: message.type, text: _messageText( content: message.content, mediaName: message.mediaName, type: message.type, + maxLength: _kMaxConversationMessageTextLength, ), ); AiConversationToolMessage _searchMessageToToolMessage( - SearchMessageDetailItem message, - ) => AiConversationToolMessage( + SearchMessageDetailItem message, { + required String query, + }) => AiConversationToolMessage( messageId: message.messageId, createdAt: message.createdAt, - senderId: message.senderId, senderName: message.senderFullName ?? message.senderId, type: message.type, text: _messageText( content: message.content, mediaName: message.mediaName, type: message.type, + query: query, + maxLength: _kSearchMessageSnippetRadius * 2, ), ); @@ -395,9 +351,13 @@ class DatabaseAiConversationToolService implements AiConversationToolService { required String? content, required String? mediaName, required String type, + String? query, + int? maxLength, }) { if (content?.trim().isNotEmpty == true) { - return content!.trim(); + final text = content!.trim(); + final snippet = query == null ? text : _searchSnippet(text, query); + return _truncateText(snippet, maxLength); } if (mediaName?.isNotEmpty == true) { return '[$type] $mediaName'; @@ -418,7 +378,7 @@ class AiConversationToolKit { genkit.Tool( name: 'get_conversation_stats', description: - 'Get message counts and boundary timestamps for the current conversation or a specific time range.', + 'Get message count and first/last timestamps for the conversation.', inputSchema: GetConversationStatsInput.schema, fn: (input, context) => _executeTool( conversationId: conversationId, @@ -438,8 +398,7 @@ class AiConversationToolKit { ), genkit.Tool( name: 'list_conversation_chunks', - description: - 'List chunk offsets that can be used to read the current conversation in fixed-size batches, optionally scoped to a time range.', + description: 'List offsets for reading conversation messages in batches.', inputSchema: ListConversationChunksInput.schema, fn: (input, context) => _executeTool( conversationId: conversationId, @@ -460,8 +419,7 @@ class AiConversationToolKit { ), genkit.Tool( name: 'read_conversation_chunk', - description: - 'Read a batch of messages from the current conversation by offset and limit, optionally scoped to a time range.', + description: 'Read conversation messages by offset and limit.', inputSchema: ReadConversationChunkInput.schema, fn: (input, context) => _executeTool( conversationId: conversationId, @@ -483,8 +441,7 @@ class AiConversationToolKit { ), genkit.Tool( name: 'search_conversation_messages', - description: - 'Search the current conversation for messages relevant to a query string.', + description: 'Search messages in the current conversation.', inputSchema: SearchConversationMessagesInput.schema, fn: (input, context) => _executeTool( conversationId: conversationId, @@ -497,6 +454,7 @@ class AiConversationToolKit { conversationId: conversationId, query: input.query, limit: input.limit, + anchorMessageId: input.anchorMessageId, ); return result.toJson(); }, @@ -578,8 +536,8 @@ class GetConversationStatsInput { ); Map toArguments() => { - 'start_time': startInclusive?.toIso8601String(), - 'end_time': endExclusive?.toIso8601String(), + 'start': startInclusive?.toIso8601String(), + 'end': endExclusive?.toIso8601String(), }..removeWhere((_, value) => value == null); } @@ -597,9 +555,9 @@ class ListConversationChunksInput { static final schema = SchemanticType.from( jsonSchema: _rangeSchema( properties: { - 'chunk_size': { + 'size': { 'type': 'integer', - 'description': 'Optional chunk size between 1 and 200.', + 'description': 'Batch size, 1-200.', }, }, ), @@ -609,7 +567,7 @@ class ListConversationChunksInput { return ListConversationChunksInput( chunkSize: _parseInt( arguments, - 'chunk_size', + 'size', defaultValue: _kDefaultConversationChunkSize, min: 1, max: _kMaxConversationChunkSize, @@ -621,9 +579,9 @@ class ListConversationChunksInput { ); Map toArguments() => { - 'chunk_size': chunkSize, - 'start_time': startInclusive?.toIso8601String(), - 'end_time': endExclusive?.toIso8601String(), + 'size': chunkSize, + 'start': startInclusive?.toIso8601String(), + 'end': endExclusive?.toIso8601String(), }..removeWhere((_, value) => value == null); } @@ -645,11 +603,11 @@ class ReadConversationChunkInput { properties: { 'offset': { 'type': 'integer', - 'description': 'Zero-based offset into the matching message list.', + 'description': 'Zero-based message offset.', }, 'limit': { 'type': 'integer', - 'description': 'Number of messages to read, between 1 and 200.', + 'description': 'Message count, 1-200.', }, }, required: ['offset'], @@ -681,8 +639,8 @@ class ReadConversationChunkInput { Map toArguments() => { 'offset': offset, 'limit': limit, - 'start_time': startInclusive?.toIso8601String(), - 'end_time': endExclusive?.toIso8601String(), + 'start': startInclusive?.toIso8601String(), + 'end': endExclusive?.toIso8601String(), }..removeWhere((_, value) => value == null); } @@ -690,10 +648,12 @@ class SearchConversationMessagesInput { const SearchConversationMessagesInput({ required this.query, required this.limit, + this.anchorMessageId, }); final String query; final int limit; + final String? anchorMessageId; static final schema = SchemanticType.from( jsonSchema: { @@ -701,12 +661,15 @@ class SearchConversationMessagesInput { 'properties': { 'query': { 'type': 'string', - 'description': 'Search query text.', + 'description': 'Search text.', }, 'limit': { 'type': 'integer', - 'description': - 'Maximum number of matches to return, between 1 and 20.', + 'description': 'Max matches, 1-20.', + }, + 'anchor_id': { + 'type': 'string', + 'description': 'Use next_anchor_id from the previous page.', }, }, 'required': ['query'], @@ -723,6 +686,7 @@ class SearchConversationMessagesInput { min: 1, max: _kMaxConversationSearchLimit, ), + anchorMessageId: _parseOptionalString(arguments, 'anchor_id'), ); }, ); @@ -730,7 +694,8 @@ class SearchConversationMessagesInput { Map toArguments() => { 'query': query, 'limit': limit, - }; + 'anchor_id': anchorMessageId, + }..removeWhere((_, value) => value == null); } Map _rangeSchema({ @@ -739,13 +704,13 @@ Map _rangeSchema({ }) => { 'type': 'object', 'properties': { - 'start_time': { + 'start': { 'type': 'string', - 'description': 'Optional inclusive ISO-8601 start time.', + 'description': 'Inclusive ISO-8601 start.', }, - 'end_time': { + 'end': { 'type': 'string', - 'description': 'Optional exclusive ISO-8601 end time.', + 'description': 'Exclusive ISO-8601 end.', }, ...properties, }, @@ -754,12 +719,12 @@ Map _rangeSchema({ }; (DateTime?, DateTime?) _parseRange(Map arguments) { - final startInclusive = _parseDateTime(arguments, 'start_time'); - final endExclusive = _parseDateTime(arguments, 'end_time'); + final startInclusive = _parseDateTime(arguments, 'start'); + final endExclusive = _parseDateTime(arguments, 'end'); if (startInclusive != null && endExclusive != null && !endExclusive.isAfter(startInclusive)) { - throw const FormatException('end_time must be later than start_time'); + throw const FormatException('end must be later than start'); } return (startInclusive, endExclusive); } @@ -808,6 +773,18 @@ String _parseRequiredString(Map arguments, String key) { return raw.trim(); } +String? _parseOptionalString(Map arguments, String key) { + final raw = arguments[key]; + if (raw == null) { + return null; + } + if (raw is! String) { + throw FormatException('$key must be a string'); + } + final value = raw.trim(); + return value.isEmpty ? null : value; +} + Map _jsonMap(dynamic value) { if (value is Map) { return value; @@ -830,6 +807,46 @@ String _previewJson(Object? value) { } } +String _formatToolDateTime(DateTime value) => + '${value.year.toString().padLeft(4, '0')}-' + '${value.month.toString().padLeft(2, '0')}-' + '${value.day.toString().padLeft(2, '0')}T' + '${value.hour.toString().padLeft(2, '0')}:' + '${value.minute.toString().padLeft(2, '0')}' + '${value.isUtc ? 'Z' : ''}'; + +String _searchSnippet(String text, String query) { + final trimmedQuery = query.trim(); + if (trimmedQuery.isEmpty || text.length <= _kSearchMessageSnippetRadius * 2) { + return text; + } + + final lowerText = text.toLowerCase(); + final lowerQuery = trimmedQuery.toLowerCase(); + final index = lowerText.indexOf(lowerQuery); + if (index < 0) { + return _truncateText(text, _kSearchMessageSnippetRadius * 2); + } + + final start = math.max(0, index - _kSearchMessageSnippetRadius); + final end = math.min( + text.length, + index + trimmedQuery.length + _kSearchMessageSnippetRadius, + ); + final prefix = start == 0 ? '' : '...'; + final suffix = end == text.length ? '' : '...'; + return '$prefix${text.substring(start, end)}$suffix'; +} + +String _truncateText(String text, int? maxLength) { + if (maxLength == null || text.length <= maxLength) { + return text; + } + const suffix = '... [truncated]'; + final end = math.max(0, maxLength - suffix.length); + return '${text.substring(0, end)}$suffix'; +} + String _encodeToolResult(Map result) => encode(_stripNullValues(result)); From c15a6cc77f45f1b2e525fb63c2973b70d43e7773 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 20:31:19 +0800 Subject: [PATCH 31/52] feat(ai): add thread-based message handling and metadata enhancements --- lib/ai/ai_chat_controller.dart | 105 +- lib/ai/ai_chat_prompt_builder.dart | 17 +- lib/ai/ai_provider_requester.dart | 22 + lib/ai/model/ai_chat_metadata.dart | 36 + lib/db/ai_database.dart | 31 + lib/db/ai_database.g.dart | 1669 +++++++++++++++++ lib/db/dao/ai_chat_message_dao.dart | 133 +- lib/db/dao/ai_chat_message_dao.g.dart | 2 +- lib/db/dao/asset_dao.g.dart | 3 - lib/db/dao/chain_dao.g.dart | 3 - lib/db/dao/circle_conversation_dao.g.dart | 3 - lib/db/dao/circle_dao.g.dart | 3 - lib/db/dao/conversation_dao.g.dart | 3 - lib/db/dao/expired_message_dao.g.dart | 3 - lib/db/dao/favorite_app_dao.g.dart | 3 - lib/db/dao/flood_message_dao.g.dart | 3 - lib/db/dao/inscription_collection_dao.g.dart | 3 - lib/db/dao/inscription_item_dao.g.dart | 3 - lib/db/dao/message_dao.g.dart | 3 - lib/db/dao/participant_dao.g.dart | 3 - lib/db/dao/participant_session_dao.g.dart | 3 - lib/db/dao/pin_message_dao.g.dart | 3 - lib/db/dao/property_dao.g.dart | 3 - lib/db/dao/safe_snapshot_dao.g.dart | 3 - lib/db/dao/snapshot_dao.g.dart | 3 - lib/db/dao/sticker_album_dao.g.dart | 3 - lib/db/dao/sticker_dao.g.dart | 3 - lib/db/dao/sticker_relationship_dao.g.dart | 3 - lib/db/dao/token_dao.g.dart | 3 - lib/db/dao/transcript_message_dao.g.dart | 3 - lib/db/dao/user_dao.g.dart | 3 - lib/db/database.dart | 8 +- lib/db/mixin_database.dart | 23 +- lib/db/mixin_database.g.dart | 1154 ------------ lib/db/moor/ai.drift | 30 + lib/db/moor/mixin.drift | 18 - lib/ui/home/chat/input_container.dart | 27 +- .../ai_assistant/message_list.dart | 18 +- .../chat_slide_page/ai_assistant_page.dart | 28 +- lib/ui/provider/database_provider.dart | 2 + lib/widgets/ai/ai_message_card.dart | 47 +- lib/workers/device_transfer.dart | 2 + lib/workers/message_worker_isolate.dart | 2 + test/ai/ai_chat_metadata_test.dart | 62 + test/ai/ai_chat_thread_test.dart | 143 ++ test/utils/device_transfer_test.dart | 3 + 46 files changed, 2309 insertions(+), 1342 deletions(-) create mode 100644 lib/db/ai_database.dart create mode 100644 lib/db/ai_database.g.dart create mode 100644 lib/db/moor/ai.drift create mode 100644 test/ai/ai_chat_metadata_test.dart create mode 100644 test/ai/ai_chat_thread_test.dart diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index 9fecd4ec8e..77134fccd2 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -5,9 +5,9 @@ import 'package:drift/drift.dart'; import 'package:mixin_logger/mixin_logger.dart'; import 'package:uuid/uuid.dart'; +import '../db/ai_database.dart'; import '../db/dao/ai_chat_message_dao.dart'; import '../db/database.dart'; -import '../db/mixin_database.dart'; import 'ai_chat_prompt_builder.dart'; import 'ai_provider_requester.dart'; import 'model/ai_chat_metadata.dart'; @@ -98,16 +98,22 @@ class AiChatController { required String conversationId, required String input, required String language, + String? threadId, AiProviderConfig? provider, void Function()? onInputAccepted, }) async { + final thread = await database.aiChatMessageDao.ensureThread( + conversationId: conversationId, + threadId: threadId, + ); await database.aiChatMessageDao.resolveStalePendingAssistantMessages( updatedBefore: kAiRuntimeStartedAt, conversationId: conversationId, + threadId: thread.id, ); final hasPendingAssistant = await database.aiChatMessageDao .hasPendingAssistantMessage( - conversationId, + thread.id, updatedAfter: kAiRuntimeStartedAt, ); if (hasPendingAssistant) { @@ -121,6 +127,7 @@ class AiChatController { d( 'AI send start: conversationId=$conversationId ' + 'threadId=${thread.id} ' 'provider=${config.type.name} model=${config.model} ' 'input=${_previewText(input)}', ); @@ -130,18 +137,13 @@ class AiChatController { final userMessageId = _uuid.v4(); final assistantMessageId = _uuid.v4(); final cancelToken = CancelToken(); - final anchorMessage = await database.messageDao - .messagesByConversationId(conversationId, 1) - .getSingleOrNull(); - await database.aiChatMessageDao.insertMessage( AiChatMessagesCompanion.insert( id: userMessageId, + threadId: Value(thread.id), conversationId: conversationId, role: _kAiRoleUser, providerId: config.id, - anchorMessageId: Value(anchorMessage?.messageId), - anchorCreatedAt: Value(anchorMessage?.createdAt), content: input, status: _kAiStatusDone, model: Value(config.model), @@ -153,11 +155,10 @@ class AiChatController { await database.aiChatMessageDao.insertMessage( AiChatMessagesCompanion.insert( id: assistantMessageId, + threadId: Value(thread.id), conversationId: conversationId, role: _kAiRoleAssistant, providerId: config.id, - anchorMessageId: Value(anchorMessage?.messageId), - anchorCreatedAt: Value(anchorMessage?.createdAt), content: '', status: _kAiStatusPending, model: Value(config.model), @@ -173,13 +174,21 @@ class AiChatController { dao: database.aiChatMessageDao, messageId: assistantMessageId, ); + final requestKeys = { + conversationId, + thread.id, + }; _activeAiRequests[conversationId] = cancelToken; + _activeAiRequests[thread.id] = cancelToken; try { final messages = await _promptBuilder.buildPromptMessages( conversationId, + thread.id, input, language, + currentMessageId: userMessageId, ); + Map? responseMetadata; final result = await _requestText( config, messages, @@ -187,8 +196,34 @@ class AiChatController { onContent: updater.append, conversationId: conversationId, assistantMessageId: assistantMessageId, + onResponseMetadata: (metadata) { + responseMetadata = createAiResponseMetadata( + elapsedMs: (metadata['elapsedMs'] as num?)?.round() ?? 0, + promptMessageCount: + (metadata['promptMessageCount'] as num?)?.round() ?? + messages.length, + toolCount: (metadata['toolCount'] as num?)?.round() ?? 0, + outputCharacters: 0, + response: metadata, + ); + }, ); await updater.flush(contentOverride: result, force: true); + final completedResponseMetadata = + responseMetadata ?? + createAiResponseMetadata( + elapsedMs: 0, + promptMessageCount: messages.length, + toolCount: 0, + outputCharacters: result.length, + response: const {}, + ); + completedResponseMetadata['outputCharacters'] = result.length; + await database.aiChatMessageDao.setMessageMetadataResponse( + assistantMessageId, + completedResponseMetadata, + updatedAt: DateTime.now(), + ); await database.aiChatMessageDao.updateMessageStatus( assistantMessageId, _kAiStatusDone, @@ -196,12 +231,14 @@ class AiChatController { ); d( 'AI send done: conversationId=$conversationId ' + 'threadId=${thread.id} ' 'assistantMessageId=$assistantMessageId output=${_previewText(result)}', ); } catch (error, stacktrace) { if (cancelToken.isCancelled) { d( 'AI send cancelled: conversationId=$conversationId ' + 'threadId=${thread.id} ' 'assistantMessageId=$assistantMessageId', ); await updater.flush(force: true); @@ -222,15 +259,22 @@ class AiChatController { ); rethrow; } finally { - if (_activeAiRequests[conversationId] == cancelToken) { - _activeAiRequests.remove(conversationId); + for (final requestKey in requestKeys) { + if (_activeAiRequests[requestKey] == cancelToken) { + _activeAiRequests.remove(requestKey); + } } } } - void stop(String conversationId) { - d('AI stop requested: conversationId=$conversationId'); - _activeAiRequests[conversationId]?.cancel('AI generation stopped'); + void stop(String conversationId, {String? threadId}) { + d('AI stop requested: conversationId=$conversationId threadId=$threadId'); + final cancelToken = threadId == null + ? _activeAiRequests[conversationId] + : null; + (cancelToken ?? _activeAiRequests[threadId])?.cancel( + 'AI generation stopped', + ); } Future _requestText( @@ -240,21 +284,32 @@ class AiChatController { required Future Function(String chunk) onContent, String? conversationId, String? assistantMessageId, - }) => _providerRequester.requestText( - config, - messages, - proxy: database.settingProperties.activatedProxy, - cancelToken: cancelToken, - onContent: onContent, - conversationId: conversationId, - tools: conversationId == null + void Function(Map metadata)? onResponseMetadata, + }) { + final tools = conversationId == null ? null : _conversationTools.genkitTools( conversationId: conversationId, onEvent: (event) => _appendAssistantToolEvent(assistantMessageId, event), - ), - ); + ); + return _providerRequester.requestText( + config, + messages, + proxy: database.settingProperties.activatedProxy, + cancelToken: cancelToken, + onContent: onContent, + conversationId: conversationId, + onResponseMetadata: onResponseMetadata == null + ? null + : (metadata) => onResponseMetadata({ + ...metadata, + 'promptMessageCount': messages.length, + 'toolCount': tools?.length ?? 0, + }), + tools: tools, + ); + } Future _appendAssistantToolEvent( String? assistantMessageId, diff --git a/lib/ai/ai_chat_prompt_builder.dart b/lib/ai/ai_chat_prompt_builder.dart index 202ad5ad50..0e13e070e9 100644 --- a/lib/ai/ai_chat_prompt_builder.dart +++ b/lib/ai/ai_chat_prompt_builder.dart @@ -16,16 +16,16 @@ class AiChatPromptBuilder { Future> buildPromptMessages( String conversationId, + String threadId, String input, - String language, - ) async { + String language, { + String? currentMessageId, + }) async { final now = DateTime.now(); final recentMessages = await database.messageDao .messagesByConversationId(conversationId, _aiContextMessageLimit) .get(); - final aiMessages = await database.aiChatMessageDao.conversationMessages( - conversationId, - ); + final aiMessages = await database.aiChatMessageDao.threadMessages(threadId); final promptMessages = [ ..._promptMessages( @@ -60,7 +60,11 @@ class AiChatPromptBuilder { ); final history = aiMessages - .where((element) => element.status != _aiStatusPending) + .where( + (element) => + element.status != _aiStatusPending && + element.id != currentMessageId, + ) .takeLast(_aiHistoryLimit); for (final item in history) { promptMessages.add( @@ -84,6 +88,7 @@ class AiChatPromptBuilder { ); d( 'AI prompt built: conversationId=$conversationId ' + 'threadId=$threadId ' 'recent=${recentMessages.length} ' 'history=${history.length} promptMessages=${promptMessages.length}', ); diff --git a/lib/ai/ai_provider_requester.dart b/lib/ai/ai_provider_requester.dart index b49daf40f5..59d847bb47 100644 --- a/lib/ai/ai_provider_requester.dart +++ b/lib/ai/ai_provider_requester.dart @@ -28,6 +28,7 @@ class AiProviderRequester { required Future Function(String chunk) onContent, required String? conversationId, List? tools, + void Function(Map metadata)? onResponseMetadata, }) async { d( 'AI request start: provider=${config.type.name} model=${config.model} ' @@ -46,6 +47,7 @@ class AiProviderRequester { cancelToken: cancelToken, onContent: onContent, conversationId: conversationId, + onResponseMetadata: onResponseMetadata, tools: tools, ), ); @@ -57,9 +59,11 @@ class AiProviderRequester { required CancelToken cancelToken, required Future Function(String chunk) onContent, required String? conversationId, + required void Function(Map metadata)? onResponseMetadata, required List? tools, }) async { final ai = _createGenkit(config); + final stopwatch = Stopwatch()..start(); try { final cancelFuture = cancelToken.whenCancel.then((_) {}); final stream = ai.generateStream( @@ -115,6 +119,13 @@ class AiProviderRequester { if (text.isEmpty) { throw Exception('Empty AI response'); } + stopwatch.stop(); + onResponseMetadata?.call( + _genkitResponseMetadata( + response, + elapsedMs: stopwatch.elapsedMilliseconds, + ), + ); d( 'AI request done: provider=${config.type.name} model=${config.model} ' 'conversationId=$conversationId text=${_previewText(text)}', @@ -176,6 +187,17 @@ class AiProviderRequester { } } +Map _genkitResponseMetadata( + genkit.GenerateResponseHelper response, { + required int elapsedMs, +}) => { + 'elapsedMs': elapsedMs, + 'latencyMs': response.latencyMs, + 'finishReason': response.finishReason?.value, + 'finishMessage': response.finishMessage, + 'usage': response.usage?.toJson(), +}..removeWhere((_, value) => value == null); + String _previewText( String? text, { int maxLength = AiProviderRequester._aiLogPreviewLength, diff --git a/lib/ai/model/ai_chat_metadata.dart b/lib/ai/model/ai_chat_metadata.dart index a8a325d7ae..632b8c7637 100644 --- a/lib/ai/model/ai_chat_metadata.dart +++ b/lib/ai/model/ai_chat_metadata.dart @@ -3,6 +3,7 @@ import 'dart:convert'; import 'ai_provider_config.dart'; const aiMetadataToolEventsKey = 'toolEvents'; +const aiMetadataResponseKey = 'response'; const aiToolEventTypeCall = 'tool_call'; const aiToolEventTypeResult = 'tool_result'; @@ -46,6 +47,41 @@ String appendAiToolEventToMetadata( return jsonEncode(root); } +String setAiResponseMetadata( + String? metadata, + Map responseMetadata, +) { + final root = decodeAiMessageMetadata(metadata); + root[aiMetadataResponseKey] = responseMetadata; + return jsonEncode(root); +} + +Map createAiResponseMetadata({ + required int elapsedMs, + required int promptMessageCount, + required int toolCount, + required int outputCharacters, + required Map response, +}) => { + 'elapsedMs': elapsedMs, + 'promptMessageCount': promptMessageCount, + 'toolCount': toolCount, + 'outputCharacters': outputCharacters, + 'completedAt': DateTime.now().toUtc().toIso8601String(), + ...response, +}..removeWhere((_, value) => value == null); + +Map aiMetadataResponse(String? metadata) { + final response = decodeAiMessageMetadata(metadata)[aiMetadataResponseKey]; + if (response is Map) { + return response; + } + if (response is Map) { + return response.map((key, value) => MapEntry('$key', value)); + } + return const {}; +} + Map createAiToolCallEvent({ required String id, required String name, diff --git a/lib/db/ai_database.dart b/lib/db/ai_database.dart new file mode 100644 index 0000000000..d08b0dc93a --- /dev/null +++ b/lib/db/ai_database.dart @@ -0,0 +1,31 @@ +import 'package:drift/drift.dart'; + +import 'converter/millis_date_converter.dart'; +import 'dao/ai_chat_message_dao.dart'; +import 'util/open_database.dart'; + +part 'ai_database.g.dart'; + +@DriftDatabase( + include: {'moor/ai.drift'}, + daos: [AiChatMessageDao], +) +class AiDatabase extends _$AiDatabase { + AiDatabase(super.e); + + static Future connect( + String identityNumber, { + bool fromMainIsolate = false, + }) async { + final queryExecutor = await openQueryExecutor( + identityNumber: identityNumber, + dbName: 'ai', + readCount: 4, + fromMainIsolate: fromMainIsolate, + ); + return AiDatabase(queryExecutor); + } + + @override + int get schemaVersion => 1; +} diff --git a/lib/db/ai_database.g.dart b/lib/db/ai_database.g.dart new file mode 100644 index 0000000000..b1f0ec4a53 --- /dev/null +++ b/lib/db/ai_database.g.dart @@ -0,0 +1,1669 @@ +// GENERATED CODE - DO NOT MODIFY BY HAND + +part of 'ai_database.dart'; + +// ignore_for_file: type=lint +class AiChatMessages extends Table + with TableInfo { + @override + final GeneratedDatabase attachedDatabase; + final String? _alias; + AiChatMessages(this.attachedDatabase, [this._alias]); + static const VerificationMeta _idMeta = const VerificationMeta('id'); + late final GeneratedColumn id = GeneratedColumn( + 'id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _threadIdMeta = const VerificationMeta( + 'threadId', + ); + late final GeneratedColumn threadId = GeneratedColumn( + 'thread_id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: 'NOT NULL DEFAULT \'\'', + defaultValue: const CustomExpression('\'\''), + ); + static const VerificationMeta _conversationIdMeta = const VerificationMeta( + 'conversationId', + ); + late final GeneratedColumn conversationId = GeneratedColumn( + 'conversation_id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _roleMeta = const VerificationMeta('role'); + late final GeneratedColumn role = GeneratedColumn( + 'role', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _providerIdMeta = const VerificationMeta( + 'providerId', + ); + late final GeneratedColumn providerId = GeneratedColumn( + 'provider_id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _contentMeta = const VerificationMeta( + 'content', + ); + late final GeneratedColumn content = GeneratedColumn( + 'content', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _statusMeta = const VerificationMeta('status'); + late final GeneratedColumn status = GeneratedColumn( + 'status', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _modelMeta = const VerificationMeta('model'); + late final GeneratedColumn model = GeneratedColumn( + 'model', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + static const VerificationMeta _errorTextMeta = const VerificationMeta( + 'errorText', + ); + late final GeneratedColumn errorText = GeneratedColumn( + 'error_text', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + static const VerificationMeta _metadataMeta = const VerificationMeta( + 'metadata', + ); + late final GeneratedColumn metadata = GeneratedColumn( + 'metadata', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + late final GeneratedColumnWithTypeConverter createdAt = + GeneratedColumn( + 'created_at', + aliasedName, + false, + type: DriftSqlType.int, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ).withConverter(AiChatMessages.$convertercreatedAt); + late final GeneratedColumnWithTypeConverter updatedAt = + GeneratedColumn( + 'updated_at', + aliasedName, + false, + type: DriftSqlType.int, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ).withConverter(AiChatMessages.$converterupdatedAt); + @override + List get $columns => [ + id, + threadId, + conversationId, + role, + providerId, + content, + status, + model, + errorText, + metadata, + createdAt, + updatedAt, + ]; + @override + String get aliasedName => _alias ?? actualTableName; + @override + String get actualTableName => $name; + static const String $name = 'ai_chat_messages'; + @override + VerificationContext validateIntegrity( + Insertable instance, { + bool isInserting = false, + }) { + final context = VerificationContext(); + final data = instance.toColumns(true); + if (data.containsKey('id')) { + context.handle(_idMeta, id.isAcceptableOrUnknown(data['id']!, _idMeta)); + } else if (isInserting) { + context.missing(_idMeta); + } + if (data.containsKey('thread_id')) { + context.handle( + _threadIdMeta, + threadId.isAcceptableOrUnknown(data['thread_id']!, _threadIdMeta), + ); + } + if (data.containsKey('conversation_id')) { + context.handle( + _conversationIdMeta, + conversationId.isAcceptableOrUnknown( + data['conversation_id']!, + _conversationIdMeta, + ), + ); + } else if (isInserting) { + context.missing(_conversationIdMeta); + } + if (data.containsKey('role')) { + context.handle( + _roleMeta, + role.isAcceptableOrUnknown(data['role']!, _roleMeta), + ); + } else if (isInserting) { + context.missing(_roleMeta); + } + if (data.containsKey('provider_id')) { + context.handle( + _providerIdMeta, + providerId.isAcceptableOrUnknown(data['provider_id']!, _providerIdMeta), + ); + } else if (isInserting) { + context.missing(_providerIdMeta); + } + if (data.containsKey('content')) { + context.handle( + _contentMeta, + content.isAcceptableOrUnknown(data['content']!, _contentMeta), + ); + } else if (isInserting) { + context.missing(_contentMeta); + } + if (data.containsKey('status')) { + context.handle( + _statusMeta, + status.isAcceptableOrUnknown(data['status']!, _statusMeta), + ); + } else if (isInserting) { + context.missing(_statusMeta); + } + if (data.containsKey('model')) { + context.handle( + _modelMeta, + model.isAcceptableOrUnknown(data['model']!, _modelMeta), + ); + } + if (data.containsKey('error_text')) { + context.handle( + _errorTextMeta, + errorText.isAcceptableOrUnknown(data['error_text']!, _errorTextMeta), + ); + } + if (data.containsKey('metadata')) { + context.handle( + _metadataMeta, + metadata.isAcceptableOrUnknown(data['metadata']!, _metadataMeta), + ); + } + return context; + } + + @override + Set get $primaryKey => {id}; + @override + AiChatMessage map(Map data, {String? tablePrefix}) { + final effectivePrefix = tablePrefix != null ? '$tablePrefix.' : ''; + return AiChatMessage( + id: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}id'], + )!, + threadId: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}thread_id'], + )!, + conversationId: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}conversation_id'], + )!, + role: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}role'], + )!, + providerId: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}provider_id'], + )!, + content: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}content'], + )!, + status: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}status'], + )!, + model: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}model'], + ), + errorText: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}error_text'], + ), + metadata: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}metadata'], + ), + createdAt: AiChatMessages.$convertercreatedAt.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}created_at'], + )!, + ), + updatedAt: AiChatMessages.$converterupdatedAt.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}updated_at'], + )!, + ), + ); + } + + @override + AiChatMessages createAlias(String alias) { + return AiChatMessages(attachedDatabase, alias); + } + + static TypeConverter $convertercreatedAt = + const MillisDateConverter(); + static TypeConverter $converterupdatedAt = + const MillisDateConverter(); + @override + List get customConstraints => const ['PRIMARY KEY(id)']; + @override + bool get dontWriteConstraints => true; +} + +class AiChatMessage extends DataClass implements Insertable { + final String id; + final String threadId; + final String conversationId; + final String role; + final String providerId; + final String content; + final String status; + final String? model; + final String? errorText; + final String? metadata; + final DateTime createdAt; + final DateTime updatedAt; + const AiChatMessage({ + required this.id, + required this.threadId, + required this.conversationId, + required this.role, + required this.providerId, + required this.content, + required this.status, + this.model, + this.errorText, + this.metadata, + required this.createdAt, + required this.updatedAt, + }); + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + map['id'] = Variable(id); + map['thread_id'] = Variable(threadId); + map['conversation_id'] = Variable(conversationId); + map['role'] = Variable(role); + map['provider_id'] = Variable(providerId); + map['content'] = Variable(content); + map['status'] = Variable(status); + if (!nullToAbsent || model != null) { + map['model'] = Variable(model); + } + if (!nullToAbsent || errorText != null) { + map['error_text'] = Variable(errorText); + } + if (!nullToAbsent || metadata != null) { + map['metadata'] = Variable(metadata); + } + { + map['created_at'] = Variable( + AiChatMessages.$convertercreatedAt.toSql(createdAt), + ); + } + { + map['updated_at'] = Variable( + AiChatMessages.$converterupdatedAt.toSql(updatedAt), + ); + } + return map; + } + + AiChatMessagesCompanion toCompanion(bool nullToAbsent) { + return AiChatMessagesCompanion( + id: Value(id), + threadId: Value(threadId), + conversationId: Value(conversationId), + role: Value(role), + providerId: Value(providerId), + content: Value(content), + status: Value(status), + model: model == null && nullToAbsent + ? const Value.absent() + : Value(model), + errorText: errorText == null && nullToAbsent + ? const Value.absent() + : Value(errorText), + metadata: metadata == null && nullToAbsent + ? const Value.absent() + : Value(metadata), + createdAt: Value(createdAt), + updatedAt: Value(updatedAt), + ); + } + + factory AiChatMessage.fromJson( + Map json, { + ValueSerializer? serializer, + }) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return AiChatMessage( + id: serializer.fromJson(json['id']), + threadId: serializer.fromJson(json['thread_id']), + conversationId: serializer.fromJson(json['conversation_id']), + role: serializer.fromJson(json['role']), + providerId: serializer.fromJson(json['provider_id']), + content: serializer.fromJson(json['content']), + status: serializer.fromJson(json['status']), + model: serializer.fromJson(json['model']), + errorText: serializer.fromJson(json['error_text']), + metadata: serializer.fromJson(json['metadata']), + createdAt: serializer.fromJson(json['created_at']), + updatedAt: serializer.fromJson(json['updated_at']), + ); + } + @override + Map toJson({ValueSerializer? serializer}) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return { + 'id': serializer.toJson(id), + 'thread_id': serializer.toJson(threadId), + 'conversation_id': serializer.toJson(conversationId), + 'role': serializer.toJson(role), + 'provider_id': serializer.toJson(providerId), + 'content': serializer.toJson(content), + 'status': serializer.toJson(status), + 'model': serializer.toJson(model), + 'error_text': serializer.toJson(errorText), + 'metadata': serializer.toJson(metadata), + 'created_at': serializer.toJson(createdAt), + 'updated_at': serializer.toJson(updatedAt), + }; + } + + AiChatMessage copyWith({ + String? id, + String? threadId, + String? conversationId, + String? role, + String? providerId, + String? content, + String? status, + Value model = const Value.absent(), + Value errorText = const Value.absent(), + Value metadata = const Value.absent(), + DateTime? createdAt, + DateTime? updatedAt, + }) => AiChatMessage( + id: id ?? this.id, + threadId: threadId ?? this.threadId, + conversationId: conversationId ?? this.conversationId, + role: role ?? this.role, + providerId: providerId ?? this.providerId, + content: content ?? this.content, + status: status ?? this.status, + model: model.present ? model.value : this.model, + errorText: errorText.present ? errorText.value : this.errorText, + metadata: metadata.present ? metadata.value : this.metadata, + createdAt: createdAt ?? this.createdAt, + updatedAt: updatedAt ?? this.updatedAt, + ); + AiChatMessage copyWithCompanion(AiChatMessagesCompanion data) { + return AiChatMessage( + id: data.id.present ? data.id.value : this.id, + threadId: data.threadId.present ? data.threadId.value : this.threadId, + conversationId: data.conversationId.present + ? data.conversationId.value + : this.conversationId, + role: data.role.present ? data.role.value : this.role, + providerId: data.providerId.present + ? data.providerId.value + : this.providerId, + content: data.content.present ? data.content.value : this.content, + status: data.status.present ? data.status.value : this.status, + model: data.model.present ? data.model.value : this.model, + errorText: data.errorText.present ? data.errorText.value : this.errorText, + metadata: data.metadata.present ? data.metadata.value : this.metadata, + createdAt: data.createdAt.present ? data.createdAt.value : this.createdAt, + updatedAt: data.updatedAt.present ? data.updatedAt.value : this.updatedAt, + ); + } + + @override + String toString() { + return (StringBuffer('AiChatMessage(') + ..write('id: $id, ') + ..write('threadId: $threadId, ') + ..write('conversationId: $conversationId, ') + ..write('role: $role, ') + ..write('providerId: $providerId, ') + ..write('content: $content, ') + ..write('status: $status, ') + ..write('model: $model, ') + ..write('errorText: $errorText, ') + ..write('metadata: $metadata, ') + ..write('createdAt: $createdAt, ') + ..write('updatedAt: $updatedAt') + ..write(')')) + .toString(); + } + + @override + int get hashCode => Object.hash( + id, + threadId, + conversationId, + role, + providerId, + content, + status, + model, + errorText, + metadata, + createdAt, + updatedAt, + ); + @override + bool operator ==(Object other) => + identical(this, other) || + (other is AiChatMessage && + other.id == this.id && + other.threadId == this.threadId && + other.conversationId == this.conversationId && + other.role == this.role && + other.providerId == this.providerId && + other.content == this.content && + other.status == this.status && + other.model == this.model && + other.errorText == this.errorText && + other.metadata == this.metadata && + other.createdAt == this.createdAt && + other.updatedAt == this.updatedAt); +} + +class AiChatMessagesCompanion extends UpdateCompanion { + final Value id; + final Value threadId; + final Value conversationId; + final Value role; + final Value providerId; + final Value content; + final Value status; + final Value model; + final Value errorText; + final Value metadata; + final Value createdAt; + final Value updatedAt; + final Value rowid; + const AiChatMessagesCompanion({ + this.id = const Value.absent(), + this.threadId = const Value.absent(), + this.conversationId = const Value.absent(), + this.role = const Value.absent(), + this.providerId = const Value.absent(), + this.content = const Value.absent(), + this.status = const Value.absent(), + this.model = const Value.absent(), + this.errorText = const Value.absent(), + this.metadata = const Value.absent(), + this.createdAt = const Value.absent(), + this.updatedAt = const Value.absent(), + this.rowid = const Value.absent(), + }); + AiChatMessagesCompanion.insert({ + required String id, + this.threadId = const Value.absent(), + required String conversationId, + required String role, + required String providerId, + required String content, + required String status, + this.model = const Value.absent(), + this.errorText = const Value.absent(), + this.metadata = const Value.absent(), + required DateTime createdAt, + required DateTime updatedAt, + this.rowid = const Value.absent(), + }) : id = Value(id), + conversationId = Value(conversationId), + role = Value(role), + providerId = Value(providerId), + content = Value(content), + status = Value(status), + createdAt = Value(createdAt), + updatedAt = Value(updatedAt); + static Insertable custom({ + Expression? id, + Expression? threadId, + Expression? conversationId, + Expression? role, + Expression? providerId, + Expression? content, + Expression? status, + Expression? model, + Expression? errorText, + Expression? metadata, + Expression? createdAt, + Expression? updatedAt, + Expression? rowid, + }) { + return RawValuesInsertable({ + if (id != null) 'id': id, + if (threadId != null) 'thread_id': threadId, + if (conversationId != null) 'conversation_id': conversationId, + if (role != null) 'role': role, + if (providerId != null) 'provider_id': providerId, + if (content != null) 'content': content, + if (status != null) 'status': status, + if (model != null) 'model': model, + if (errorText != null) 'error_text': errorText, + if (metadata != null) 'metadata': metadata, + if (createdAt != null) 'created_at': createdAt, + if (updatedAt != null) 'updated_at': updatedAt, + if (rowid != null) 'rowid': rowid, + }); + } + + AiChatMessagesCompanion copyWith({ + Value? id, + Value? threadId, + Value? conversationId, + Value? role, + Value? providerId, + Value? content, + Value? status, + Value? model, + Value? errorText, + Value? metadata, + Value? createdAt, + Value? updatedAt, + Value? rowid, + }) { + return AiChatMessagesCompanion( + id: id ?? this.id, + threadId: threadId ?? this.threadId, + conversationId: conversationId ?? this.conversationId, + role: role ?? this.role, + providerId: providerId ?? this.providerId, + content: content ?? this.content, + status: status ?? this.status, + model: model ?? this.model, + errorText: errorText ?? this.errorText, + metadata: metadata ?? this.metadata, + createdAt: createdAt ?? this.createdAt, + updatedAt: updatedAt ?? this.updatedAt, + rowid: rowid ?? this.rowid, + ); + } + + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + if (id.present) { + map['id'] = Variable(id.value); + } + if (threadId.present) { + map['thread_id'] = Variable(threadId.value); + } + if (conversationId.present) { + map['conversation_id'] = Variable(conversationId.value); + } + if (role.present) { + map['role'] = Variable(role.value); + } + if (providerId.present) { + map['provider_id'] = Variable(providerId.value); + } + if (content.present) { + map['content'] = Variable(content.value); + } + if (status.present) { + map['status'] = Variable(status.value); + } + if (model.present) { + map['model'] = Variable(model.value); + } + if (errorText.present) { + map['error_text'] = Variable(errorText.value); + } + if (metadata.present) { + map['metadata'] = Variable(metadata.value); + } + if (createdAt.present) { + map['created_at'] = Variable( + AiChatMessages.$convertercreatedAt.toSql(createdAt.value), + ); + } + if (updatedAt.present) { + map['updated_at'] = Variable( + AiChatMessages.$converterupdatedAt.toSql(updatedAt.value), + ); + } + if (rowid.present) { + map['rowid'] = Variable(rowid.value); + } + return map; + } + + @override + String toString() { + return (StringBuffer('AiChatMessagesCompanion(') + ..write('id: $id, ') + ..write('threadId: $threadId, ') + ..write('conversationId: $conversationId, ') + ..write('role: $role, ') + ..write('providerId: $providerId, ') + ..write('content: $content, ') + ..write('status: $status, ') + ..write('model: $model, ') + ..write('errorText: $errorText, ') + ..write('metadata: $metadata, ') + ..write('createdAt: $createdAt, ') + ..write('updatedAt: $updatedAt, ') + ..write('rowid: $rowid') + ..write(')')) + .toString(); + } +} + +class AiChatThreads extends Table with TableInfo { + @override + final GeneratedDatabase attachedDatabase; + final String? _alias; + AiChatThreads(this.attachedDatabase, [this._alias]); + static const VerificationMeta _idMeta = const VerificationMeta('id'); + late final GeneratedColumn id = GeneratedColumn( + 'id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _conversationIdMeta = const VerificationMeta( + 'conversationId', + ); + late final GeneratedColumn conversationId = GeneratedColumn( + 'conversation_id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _titleMeta = const VerificationMeta('title'); + late final GeneratedColumn title = GeneratedColumn( + 'title', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + late final GeneratedColumnWithTypeConverter createdAt = + GeneratedColumn( + 'created_at', + aliasedName, + false, + type: DriftSqlType.int, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ).withConverter(AiChatThreads.$convertercreatedAt); + late final GeneratedColumnWithTypeConverter updatedAt = + GeneratedColumn( + 'updated_at', + aliasedName, + false, + type: DriftSqlType.int, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ).withConverter(AiChatThreads.$converterupdatedAt); + @override + List get $columns => [ + id, + conversationId, + title, + createdAt, + updatedAt, + ]; + @override + String get aliasedName => _alias ?? actualTableName; + @override + String get actualTableName => $name; + static const String $name = 'ai_chat_threads'; + @override + VerificationContext validateIntegrity( + Insertable instance, { + bool isInserting = false, + }) { + final context = VerificationContext(); + final data = instance.toColumns(true); + if (data.containsKey('id')) { + context.handle(_idMeta, id.isAcceptableOrUnknown(data['id']!, _idMeta)); + } else if (isInserting) { + context.missing(_idMeta); + } + if (data.containsKey('conversation_id')) { + context.handle( + _conversationIdMeta, + conversationId.isAcceptableOrUnknown( + data['conversation_id']!, + _conversationIdMeta, + ), + ); + } else if (isInserting) { + context.missing(_conversationIdMeta); + } + if (data.containsKey('title')) { + context.handle( + _titleMeta, + title.isAcceptableOrUnknown(data['title']!, _titleMeta), + ); + } + return context; + } + + @override + Set get $primaryKey => {id}; + @override + AiChatThread map(Map data, {String? tablePrefix}) { + final effectivePrefix = tablePrefix != null ? '$tablePrefix.' : ''; + return AiChatThread( + id: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}id'], + )!, + conversationId: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}conversation_id'], + )!, + title: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}title'], + ), + createdAt: AiChatThreads.$convertercreatedAt.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}created_at'], + )!, + ), + updatedAt: AiChatThreads.$converterupdatedAt.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}updated_at'], + )!, + ), + ); + } + + @override + AiChatThreads createAlias(String alias) { + return AiChatThreads(attachedDatabase, alias); + } + + static TypeConverter $convertercreatedAt = + const MillisDateConverter(); + static TypeConverter $converterupdatedAt = + const MillisDateConverter(); + @override + List get customConstraints => const ['PRIMARY KEY(id)']; + @override + bool get dontWriteConstraints => true; +} + +class AiChatThread extends DataClass implements Insertable { + final String id; + final String conversationId; + final String? title; + final DateTime createdAt; + final DateTime updatedAt; + const AiChatThread({ + required this.id, + required this.conversationId, + this.title, + required this.createdAt, + required this.updatedAt, + }); + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + map['id'] = Variable(id); + map['conversation_id'] = Variable(conversationId); + if (!nullToAbsent || title != null) { + map['title'] = Variable(title); + } + { + map['created_at'] = Variable( + AiChatThreads.$convertercreatedAt.toSql(createdAt), + ); + } + { + map['updated_at'] = Variable( + AiChatThreads.$converterupdatedAt.toSql(updatedAt), + ); + } + return map; + } + + AiChatThreadsCompanion toCompanion(bool nullToAbsent) { + return AiChatThreadsCompanion( + id: Value(id), + conversationId: Value(conversationId), + title: title == null && nullToAbsent + ? const Value.absent() + : Value(title), + createdAt: Value(createdAt), + updatedAt: Value(updatedAt), + ); + } + + factory AiChatThread.fromJson( + Map json, { + ValueSerializer? serializer, + }) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return AiChatThread( + id: serializer.fromJson(json['id']), + conversationId: serializer.fromJson(json['conversation_id']), + title: serializer.fromJson(json['title']), + createdAt: serializer.fromJson(json['created_at']), + updatedAt: serializer.fromJson(json['updated_at']), + ); + } + @override + Map toJson({ValueSerializer? serializer}) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return { + 'id': serializer.toJson(id), + 'conversation_id': serializer.toJson(conversationId), + 'title': serializer.toJson(title), + 'created_at': serializer.toJson(createdAt), + 'updated_at': serializer.toJson(updatedAt), + }; + } + + AiChatThread copyWith({ + String? id, + String? conversationId, + Value title = const Value.absent(), + DateTime? createdAt, + DateTime? updatedAt, + }) => AiChatThread( + id: id ?? this.id, + conversationId: conversationId ?? this.conversationId, + title: title.present ? title.value : this.title, + createdAt: createdAt ?? this.createdAt, + updatedAt: updatedAt ?? this.updatedAt, + ); + AiChatThread copyWithCompanion(AiChatThreadsCompanion data) { + return AiChatThread( + id: data.id.present ? data.id.value : this.id, + conversationId: data.conversationId.present + ? data.conversationId.value + : this.conversationId, + title: data.title.present ? data.title.value : this.title, + createdAt: data.createdAt.present ? data.createdAt.value : this.createdAt, + updatedAt: data.updatedAt.present ? data.updatedAt.value : this.updatedAt, + ); + } + + @override + String toString() { + return (StringBuffer('AiChatThread(') + ..write('id: $id, ') + ..write('conversationId: $conversationId, ') + ..write('title: $title, ') + ..write('createdAt: $createdAt, ') + ..write('updatedAt: $updatedAt') + ..write(')')) + .toString(); + } + + @override + int get hashCode => + Object.hash(id, conversationId, title, createdAt, updatedAt); + @override + bool operator ==(Object other) => + identical(this, other) || + (other is AiChatThread && + other.id == this.id && + other.conversationId == this.conversationId && + other.title == this.title && + other.createdAt == this.createdAt && + other.updatedAt == this.updatedAt); +} + +class AiChatThreadsCompanion extends UpdateCompanion { + final Value id; + final Value conversationId; + final Value title; + final Value createdAt; + final Value updatedAt; + final Value rowid; + const AiChatThreadsCompanion({ + this.id = const Value.absent(), + this.conversationId = const Value.absent(), + this.title = const Value.absent(), + this.createdAt = const Value.absent(), + this.updatedAt = const Value.absent(), + this.rowid = const Value.absent(), + }); + AiChatThreadsCompanion.insert({ + required String id, + required String conversationId, + this.title = const Value.absent(), + required DateTime createdAt, + required DateTime updatedAt, + this.rowid = const Value.absent(), + }) : id = Value(id), + conversationId = Value(conversationId), + createdAt = Value(createdAt), + updatedAt = Value(updatedAt); + static Insertable custom({ + Expression? id, + Expression? conversationId, + Expression? title, + Expression? createdAt, + Expression? updatedAt, + Expression? rowid, + }) { + return RawValuesInsertable({ + if (id != null) 'id': id, + if (conversationId != null) 'conversation_id': conversationId, + if (title != null) 'title': title, + if (createdAt != null) 'created_at': createdAt, + if (updatedAt != null) 'updated_at': updatedAt, + if (rowid != null) 'rowid': rowid, + }); + } + + AiChatThreadsCompanion copyWith({ + Value? id, + Value? conversationId, + Value? title, + Value? createdAt, + Value? updatedAt, + Value? rowid, + }) { + return AiChatThreadsCompanion( + id: id ?? this.id, + conversationId: conversationId ?? this.conversationId, + title: title ?? this.title, + createdAt: createdAt ?? this.createdAt, + updatedAt: updatedAt ?? this.updatedAt, + rowid: rowid ?? this.rowid, + ); + } + + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + if (id.present) { + map['id'] = Variable(id.value); + } + if (conversationId.present) { + map['conversation_id'] = Variable(conversationId.value); + } + if (title.present) { + map['title'] = Variable(title.value); + } + if (createdAt.present) { + map['created_at'] = Variable( + AiChatThreads.$convertercreatedAt.toSql(createdAt.value), + ); + } + if (updatedAt.present) { + map['updated_at'] = Variable( + AiChatThreads.$converterupdatedAt.toSql(updatedAt.value), + ); + } + if (rowid.present) { + map['rowid'] = Variable(rowid.value); + } + return map; + } + + @override + String toString() { + return (StringBuffer('AiChatThreadsCompanion(') + ..write('id: $id, ') + ..write('conversationId: $conversationId, ') + ..write('title: $title, ') + ..write('createdAt: $createdAt, ') + ..write('updatedAt: $updatedAt, ') + ..write('rowid: $rowid') + ..write(')')) + .toString(); + } +} + +abstract class _$AiDatabase extends GeneratedDatabase { + _$AiDatabase(QueryExecutor e) : super(e); + $AiDatabaseManager get managers => $AiDatabaseManager(this); + late final AiChatMessages aiChatMessages = AiChatMessages(this); + late final AiChatThreads aiChatThreads = AiChatThreads(this); + late final Index indexAiChatMessagesConversationIdCreatedAt = Index( + 'index_ai_chat_messages_conversation_id_created_at', + 'CREATE INDEX IF NOT EXISTS index_ai_chat_messages_conversation_id_created_at ON ai_chat_messages (conversation_id, created_at DESC)', + ); + late final Index indexAiChatMessagesThreadIdCreatedAt = Index( + 'index_ai_chat_messages_thread_id_created_at', + 'CREATE INDEX IF NOT EXISTS index_ai_chat_messages_thread_id_created_at ON ai_chat_messages (thread_id, created_at DESC)', + ); + late final Index indexAiChatThreadsConversationIdUpdatedAt = Index( + 'index_ai_chat_threads_conversation_id_updated_at', + 'CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_updated_at ON ai_chat_threads (conversation_id, updated_at DESC)', + ); + late final AiChatMessageDao aiChatMessageDao = AiChatMessageDao( + this as AiDatabase, + ); + @override + Iterable> get allTables => + allSchemaEntities.whereType>(); + @override + List get allSchemaEntities => [ + aiChatMessages, + aiChatThreads, + indexAiChatMessagesConversationIdCreatedAt, + indexAiChatMessagesThreadIdCreatedAt, + indexAiChatThreadsConversationIdUpdatedAt, + ]; +} + +typedef $AiChatMessagesCreateCompanionBuilder = + AiChatMessagesCompanion Function({ + required String id, + Value threadId, + required String conversationId, + required String role, + required String providerId, + required String content, + required String status, + Value model, + Value errorText, + Value metadata, + required DateTime createdAt, + required DateTime updatedAt, + Value rowid, + }); +typedef $AiChatMessagesUpdateCompanionBuilder = + AiChatMessagesCompanion Function({ + Value id, + Value threadId, + Value conversationId, + Value role, + Value providerId, + Value content, + Value status, + Value model, + Value errorText, + Value metadata, + Value createdAt, + Value updatedAt, + Value rowid, + }); + +class $AiChatMessagesFilterComposer + extends Composer<_$AiDatabase, AiChatMessages> { + $AiChatMessagesFilterComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + ColumnFilters get id => $composableBuilder( + column: $table.id, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get threadId => $composableBuilder( + column: $table.threadId, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get role => $composableBuilder( + column: $table.role, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get providerId => $composableBuilder( + column: $table.providerId, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get content => $composableBuilder( + column: $table.content, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get status => $composableBuilder( + column: $table.status, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get model => $composableBuilder( + column: $table.model, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get errorText => $composableBuilder( + column: $table.errorText, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get metadata => $composableBuilder( + column: $table.metadata, + builder: (column) => ColumnFilters(column), + ); + + ColumnWithTypeConverterFilters get createdAt => + $composableBuilder( + column: $table.createdAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); + + ColumnWithTypeConverterFilters get updatedAt => + $composableBuilder( + column: $table.updatedAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); +} + +class $AiChatMessagesOrderingComposer + extends Composer<_$AiDatabase, AiChatMessages> { + $AiChatMessagesOrderingComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + ColumnOrderings get id => $composableBuilder( + column: $table.id, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get threadId => $composableBuilder( + column: $table.threadId, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get role => $composableBuilder( + column: $table.role, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get providerId => $composableBuilder( + column: $table.providerId, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get content => $composableBuilder( + column: $table.content, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get status => $composableBuilder( + column: $table.status, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get model => $composableBuilder( + column: $table.model, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get errorText => $composableBuilder( + column: $table.errorText, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get metadata => $composableBuilder( + column: $table.metadata, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get createdAt => $composableBuilder( + column: $table.createdAt, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get updatedAt => $composableBuilder( + column: $table.updatedAt, + builder: (column) => ColumnOrderings(column), + ); +} + +class $AiChatMessagesAnnotationComposer + extends Composer<_$AiDatabase, AiChatMessages> { + $AiChatMessagesAnnotationComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + GeneratedColumn get id => + $composableBuilder(column: $table.id, builder: (column) => column); + + GeneratedColumn get threadId => + $composableBuilder(column: $table.threadId, builder: (column) => column); + + GeneratedColumn get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => column, + ); + + GeneratedColumn get role => + $composableBuilder(column: $table.role, builder: (column) => column); + + GeneratedColumn get providerId => $composableBuilder( + column: $table.providerId, + builder: (column) => column, + ); + + GeneratedColumn get content => + $composableBuilder(column: $table.content, builder: (column) => column); + + GeneratedColumn get status => + $composableBuilder(column: $table.status, builder: (column) => column); + + GeneratedColumn get model => + $composableBuilder(column: $table.model, builder: (column) => column); + + GeneratedColumn get errorText => + $composableBuilder(column: $table.errorText, builder: (column) => column); + + GeneratedColumn get metadata => + $composableBuilder(column: $table.metadata, builder: (column) => column); + + GeneratedColumnWithTypeConverter get createdAt => + $composableBuilder(column: $table.createdAt, builder: (column) => column); + + GeneratedColumnWithTypeConverter get updatedAt => + $composableBuilder(column: $table.updatedAt, builder: (column) => column); +} + +class $AiChatMessagesTableManager + extends + RootTableManager< + _$AiDatabase, + AiChatMessages, + AiChatMessage, + $AiChatMessagesFilterComposer, + $AiChatMessagesOrderingComposer, + $AiChatMessagesAnnotationComposer, + $AiChatMessagesCreateCompanionBuilder, + $AiChatMessagesUpdateCompanionBuilder, + ( + AiChatMessage, + BaseReferences<_$AiDatabase, AiChatMessages, AiChatMessage>, + ), + AiChatMessage, + PrefetchHooks Function() + > { + $AiChatMessagesTableManager(_$AiDatabase db, AiChatMessages table) + : super( + TableManagerState( + db: db, + table: table, + createFilteringComposer: () => + $AiChatMessagesFilterComposer($db: db, $table: table), + createOrderingComposer: () => + $AiChatMessagesOrderingComposer($db: db, $table: table), + createComputedFieldComposer: () => + $AiChatMessagesAnnotationComposer($db: db, $table: table), + updateCompanionCallback: + ({ + Value id = const Value.absent(), + Value threadId = const Value.absent(), + Value conversationId = const Value.absent(), + Value role = const Value.absent(), + Value providerId = const Value.absent(), + Value content = const Value.absent(), + Value status = const Value.absent(), + Value model = const Value.absent(), + Value errorText = const Value.absent(), + Value metadata = const Value.absent(), + Value createdAt = const Value.absent(), + Value updatedAt = const Value.absent(), + Value rowid = const Value.absent(), + }) => AiChatMessagesCompanion( + id: id, + threadId: threadId, + conversationId: conversationId, + role: role, + providerId: providerId, + content: content, + status: status, + model: model, + errorText: errorText, + metadata: metadata, + createdAt: createdAt, + updatedAt: updatedAt, + rowid: rowid, + ), + createCompanionCallback: + ({ + required String id, + Value threadId = const Value.absent(), + required String conversationId, + required String role, + required String providerId, + required String content, + required String status, + Value model = const Value.absent(), + Value errorText = const Value.absent(), + Value metadata = const Value.absent(), + required DateTime createdAt, + required DateTime updatedAt, + Value rowid = const Value.absent(), + }) => AiChatMessagesCompanion.insert( + id: id, + threadId: threadId, + conversationId: conversationId, + role: role, + providerId: providerId, + content: content, + status: status, + model: model, + errorText: errorText, + metadata: metadata, + createdAt: createdAt, + updatedAt: updatedAt, + rowid: rowid, + ), + withReferenceMapper: (p0) => p0 + .map((e) => (e.readTable(table), BaseReferences(db, table, e))) + .toList(), + prefetchHooksCallback: null, + ), + ); +} + +typedef $AiChatMessagesProcessedTableManager = + ProcessedTableManager< + _$AiDatabase, + AiChatMessages, + AiChatMessage, + $AiChatMessagesFilterComposer, + $AiChatMessagesOrderingComposer, + $AiChatMessagesAnnotationComposer, + $AiChatMessagesCreateCompanionBuilder, + $AiChatMessagesUpdateCompanionBuilder, + ( + AiChatMessage, + BaseReferences<_$AiDatabase, AiChatMessages, AiChatMessage>, + ), + AiChatMessage, + PrefetchHooks Function() + >; +typedef $AiChatThreadsCreateCompanionBuilder = + AiChatThreadsCompanion Function({ + required String id, + required String conversationId, + Value title, + required DateTime createdAt, + required DateTime updatedAt, + Value rowid, + }); +typedef $AiChatThreadsUpdateCompanionBuilder = + AiChatThreadsCompanion Function({ + Value id, + Value conversationId, + Value title, + Value createdAt, + Value updatedAt, + Value rowid, + }); + +class $AiChatThreadsFilterComposer + extends Composer<_$AiDatabase, AiChatThreads> { + $AiChatThreadsFilterComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + ColumnFilters get id => $composableBuilder( + column: $table.id, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get title => $composableBuilder( + column: $table.title, + builder: (column) => ColumnFilters(column), + ); + + ColumnWithTypeConverterFilters get createdAt => + $composableBuilder( + column: $table.createdAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); + + ColumnWithTypeConverterFilters get updatedAt => + $composableBuilder( + column: $table.updatedAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); +} + +class $AiChatThreadsOrderingComposer + extends Composer<_$AiDatabase, AiChatThreads> { + $AiChatThreadsOrderingComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + ColumnOrderings get id => $composableBuilder( + column: $table.id, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get title => $composableBuilder( + column: $table.title, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get createdAt => $composableBuilder( + column: $table.createdAt, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get updatedAt => $composableBuilder( + column: $table.updatedAt, + builder: (column) => ColumnOrderings(column), + ); +} + +class $AiChatThreadsAnnotationComposer + extends Composer<_$AiDatabase, AiChatThreads> { + $AiChatThreadsAnnotationComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + GeneratedColumn get id => + $composableBuilder(column: $table.id, builder: (column) => column); + + GeneratedColumn get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => column, + ); + + GeneratedColumn get title => + $composableBuilder(column: $table.title, builder: (column) => column); + + GeneratedColumnWithTypeConverter get createdAt => + $composableBuilder(column: $table.createdAt, builder: (column) => column); + + GeneratedColumnWithTypeConverter get updatedAt => + $composableBuilder(column: $table.updatedAt, builder: (column) => column); +} + +class $AiChatThreadsTableManager + extends + RootTableManager< + _$AiDatabase, + AiChatThreads, + AiChatThread, + $AiChatThreadsFilterComposer, + $AiChatThreadsOrderingComposer, + $AiChatThreadsAnnotationComposer, + $AiChatThreadsCreateCompanionBuilder, + $AiChatThreadsUpdateCompanionBuilder, + ( + AiChatThread, + BaseReferences<_$AiDatabase, AiChatThreads, AiChatThread>, + ), + AiChatThread, + PrefetchHooks Function() + > { + $AiChatThreadsTableManager(_$AiDatabase db, AiChatThreads table) + : super( + TableManagerState( + db: db, + table: table, + createFilteringComposer: () => + $AiChatThreadsFilterComposer($db: db, $table: table), + createOrderingComposer: () => + $AiChatThreadsOrderingComposer($db: db, $table: table), + createComputedFieldComposer: () => + $AiChatThreadsAnnotationComposer($db: db, $table: table), + updateCompanionCallback: + ({ + Value id = const Value.absent(), + Value conversationId = const Value.absent(), + Value title = const Value.absent(), + Value createdAt = const Value.absent(), + Value updatedAt = const Value.absent(), + Value rowid = const Value.absent(), + }) => AiChatThreadsCompanion( + id: id, + conversationId: conversationId, + title: title, + createdAt: createdAt, + updatedAt: updatedAt, + rowid: rowid, + ), + createCompanionCallback: + ({ + required String id, + required String conversationId, + Value title = const Value.absent(), + required DateTime createdAt, + required DateTime updatedAt, + Value rowid = const Value.absent(), + }) => AiChatThreadsCompanion.insert( + id: id, + conversationId: conversationId, + title: title, + createdAt: createdAt, + updatedAt: updatedAt, + rowid: rowid, + ), + withReferenceMapper: (p0) => p0 + .map((e) => (e.readTable(table), BaseReferences(db, table, e))) + .toList(), + prefetchHooksCallback: null, + ), + ); +} + +typedef $AiChatThreadsProcessedTableManager = + ProcessedTableManager< + _$AiDatabase, + AiChatThreads, + AiChatThread, + $AiChatThreadsFilterComposer, + $AiChatThreadsOrderingComposer, + $AiChatThreadsAnnotationComposer, + $AiChatThreadsCreateCompanionBuilder, + $AiChatThreadsUpdateCompanionBuilder, + (AiChatThread, BaseReferences<_$AiDatabase, AiChatThreads, AiChatThread>), + AiChatThread, + PrefetchHooks Function() + >; + +class $AiDatabaseManager { + final _$AiDatabase _db; + $AiDatabaseManager(this._db); + $AiChatMessagesTableManager get aiChatMessages => + $AiChatMessagesTableManager(_db, _db.aiChatMessages); + $AiChatThreadsTableManager get aiChatThreads => + $AiChatThreadsTableManager(_db, _db.aiChatThreads); +} diff --git a/lib/db/dao/ai_chat_message_dao.dart b/lib/db/dao/ai_chat_message_dao.dart index 4a1929f83f..ebe54752f4 100644 --- a/lib/db/dao/ai_chat_message_dao.dart +++ b/lib/db/dao/ai_chat_message_dao.dart @@ -1,40 +1,95 @@ import 'package:drift/drift.dart'; +import 'package:uuid/uuid.dart'; import '../../ai/model/ai_chat_metadata.dart'; -import '../mixin_database.dart'; +import '../ai_database.dart'; part 'ai_chat_message_dao.g.dart'; @DriftAccessor() -class AiChatMessageDao extends DatabaseAccessor +class AiChatMessageDao extends DatabaseAccessor with _$AiChatMessageDaoMixin { AiChatMessageDao(super.db); static const assistantRole = 'assistant'; static const pendingStatus = 'pending'; static const errorStatus = 'error'; + static const _uuid = Uuid(); - Stream> watchConversationMessages( - String conversationId, - ) => + Stream watchLatestThread(String conversationId) => + (select(db.aiChatThreads) + ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..orderBy([ + (tbl) => OrderingTerm.desc(tbl.updatedAt), + (tbl) => OrderingTerm.desc(tbl.createdAt), + (tbl) => OrderingTerm.desc(tbl.id), + ]) + ..limit(1)) + .watchSingleOrNull(); + + Future latestThread(String conversationId) => + (select(db.aiChatThreads) + ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..orderBy([ + (tbl) => OrderingTerm.desc(tbl.updatedAt), + (tbl) => OrderingTerm.desc(tbl.createdAt), + (tbl) => OrderingTerm.desc(tbl.id), + ]) + ..limit(1)) + .getSingleOrNull(); + + Future threadById(String threadId) => (select( + db.aiChatThreads, + )..where((tbl) => tbl.id.equals(threadId))).getSingleOrNull(); + + Future createThread(String conversationId) async { + final now = DateTime.now(); + final thread = AiChatThread( + id: _uuid.v4(), + conversationId: conversationId, + createdAt: now, + updatedAt: now, + ); + await into(db.aiChatThreads).insert(thread); + return thread; + } + + Future ensureThread({ + required String conversationId, + String? threadId, + }) async { + if (threadId != null) { + final thread = await threadById(threadId); + if (thread == null || thread.conversationId != conversationId) { + throw StateError('AI thread not found'); + } + return thread; + } + + final existing = await latestThread(conversationId); + if (existing != null) return existing; + return createThread(conversationId); + } + + Stream> watchThreadMessages(String threadId) => (select( db.aiChatMessages, ) - ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..where((tbl) => tbl.threadId.equals(threadId)) ..orderBy([ (tbl) => OrderingTerm.asc(tbl.createdAt), (tbl) => OrderingTerm.asc(tbl.id), ])) .watch(); - Stream> watchLatestConversationMessages( - String conversationId, + Stream> watchLatestThreadMessages( + String threadId, int limit, ) => (select( db.aiChatMessages, ) - ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..where((tbl) => tbl.threadId.equals(threadId)) ..orderBy([ (tbl) => OrderingTerm.desc(tbl.createdAt), (tbl) => OrderingTerm.desc(tbl.id), @@ -43,19 +98,19 @@ class AiChatMessageDao extends DatabaseAccessor .watch() .map((items) => items.reversed.toList(growable: false)); - Future> conversationMessages(String conversationId) => + Future> threadMessages(String threadId) => (select( db.aiChatMessages, ) - ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..where((tbl) => tbl.threadId.equals(threadId)) ..orderBy([ (tbl) => OrderingTerm.asc(tbl.createdAt), (tbl) => OrderingTerm.asc(tbl.id), ])) .get(); - Future> beforeConversationMessages({ - required String conversationId, + Future> beforeThreadMessages({ + required String threadId, required AiChatMessage before, required int limit, }) async { @@ -66,7 +121,7 @@ class AiChatMessageDao extends DatabaseAccessor ) ..where( (tbl) => - tbl.conversationId.equals(conversationId) & + tbl.threadId.equals(threadId) & (tbl.createdAt.isSmallerThanValue(beforeCreatedAt) | (tbl.createdAt.equals(beforeCreatedAt) & tbl.id.isSmallerThanValue(before.id))), @@ -80,8 +135,10 @@ class AiChatMessageDao extends DatabaseAccessor return list.reversed.toList(growable: false); } - Future insertMessage(AiChatMessagesCompanion row) => - into(db.aiChatMessages).insertOnConflictUpdate(row); + Future insertMessage(AiChatMessagesCompanion row) async { + await into(db.aiChatMessages).insertOnConflictUpdate(row); + await _touchThread(row.threadId.value); + } Future updateMessageContent( String id, @@ -131,18 +188,45 @@ class AiChatMessageDao extends DatabaseAccessor }); } + Future setMessageMetadataResponse( + String id, + Map responseMetadata, { + required DateTime updatedAt, + }) async { + await transaction(() async { + final message = await (select( + db.aiChatMessages, + )..where((tbl) => tbl.id.equals(id))).getSingleOrNull(); + if (message == null) { + return; + } + final metadata = setAiResponseMetadata( + message.metadata, + responseMetadata, + ); + await (update( + db.aiChatMessages, + )..where((tbl) => tbl.id.equals(id))).write( + AiChatMessagesCompanion( + metadata: Value(metadata), + updatedAt: Value(updatedAt), + ), + ); + }); + } + Future deleteConversationMessages(String conversationId) => (delete( db.aiChatMessages, )..where((tbl) => tbl.conversationId.equals(conversationId))).go(); Future hasPendingAssistantMessage( - String conversationId, { + String threadId, { DateTime? updatedAfter, }) async { final query = selectOnly(db.aiChatMessages) ..addColumns([db.aiChatMessages.id.count()]) ..where( - db.aiChatMessages.conversationId.equals(conversationId) & + db.aiChatMessages.threadId.equals(threadId) & db.aiChatMessages.role.equals(assistantRole) & db.aiChatMessages.status.equals(pendingStatus) & (updatedAfter == null @@ -159,6 +243,7 @@ class AiChatMessageDao extends DatabaseAccessor Future resolveStalePendingAssistantMessages({ required DateTime updatedBefore, String? conversationId, + String? threadId, String errorText = 'Interrupted by app restart', }) { final query = update(db.aiChatMessages) @@ -171,7 +256,10 @@ class AiChatMessageDao extends DatabaseAccessor ) & (conversationId == null ? const Constant(true) - : tbl.conversationId.equals(conversationId)), + : tbl.conversationId.equals(conversationId)) & + (threadId == null + ? const Constant(true) + : tbl.threadId.equals(threadId)), ); return query.write( AiChatMessagesCompanion( @@ -181,4 +269,11 @@ class AiChatMessageDao extends DatabaseAccessor ), ); } + + Future _touchThread(String threadId) => + (update( + db.aiChatThreads, + )..where((tbl) => tbl.id.equals(threadId))).write( + AiChatThreadsCompanion(updatedAt: Value(DateTime.now())), + ); } diff --git a/lib/db/dao/ai_chat_message_dao.g.dart b/lib/db/dao/ai_chat_message_dao.g.dart index 7ffc7133ff..9d95f46028 100644 --- a/lib/db/dao/ai_chat_message_dao.g.dart +++ b/lib/db/dao/ai_chat_message_dao.g.dart @@ -3,7 +3,7 @@ part of 'ai_chat_message_dao.dart'; // ignore_for_file: type=lint -mixin _$AiChatMessageDaoMixin on DatabaseAccessor { +mixin _$AiChatMessageDaoMixin on DatabaseAccessor { AiChatMessageDaoManager get managers => AiChatMessageDaoManager(this); } diff --git a/lib/db/dao/asset_dao.g.dart b/lib/db/dao/asset_dao.g.dart index cc367fcc26..12e7d29377 100644 --- a/lib/db/dao/asset_dao.g.dart +++ b/lib/db/dao/asset_dao.g.dart @@ -39,7 +39,6 @@ mixin _$AssetDaoMixin on DatabaseAccessor { FavoriteApps get favoriteApps => attachedDatabase.favoriteApps; ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -165,8 +164,6 @@ class AssetDaoManager { $ExpiredMessagesTableManager(_db.attachedDatabase, _db.expiredMessages); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/chain_dao.g.dart b/lib/db/dao/chain_dao.g.dart index 6d337c8b17..0c5edb0e6d 100644 --- a/lib/db/dao/chain_dao.g.dart +++ b/lib/db/dao/chain_dao.g.dart @@ -39,7 +39,6 @@ mixin _$ChainDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -127,8 +126,6 @@ class ChainDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/circle_conversation_dao.g.dart b/lib/db/dao/circle_conversation_dao.g.dart index ebe5a98ba9..2f4472a88c 100644 --- a/lib/db/dao/circle_conversation_dao.g.dart +++ b/lib/db/dao/circle_conversation_dao.g.dart @@ -39,7 +39,6 @@ mixin _$CircleConversationDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -146,8 +145,6 @@ class CircleConversationDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/circle_dao.g.dart b/lib/db/dao/circle_dao.g.dart index ff29b0ce0b..f2d332e379 100644 --- a/lib/db/dao/circle_dao.g.dart +++ b/lib/db/dao/circle_dao.g.dart @@ -39,7 +39,6 @@ mixin _$CircleDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -184,8 +183,6 @@ class CircleDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/conversation_dao.g.dart b/lib/db/dao/conversation_dao.g.dart index 595adcc3ec..eab0696108 100644 --- a/lib/db/dao/conversation_dao.g.dart +++ b/lib/db/dao/conversation_dao.g.dart @@ -39,7 +39,6 @@ mixin _$ConversationDaoMixin on DatabaseAccessor { FavoriteApps get favoriteApps => attachedDatabase.favoriteApps; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -801,8 +800,6 @@ class ConversationDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/expired_message_dao.g.dart b/lib/db/dao/expired_message_dao.g.dart index d1cb868dc4..a8c2b5524d 100644 --- a/lib/db/dao/expired_message_dao.g.dart +++ b/lib/db/dao/expired_message_dao.g.dart @@ -39,7 +39,6 @@ mixin _$ExpiredMessageDaoMixin on DatabaseAccessor { FavoriteApps get favoriteApps => attachedDatabase.favoriteApps; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -197,8 +196,6 @@ class ExpiredMessageDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/favorite_app_dao.g.dart b/lib/db/dao/favorite_app_dao.g.dart index 4642f1d59d..69968be3c0 100644 --- a/lib/db/dao/favorite_app_dao.g.dart +++ b/lib/db/dao/favorite_app_dao.g.dart @@ -39,7 +39,6 @@ mixin _$FavoriteAppDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -144,8 +143,6 @@ class FavoriteAppDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/flood_message_dao.g.dart b/lib/db/dao/flood_message_dao.g.dart index 52042fb936..9d78277a5b 100644 --- a/lib/db/dao/flood_message_dao.g.dart +++ b/lib/db/dao/flood_message_dao.g.dart @@ -39,7 +39,6 @@ mixin _$FloodMessageDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -139,8 +138,6 @@ class FloodMessageDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/inscription_collection_dao.g.dart b/lib/db/dao/inscription_collection_dao.g.dart index e9d4cb8d3e..08ed37c7a4 100644 --- a/lib/db/dao/inscription_collection_dao.g.dart +++ b/lib/db/dao/inscription_collection_dao.g.dart @@ -39,7 +39,6 @@ mixin _$InscriptionCollectionDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -128,8 +127,6 @@ class InscriptionCollectionDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/inscription_item_dao.g.dart b/lib/db/dao/inscription_item_dao.g.dart index 6c995896e0..5c45471e24 100644 --- a/lib/db/dao/inscription_item_dao.g.dart +++ b/lib/db/dao/inscription_item_dao.g.dart @@ -42,7 +42,6 @@ mixin _$InscriptionItemDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; Selectable inscriptionByHash(String hash) { @@ -152,8 +151,6 @@ class InscriptionItemDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/message_dao.g.dart b/lib/db/dao/message_dao.g.dart index b9c45209c8..e69968a36f 100644 --- a/lib/db/dao/message_dao.g.dart +++ b/lib/db/dao/message_dao.g.dart @@ -39,7 +39,6 @@ mixin _$MessageDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -493,8 +492,6 @@ class MessageDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/participant_dao.g.dart b/lib/db/dao/participant_dao.g.dart index 52e0232ad6..b44cf301e7 100644 --- a/lib/db/dao/participant_dao.g.dart +++ b/lib/db/dao/participant_dao.g.dart @@ -39,7 +39,6 @@ mixin _$ParticipantDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -217,8 +216,6 @@ class ParticipantDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/participant_session_dao.g.dart b/lib/db/dao/participant_session_dao.g.dart index 8f2e9cf93e..fcf4c41ee9 100644 --- a/lib/db/dao/participant_session_dao.g.dart +++ b/lib/db/dao/participant_session_dao.g.dart @@ -39,7 +39,6 @@ mixin _$ParticipantSessionDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -183,8 +182,6 @@ class ParticipantSessionDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/pin_message_dao.g.dart b/lib/db/dao/pin_message_dao.g.dart index 3375cd094b..a3d2dd2725 100644 --- a/lib/db/dao/pin_message_dao.g.dart +++ b/lib/db/dao/pin_message_dao.g.dart @@ -39,7 +39,6 @@ mixin _$PinMessageDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -162,8 +161,6 @@ class PinMessageDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/property_dao.g.dart b/lib/db/dao/property_dao.g.dart index d83c18cb18..f611989531 100644 --- a/lib/db/dao/property_dao.g.dart +++ b/lib/db/dao/property_dao.g.dart @@ -39,7 +39,6 @@ mixin _$PropertyDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -127,8 +126,6 @@ class PropertyDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/safe_snapshot_dao.g.dart b/lib/db/dao/safe_snapshot_dao.g.dart index 40d9713c2b..85d4854f19 100644 --- a/lib/db/dao/safe_snapshot_dao.g.dart +++ b/lib/db/dao/safe_snapshot_dao.g.dart @@ -41,7 +41,6 @@ mixin _$SafeSnapshotDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; InscriptionCollections get inscriptionCollections => attachedDatabase.inscriptionCollections; InscriptionItems get inscriptionItems => attachedDatabase.inscriptionItems; @@ -222,8 +221,6 @@ class SafeSnapshotDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $InscriptionCollectionsTableManager get inscriptionCollections => $InscriptionCollectionsTableManager( _db.attachedDatabase, diff --git a/lib/db/dao/snapshot_dao.g.dart b/lib/db/dao/snapshot_dao.g.dart index fb32f4242b..aba62b6afe 100644 --- a/lib/db/dao/snapshot_dao.g.dart +++ b/lib/db/dao/snapshot_dao.g.dart @@ -39,7 +39,6 @@ mixin _$SnapshotDaoMixin on DatabaseAccessor { FavoriteApps get favoriteApps => attachedDatabase.favoriteApps; ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -231,8 +230,6 @@ class SnapshotDaoManager { $ExpiredMessagesTableManager(_db.attachedDatabase, _db.expiredMessages); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/sticker_album_dao.g.dart b/lib/db/dao/sticker_album_dao.g.dart index acee16eacb..ae0ab42d06 100644 --- a/lib/db/dao/sticker_album_dao.g.dart +++ b/lib/db/dao/sticker_album_dao.g.dart @@ -39,7 +39,6 @@ mixin _$StickerAlbumDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -127,8 +126,6 @@ class StickerAlbumDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/sticker_dao.g.dart b/lib/db/dao/sticker_dao.g.dart index f2b1d48f31..f8b6a8a11d 100644 --- a/lib/db/dao/sticker_dao.g.dart +++ b/lib/db/dao/sticker_dao.g.dart @@ -39,7 +39,6 @@ mixin _$StickerDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -151,8 +150,6 @@ class StickerDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/sticker_relationship_dao.g.dart b/lib/db/dao/sticker_relationship_dao.g.dart index 078bba0360..efbeb178c6 100644 --- a/lib/db/dao/sticker_relationship_dao.g.dart +++ b/lib/db/dao/sticker_relationship_dao.g.dart @@ -39,7 +39,6 @@ mixin _$StickerRelationshipDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -144,8 +143,6 @@ class StickerRelationshipDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/token_dao.g.dart b/lib/db/dao/token_dao.g.dart index 82b5064992..641feb2cdf 100644 --- a/lib/db/dao/token_dao.g.dart +++ b/lib/db/dao/token_dao.g.dart @@ -40,7 +40,6 @@ mixin _$TokenDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; InscriptionCollections get inscriptionCollections => attachedDatabase.inscriptionCollections; @@ -137,8 +136,6 @@ class TokenDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $InscriptionCollectionsTableManager get inscriptionCollections => diff --git a/lib/db/dao/transcript_message_dao.g.dart b/lib/db/dao/transcript_message_dao.g.dart index ec5faee347..1a6c4cbfb7 100644 --- a/lib/db/dao/transcript_message_dao.g.dart +++ b/lib/db/dao/transcript_message_dao.g.dart @@ -39,7 +39,6 @@ mixin _$TranscriptMessageDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -231,8 +230,6 @@ class TranscriptMessageDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/dao/user_dao.g.dart b/lib/db/dao/user_dao.g.dart index 1939b66a2f..5cf1907be6 100644 --- a/lib/db/dao/user_dao.g.dart +++ b/lib/db/dao/user_dao.g.dart @@ -39,7 +39,6 @@ mixin _$UserDaoMixin on DatabaseAccessor { ExpiredMessages get expiredMessages => attachedDatabase.expiredMessages; Chains get chains => attachedDatabase.chains; Properties get properties => attachedDatabase.properties; - AiChatMessages get aiChatMessages => attachedDatabase.aiChatMessages; SafeSnapshots get safeSnapshots => attachedDatabase.safeSnapshots; Tokens get tokens => attachedDatabase.tokens; InscriptionCollections get inscriptionCollections => @@ -311,8 +310,6 @@ class UserDaoManager { $ChainsTableManager(_db.attachedDatabase, _db.chains); $PropertiesTableManager get properties => $PropertiesTableManager(_db.attachedDatabase, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db.attachedDatabase, _db.aiChatMessages); $SafeSnapshotsTableManager get safeSnapshots => $SafeSnapshotsTableManager(_db.attachedDatabase, _db.safeSnapshots); $TokensTableManager get tokens => diff --git a/lib/db/database.dart b/lib/db/database.dart index 1e76227eed..200f46ce60 100644 --- a/lib/db/database.dart +++ b/lib/db/database.dart @@ -4,6 +4,7 @@ import '../ui/provider/slide_category_provider.dart'; import '../utils/extension/extension.dart'; import '../utils/logger.dart'; import '../utils/property/setting_property.dart'; +import 'ai_database.dart'; import 'dao/ai_chat_message_dao.dart'; import 'dao/app_dao.dart'; import 'dao/asset_dao.dart'; @@ -38,7 +39,7 @@ import 'fts_database.dart'; import 'mixin_database.dart'; class Database { - Database(this.mixinDatabase, this.ftsDatabase) { + Database(this.mixinDatabase, this.ftsDatabase, this.aiDatabase) { settingProperties = SettingPropertyStorage(mixinDatabase.propertyDao); } @@ -46,9 +47,11 @@ class Database { final FtsDatabase ftsDatabase; + final AiDatabase aiDatabase; + AppDao get appDao => mixinDatabase.appDao; - AiChatMessageDao get aiChatMessageDao => mixinDatabase.aiChatMessageDao; + AiChatMessageDao get aiChatMessageDao => aiDatabase.aiChatMessageDao; AssetDao get assetDao => mixinDatabase.assetDao; @@ -117,6 +120,7 @@ class Database { Future dispose() async { await mixinDatabase.close(); await ftsDatabase.close(); + await aiDatabase.close(); // dispose stream, https://github.com/simolus3/moor/issues/290 } diff --git a/lib/db/mixin_database.dart b/lib/db/mixin_database.dart index ab55e76c35..c1ce880652 100644 --- a/lib/db/mixin_database.dart +++ b/lib/db/mixin_database.dart @@ -18,7 +18,6 @@ import 'converter/safe_deposit_type_converter.dart'; import 'converter/safe_withdrawal_type_converter.dart'; import 'converter/user_relationship_converter.dart'; import 'dao/address_dao.dart'; -import 'dao/ai_chat_message_dao.dart'; import 'dao/app_dao.dart'; import 'dao/asset_dao.dart'; import 'dao/chain_dao.dart'; @@ -62,7 +61,6 @@ part 'mixin_database.g.dart'; include: {'moor/mixin.drift', 'moor/dao/common.drift'}, daos: [ AddressDao, - AiChatMessageDao, AppDao, AssetDao, CircleConversationDao, @@ -101,7 +99,7 @@ class MixinDatabase extends _$MixinDatabase { MixinDatabase(super.e); @override - int get schemaVersion => 31; + int get schemaVersion => 28; final eventBus = DataBaseEventBus.instance; @@ -280,25 +278,6 @@ class MixinDatabase extends _$MixinDatabase { if (from <= 27) { await _addColumnIfNotExists(m, tokens, tokens.precision); } - if (from <= 28) { - await m.createTable(aiChatMessages); - await m.createIndex(indexAiChatMessagesConversationIdCreatedAt); - } - if (from <= 29) { - await _addColumnIfNotExists( - m, - aiChatMessages, - aiChatMessages.anchorMessageId, - ); - await _addColumnIfNotExists( - m, - aiChatMessages, - aiChatMessages.anchorCreatedAt, - ); - } - if (from <= 30) { - await _addColumnIfNotExists(m, aiChatMessages, aiChatMessages.metadata); - } }, beforeOpen: (details) async { if (details.hadUpgrade && details.versionBefore! <= 20) { diff --git a/lib/db/mixin_database.g.dart b/lib/db/mixin_database.g.dart index 94de2d2c84..9a4c3f0f7e 100644 --- a/lib/db/mixin_database.g.dart +++ b/lib/db/mixin_database.g.dart @@ -17665,784 +17665,6 @@ class PropertiesCompanion extends UpdateCompanion { } } -class AiChatMessages extends Table - with TableInfo { - @override - final GeneratedDatabase attachedDatabase; - final String? _alias; - AiChatMessages(this.attachedDatabase, [this._alias]); - static const VerificationMeta _idMeta = const VerificationMeta('id'); - late final GeneratedColumn id = GeneratedColumn( - 'id', - aliasedName, - false, - type: DriftSqlType.string, - requiredDuringInsert: true, - $customConstraints: 'NOT NULL', - ); - static const VerificationMeta _conversationIdMeta = const VerificationMeta( - 'conversationId', - ); - late final GeneratedColumn conversationId = GeneratedColumn( - 'conversation_id', - aliasedName, - false, - type: DriftSqlType.string, - requiredDuringInsert: true, - $customConstraints: 'NOT NULL', - ); - static const VerificationMeta _roleMeta = const VerificationMeta('role'); - late final GeneratedColumn role = GeneratedColumn( - 'role', - aliasedName, - false, - type: DriftSqlType.string, - requiredDuringInsert: true, - $customConstraints: 'NOT NULL', - ); - static const VerificationMeta _providerIdMeta = const VerificationMeta( - 'providerId', - ); - late final GeneratedColumn providerId = GeneratedColumn( - 'provider_id', - aliasedName, - false, - type: DriftSqlType.string, - requiredDuringInsert: true, - $customConstraints: 'NOT NULL', - ); - static const VerificationMeta _anchorMessageIdMeta = const VerificationMeta( - 'anchorMessageId', - ); - late final GeneratedColumn anchorMessageId = GeneratedColumn( - 'anchor_message_id', - aliasedName, - true, - type: DriftSqlType.string, - requiredDuringInsert: false, - $customConstraints: '', - ); - late final GeneratedColumnWithTypeConverter anchorCreatedAt = - GeneratedColumn( - 'anchor_created_at', - aliasedName, - true, - type: DriftSqlType.int, - requiredDuringInsert: false, - $customConstraints: '', - ).withConverter(AiChatMessages.$converteranchorCreatedAtn); - static const VerificationMeta _contentMeta = const VerificationMeta( - 'content', - ); - late final GeneratedColumn content = GeneratedColumn( - 'content', - aliasedName, - false, - type: DriftSqlType.string, - requiredDuringInsert: true, - $customConstraints: 'NOT NULL', - ); - static const VerificationMeta _statusMeta = const VerificationMeta('status'); - late final GeneratedColumn status = GeneratedColumn( - 'status', - aliasedName, - false, - type: DriftSqlType.string, - requiredDuringInsert: true, - $customConstraints: 'NOT NULL', - ); - static const VerificationMeta _modelMeta = const VerificationMeta('model'); - late final GeneratedColumn model = GeneratedColumn( - 'model', - aliasedName, - true, - type: DriftSqlType.string, - requiredDuringInsert: false, - $customConstraints: '', - ); - static const VerificationMeta _errorTextMeta = const VerificationMeta( - 'errorText', - ); - late final GeneratedColumn errorText = GeneratedColumn( - 'error_text', - aliasedName, - true, - type: DriftSqlType.string, - requiredDuringInsert: false, - $customConstraints: '', - ); - static const VerificationMeta _metadataMeta = const VerificationMeta( - 'metadata', - ); - late final GeneratedColumn metadata = GeneratedColumn( - 'metadata', - aliasedName, - true, - type: DriftSqlType.string, - requiredDuringInsert: false, - $customConstraints: '', - ); - late final GeneratedColumnWithTypeConverter createdAt = - GeneratedColumn( - 'created_at', - aliasedName, - false, - type: DriftSqlType.int, - requiredDuringInsert: true, - $customConstraints: 'NOT NULL', - ).withConverter(AiChatMessages.$convertercreatedAt); - late final GeneratedColumnWithTypeConverter updatedAt = - GeneratedColumn( - 'updated_at', - aliasedName, - false, - type: DriftSqlType.int, - requiredDuringInsert: true, - $customConstraints: 'NOT NULL', - ).withConverter(AiChatMessages.$converterupdatedAt); - @override - List get $columns => [ - id, - conversationId, - role, - providerId, - anchorMessageId, - anchorCreatedAt, - content, - status, - model, - errorText, - metadata, - createdAt, - updatedAt, - ]; - @override - String get aliasedName => _alias ?? actualTableName; - @override - String get actualTableName => $name; - static const String $name = 'ai_chat_messages'; - @override - VerificationContext validateIntegrity( - Insertable instance, { - bool isInserting = false, - }) { - final context = VerificationContext(); - final data = instance.toColumns(true); - if (data.containsKey('id')) { - context.handle(_idMeta, id.isAcceptableOrUnknown(data['id']!, _idMeta)); - } else if (isInserting) { - context.missing(_idMeta); - } - if (data.containsKey('conversation_id')) { - context.handle( - _conversationIdMeta, - conversationId.isAcceptableOrUnknown( - data['conversation_id']!, - _conversationIdMeta, - ), - ); - } else if (isInserting) { - context.missing(_conversationIdMeta); - } - if (data.containsKey('role')) { - context.handle( - _roleMeta, - role.isAcceptableOrUnknown(data['role']!, _roleMeta), - ); - } else if (isInserting) { - context.missing(_roleMeta); - } - if (data.containsKey('provider_id')) { - context.handle( - _providerIdMeta, - providerId.isAcceptableOrUnknown(data['provider_id']!, _providerIdMeta), - ); - } else if (isInserting) { - context.missing(_providerIdMeta); - } - if (data.containsKey('anchor_message_id')) { - context.handle( - _anchorMessageIdMeta, - anchorMessageId.isAcceptableOrUnknown( - data['anchor_message_id']!, - _anchorMessageIdMeta, - ), - ); - } - if (data.containsKey('content')) { - context.handle( - _contentMeta, - content.isAcceptableOrUnknown(data['content']!, _contentMeta), - ); - } else if (isInserting) { - context.missing(_contentMeta); - } - if (data.containsKey('status')) { - context.handle( - _statusMeta, - status.isAcceptableOrUnknown(data['status']!, _statusMeta), - ); - } else if (isInserting) { - context.missing(_statusMeta); - } - if (data.containsKey('model')) { - context.handle( - _modelMeta, - model.isAcceptableOrUnknown(data['model']!, _modelMeta), - ); - } - if (data.containsKey('error_text')) { - context.handle( - _errorTextMeta, - errorText.isAcceptableOrUnknown(data['error_text']!, _errorTextMeta), - ); - } - if (data.containsKey('metadata')) { - context.handle( - _metadataMeta, - metadata.isAcceptableOrUnknown(data['metadata']!, _metadataMeta), - ); - } - return context; - } - - @override - Set get $primaryKey => {id}; - @override - AiChatMessage map(Map data, {String? tablePrefix}) { - final effectivePrefix = tablePrefix != null ? '$tablePrefix.' : ''; - return AiChatMessage( - id: attachedDatabase.typeMapping.read( - DriftSqlType.string, - data['${effectivePrefix}id'], - )!, - conversationId: attachedDatabase.typeMapping.read( - DriftSqlType.string, - data['${effectivePrefix}conversation_id'], - )!, - role: attachedDatabase.typeMapping.read( - DriftSqlType.string, - data['${effectivePrefix}role'], - )!, - providerId: attachedDatabase.typeMapping.read( - DriftSqlType.string, - data['${effectivePrefix}provider_id'], - )!, - anchorMessageId: attachedDatabase.typeMapping.read( - DriftSqlType.string, - data['${effectivePrefix}anchor_message_id'], - ), - anchorCreatedAt: AiChatMessages.$converteranchorCreatedAtn.fromSql( - attachedDatabase.typeMapping.read( - DriftSqlType.int, - data['${effectivePrefix}anchor_created_at'], - ), - ), - content: attachedDatabase.typeMapping.read( - DriftSqlType.string, - data['${effectivePrefix}content'], - )!, - status: attachedDatabase.typeMapping.read( - DriftSqlType.string, - data['${effectivePrefix}status'], - )!, - model: attachedDatabase.typeMapping.read( - DriftSqlType.string, - data['${effectivePrefix}model'], - ), - errorText: attachedDatabase.typeMapping.read( - DriftSqlType.string, - data['${effectivePrefix}error_text'], - ), - metadata: attachedDatabase.typeMapping.read( - DriftSqlType.string, - data['${effectivePrefix}metadata'], - ), - createdAt: AiChatMessages.$convertercreatedAt.fromSql( - attachedDatabase.typeMapping.read( - DriftSqlType.int, - data['${effectivePrefix}created_at'], - )!, - ), - updatedAt: AiChatMessages.$converterupdatedAt.fromSql( - attachedDatabase.typeMapping.read( - DriftSqlType.int, - data['${effectivePrefix}updated_at'], - )!, - ), - ); - } - - @override - AiChatMessages createAlias(String alias) { - return AiChatMessages(attachedDatabase, alias); - } - - static TypeConverter $converteranchorCreatedAt = - const MillisDateConverter(); - static TypeConverter $converteranchorCreatedAtn = - NullAwareTypeConverter.wrap($converteranchorCreatedAt); - static TypeConverter $convertercreatedAt = - const MillisDateConverter(); - static TypeConverter $converterupdatedAt = - const MillisDateConverter(); - @override - List get customConstraints => const ['PRIMARY KEY(id)']; - @override - bool get dontWriteConstraints => true; -} - -class AiChatMessage extends DataClass implements Insertable { - final String id; - final String conversationId; - final String role; - final String providerId; - final String? anchorMessageId; - final DateTime? anchorCreatedAt; - final String content; - final String status; - final String? model; - final String? errorText; - final String? metadata; - final DateTime createdAt; - final DateTime updatedAt; - const AiChatMessage({ - required this.id, - required this.conversationId, - required this.role, - required this.providerId, - this.anchorMessageId, - this.anchorCreatedAt, - required this.content, - required this.status, - this.model, - this.errorText, - this.metadata, - required this.createdAt, - required this.updatedAt, - }); - @override - Map toColumns(bool nullToAbsent) { - final map = {}; - map['id'] = Variable(id); - map['conversation_id'] = Variable(conversationId); - map['role'] = Variable(role); - map['provider_id'] = Variable(providerId); - if (!nullToAbsent || anchorMessageId != null) { - map['anchor_message_id'] = Variable(anchorMessageId); - } - if (!nullToAbsent || anchorCreatedAt != null) { - map['anchor_created_at'] = Variable( - AiChatMessages.$converteranchorCreatedAtn.toSql(anchorCreatedAt), - ); - } - map['content'] = Variable(content); - map['status'] = Variable(status); - if (!nullToAbsent || model != null) { - map['model'] = Variable(model); - } - if (!nullToAbsent || errorText != null) { - map['error_text'] = Variable(errorText); - } - if (!nullToAbsent || metadata != null) { - map['metadata'] = Variable(metadata); - } - { - map['created_at'] = Variable( - AiChatMessages.$convertercreatedAt.toSql(createdAt), - ); - } - { - map['updated_at'] = Variable( - AiChatMessages.$converterupdatedAt.toSql(updatedAt), - ); - } - return map; - } - - AiChatMessagesCompanion toCompanion(bool nullToAbsent) { - return AiChatMessagesCompanion( - id: Value(id), - conversationId: Value(conversationId), - role: Value(role), - providerId: Value(providerId), - anchorMessageId: anchorMessageId == null && nullToAbsent - ? const Value.absent() - : Value(anchorMessageId), - anchorCreatedAt: anchorCreatedAt == null && nullToAbsent - ? const Value.absent() - : Value(anchorCreatedAt), - content: Value(content), - status: Value(status), - model: model == null && nullToAbsent - ? const Value.absent() - : Value(model), - errorText: errorText == null && nullToAbsent - ? const Value.absent() - : Value(errorText), - metadata: metadata == null && nullToAbsent - ? const Value.absent() - : Value(metadata), - createdAt: Value(createdAt), - updatedAt: Value(updatedAt), - ); - } - - factory AiChatMessage.fromJson( - Map json, { - ValueSerializer? serializer, - }) { - serializer ??= driftRuntimeOptions.defaultSerializer; - return AiChatMessage( - id: serializer.fromJson(json['id']), - conversationId: serializer.fromJson(json['conversation_id']), - role: serializer.fromJson(json['role']), - providerId: serializer.fromJson(json['provider_id']), - anchorMessageId: serializer.fromJson(json['anchor_message_id']), - anchorCreatedAt: serializer.fromJson( - json['anchor_created_at'], - ), - content: serializer.fromJson(json['content']), - status: serializer.fromJson(json['status']), - model: serializer.fromJson(json['model']), - errorText: serializer.fromJson(json['error_text']), - metadata: serializer.fromJson(json['metadata']), - createdAt: serializer.fromJson(json['created_at']), - updatedAt: serializer.fromJson(json['updated_at']), - ); - } - @override - Map toJson({ValueSerializer? serializer}) { - serializer ??= driftRuntimeOptions.defaultSerializer; - return { - 'id': serializer.toJson(id), - 'conversation_id': serializer.toJson(conversationId), - 'role': serializer.toJson(role), - 'provider_id': serializer.toJson(providerId), - 'anchor_message_id': serializer.toJson(anchorMessageId), - 'anchor_created_at': serializer.toJson(anchorCreatedAt), - 'content': serializer.toJson(content), - 'status': serializer.toJson(status), - 'model': serializer.toJson(model), - 'error_text': serializer.toJson(errorText), - 'metadata': serializer.toJson(metadata), - 'created_at': serializer.toJson(createdAt), - 'updated_at': serializer.toJson(updatedAt), - }; - } - - AiChatMessage copyWith({ - String? id, - String? conversationId, - String? role, - String? providerId, - Value anchorMessageId = const Value.absent(), - Value anchorCreatedAt = const Value.absent(), - String? content, - String? status, - Value model = const Value.absent(), - Value errorText = const Value.absent(), - Value metadata = const Value.absent(), - DateTime? createdAt, - DateTime? updatedAt, - }) => AiChatMessage( - id: id ?? this.id, - conversationId: conversationId ?? this.conversationId, - role: role ?? this.role, - providerId: providerId ?? this.providerId, - anchorMessageId: anchorMessageId.present - ? anchorMessageId.value - : this.anchorMessageId, - anchorCreatedAt: anchorCreatedAt.present - ? anchorCreatedAt.value - : this.anchorCreatedAt, - content: content ?? this.content, - status: status ?? this.status, - model: model.present ? model.value : this.model, - errorText: errorText.present ? errorText.value : this.errorText, - metadata: metadata.present ? metadata.value : this.metadata, - createdAt: createdAt ?? this.createdAt, - updatedAt: updatedAt ?? this.updatedAt, - ); - AiChatMessage copyWithCompanion(AiChatMessagesCompanion data) { - return AiChatMessage( - id: data.id.present ? data.id.value : this.id, - conversationId: data.conversationId.present - ? data.conversationId.value - : this.conversationId, - role: data.role.present ? data.role.value : this.role, - providerId: data.providerId.present - ? data.providerId.value - : this.providerId, - anchorMessageId: data.anchorMessageId.present - ? data.anchorMessageId.value - : this.anchorMessageId, - anchorCreatedAt: data.anchorCreatedAt.present - ? data.anchorCreatedAt.value - : this.anchorCreatedAt, - content: data.content.present ? data.content.value : this.content, - status: data.status.present ? data.status.value : this.status, - model: data.model.present ? data.model.value : this.model, - errorText: data.errorText.present ? data.errorText.value : this.errorText, - metadata: data.metadata.present ? data.metadata.value : this.metadata, - createdAt: data.createdAt.present ? data.createdAt.value : this.createdAt, - updatedAt: data.updatedAt.present ? data.updatedAt.value : this.updatedAt, - ); - } - - @override - String toString() { - return (StringBuffer('AiChatMessage(') - ..write('id: $id, ') - ..write('conversationId: $conversationId, ') - ..write('role: $role, ') - ..write('providerId: $providerId, ') - ..write('anchorMessageId: $anchorMessageId, ') - ..write('anchorCreatedAt: $anchorCreatedAt, ') - ..write('content: $content, ') - ..write('status: $status, ') - ..write('model: $model, ') - ..write('errorText: $errorText, ') - ..write('metadata: $metadata, ') - ..write('createdAt: $createdAt, ') - ..write('updatedAt: $updatedAt') - ..write(')')) - .toString(); - } - - @override - int get hashCode => Object.hash( - id, - conversationId, - role, - providerId, - anchorMessageId, - anchorCreatedAt, - content, - status, - model, - errorText, - metadata, - createdAt, - updatedAt, - ); - @override - bool operator ==(Object other) => - identical(this, other) || - (other is AiChatMessage && - other.id == this.id && - other.conversationId == this.conversationId && - other.role == this.role && - other.providerId == this.providerId && - other.anchorMessageId == this.anchorMessageId && - other.anchorCreatedAt == this.anchorCreatedAt && - other.content == this.content && - other.status == this.status && - other.model == this.model && - other.errorText == this.errorText && - other.metadata == this.metadata && - other.createdAt == this.createdAt && - other.updatedAt == this.updatedAt); -} - -class AiChatMessagesCompanion extends UpdateCompanion { - final Value id; - final Value conversationId; - final Value role; - final Value providerId; - final Value anchorMessageId; - final Value anchorCreatedAt; - final Value content; - final Value status; - final Value model; - final Value errorText; - final Value metadata; - final Value createdAt; - final Value updatedAt; - final Value rowid; - const AiChatMessagesCompanion({ - this.id = const Value.absent(), - this.conversationId = const Value.absent(), - this.role = const Value.absent(), - this.providerId = const Value.absent(), - this.anchorMessageId = const Value.absent(), - this.anchorCreatedAt = const Value.absent(), - this.content = const Value.absent(), - this.status = const Value.absent(), - this.model = const Value.absent(), - this.errorText = const Value.absent(), - this.metadata = const Value.absent(), - this.createdAt = const Value.absent(), - this.updatedAt = const Value.absent(), - this.rowid = const Value.absent(), - }); - AiChatMessagesCompanion.insert({ - required String id, - required String conversationId, - required String role, - required String providerId, - this.anchorMessageId = const Value.absent(), - this.anchorCreatedAt = const Value.absent(), - required String content, - required String status, - this.model = const Value.absent(), - this.errorText = const Value.absent(), - this.metadata = const Value.absent(), - required DateTime createdAt, - required DateTime updatedAt, - this.rowid = const Value.absent(), - }) : id = Value(id), - conversationId = Value(conversationId), - role = Value(role), - providerId = Value(providerId), - content = Value(content), - status = Value(status), - createdAt = Value(createdAt), - updatedAt = Value(updatedAt); - static Insertable custom({ - Expression? id, - Expression? conversationId, - Expression? role, - Expression? providerId, - Expression? anchorMessageId, - Expression? anchorCreatedAt, - Expression? content, - Expression? status, - Expression? model, - Expression? errorText, - Expression? metadata, - Expression? createdAt, - Expression? updatedAt, - Expression? rowid, - }) { - return RawValuesInsertable({ - if (id != null) 'id': id, - if (conversationId != null) 'conversation_id': conversationId, - if (role != null) 'role': role, - if (providerId != null) 'provider_id': providerId, - if (anchorMessageId != null) 'anchor_message_id': anchorMessageId, - if (anchorCreatedAt != null) 'anchor_created_at': anchorCreatedAt, - if (content != null) 'content': content, - if (status != null) 'status': status, - if (model != null) 'model': model, - if (errorText != null) 'error_text': errorText, - if (metadata != null) 'metadata': metadata, - if (createdAt != null) 'created_at': createdAt, - if (updatedAt != null) 'updated_at': updatedAt, - if (rowid != null) 'rowid': rowid, - }); - } - - AiChatMessagesCompanion copyWith({ - Value? id, - Value? conversationId, - Value? role, - Value? providerId, - Value? anchorMessageId, - Value? anchorCreatedAt, - Value? content, - Value? status, - Value? model, - Value? errorText, - Value? metadata, - Value? createdAt, - Value? updatedAt, - Value? rowid, - }) { - return AiChatMessagesCompanion( - id: id ?? this.id, - conversationId: conversationId ?? this.conversationId, - role: role ?? this.role, - providerId: providerId ?? this.providerId, - anchorMessageId: anchorMessageId ?? this.anchorMessageId, - anchorCreatedAt: anchorCreatedAt ?? this.anchorCreatedAt, - content: content ?? this.content, - status: status ?? this.status, - model: model ?? this.model, - errorText: errorText ?? this.errorText, - metadata: metadata ?? this.metadata, - createdAt: createdAt ?? this.createdAt, - updatedAt: updatedAt ?? this.updatedAt, - rowid: rowid ?? this.rowid, - ); - } - - @override - Map toColumns(bool nullToAbsent) { - final map = {}; - if (id.present) { - map['id'] = Variable(id.value); - } - if (conversationId.present) { - map['conversation_id'] = Variable(conversationId.value); - } - if (role.present) { - map['role'] = Variable(role.value); - } - if (providerId.present) { - map['provider_id'] = Variable(providerId.value); - } - if (anchorMessageId.present) { - map['anchor_message_id'] = Variable(anchorMessageId.value); - } - if (anchorCreatedAt.present) { - map['anchor_created_at'] = Variable( - AiChatMessages.$converteranchorCreatedAtn.toSql(anchorCreatedAt.value), - ); - } - if (content.present) { - map['content'] = Variable(content.value); - } - if (status.present) { - map['status'] = Variable(status.value); - } - if (model.present) { - map['model'] = Variable(model.value); - } - if (errorText.present) { - map['error_text'] = Variable(errorText.value); - } - if (metadata.present) { - map['metadata'] = Variable(metadata.value); - } - if (createdAt.present) { - map['created_at'] = Variable( - AiChatMessages.$convertercreatedAt.toSql(createdAt.value), - ); - } - if (updatedAt.present) { - map['updated_at'] = Variable( - AiChatMessages.$converterupdatedAt.toSql(updatedAt.value), - ); - } - if (rowid.present) { - map['rowid'] = Variable(rowid.value); - } - return map; - } - - @override - String toString() { - return (StringBuffer('AiChatMessagesCompanion(') - ..write('id: $id, ') - ..write('conversationId: $conversationId, ') - ..write('role: $role, ') - ..write('providerId: $providerId, ') - ..write('anchorMessageId: $anchorMessageId, ') - ..write('anchorCreatedAt: $anchorCreatedAt, ') - ..write('content: $content, ') - ..write('status: $status, ') - ..write('model: $model, ') - ..write('errorText: $errorText, ') - ..write('metadata: $metadata, ') - ..write('createdAt: $createdAt, ') - ..write('updatedAt: $updatedAt, ') - ..write('rowid: $rowid') - ..write(')')) - .toString(); - } -} - class InscriptionCollections extends Table with TableInfo { @override @@ -19599,7 +18821,6 @@ abstract class _$MixinDatabase extends GeneratedDatabase { late final Fiats fiats = Fiats(this); late final FavoriteApps favoriteApps = FavoriteApps(this); late final Properties properties = Properties(this); - late final AiChatMessages aiChatMessages = AiChatMessages(this); late final InscriptionCollections inscriptionCollections = InscriptionCollections(this); late final InscriptionItems inscriptionItems = InscriptionItems(this); @@ -19663,14 +18884,7 @@ abstract class _$MixinDatabase extends GeneratedDatabase { 'index_tokens_collection_hash', 'CREATE INDEX IF NOT EXISTS index_tokens_collection_hash ON tokens (collection_hash)', ); - late final Index indexAiChatMessagesConversationIdCreatedAt = Index( - 'index_ai_chat_messages_conversation_id_created_at', - 'CREATE INDEX IF NOT EXISTS index_ai_chat_messages_conversation_id_created_at ON ai_chat_messages (conversation_id, created_at DESC)', - ); late final AddressDao addressDao = AddressDao(this as MixinDatabase); - late final AiChatMessageDao aiChatMessageDao = AiChatMessageDao( - this as MixinDatabase, - ); late final AppDao appDao = AppDao(this as MixinDatabase); late final AssetDao assetDao = AssetDao(this as MixinDatabase); late final CircleConversationDao circleConversationDao = @@ -20201,7 +19415,6 @@ abstract class _$MixinDatabase extends GeneratedDatabase { fiats, favoriteApps, properties, - aiChatMessages, inscriptionCollections, inscriptionItems, indexConversationsCategoryStatus, @@ -20219,7 +19432,6 @@ abstract class _$MixinDatabase extends GeneratedDatabase { indexMessagesConversationIdQuoteMessageId, indexTokensKernelAssetId, indexTokensCollectionHash, - indexAiChatMessagesConversationIdCreatedAt, ]; @override StreamQueryUpdateRules get streamUpdateRules => const StreamQueryUpdateRules([ @@ -28801,370 +28013,6 @@ typedef $PropertiesProcessedTableManager = Property, PrefetchHooks Function() >; -typedef $AiChatMessagesCreateCompanionBuilder = - AiChatMessagesCompanion Function({ - required String id, - required String conversationId, - required String role, - required String providerId, - Value anchorMessageId, - Value anchorCreatedAt, - required String content, - required String status, - Value model, - Value errorText, - Value metadata, - required DateTime createdAt, - required DateTime updatedAt, - Value rowid, - }); -typedef $AiChatMessagesUpdateCompanionBuilder = - AiChatMessagesCompanion Function({ - Value id, - Value conversationId, - Value role, - Value providerId, - Value anchorMessageId, - Value anchorCreatedAt, - Value content, - Value status, - Value model, - Value errorText, - Value metadata, - Value createdAt, - Value updatedAt, - Value rowid, - }); - -class $AiChatMessagesFilterComposer - extends Composer<_$MixinDatabase, AiChatMessages> { - $AiChatMessagesFilterComposer({ - required super.$db, - required super.$table, - super.joinBuilder, - super.$addJoinBuilderToRootComposer, - super.$removeJoinBuilderFromRootComposer, - }); - ColumnFilters get id => $composableBuilder( - column: $table.id, - builder: (column) => ColumnFilters(column), - ); - - ColumnFilters get conversationId => $composableBuilder( - column: $table.conversationId, - builder: (column) => ColumnFilters(column), - ); - - ColumnFilters get role => $composableBuilder( - column: $table.role, - builder: (column) => ColumnFilters(column), - ); - - ColumnFilters get providerId => $composableBuilder( - column: $table.providerId, - builder: (column) => ColumnFilters(column), - ); - - ColumnFilters get anchorMessageId => $composableBuilder( - column: $table.anchorMessageId, - builder: (column) => ColumnFilters(column), - ); - - ColumnWithTypeConverterFilters - get anchorCreatedAt => $composableBuilder( - column: $table.anchorCreatedAt, - builder: (column) => ColumnWithTypeConverterFilters(column), - ); - - ColumnFilters get content => $composableBuilder( - column: $table.content, - builder: (column) => ColumnFilters(column), - ); - - ColumnFilters get status => $composableBuilder( - column: $table.status, - builder: (column) => ColumnFilters(column), - ); - - ColumnFilters get model => $composableBuilder( - column: $table.model, - builder: (column) => ColumnFilters(column), - ); - - ColumnFilters get errorText => $composableBuilder( - column: $table.errorText, - builder: (column) => ColumnFilters(column), - ); - - ColumnFilters get metadata => $composableBuilder( - column: $table.metadata, - builder: (column) => ColumnFilters(column), - ); - - ColumnWithTypeConverterFilters get createdAt => - $composableBuilder( - column: $table.createdAt, - builder: (column) => ColumnWithTypeConverterFilters(column), - ); - - ColumnWithTypeConverterFilters get updatedAt => - $composableBuilder( - column: $table.updatedAt, - builder: (column) => ColumnWithTypeConverterFilters(column), - ); -} - -class $AiChatMessagesOrderingComposer - extends Composer<_$MixinDatabase, AiChatMessages> { - $AiChatMessagesOrderingComposer({ - required super.$db, - required super.$table, - super.joinBuilder, - super.$addJoinBuilderToRootComposer, - super.$removeJoinBuilderFromRootComposer, - }); - ColumnOrderings get id => $composableBuilder( - column: $table.id, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get conversationId => $composableBuilder( - column: $table.conversationId, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get role => $composableBuilder( - column: $table.role, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get providerId => $composableBuilder( - column: $table.providerId, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get anchorMessageId => $composableBuilder( - column: $table.anchorMessageId, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get anchorCreatedAt => $composableBuilder( - column: $table.anchorCreatedAt, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get content => $composableBuilder( - column: $table.content, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get status => $composableBuilder( - column: $table.status, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get model => $composableBuilder( - column: $table.model, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get errorText => $composableBuilder( - column: $table.errorText, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get metadata => $composableBuilder( - column: $table.metadata, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get createdAt => $composableBuilder( - column: $table.createdAt, - builder: (column) => ColumnOrderings(column), - ); - - ColumnOrderings get updatedAt => $composableBuilder( - column: $table.updatedAt, - builder: (column) => ColumnOrderings(column), - ); -} - -class $AiChatMessagesAnnotationComposer - extends Composer<_$MixinDatabase, AiChatMessages> { - $AiChatMessagesAnnotationComposer({ - required super.$db, - required super.$table, - super.joinBuilder, - super.$addJoinBuilderToRootComposer, - super.$removeJoinBuilderFromRootComposer, - }); - GeneratedColumn get id => - $composableBuilder(column: $table.id, builder: (column) => column); - - GeneratedColumn get conversationId => $composableBuilder( - column: $table.conversationId, - builder: (column) => column, - ); - - GeneratedColumn get role => - $composableBuilder(column: $table.role, builder: (column) => column); - - GeneratedColumn get providerId => $composableBuilder( - column: $table.providerId, - builder: (column) => column, - ); - - GeneratedColumn get anchorMessageId => $composableBuilder( - column: $table.anchorMessageId, - builder: (column) => column, - ); - - GeneratedColumnWithTypeConverter get anchorCreatedAt => - $composableBuilder( - column: $table.anchorCreatedAt, - builder: (column) => column, - ); - - GeneratedColumn get content => - $composableBuilder(column: $table.content, builder: (column) => column); - - GeneratedColumn get status => - $composableBuilder(column: $table.status, builder: (column) => column); - - GeneratedColumn get model => - $composableBuilder(column: $table.model, builder: (column) => column); - - GeneratedColumn get errorText => - $composableBuilder(column: $table.errorText, builder: (column) => column); - - GeneratedColumn get metadata => - $composableBuilder(column: $table.metadata, builder: (column) => column); - - GeneratedColumnWithTypeConverter get createdAt => - $composableBuilder(column: $table.createdAt, builder: (column) => column); - - GeneratedColumnWithTypeConverter get updatedAt => - $composableBuilder(column: $table.updatedAt, builder: (column) => column); -} - -class $AiChatMessagesTableManager - extends - RootTableManager< - _$MixinDatabase, - AiChatMessages, - AiChatMessage, - $AiChatMessagesFilterComposer, - $AiChatMessagesOrderingComposer, - $AiChatMessagesAnnotationComposer, - $AiChatMessagesCreateCompanionBuilder, - $AiChatMessagesUpdateCompanionBuilder, - ( - AiChatMessage, - BaseReferences<_$MixinDatabase, AiChatMessages, AiChatMessage>, - ), - AiChatMessage, - PrefetchHooks Function() - > { - $AiChatMessagesTableManager(_$MixinDatabase db, AiChatMessages table) - : super( - TableManagerState( - db: db, - table: table, - createFilteringComposer: () => - $AiChatMessagesFilterComposer($db: db, $table: table), - createOrderingComposer: () => - $AiChatMessagesOrderingComposer($db: db, $table: table), - createComputedFieldComposer: () => - $AiChatMessagesAnnotationComposer($db: db, $table: table), - updateCompanionCallback: - ({ - Value id = const Value.absent(), - Value conversationId = const Value.absent(), - Value role = const Value.absent(), - Value providerId = const Value.absent(), - Value anchorMessageId = const Value.absent(), - Value anchorCreatedAt = const Value.absent(), - Value content = const Value.absent(), - Value status = const Value.absent(), - Value model = const Value.absent(), - Value errorText = const Value.absent(), - Value metadata = const Value.absent(), - Value createdAt = const Value.absent(), - Value updatedAt = const Value.absent(), - Value rowid = const Value.absent(), - }) => AiChatMessagesCompanion( - id: id, - conversationId: conversationId, - role: role, - providerId: providerId, - anchorMessageId: anchorMessageId, - anchorCreatedAt: anchorCreatedAt, - content: content, - status: status, - model: model, - errorText: errorText, - metadata: metadata, - createdAt: createdAt, - updatedAt: updatedAt, - rowid: rowid, - ), - createCompanionCallback: - ({ - required String id, - required String conversationId, - required String role, - required String providerId, - Value anchorMessageId = const Value.absent(), - Value anchorCreatedAt = const Value.absent(), - required String content, - required String status, - Value model = const Value.absent(), - Value errorText = const Value.absent(), - Value metadata = const Value.absent(), - required DateTime createdAt, - required DateTime updatedAt, - Value rowid = const Value.absent(), - }) => AiChatMessagesCompanion.insert( - id: id, - conversationId: conversationId, - role: role, - providerId: providerId, - anchorMessageId: anchorMessageId, - anchorCreatedAt: anchorCreatedAt, - content: content, - status: status, - model: model, - errorText: errorText, - metadata: metadata, - createdAt: createdAt, - updatedAt: updatedAt, - rowid: rowid, - ), - withReferenceMapper: (p0) => p0 - .map((e) => (e.readTable(table), BaseReferences(db, table, e))) - .toList(), - prefetchHooksCallback: null, - ), - ); -} - -typedef $AiChatMessagesProcessedTableManager = - ProcessedTableManager< - _$MixinDatabase, - AiChatMessages, - AiChatMessage, - $AiChatMessagesFilterComposer, - $AiChatMessagesOrderingComposer, - $AiChatMessagesAnnotationComposer, - $AiChatMessagesCreateCompanionBuilder, - $AiChatMessagesUpdateCompanionBuilder, - ( - AiChatMessage, - BaseReferences<_$MixinDatabase, AiChatMessages, AiChatMessage>, - ), - AiChatMessage, - PrefetchHooks Function() - >; typedef $InscriptionCollectionsCreateCompanionBuilder = InscriptionCollectionsCompanion Function({ required String collectionHash, @@ -29783,8 +28631,6 @@ class $MixinDatabaseManager { $FavoriteAppsTableManager(_db, _db.favoriteApps); $PropertiesTableManager get properties => $PropertiesTableManager(_db, _db.properties); - $AiChatMessagesTableManager get aiChatMessages => - $AiChatMessagesTableManager(_db, _db.aiChatMessages); $InscriptionCollectionsTableManager get inscriptionCollections => $InscriptionCollectionsTableManager(_db, _db.inscriptionCollections); $InscriptionItemsTableManager get inscriptionItems => diff --git a/lib/db/moor/ai.drift b/lib/db/moor/ai.drift new file mode 100644 index 0000000000..16c7ef466f --- /dev/null +++ b/lib/db/moor/ai.drift @@ -0,0 +1,30 @@ +import '../converter/millis_date_converter.dart'; + +CREATE TABLE ai_chat_messages ( + id TEXT NOT NULL, + thread_id TEXT NOT NULL DEFAULT '', + conversation_id TEXT NOT NULL, + role TEXT NOT NULL, + provider_id TEXT NOT NULL, + content TEXT NOT NULL, + status TEXT NOT NULL, + model TEXT, + error_text TEXT, + metadata TEXT, + created_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, + updated_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, + PRIMARY KEY(id) +); + +CREATE TABLE ai_chat_threads ( + id TEXT NOT NULL, + conversation_id TEXT NOT NULL, + title TEXT, + created_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, + updated_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, + PRIMARY KEY(id) +); + +CREATE INDEX IF NOT EXISTS index_ai_chat_messages_conversation_id_created_at ON ai_chat_messages(conversation_id, created_at DESC); +CREATE INDEX IF NOT EXISTS index_ai_chat_messages_thread_id_created_at ON ai_chat_messages(thread_id, created_at DESC); +CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_updated_at ON ai_chat_threads(conversation_id, updated_at DESC); diff --git a/lib/db/moor/mixin.drift b/lib/db/moor/mixin.drift index dbafc21d8b..6599aea1f2 100644 --- a/lib/db/moor/mixin.drift +++ b/lib/db/moor/mixin.drift @@ -73,23 +73,6 @@ CREATE TABLE chains (chain_id TEXT NOT NULL, name TEXT NOT NULL, symbol TEXT NOT CREATE TABLE properties ("key" TEXT NOT NULL, "group" TEXT NOT NULL MAPPED BY `const PropertyGroupConverter()`, "value" TEXT NOT NULL, PRIMARY KEY("key", "group")); -CREATE TABLE ai_chat_messages ( - id TEXT NOT NULL, - conversation_id TEXT NOT NULL, - role TEXT NOT NULL, - provider_id TEXT NOT NULL, - anchor_message_id TEXT, - anchor_created_at INTEGER MAPPED BY `const MillisDateConverter()`, - content TEXT NOT NULL, - status TEXT NOT NULL, - model TEXT, - error_text TEXT, - metadata TEXT, - created_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, - updated_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, - PRIMARY KEY(id) -); - CREATE TABLE safe_snapshots ( snapshot_id TEXT NOT NULL, type TEXT NOT NULL, @@ -170,4 +153,3 @@ CREATE INDEX IF NOT EXISTS index_message_conversation_id_status_user_id ON messa CREATE INDEX IF NOT EXISTS index_messages_conversation_id_quote_message_id ON messages(conversation_id, quote_message_id); CREATE INDEX IF NOT EXISTS index_tokens_kernel_asset_id ON tokens(kernel_asset_id); CREATE INDEX IF NOT EXISTS index_tokens_collection_hash ON tokens(collection_hash); -CREATE INDEX IF NOT EXISTS index_ai_chat_messages_conversation_id_created_at ON ai_chat_messages(conversation_id, created_at DESC); diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 07449513fc..4bb98a49e1 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -25,6 +25,7 @@ import '../../../ai/model/ai_provider_config.dart'; import '../../../constants/constants.dart'; import '../../../constants/icon_fonts.dart'; import '../../../constants/resources.dart'; +import '../../../db/ai_database.dart'; import '../../../db/database_event_bus.dart'; import '../../../db/mixin_database.dart' hide Offset; import '../../../enum/encrypt_category.dart'; @@ -120,14 +121,22 @@ class _InputContainer extends HookConsumerWidget { selectedModel: aiModeState.model, ); final aiModeEnabled = aiModeState.enabled; + final activeAiThread = useMemoizedStream( + () => conversationId == null + ? Stream.value(null) + : context.database.aiChatMessageDao.watchLatestThread( + conversationId, + ), + keys: [conversationId], + ).data; final aiMessages = useMemoizedStream( - () => conversationId == null + () => activeAiThread == null ? Stream.value(const []) - : context.database.aiChatMessageDao.watchConversationMessages( - conversationId, + : context.database.aiChatMessageDao.watchThreadMessages( + activeAiThread.id, ), - keys: [conversationId], + keys: [activeAiThread?.id], initialData: const [], ).data ?? const []; @@ -436,6 +445,7 @@ class _InputContainer extends HookConsumerWidget { textEditingValueStream: textEditingValueStream, aiModeEnabled: aiModeEnabled, aiRequestInFlight: aiRequestInFlight, + aiThreadId: activeAiThread?.id, ), ], ), @@ -454,6 +464,7 @@ class _InputContainer extends HookConsumerWidget { class _AnimatedSendOrVoiceButton extends HookConsumerWidget { const _AnimatedSendOrVoiceButton({ required this.conversationId, + required this.aiThreadId, required this.textEditingValueStream, required this.textEditingController, required this.aiModeEnabled, @@ -461,6 +472,7 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { }); final String? conversationId; + final String? aiThreadId; final Stream textEditingValueStream; final TextEditingController textEditingController; final bool aiModeEnabled; @@ -485,7 +497,9 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { onTap: () { final currentConversationId = conversationId; if (currentConversationId == null) return; - AiChatController(context.database).stop(currentConversationId); + AiChatController( + context.database, + ).stop(currentConversationId, threadId: aiThreadId); }, ); } @@ -506,6 +520,7 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { context, textEditingController, conversationId: conversationId, + aiThreadId: aiThreadId, ), ), ), @@ -688,6 +703,7 @@ Future _sendMessage( BuildContext context, TextEditingController textEditingController, { required String? conversationId, + String? aiThreadId, bool silent = false, }) async { final text = textEditingController.value.text.trim(); @@ -729,6 +745,7 @@ Future _sendMessage( try { await AiChatController(context.database).send( conversationId: conversationId, + threadId: aiThreadId, input: inlineAiInput, language: _currentLanguageTag(context), provider: provider, diff --git a/lib/ui/home/chat_slide_page/ai_assistant/message_list.dart b/lib/ui/home/chat_slide_page/ai_assistant/message_list.dart index 1c1fbcdec2..12ad2e34ee 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant/message_list.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant/message_list.dart @@ -3,7 +3,7 @@ import 'dart:async'; import 'package:flutter/material.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; -import '../../../../db/mixin_database.dart'; +import '../../../../db/ai_database.dart'; import '../../../../utils/extension/extension.dart'; import '../../../../widgets/ai/ai_message_card.dart'; import '../../../../widgets/clamping_custom_scroll_view/clamping_custom_scroll_view.dart'; @@ -15,11 +15,13 @@ import 'constants.dart'; class AiAssistantMessageList extends HookWidget { const AiAssistantMessageList({ required this.conversationId, + required this.threadId, required this.latestMessages, super.key, }); final String conversationId; + final String? threadId; final List latestMessages; @override @@ -32,8 +34,8 @@ class AiAssistantMessageList extends HookWidget { [olderMessages.value, latestMessages], ); final centerKey = useMemoized( - () => ValueKey('ai-list-center-$conversationId'), - [conversationId], + () => ValueKey('ai-list-center-$conversationId-$threadId'), + [conversationId, threadId], ); final topKey = useMemoized( () => GlobalKey(debugLabel: 'ai list top'), @@ -71,14 +73,16 @@ class AiAssistantMessageList extends HookWidget { if (isLoadingOlder.value || isOldest.value || messages.isEmpty) { return; } + final currentThreadId = threadId; + if (currentThreadId == null) return; final before = messages.first; isLoadingOlder.value = true; try { final list = await context.database.aiChatMessageDao - .beforeConversationMessages( - conversationId: conversationId, + .beforeThreadMessages( + threadId: currentThreadId, before: before, limit: aiAssistantMessagePageLimit, ); @@ -103,7 +107,7 @@ class AiAssistantMessageList extends HookWidget { lastUserMessageIdRef.value = null; previousLatestMessagesRef.value = const []; return null; - }, [conversationId]); + }, [conversationId, threadId]); useEffect(() { final previousLatestMessages = previousLatestMessagesRef.value; @@ -197,7 +201,7 @@ class AiAssistantMessageList extends HookWidget { return false; }, child: MessageDayTimeViewportWidget.chatPage( - key: ValueKey(conversationId), + key: ValueKey('$conversationId-$threadId'), bottomKey: bottomKey, center: null, topKey: topKey, diff --git a/lib/ui/home/chat_slide_page/ai_assistant_page.dart b/lib/ui/home/chat_slide_page/ai_assistant_page.dart index 7133a8dd2b..1ec25a172b 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant_page.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant_page.dart @@ -5,7 +5,7 @@ import 'package:hooks_riverpod/hooks_riverpod.dart'; import '../../../ai/ai_chat_controller.dart'; import '../../../ai/model/ai_provider_config.dart'; import '../../../constants/constants.dart'; -import '../../../db/mixin_database.dart'; +import '../../../db/ai_database.dart'; import '../../../utils/extension/extension.dart'; import '../../../utils/hook.dart'; import '../../../widgets/app_bar.dart'; @@ -41,14 +41,21 @@ class AiAssistantPage extends HookConsumerWidget { providerId: aiModeState.providerId, selectedModel: aiModeState.model, ); + final activeThread = useMemoizedStream( + () => context.database.aiChatMessageDao.watchLatestThread( + conversationId, + ), + keys: [conversationId], + ).data; final latestMessages = useMemoizedStream( - () => - context.database.aiChatMessageDao.watchLatestConversationMessages( - conversationId, - aiAssistantMessagePageLimit, - ), - keys: [conversationId], + () => activeThread == null + ? Stream.value(const []) + : context.database.aiChatMessageDao.watchLatestThreadMessages( + activeThread.id, + aiAssistantMessagePageLimit, + ), + keys: [activeThread?.id], initialData: const [], ).data ?? const []; @@ -83,6 +90,7 @@ class AiAssistantPage extends HookConsumerWidget { try { await AiChatController(context.database).send( conversationId: conversationId, + threadId: activeThread?.id, input: text, language: currentLanguageTag(context), provider: aiProvider, @@ -101,6 +109,7 @@ class AiAssistantPage extends HookConsumerWidget { Expanded( child: AiAssistantMessageList( conversationId: conversationId, + threadId: activeThread?.id, latestMessages: latestMessages, ), ), @@ -112,8 +121,9 @@ class AiAssistantPage extends HookConsumerWidget { enabledAiProviders: enabledAiProviders, requestInFlight: requestInFlight, onSend: send, - onStop: () => - AiChatController(context.database).stop(conversationId), + onStop: () => AiChatController( + context.database, + ).stop(conversationId, threadId: activeThread?.id), onProviderSelected: (value) => aiModeNotifier.updateProvider( providerId: value.id, model: value.model, diff --git a/lib/ui/provider/database_provider.dart b/lib/ui/provider/database_provider.dart index 09fd1b2346..f1d0cbb767 100644 --- a/lib/ui/provider/database_provider.dart +++ b/lib/ui/provider/database_provider.dart @@ -1,6 +1,7 @@ import 'package:hooks_riverpod/hooks_riverpod.dart'; import 'package:mixin_logger/mixin_logger.dart'; +import '../../db/ai_database.dart'; import '../../db/database.dart'; import '../../db/fts_database.dart'; import '../../db/mixin_database.dart'; @@ -52,6 +53,7 @@ class DatabaseOpener extends DistinctStateNotifier> { final db = Database( mixinDatabase, await FtsDatabase.connect(identityNumber, fromMainIsolate: true), + await AiDatabase.connect(identityNumber, fromMainIsolate: true), ); // Do a database query, to ensure database has properly initialized. await mixinDatabase.doInitVerify(); diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index 75feab79ed..f933528dd6 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -6,7 +6,7 @@ import 'package:intl/intl.dart'; import 'package:super_context_menu/super_context_menu.dart'; import '../../ai/model/ai_chat_metadata.dart'; -import '../../db/mixin_database.dart' hide Offset; +import '../../db/ai_database.dart'; import '../../utils/datetime_format_utils.dart'; import '../../utils/extension/extension.dart'; import '../../utils/platform.dart'; @@ -123,6 +123,7 @@ class _AiResponseMessageCard extends StatelessWidget { const SizedBox(height: 4), _AiResponseFooter( model: message.model, + metadata: message.metadata, dateTime: message.createdAt, ), ], @@ -398,10 +399,12 @@ SelectedContent? _findSelectedContent(BuildContext context) { class _AiResponseFooter extends StatelessWidget { const _AiResponseFooter({ required this.model, + required this.metadata, required this.dateTime, }); final String? model; + final String? metadata; final DateTime dateTime; @override @@ -416,18 +419,25 @@ class _AiResponseFooter extends StatelessWidget { ); final dateTimeText = DateFormat.Hm().format(dateTime.toLocal()); final trimmedModel = model?.trim(); + final responseMeta = aiMetadataResponse(metadata); + final elapsedMs = (responseMeta['elapsedMs'] as num?)?.round(); + final totalTokens = _totalTokens(responseMeta); return SelectionContainer.disabled( child: SizedBox( width: double.infinity, - child: Row( + child: Wrap( + spacing: 12, + runSpacing: 2, children: [ const SizedBox(width: 4), Text(dateTimeText, style: textStyle), - if (trimmedModel != null && trimmedModel.isNotEmpty) ...[ - const SizedBox(width: 12), + if (trimmedModel != null && trimmedModel.isNotEmpty) Text(trimmedModel, style: textStyle), - ], + if (elapsedMs != null && elapsedMs > 0) + Text(_formatElapsed(elapsedMs), style: textStyle), + if (totalTokens != null && totalTokens > 0) + Text(_formatTokens(totalTokens), style: textStyle), ], ), ), @@ -435,6 +445,33 @@ class _AiResponseFooter extends StatelessWidget { } } +num? _totalTokens(Map responseMeta) => + _usageValue(responseMeta, 'totalTokens') ?? + ((_usageValue(responseMeta, 'inputTokens') ?? 0) + + (_usageValue(responseMeta, 'outputTokens') ?? 0)); + +num? _usageValue(Map responseMeta, String key) { + final usage = responseMeta['usage']; + if (usage is Map) { + return usage[key] as num?; + } + if (usage is Map) { + return usage[key] as num?; + } + return null; +} + +String _formatElapsed(int elapsedMs) { + if (elapsedMs < 1000) { + return '${elapsedMs}ms'; + } + final seconds = elapsedMs / Duration.millisecondsPerSecond; + return '${seconds.toStringAsFixed(seconds >= 10 ? 0 : 1)}s'; +} + +String _formatTokens(num tokens) => + '${NumberFormat.decimalPattern().format(tokens.round())} tokens'; + String _displayText(AiChatMessage message) { final content = message.content.trim(); if (content.isNotEmpty) return content; diff --git a/lib/workers/device_transfer.dart b/lib/workers/device_transfer.dart index 9ea27ed40c..644c9e8962 100644 --- a/lib/workers/device_transfer.dart +++ b/lib/workers/device_transfer.dart @@ -18,6 +18,7 @@ import '../blaze/blaze_message.dart'; import '../blaze/vo/plain_json_message.dart'; import '../constants/constants.dart'; import '../crypto/uuid/uuid.dart'; +import '../db/ai_database.dart'; import '../db/database.dart'; import '../db/fts_database.dart'; import '../db/mixin_database.dart'; @@ -160,6 +161,7 @@ Future _deviceTransferIsolateEntryPoint( final database = Database( await connectToDatabase(params.identityNumber, readCount: 1), await FtsDatabase.connect(params.identityNumber), + await AiDatabase.connect(params.identityNumber), ); final deviceTransfer = await DeviceTransfer.create( database: database, diff --git a/lib/workers/message_worker_isolate.dart b/lib/workers/message_worker_isolate.dart index 04dfebcd2e..e15b183de2 100644 --- a/lib/workers/message_worker_isolate.dart +++ b/lib/workers/message_worker_isolate.dart @@ -16,6 +16,7 @@ import 'package:stream_channel/isolate_channel.dart'; import '../blaze/blaze.dart'; import '../crypto/signal/signal_protocol.dart'; +import '../db/ai_database.dart'; import '../db/database.dart'; import '../db/database_event_bus.dart'; import '../db/fts_database.dart'; @@ -149,6 +150,7 @@ class _MessageProcessRunner { database = Database( await connectToDatabase(identityNumber, readCount: 4), await FtsDatabase.connect(identityNumber), + await AiDatabase.connect(identityNumber), ); client = createClient( diff --git a/test/ai/ai_chat_metadata_test.dart b/test/ai/ai_chat_metadata_test.dart new file mode 100644 index 0000000000..da6739c685 --- /dev/null +++ b/test/ai/ai_chat_metadata_test.dart @@ -0,0 +1,62 @@ +import 'dart:convert'; + +import 'package:flutter_app/ai/model/ai_chat_metadata.dart'; +import 'package:flutter_app/ai/model/ai_provider_config.dart'; +import 'package:flutter_app/ai/model/ai_provider_type.dart'; +import 'package:flutter_test/flutter_test.dart'; + +void main() { + group('AI chat metadata', () { + test('keeps provider and tool events when response metadata is set', () { + final initialMetadata = createAiMessageMetadata( + AiProviderConfig( + id: 'provider-id', + name: 'Provider', + type: AiProviderType.openaiCompatible, + baseUrl: 'https://api.example.com/v1', + apiKey: 'key', + model: 'test-model', + ), + ); + final withToolEvent = appendAiToolEventToMetadata( + initialMetadata, + createAiToolCallEvent( + id: 'tool-id', + name: 'read_conversation_chunk', + arguments: const {'limit': 20}, + ), + ); + + final updated = setAiResponseMetadata( + withToolEvent, + createAiResponseMetadata( + elapsedMs: 1234, + promptMessageCount: 7, + toolCount: 4, + outputCharacters: 42, + response: const { + 'finishReason': 'stop', + 'usage': { + 'inputTokens': 100, + 'outputTokens': 24, + 'totalTokens': 124, + }, + }, + ), + ); + + final decoded = jsonDecode(updated) as Map; + expect(decoded['provider'], isA>()); + expect(aiMetadataToolEvents(updated), hasLength(1)); + expect(aiMetadataResponse(updated), containsPair('elapsedMs', 1234)); + expect( + aiMetadataResponse(updated), + containsPair('promptMessageCount', 7), + ); + expect( + aiMetadataResponse(updated)['usage'], + containsPair('totalTokens', 124), + ); + }); + }); +} diff --git a/test/ai/ai_chat_thread_test.dart b/test/ai/ai_chat_thread_test.dart new file mode 100644 index 0000000000..13566ede73 --- /dev/null +++ b/test/ai/ai_chat_thread_test.dart @@ -0,0 +1,143 @@ +import 'package:drift/drift.dart'; +import 'package:drift/native.dart'; +import 'package:flutter_app/ai/ai_chat_prompt_builder.dart'; +import 'package:flutter_app/ai/model/ai_prompt_message.dart'; +import 'package:flutter_app/db/ai_database.dart'; +import 'package:flutter_app/db/database.dart'; +import 'package:flutter_app/db/fts_database.dart'; +import 'package:flutter_app/db/mixin_database.dart'; +import 'package:flutter_test/flutter_test.dart'; + +void main() { + group('AI chat threads', () { + late MixinDatabase mixinDatabase; + late FtsDatabase ftsDatabase; + late AiDatabase aiDatabase; + late Database database; + + setUp(() { + mixinDatabase = MixinDatabase(NativeDatabase.memory()); + ftsDatabase = FtsDatabase(NativeDatabase.memory()); + aiDatabase = AiDatabase(NativeDatabase.memory()); + database = Database(mixinDatabase, ftsDatabase, aiDatabase); + }); + + tearDown(() => database.dispose()); + + test('scopes messages and pending state by thread', () async { + const conversationId = 'conversation-id'; + final firstThread = await database.aiChatMessageDao.createThread( + conversationId, + ); + final secondThread = await database.aiChatMessageDao.createThread( + conversationId, + ); + final now = DateTime.now(); + + await database.aiChatMessageDao.insertMessage( + AiChatMessagesCompanion.insert( + id: 'first-thread-message', + threadId: Value(firstThread.id), + conversationId: conversationId, + role: 'assistant', + providerId: 'provider-id', + content: 'pending in first thread', + status: 'pending', + createdAt: now, + updatedAt: now, + ), + ); + await database.aiChatMessageDao.insertMessage( + AiChatMessagesCompanion.insert( + id: 'second-thread-message', + threadId: Value(secondThread.id), + conversationId: conversationId, + role: 'user', + providerId: 'provider-id', + content: 'done in second thread', + status: 'done', + createdAt: now.add(const Duration(milliseconds: 1)), + updatedAt: now.add(const Duration(milliseconds: 1)), + ), + ); + + final firstMessages = await database.aiChatMessageDao.threadMessages( + firstThread.id, + ); + final secondMessages = await database.aiChatMessageDao.threadMessages( + secondThread.id, + ); + + expect(firstMessages.map((item) => item.id), ['first-thread-message']); + expect(secondMessages.map((item) => item.id), ['second-thread-message']); + expect( + await database.aiChatMessageDao.hasPendingAssistantMessage( + firstThread.id, + ), + isTrue, + ); + expect( + await database.aiChatMessageDao.hasPendingAssistantMessage( + secondThread.id, + ), + isFalse, + ); + }); + + test('prompt history excludes the current user message', () async { + const conversationId = 'conversation-id'; + final thread = await database.aiChatMessageDao.createThread( + conversationId, + ); + final now = DateTime.now(); + + await database.aiChatMessageDao.insertMessage( + AiChatMessagesCompanion.insert( + id: 'previous-message', + threadId: Value(thread.id), + conversationId: conversationId, + role: 'assistant', + providerId: 'provider-id', + content: 'previous answer', + status: 'done', + createdAt: now, + updatedAt: now, + ), + ); + await database.aiChatMessageDao.insertMessage( + AiChatMessagesCompanion.insert( + id: 'current-message', + threadId: Value(thread.id), + conversationId: conversationId, + role: 'user', + providerId: 'provider-id', + content: 'current question', + status: 'done', + createdAt: now.add(const Duration(milliseconds: 1)), + updatedAt: now.add(const Duration(milliseconds: 1)), + ), + ); + + final messages = await AiChatPromptBuilder(database).buildPromptMessages( + conversationId, + thread.id, + 'current question', + 'en', + currentMessageId: 'current-message', + ); + + expect( + messages.where( + (item) => + item.role.value == AiPromptRole.user.value && + item.content.contains('current question'), + ), + hasLength(1), + ); + expect( + messages.where((item) => item.content == 'previous answer'), + hasLength(1), + ); + }); + }); +} diff --git a/test/utils/device_transfer_test.dart b/test/utils/device_transfer_test.dart index 885bbd155a..d9244ea371 100644 --- a/test/utils/device_transfer_test.dart +++ b/test/utils/device_transfer_test.dart @@ -8,6 +8,7 @@ import 'dart:io'; import 'package:ansicolor/ansicolor.dart'; import 'package:drift/drift.dart'; import 'package:drift/native.dart'; +import 'package:flutter_app/db/ai_database.dart'; import 'package:flutter_app/db/database.dart'; import 'package:flutter_app/db/fts_database.dart'; import 'package:flutter_app/db/mixin_database.dart'; @@ -169,6 +170,7 @@ void main() { receiverDatabase = Database( MixinDatabase(NativeDatabase.memory()), FtsDatabase(NativeDatabase.memory()), + AiDatabase(NativeDatabase.memory()), ); final userId = const Uuid().v4(); @@ -205,6 +207,7 @@ void main() { senderDatabase = Database( MixinDatabase(NativeDatabase.memory()), FtsDatabase(NativeDatabase.memory()), + AiDatabase(NativeDatabase.memory()), )..addTestData(userId); final senderDeviceId = const Uuid().v4(); From 16727134535927f44ffbf2f99a17a5d77a879623 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 21:12:30 +0800 Subject: [PATCH 32/52] feat: add multi-thread support for AI chat with metadata management --- AGENTS.md | 4 + lib/ai/ai_chat_controller.dart | 5 +- lib/db/ai_database.dart | 79 ++- lib/db/ai_database.g.dart | 592 +++++++++++++++++- lib/db/dao/ai_chat_message_dao.dart | 167 ++++- lib/db/moor/ai.drift | 11 +- lib/ui/home/chat/chat_page.dart | 7 + .../ai_assistant/constants.dart | 6 + .../chat_slide_page/ai_assistant_page.dart | 289 ++++++++- lib/widgets/app_bar.dart | 3 + test/ai/ai_chat_thread_test.dart | 61 +- 11 files changed, 1176 insertions(+), 48 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 009702fd4d..591dcdee82 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -141,6 +141,10 @@ SENTRY_DSN=... - Keep generated, third-party, and platform registrant files untouched unless the task explicitly requires them. - Keep changes scoped to the requested behavior; do not refactor unrelated areas. - Reuse existing UI components from `lib/widgets` and patterns from nearby screens. +- For context menus on list items, messages, and other right-click surfaces, use + the native `super_context_menu` flow (`CustomContextMenuWidget`, + `MenuAction`, `MenusWithSeparator`) used elsewhere in the app instead of + custom popup/menu widgets. - Reuse existing DB access through DAOs and providers instead of bypassing with ad hoc SQL unless Drift APIs cannot express the query. - For user-facing text, use `Localization` and ARB files rather than hard-coded strings. - For async work, preserve current error propagation style and do not swallow exceptions without a concrete recovery path. diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index 77134fccd2..7f44aae5f5 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -94,7 +94,7 @@ class AiChatController { } } - Future send({ + Future send({ required String conversationId, required String input, required String language, @@ -234,6 +234,7 @@ class AiChatController { 'threadId=${thread.id} ' 'assistantMessageId=$assistantMessageId output=${_previewText(result)}', ); + return thread.id; } catch (error, stacktrace) { if (cancelToken.isCancelled) { d( @@ -248,7 +249,7 @@ class AiChatController { updatedAt: DateTime.now(), errorText: 'Stopped', ); - return; + return thread.id; } e('AI chat error: $error, $stacktrace'); await database.aiChatMessageDao.updateMessageStatus( diff --git a/lib/db/ai_database.dart b/lib/db/ai_database.dart index d08b0dc93a..22f6ac3b08 100644 --- a/lib/db/ai_database.dart +++ b/lib/db/ai_database.dart @@ -27,5 +27,82 @@ class AiDatabase extends _$AiDatabase { } @override - int get schemaVersion => 1; + int get schemaVersion => 2; + + @override + MigrationStrategy get migration => MigrationStrategy( + onUpgrade: (m, from, to) async { + if (from <= 1) { + await _addColumnIfNotExists(m, aiChatThreads, aiChatThreads.summary); + await _addColumnIfNotExists( + m, + aiChatThreads, + aiChatThreads.lastMessagePreview, + ); + await _addColumnIfNotExists( + m, + aiChatThreads, + aiChatThreads.messageCount, + ); + await _addColumnIfNotExists(m, aiChatThreads, aiChatThreads.status); + await _addColumnIfNotExists(m, aiChatThreads, aiChatThreads.pinnedAt); + await _addColumnIfNotExists(m, aiChatThreads, aiChatThreads.archivedAt); + await _addColumnIfNotExists( + m, + aiChatThreads, + aiChatThreads.lastMessageAt, + ); + await _addColumnIfNotExists(m, aiChatThreads, aiChatThreads.metadata); + await _backfillThreadStats(); + await customStatement( + 'DROP INDEX IF EXISTS index_ai_chat_threads_conversation_id_updated_at', + ); + await m.createIndex(indexAiChatThreadsConversationIdUpdatedAt); + await m.createIndex(indexAiChatThreadsConversationIdLastMessageAt); + } + }, + ); + + Future _addColumnIfNotExists( + Migrator m, + TableInfo table, + GeneratedColumn column, + ) async { + if (!await _checkColumnExists(table.actualTableName, column.name)) { + await m.addColumn(table, column); + } + } + + Future _checkColumnExists(String tableName, String columnName) async { + final queryRow = await customSelect( + "SELECT COUNT(*) AS CNTREC FROM pragma_table_info('$tableName') WHERE name='$columnName'", + ).getSingle(); + return queryRow.read('CNTREC'); + } + + Future _backfillThreadStats() async { + await customStatement(''' +UPDATE ai_chat_threads +SET + message_count = ( + SELECT COUNT(*) + FROM ai_chat_messages + WHERE ai_chat_messages.thread_id = ai_chat_threads.id + ), + last_message_at = ( + SELECT created_at + FROM ai_chat_messages + WHERE ai_chat_messages.thread_id = ai_chat_threads.id + ORDER BY created_at DESC, id DESC + LIMIT 1 + ), + last_message_preview = ( + SELECT substr(trim(replace(replace(content, char(10), ' '), char(13), ' ')), 1, 160) + FROM ai_chat_messages + WHERE ai_chat_messages.thread_id = ai_chat_threads.id + ORDER BY created_at DESC, id DESC + LIMIT 1 + ) +'''); + } } diff --git a/lib/db/ai_database.g.dart b/lib/db/ai_database.g.dart index b1f0ec4a53..03c6a64c7b 100644 --- a/lib/db/ai_database.g.dart +++ b/lib/db/ai_database.g.dart @@ -748,6 +748,88 @@ class AiChatThreads extends Table with TableInfo { requiredDuringInsert: false, $customConstraints: '', ); + static const VerificationMeta _summaryMeta = const VerificationMeta( + 'summary', + ); + late final GeneratedColumn summary = GeneratedColumn( + 'summary', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + static const VerificationMeta _lastMessagePreviewMeta = + const VerificationMeta('lastMessagePreview'); + late final GeneratedColumn lastMessagePreview = + GeneratedColumn( + 'last_message_preview', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + static const VerificationMeta _messageCountMeta = const VerificationMeta( + 'messageCount', + ); + late final GeneratedColumn messageCount = GeneratedColumn( + 'message_count', + aliasedName, + false, + type: DriftSqlType.int, + requiredDuringInsert: false, + $customConstraints: 'NOT NULL DEFAULT 0', + defaultValue: const CustomExpression('0'), + ); + static const VerificationMeta _statusMeta = const VerificationMeta('status'); + late final GeneratedColumn status = GeneratedColumn( + 'status', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: 'NOT NULL DEFAULT \'active\'', + defaultValue: const CustomExpression('\'active\''), + ); + late final GeneratedColumnWithTypeConverter pinnedAt = + GeneratedColumn( + 'pinned_at', + aliasedName, + true, + type: DriftSqlType.int, + requiredDuringInsert: false, + $customConstraints: '', + ).withConverter(AiChatThreads.$converterpinnedAtn); + late final GeneratedColumnWithTypeConverter archivedAt = + GeneratedColumn( + 'archived_at', + aliasedName, + true, + type: DriftSqlType.int, + requiredDuringInsert: false, + $customConstraints: '', + ).withConverter(AiChatThreads.$converterarchivedAtn); + late final GeneratedColumnWithTypeConverter lastMessageAt = + GeneratedColumn( + 'last_message_at', + aliasedName, + true, + type: DriftSqlType.int, + requiredDuringInsert: false, + $customConstraints: '', + ).withConverter(AiChatThreads.$converterlastMessageAtn); + static const VerificationMeta _metadataMeta = const VerificationMeta( + 'metadata', + ); + late final GeneratedColumn metadata = GeneratedColumn( + 'metadata', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); late final GeneratedColumnWithTypeConverter createdAt = GeneratedColumn( 'created_at', @@ -771,6 +853,14 @@ class AiChatThreads extends Table with TableInfo { id, conversationId, title, + summary, + lastMessagePreview, + messageCount, + status, + pinnedAt, + archivedAt, + lastMessageAt, + metadata, createdAt, updatedAt, ]; @@ -808,6 +898,42 @@ class AiChatThreads extends Table with TableInfo { title.isAcceptableOrUnknown(data['title']!, _titleMeta), ); } + if (data.containsKey('summary')) { + context.handle( + _summaryMeta, + summary.isAcceptableOrUnknown(data['summary']!, _summaryMeta), + ); + } + if (data.containsKey('last_message_preview')) { + context.handle( + _lastMessagePreviewMeta, + lastMessagePreview.isAcceptableOrUnknown( + data['last_message_preview']!, + _lastMessagePreviewMeta, + ), + ); + } + if (data.containsKey('message_count')) { + context.handle( + _messageCountMeta, + messageCount.isAcceptableOrUnknown( + data['message_count']!, + _messageCountMeta, + ), + ); + } + if (data.containsKey('status')) { + context.handle( + _statusMeta, + status.isAcceptableOrUnknown(data['status']!, _statusMeta), + ); + } + if (data.containsKey('metadata')) { + context.handle( + _metadataMeta, + metadata.isAcceptableOrUnknown(data['metadata']!, _metadataMeta), + ); + } return context; } @@ -829,6 +955,44 @@ class AiChatThreads extends Table with TableInfo { DriftSqlType.string, data['${effectivePrefix}title'], ), + summary: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}summary'], + ), + lastMessagePreview: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}last_message_preview'], + ), + messageCount: attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}message_count'], + )!, + status: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}status'], + )!, + pinnedAt: AiChatThreads.$converterpinnedAtn.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}pinned_at'], + ), + ), + archivedAt: AiChatThreads.$converterarchivedAtn.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}archived_at'], + ), + ), + lastMessageAt: AiChatThreads.$converterlastMessageAtn.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}last_message_at'], + ), + ), + metadata: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}metadata'], + ), createdAt: AiChatThreads.$convertercreatedAt.fromSql( attachedDatabase.typeMapping.read( DriftSqlType.int, @@ -849,6 +1013,18 @@ class AiChatThreads extends Table with TableInfo { return AiChatThreads(attachedDatabase, alias); } + static TypeConverter $converterpinnedAt = + const MillisDateConverter(); + static TypeConverter $converterpinnedAtn = + NullAwareTypeConverter.wrap($converterpinnedAt); + static TypeConverter $converterarchivedAt = + const MillisDateConverter(); + static TypeConverter $converterarchivedAtn = + NullAwareTypeConverter.wrap($converterarchivedAt); + static TypeConverter $converterlastMessageAt = + const MillisDateConverter(); + static TypeConverter $converterlastMessageAtn = + NullAwareTypeConverter.wrap($converterlastMessageAt); static TypeConverter $convertercreatedAt = const MillisDateConverter(); static TypeConverter $converterupdatedAt = @@ -863,12 +1039,28 @@ class AiChatThread extends DataClass implements Insertable { final String id; final String conversationId; final String? title; + final String? summary; + final String? lastMessagePreview; + final int messageCount; + final String status; + final DateTime? pinnedAt; + final DateTime? archivedAt; + final DateTime? lastMessageAt; + final String? metadata; final DateTime createdAt; final DateTime updatedAt; const AiChatThread({ required this.id, required this.conversationId, this.title, + this.summary, + this.lastMessagePreview, + required this.messageCount, + required this.status, + this.pinnedAt, + this.archivedAt, + this.lastMessageAt, + this.metadata, required this.createdAt, required this.updatedAt, }); @@ -880,6 +1072,32 @@ class AiChatThread extends DataClass implements Insertable { if (!nullToAbsent || title != null) { map['title'] = Variable(title); } + if (!nullToAbsent || summary != null) { + map['summary'] = Variable(summary); + } + if (!nullToAbsent || lastMessagePreview != null) { + map['last_message_preview'] = Variable(lastMessagePreview); + } + map['message_count'] = Variable(messageCount); + map['status'] = Variable(status); + if (!nullToAbsent || pinnedAt != null) { + map['pinned_at'] = Variable( + AiChatThreads.$converterpinnedAtn.toSql(pinnedAt), + ); + } + if (!nullToAbsent || archivedAt != null) { + map['archived_at'] = Variable( + AiChatThreads.$converterarchivedAtn.toSql(archivedAt), + ); + } + if (!nullToAbsent || lastMessageAt != null) { + map['last_message_at'] = Variable( + AiChatThreads.$converterlastMessageAtn.toSql(lastMessageAt), + ); + } + if (!nullToAbsent || metadata != null) { + map['metadata'] = Variable(metadata); + } { map['created_at'] = Variable( AiChatThreads.$convertercreatedAt.toSql(createdAt), @@ -900,6 +1118,26 @@ class AiChatThread extends DataClass implements Insertable { title: title == null && nullToAbsent ? const Value.absent() : Value(title), + summary: summary == null && nullToAbsent + ? const Value.absent() + : Value(summary), + lastMessagePreview: lastMessagePreview == null && nullToAbsent + ? const Value.absent() + : Value(lastMessagePreview), + messageCount: Value(messageCount), + status: Value(status), + pinnedAt: pinnedAt == null && nullToAbsent + ? const Value.absent() + : Value(pinnedAt), + archivedAt: archivedAt == null && nullToAbsent + ? const Value.absent() + : Value(archivedAt), + lastMessageAt: lastMessageAt == null && nullToAbsent + ? const Value.absent() + : Value(lastMessageAt), + metadata: metadata == null && nullToAbsent + ? const Value.absent() + : Value(metadata), createdAt: Value(createdAt), updatedAt: Value(updatedAt), ); @@ -914,6 +1152,16 @@ class AiChatThread extends DataClass implements Insertable { id: serializer.fromJson(json['id']), conversationId: serializer.fromJson(json['conversation_id']), title: serializer.fromJson(json['title']), + summary: serializer.fromJson(json['summary']), + lastMessagePreview: serializer.fromJson( + json['last_message_preview'], + ), + messageCount: serializer.fromJson(json['message_count']), + status: serializer.fromJson(json['status']), + pinnedAt: serializer.fromJson(json['pinned_at']), + archivedAt: serializer.fromJson(json['archived_at']), + lastMessageAt: serializer.fromJson(json['last_message_at']), + metadata: serializer.fromJson(json['metadata']), createdAt: serializer.fromJson(json['created_at']), updatedAt: serializer.fromJson(json['updated_at']), ); @@ -925,6 +1173,14 @@ class AiChatThread extends DataClass implements Insertable { 'id': serializer.toJson(id), 'conversation_id': serializer.toJson(conversationId), 'title': serializer.toJson(title), + 'summary': serializer.toJson(summary), + 'last_message_preview': serializer.toJson(lastMessagePreview), + 'message_count': serializer.toJson(messageCount), + 'status': serializer.toJson(status), + 'pinned_at': serializer.toJson(pinnedAt), + 'archived_at': serializer.toJson(archivedAt), + 'last_message_at': serializer.toJson(lastMessageAt), + 'metadata': serializer.toJson(metadata), 'created_at': serializer.toJson(createdAt), 'updated_at': serializer.toJson(updatedAt), }; @@ -934,12 +1190,32 @@ class AiChatThread extends DataClass implements Insertable { String? id, String? conversationId, Value title = const Value.absent(), + Value summary = const Value.absent(), + Value lastMessagePreview = const Value.absent(), + int? messageCount, + String? status, + Value pinnedAt = const Value.absent(), + Value archivedAt = const Value.absent(), + Value lastMessageAt = const Value.absent(), + Value metadata = const Value.absent(), DateTime? createdAt, DateTime? updatedAt, }) => AiChatThread( id: id ?? this.id, conversationId: conversationId ?? this.conversationId, title: title.present ? title.value : this.title, + summary: summary.present ? summary.value : this.summary, + lastMessagePreview: lastMessagePreview.present + ? lastMessagePreview.value + : this.lastMessagePreview, + messageCount: messageCount ?? this.messageCount, + status: status ?? this.status, + pinnedAt: pinnedAt.present ? pinnedAt.value : this.pinnedAt, + archivedAt: archivedAt.present ? archivedAt.value : this.archivedAt, + lastMessageAt: lastMessageAt.present + ? lastMessageAt.value + : this.lastMessageAt, + metadata: metadata.present ? metadata.value : this.metadata, createdAt: createdAt ?? this.createdAt, updatedAt: updatedAt ?? this.updatedAt, ); @@ -950,6 +1226,22 @@ class AiChatThread extends DataClass implements Insertable { ? data.conversationId.value : this.conversationId, title: data.title.present ? data.title.value : this.title, + summary: data.summary.present ? data.summary.value : this.summary, + lastMessagePreview: data.lastMessagePreview.present + ? data.lastMessagePreview.value + : this.lastMessagePreview, + messageCount: data.messageCount.present + ? data.messageCount.value + : this.messageCount, + status: data.status.present ? data.status.value : this.status, + pinnedAt: data.pinnedAt.present ? data.pinnedAt.value : this.pinnedAt, + archivedAt: data.archivedAt.present + ? data.archivedAt.value + : this.archivedAt, + lastMessageAt: data.lastMessageAt.present + ? data.lastMessageAt.value + : this.lastMessageAt, + metadata: data.metadata.present ? data.metadata.value : this.metadata, createdAt: data.createdAt.present ? data.createdAt.value : this.createdAt, updatedAt: data.updatedAt.present ? data.updatedAt.value : this.updatedAt, ); @@ -961,6 +1253,14 @@ class AiChatThread extends DataClass implements Insertable { ..write('id: $id, ') ..write('conversationId: $conversationId, ') ..write('title: $title, ') + ..write('summary: $summary, ') + ..write('lastMessagePreview: $lastMessagePreview, ') + ..write('messageCount: $messageCount, ') + ..write('status: $status, ') + ..write('pinnedAt: $pinnedAt, ') + ..write('archivedAt: $archivedAt, ') + ..write('lastMessageAt: $lastMessageAt, ') + ..write('metadata: $metadata, ') ..write('createdAt: $createdAt, ') ..write('updatedAt: $updatedAt') ..write(')')) @@ -968,8 +1268,21 @@ class AiChatThread extends DataClass implements Insertable { } @override - int get hashCode => - Object.hash(id, conversationId, title, createdAt, updatedAt); + int get hashCode => Object.hash( + id, + conversationId, + title, + summary, + lastMessagePreview, + messageCount, + status, + pinnedAt, + archivedAt, + lastMessageAt, + metadata, + createdAt, + updatedAt, + ); @override bool operator ==(Object other) => identical(this, other) || @@ -977,6 +1290,14 @@ class AiChatThread extends DataClass implements Insertable { other.id == this.id && other.conversationId == this.conversationId && other.title == this.title && + other.summary == this.summary && + other.lastMessagePreview == this.lastMessagePreview && + other.messageCount == this.messageCount && + other.status == this.status && + other.pinnedAt == this.pinnedAt && + other.archivedAt == this.archivedAt && + other.lastMessageAt == this.lastMessageAt && + other.metadata == this.metadata && other.createdAt == this.createdAt && other.updatedAt == this.updatedAt); } @@ -985,6 +1306,14 @@ class AiChatThreadsCompanion extends UpdateCompanion { final Value id; final Value conversationId; final Value title; + final Value summary; + final Value lastMessagePreview; + final Value messageCount; + final Value status; + final Value pinnedAt; + final Value archivedAt; + final Value lastMessageAt; + final Value metadata; final Value createdAt; final Value updatedAt; final Value rowid; @@ -992,6 +1321,14 @@ class AiChatThreadsCompanion extends UpdateCompanion { this.id = const Value.absent(), this.conversationId = const Value.absent(), this.title = const Value.absent(), + this.summary = const Value.absent(), + this.lastMessagePreview = const Value.absent(), + this.messageCount = const Value.absent(), + this.status = const Value.absent(), + this.pinnedAt = const Value.absent(), + this.archivedAt = const Value.absent(), + this.lastMessageAt = const Value.absent(), + this.metadata = const Value.absent(), this.createdAt = const Value.absent(), this.updatedAt = const Value.absent(), this.rowid = const Value.absent(), @@ -1000,6 +1337,14 @@ class AiChatThreadsCompanion extends UpdateCompanion { required String id, required String conversationId, this.title = const Value.absent(), + this.summary = const Value.absent(), + this.lastMessagePreview = const Value.absent(), + this.messageCount = const Value.absent(), + this.status = const Value.absent(), + this.pinnedAt = const Value.absent(), + this.archivedAt = const Value.absent(), + this.lastMessageAt = const Value.absent(), + this.metadata = const Value.absent(), required DateTime createdAt, required DateTime updatedAt, this.rowid = const Value.absent(), @@ -1011,6 +1356,14 @@ class AiChatThreadsCompanion extends UpdateCompanion { Expression? id, Expression? conversationId, Expression? title, + Expression? summary, + Expression? lastMessagePreview, + Expression? messageCount, + Expression? status, + Expression? pinnedAt, + Expression? archivedAt, + Expression? lastMessageAt, + Expression? metadata, Expression? createdAt, Expression? updatedAt, Expression? rowid, @@ -1019,6 +1372,15 @@ class AiChatThreadsCompanion extends UpdateCompanion { if (id != null) 'id': id, if (conversationId != null) 'conversation_id': conversationId, if (title != null) 'title': title, + if (summary != null) 'summary': summary, + if (lastMessagePreview != null) + 'last_message_preview': lastMessagePreview, + if (messageCount != null) 'message_count': messageCount, + if (status != null) 'status': status, + if (pinnedAt != null) 'pinned_at': pinnedAt, + if (archivedAt != null) 'archived_at': archivedAt, + if (lastMessageAt != null) 'last_message_at': lastMessageAt, + if (metadata != null) 'metadata': metadata, if (createdAt != null) 'created_at': createdAt, if (updatedAt != null) 'updated_at': updatedAt, if (rowid != null) 'rowid': rowid, @@ -1029,6 +1391,14 @@ class AiChatThreadsCompanion extends UpdateCompanion { Value? id, Value? conversationId, Value? title, + Value? summary, + Value? lastMessagePreview, + Value? messageCount, + Value? status, + Value? pinnedAt, + Value? archivedAt, + Value? lastMessageAt, + Value? metadata, Value? createdAt, Value? updatedAt, Value? rowid, @@ -1037,6 +1407,14 @@ class AiChatThreadsCompanion extends UpdateCompanion { id: id ?? this.id, conversationId: conversationId ?? this.conversationId, title: title ?? this.title, + summary: summary ?? this.summary, + lastMessagePreview: lastMessagePreview ?? this.lastMessagePreview, + messageCount: messageCount ?? this.messageCount, + status: status ?? this.status, + pinnedAt: pinnedAt ?? this.pinnedAt, + archivedAt: archivedAt ?? this.archivedAt, + lastMessageAt: lastMessageAt ?? this.lastMessageAt, + metadata: metadata ?? this.metadata, createdAt: createdAt ?? this.createdAt, updatedAt: updatedAt ?? this.updatedAt, rowid: rowid ?? this.rowid, @@ -1055,6 +1433,36 @@ class AiChatThreadsCompanion extends UpdateCompanion { if (title.present) { map['title'] = Variable(title.value); } + if (summary.present) { + map['summary'] = Variable(summary.value); + } + if (lastMessagePreview.present) { + map['last_message_preview'] = Variable(lastMessagePreview.value); + } + if (messageCount.present) { + map['message_count'] = Variable(messageCount.value); + } + if (status.present) { + map['status'] = Variable(status.value); + } + if (pinnedAt.present) { + map['pinned_at'] = Variable( + AiChatThreads.$converterpinnedAtn.toSql(pinnedAt.value), + ); + } + if (archivedAt.present) { + map['archived_at'] = Variable( + AiChatThreads.$converterarchivedAtn.toSql(archivedAt.value), + ); + } + if (lastMessageAt.present) { + map['last_message_at'] = Variable( + AiChatThreads.$converterlastMessageAtn.toSql(lastMessageAt.value), + ); + } + if (metadata.present) { + map['metadata'] = Variable(metadata.value); + } if (createdAt.present) { map['created_at'] = Variable( AiChatThreads.$convertercreatedAt.toSql(createdAt.value), @@ -1077,6 +1485,14 @@ class AiChatThreadsCompanion extends UpdateCompanion { ..write('id: $id, ') ..write('conversationId: $conversationId, ') ..write('title: $title, ') + ..write('summary: $summary, ') + ..write('lastMessagePreview: $lastMessagePreview, ') + ..write('messageCount: $messageCount, ') + ..write('status: $status, ') + ..write('pinnedAt: $pinnedAt, ') + ..write('archivedAt: $archivedAt, ') + ..write('lastMessageAt: $lastMessageAt, ') + ..write('metadata: $metadata, ') ..write('createdAt: $createdAt, ') ..write('updatedAt: $updatedAt, ') ..write('rowid: $rowid') @@ -1100,7 +1516,11 @@ abstract class _$AiDatabase extends GeneratedDatabase { ); late final Index indexAiChatThreadsConversationIdUpdatedAt = Index( 'index_ai_chat_threads_conversation_id_updated_at', - 'CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_updated_at ON ai_chat_threads (conversation_id, updated_at DESC)', + 'CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_updated_at ON ai_chat_threads (conversation_id, status, updated_at DESC)', + ); + late final Index indexAiChatThreadsConversationIdLastMessageAt = Index( + 'index_ai_chat_threads_conversation_id_last_message_at', + 'CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_last_message_at ON ai_chat_threads (conversation_id, status, last_message_at DESC)', ); late final AiChatMessageDao aiChatMessageDao = AiChatMessageDao( this as AiDatabase, @@ -1115,6 +1535,7 @@ abstract class _$AiDatabase extends GeneratedDatabase { indexAiChatMessagesConversationIdCreatedAt, indexAiChatMessagesThreadIdCreatedAt, indexAiChatThreadsConversationIdUpdatedAt, + indexAiChatThreadsConversationIdLastMessageAt, ]; } @@ -1462,6 +1883,14 @@ typedef $AiChatThreadsCreateCompanionBuilder = required String id, required String conversationId, Value title, + Value summary, + Value lastMessagePreview, + Value messageCount, + Value status, + Value pinnedAt, + Value archivedAt, + Value lastMessageAt, + Value metadata, required DateTime createdAt, required DateTime updatedAt, Value rowid, @@ -1471,6 +1900,14 @@ typedef $AiChatThreadsUpdateCompanionBuilder = Value id, Value conversationId, Value title, + Value summary, + Value lastMessagePreview, + Value messageCount, + Value status, + Value pinnedAt, + Value archivedAt, + Value lastMessageAt, + Value metadata, Value createdAt, Value updatedAt, Value rowid, @@ -1500,6 +1937,49 @@ class $AiChatThreadsFilterComposer builder: (column) => ColumnFilters(column), ); + ColumnFilters get summary => $composableBuilder( + column: $table.summary, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get lastMessagePreview => $composableBuilder( + column: $table.lastMessagePreview, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get messageCount => $composableBuilder( + column: $table.messageCount, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get status => $composableBuilder( + column: $table.status, + builder: (column) => ColumnFilters(column), + ); + + ColumnWithTypeConverterFilters get pinnedAt => + $composableBuilder( + column: $table.pinnedAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); + + ColumnWithTypeConverterFilters get archivedAt => + $composableBuilder( + column: $table.archivedAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); + + ColumnWithTypeConverterFilters get lastMessageAt => + $composableBuilder( + column: $table.lastMessageAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); + + ColumnFilters get metadata => $composableBuilder( + column: $table.metadata, + builder: (column) => ColumnFilters(column), + ); + ColumnWithTypeConverterFilters get createdAt => $composableBuilder( column: $table.createdAt, @@ -1537,6 +2017,46 @@ class $AiChatThreadsOrderingComposer builder: (column) => ColumnOrderings(column), ); + ColumnOrderings get summary => $composableBuilder( + column: $table.summary, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get lastMessagePreview => $composableBuilder( + column: $table.lastMessagePreview, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get messageCount => $composableBuilder( + column: $table.messageCount, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get status => $composableBuilder( + column: $table.status, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get pinnedAt => $composableBuilder( + column: $table.pinnedAt, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get archivedAt => $composableBuilder( + column: $table.archivedAt, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get lastMessageAt => $composableBuilder( + column: $table.lastMessageAt, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get metadata => $composableBuilder( + column: $table.metadata, + builder: (column) => ColumnOrderings(column), + ); + ColumnOrderings get createdAt => $composableBuilder( column: $table.createdAt, builder: (column) => ColumnOrderings(column), @@ -1568,6 +2088,40 @@ class $AiChatThreadsAnnotationComposer GeneratedColumn get title => $composableBuilder(column: $table.title, builder: (column) => column); + GeneratedColumn get summary => + $composableBuilder(column: $table.summary, builder: (column) => column); + + GeneratedColumn get lastMessagePreview => $composableBuilder( + column: $table.lastMessagePreview, + builder: (column) => column, + ); + + GeneratedColumn get messageCount => $composableBuilder( + column: $table.messageCount, + builder: (column) => column, + ); + + GeneratedColumn get status => + $composableBuilder(column: $table.status, builder: (column) => column); + + GeneratedColumnWithTypeConverter get pinnedAt => + $composableBuilder(column: $table.pinnedAt, builder: (column) => column); + + GeneratedColumnWithTypeConverter get archivedAt => + $composableBuilder( + column: $table.archivedAt, + builder: (column) => column, + ); + + GeneratedColumnWithTypeConverter get lastMessageAt => + $composableBuilder( + column: $table.lastMessageAt, + builder: (column) => column, + ); + + GeneratedColumn get metadata => + $composableBuilder(column: $table.metadata, builder: (column) => column); + GeneratedColumnWithTypeConverter get createdAt => $composableBuilder(column: $table.createdAt, builder: (column) => column); @@ -1609,6 +2163,14 @@ class $AiChatThreadsTableManager Value id = const Value.absent(), Value conversationId = const Value.absent(), Value title = const Value.absent(), + Value summary = const Value.absent(), + Value lastMessagePreview = const Value.absent(), + Value messageCount = const Value.absent(), + Value status = const Value.absent(), + Value pinnedAt = const Value.absent(), + Value archivedAt = const Value.absent(), + Value lastMessageAt = const Value.absent(), + Value metadata = const Value.absent(), Value createdAt = const Value.absent(), Value updatedAt = const Value.absent(), Value rowid = const Value.absent(), @@ -1616,6 +2178,14 @@ class $AiChatThreadsTableManager id: id, conversationId: conversationId, title: title, + summary: summary, + lastMessagePreview: lastMessagePreview, + messageCount: messageCount, + status: status, + pinnedAt: pinnedAt, + archivedAt: archivedAt, + lastMessageAt: lastMessageAt, + metadata: metadata, createdAt: createdAt, updatedAt: updatedAt, rowid: rowid, @@ -1625,6 +2195,14 @@ class $AiChatThreadsTableManager required String id, required String conversationId, Value title = const Value.absent(), + Value summary = const Value.absent(), + Value lastMessagePreview = const Value.absent(), + Value messageCount = const Value.absent(), + Value status = const Value.absent(), + Value pinnedAt = const Value.absent(), + Value archivedAt = const Value.absent(), + Value lastMessageAt = const Value.absent(), + Value metadata = const Value.absent(), required DateTime createdAt, required DateTime updatedAt, Value rowid = const Value.absent(), @@ -1632,6 +2210,14 @@ class $AiChatThreadsTableManager id: id, conversationId: conversationId, title: title, + summary: summary, + lastMessagePreview: lastMessagePreview, + messageCount: messageCount, + status: status, + pinnedAt: pinnedAt, + archivedAt: archivedAt, + lastMessageAt: lastMessageAt, + metadata: metadata, createdAt: createdAt, updatedAt: updatedAt, rowid: rowid, diff --git a/lib/db/dao/ai_chat_message_dao.dart b/lib/db/dao/ai_chat_message_dao.dart index ebe54752f4..8e2d2afd9f 100644 --- a/lib/db/dao/ai_chat_message_dao.dart +++ b/lib/db/dao/ai_chat_message_dao.dart @@ -14,12 +14,35 @@ class AiChatMessageDao extends DatabaseAccessor static const assistantRole = 'assistant'; static const pendingStatus = 'pending'; static const errorStatus = 'error'; + static const activeThreadStatus = 'active'; static const _uuid = Uuid(); + Stream> watchThreads(String conversationId) => + (select(db.aiChatThreads) + ..where( + (tbl) => + tbl.conversationId.equals(conversationId) & + tbl.status.equals(activeThreadStatus), + ) + ..orderBy([ + (tbl) => OrderingTerm.desc(tbl.pinnedAt), + (tbl) => OrderingTerm.desc(tbl.lastMessageAt), + (tbl) => OrderingTerm.desc(tbl.updatedAt), + (tbl) => OrderingTerm.desc(tbl.createdAt), + (tbl) => OrderingTerm.desc(tbl.id), + ])) + .watch(); + Stream watchLatestThread(String conversationId) => (select(db.aiChatThreads) - ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..where( + (tbl) => + tbl.conversationId.equals(conversationId) & + tbl.status.equals(activeThreadStatus), + ) ..orderBy([ + (tbl) => OrderingTerm.desc(tbl.pinnedAt), + (tbl) => OrderingTerm.desc(tbl.lastMessageAt), (tbl) => OrderingTerm.desc(tbl.updatedAt), (tbl) => OrderingTerm.desc(tbl.createdAt), (tbl) => OrderingTerm.desc(tbl.id), @@ -29,8 +52,14 @@ class AiChatMessageDao extends DatabaseAccessor Future latestThread(String conversationId) => (select(db.aiChatThreads) - ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..where( + (tbl) => + tbl.conversationId.equals(conversationId) & + tbl.status.equals(activeThreadStatus), + ) ..orderBy([ + (tbl) => OrderingTerm.desc(tbl.pinnedAt), + (tbl) => OrderingTerm.desc(tbl.lastMessageAt), (tbl) => OrderingTerm.desc(tbl.updatedAt), (tbl) => OrderingTerm.desc(tbl.createdAt), (tbl) => OrderingTerm.desc(tbl.id), @@ -47,6 +76,8 @@ class AiChatMessageDao extends DatabaseAccessor final thread = AiChatThread( id: _uuid.v4(), conversationId: conversationId, + messageCount: 0, + status: activeThreadStatus, createdAt: now, updatedAt: now, ); @@ -54,13 +85,36 @@ class AiChatMessageDao extends DatabaseAccessor return thread; } + Future updateThreadTitle(String threadId, String? title) => + (update( + db.aiChatThreads, + )..where((tbl) => tbl.id.equals(threadId))).write( + AiChatThreadsCompanion( + title: Value(title?.trim().isEmpty ?? true ? null : title!.trim()), + updatedAt: Value(DateTime.now()), + ), + ); + + Future deleteThread(String threadId) async { + await transaction(() async { + await (delete( + db.aiChatMessages, + )..where((tbl) => tbl.threadId.equals(threadId))).go(); + await (delete( + db.aiChatThreads, + )..where((tbl) => tbl.id.equals(threadId))).go(); + }); + } + Future ensureThread({ required String conversationId, String? threadId, }) async { if (threadId != null) { final thread = await threadById(threadId); - if (thread == null || thread.conversationId != conversationId) { + if (thread == null || + thread.conversationId != conversationId || + thread.status != activeThreadStatus) { throw StateError('AI thread not found'); } return thread; @@ -137,32 +191,44 @@ class AiChatMessageDao extends DatabaseAccessor Future insertMessage(AiChatMessagesCompanion row) async { await into(db.aiChatMessages).insertOnConflictUpdate(row); - await _touchThread(row.threadId.value); + await refreshThreadStats(row.threadId.value); } Future updateMessageContent( String id, String content, { required DateTime updatedAt, - }) => (update(db.aiChatMessages)..where((tbl) => tbl.id.equals(id))).write( - AiChatMessagesCompanion( - content: Value(content), - updatedAt: Value(updatedAt), - ), - ); + }) async { + final threadId = await _messageThreadId(id); + await (update(db.aiChatMessages)..where((tbl) => tbl.id.equals(id))).write( + AiChatMessagesCompanion( + content: Value(content), + updatedAt: Value(updatedAt), + ), + ); + if (threadId != null) { + await refreshThreadStats(threadId); + } + } Future updateMessageStatus( String id, String status, { required DateTime updatedAt, String? errorText, - }) => (update(db.aiChatMessages)..where((tbl) => tbl.id.equals(id))).write( - AiChatMessagesCompanion( - status: Value(status), - errorText: Value(errorText), - updatedAt: Value(updatedAt), - ), - ); + }) async { + final threadId = await _messageThreadId(id); + await (update(db.aiChatMessages)..where((tbl) => tbl.id.equals(id))).write( + AiChatMessagesCompanion( + status: Value(status), + errorText: Value(errorText), + updatedAt: Value(updatedAt), + ), + ); + if (threadId != null) { + await refreshThreadStats(threadId); + } + } Future appendMessageMetadataToolEvent( String id, @@ -215,9 +281,16 @@ class AiChatMessageDao extends DatabaseAccessor }); } - Future deleteConversationMessages(String conversationId) => (delete( - db.aiChatMessages, - )..where((tbl) => tbl.conversationId.equals(conversationId))).go(); + Future deleteConversationMessages(String conversationId) async { + await transaction(() async { + await (delete( + db.aiChatMessages, + )..where((tbl) => tbl.conversationId.equals(conversationId))).go(); + await (delete( + db.aiChatThreads, + )..where((tbl) => tbl.conversationId.equals(conversationId))).go(); + }); + } Future hasPendingAssistantMessage( String threadId, { @@ -270,10 +343,52 @@ class AiChatMessageDao extends DatabaseAccessor ); } - Future _touchThread(String threadId) => - (update( - db.aiChatThreads, - )..where((tbl) => tbl.id.equals(threadId))).write( - AiChatThreadsCompanion(updatedAt: Value(DateTime.now())), - ); + Future refreshThreadStats(String threadId) async { + final countExpression = db.aiChatMessages.id.count(); + final countQuery = selectOnly(db.aiChatMessages) + ..addColumns([countExpression]) + ..where(db.aiChatMessages.threadId.equals(threadId)); + final countRow = await countQuery.getSingleOrNull(); + final messageCount = countRow?.read(countExpression) ?? 0; + final latestMessage = + await (select(db.aiChatMessages) + ..where((tbl) => tbl.threadId.equals(threadId)) + ..orderBy([ + (tbl) => OrderingTerm.desc(tbl.createdAt), + (tbl) => OrderingTerm.desc(tbl.id), + ]) + ..limit(1)) + .getSingleOrNull(); + + await (update( + db.aiChatThreads, + )..where((tbl) => tbl.id.equals(threadId))).write( + AiChatThreadsCompanion( + lastMessagePreview: Value(_previewMessageContent(latestMessage)), + messageCount: Value(messageCount), + lastMessageAt: Value(latestMessage?.createdAt), + updatedAt: Value(DateTime.now()), + ), + ); + } + + Future _messageThreadId(String messageId) async { + final row = + await (select(db.aiChatMessages) + ..where((tbl) => tbl.id.equals(messageId)) + ..limit(1)) + .getSingleOrNull(); + return row?.threadId; + } + + String? _previewMessageContent(AiChatMessage? message) { + final content = message?.content.replaceAll(RegExp(r'\s+'), ' ').trim(); + if (content == null || content.isEmpty) { + return null; + } + if (content.length <= 160) { + return content; + } + return content.substring(0, 160); + } } diff --git a/lib/db/moor/ai.drift b/lib/db/moor/ai.drift index 16c7ef466f..829ddaa6ef 100644 --- a/lib/db/moor/ai.drift +++ b/lib/db/moor/ai.drift @@ -20,6 +20,14 @@ CREATE TABLE ai_chat_threads ( id TEXT NOT NULL, conversation_id TEXT NOT NULL, title TEXT, + summary TEXT, + last_message_preview TEXT, + message_count INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'active', + pinned_at INTEGER MAPPED BY `const MillisDateConverter()`, + archived_at INTEGER MAPPED BY `const MillisDateConverter()`, + last_message_at INTEGER MAPPED BY `const MillisDateConverter()`, + metadata TEXT, created_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, updated_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, PRIMARY KEY(id) @@ -27,4 +35,5 @@ CREATE TABLE ai_chat_threads ( CREATE INDEX IF NOT EXISTS index_ai_chat_messages_conversation_id_created_at ON ai_chat_messages(conversation_id, created_at DESC); CREATE INDEX IF NOT EXISTS index_ai_chat_messages_thread_id_created_at ON ai_chat_messages(thread_id, created_at DESC); -CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_updated_at ON ai_chat_threads(conversation_id, updated_at DESC); +CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_updated_at ON ai_chat_threads(conversation_id, status, updated_at DESC); +CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_last_message_at ON ai_chat_threads(conversation_id, status, last_message_at DESC); diff --git a/lib/ui/home/chat/chat_page.dart b/lib/ui/home/chat/chat_page.dart index 8697933d2b..de550e9456 100644 --- a/lib/ui/home/chat/chat_page.dart +++ b/lib/ui/home/chat/chat_page.dart @@ -73,6 +73,7 @@ class ChatSideCubit extends AbstractResponsiveNavigatorCubit { static const groupsInCommon = 'groupsInCommon'; static const disappearMessages = 'disappearMessages'; static const aiAssistantPage = 'aiAssistantPage'; + static const aiAssistantThreadsPage = 'aiAssistantThreadsPage'; @override MaterialPage route(String name, Object? arguments) { @@ -137,6 +138,12 @@ class ChatSideCubit extends AbstractResponsiveNavigatorCubit { name: aiAssistantPage, child: _ChatSidePageBuilder(AiAssistantPage.new), ); + case aiAssistantThreadsPage: + return const MaterialPage( + key: ValueKey(aiAssistantThreadsPage), + name: aiAssistantThreadsPage, + child: _ChatSidePageBuilder(AiAssistantThreadsPage.new), + ); default: throw ArgumentError('Invalid route'); } diff --git a/lib/ui/home/chat_slide_page/ai_assistant/constants.dart b/lib/ui/home/chat_slide_page/ai_assistant/constants.dart index e35e29800b..d0039e8fdb 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant/constants.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant/constants.dart @@ -2,5 +2,11 @@ const aiAssistantTitle = 'AI Assistant'; const aiAssistantEmpty = 'Ask AI about this conversation'; const aiAssistantInputHint = 'Ask about this conversation'; const aiAssistantUnavailable = 'Add a usable AI model in Settings first'; +const aiAssistantNewThread = 'New Chat'; +const aiAssistantThreads = 'AI Chats'; +const aiAssistantUntitledThread = 'New chat'; +const aiAssistantDeleteThread = 'Delete Chat'; +const aiAssistantDeleteThreadDescription = + 'This removes the AI messages in this chat.'; const aiAssistantStickToBottomDistance = 96.0; const aiAssistantMessagePageLimit = 80; diff --git a/lib/ui/home/chat_slide_page/ai_assistant_page.dart b/lib/ui/home/chat_slide_page/ai_assistant_page.dart index 1ec25a172b..5cfb883168 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant_page.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant_page.dart @@ -1,22 +1,36 @@ import 'package:flutter/material.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; +import 'package:super_context_menu/super_context_menu.dart'; import '../../../ai/ai_chat_controller.dart'; import '../../../ai/model/ai_provider_config.dart'; import '../../../constants/constants.dart'; +import '../../../constants/icon_fonts.dart'; +import '../../../constants/resources.dart'; import '../../../db/ai_database.dart'; import '../../../utils/extension/extension.dart'; import '../../../utils/hook.dart'; +import '../../../widgets/action_button.dart'; import '../../../widgets/app_bar.dart'; +import '../../../widgets/buttons.dart'; +import '../../../widgets/dialog.dart'; +import '../../../widgets/empty.dart'; +import '../../../widgets/menu.dart'; import '../../../widgets/toast.dart'; import '../../provider/ai_input_mode_provider.dart'; import '../../provider/conversation_provider.dart'; +import '../chat/chat_page.dart'; import 'ai_assistant/composer.dart'; import 'ai_assistant/constants.dart'; import 'ai_assistant/helpers.dart'; import 'ai_assistant/message_list.dart'; +const _newAiAssistantThreadId = ''; + +final aiAssistantThreadIdProvider = StateProvider.autoDispose + .family((ref, conversationId) => null); + class AiAssistantPage extends HookConsumerWidget { const AiAssistantPage(this.conversationState, {super.key}); @@ -41,21 +55,37 @@ class AiAssistantPage extends HookConsumerWidget { providerId: aiModeState.providerId, selectedModel: aiModeState.model, ); - final activeThread = useMemoizedStream( - () => context.database.aiChatMessageDao.watchLatestThread( - conversationId, - ), - keys: [conversationId], - ).data; + final threads = + useMemoizedStream( + () => context.database.aiChatMessageDao.watchThreads(conversationId), + keys: [conversationId], + initialData: const [], + ).data ?? + const []; + final activeThreadId = ref.watch( + aiAssistantThreadIdProvider(conversationId), + ); + final activeThreadNotifier = ref.read( + aiAssistantThreadIdProvider(conversationId).notifier, + ); + final isNewThreadPage = + activeThreadId == _newAiAssistantThreadId || threads.isEmpty; + final activeThread = threads.firstWhereOrNull( + (item) => item.id == activeThreadId, + ); + final fallbackThread = threads.firstOrNull; + final currentThread = isNewThreadPage + ? null + : activeThread ?? fallbackThread; final latestMessages = useMemoizedStream( - () => activeThread == null + () => currentThread == null ? Stream.value(const []) : context.database.aiChatMessageDao.watchLatestThreadMessages( - activeThread.id, + currentThread.id, aiAssistantMessagePageLimit, ), - keys: [activeThread?.id], + keys: [currentThread?.id], initialData: const [], ).data ?? const []; @@ -88,28 +118,49 @@ class AiAssistantPage extends HookConsumerWidget { } try { - await AiChatController(context.database).send( + final threadId = await AiChatController(context.database).send( conversationId: conversationId, - threadId: activeThread?.id, + threadId: currentThread?.id, input: text, language: currentLanguageTag(context), provider: aiProvider, onInputAccepted: textEditingController.clear, ); + activeThreadNotifier.state = threadId; } catch (error, _) { showToastFailed(error); } } + void openNewThreadPage() { + if (isNewThreadPage) return; + activeThreadNotifier.state = _newAiAssistantThreadId; + } + return Scaffold( backgroundColor: context.theme.primary, - appBar: const MixinAppBar(title: Text(aiAssistantTitle)), + appBar: MixinAppBar( + leadingWidth: 104, + leading: _AiAssistantLeadingActions( + addEnabled: !isNewThreadPage, + onOpenThreads: () => context.read().pushPage( + ChatSideCubit.aiAssistantThreadsPage, + ), + onNewThread: openNewThreadPage, + ), + title: Text(_threadTitle(currentThread, threads)), + actions: [ + MixinCloseButton( + onTap: () => context.read().onPopPage(), + ), + ], + ), body: Column( children: [ Expanded( child: AiAssistantMessageList( conversationId: conversationId, - threadId: activeThread?.id, + threadId: currentThread?.id, latestMessages: latestMessages, ), ), @@ -123,7 +174,7 @@ class AiAssistantPage extends HookConsumerWidget { onSend: send, onStop: () => AiChatController( context.database, - ).stop(conversationId, threadId: activeThread?.id), + ).stop(conversationId, threadId: currentThread?.id), onProviderSelected: (value) => aiModeNotifier.updateProvider( providerId: value.id, model: value.model, @@ -135,3 +186,213 @@ class AiAssistantPage extends HookConsumerWidget { ); } } + +class AiAssistantThreadsPage extends HookConsumerWidget { + const AiAssistantThreadsPage(this.conversationState, {super.key}); + + final ConversationState conversationState; + + @override + Widget build(BuildContext context, WidgetRef ref) { + final conversationId = conversationState.conversationId; + final threads = + useMemoizedStream( + () => context.database.aiChatMessageDao.watchThreads(conversationId), + keys: [conversationId], + initialData: const [], + ).data ?? + const []; + final activeThreadId = ref.watch( + aiAssistantThreadIdProvider(conversationId), + ); + final activeThreadNotifier = ref.read( + aiAssistantThreadIdProvider(conversationId).notifier, + ); + final hasSelectedThread = threads.any((item) => item.id == activeThreadId); + + return Scaffold( + backgroundColor: context.theme.primary, + appBar: MixinAppBar( + title: const Text(aiAssistantThreads), + actions: [ + MixinCloseButton( + onTap: () => context.read().onPopPage(), + ), + ], + ), + body: threads.isEmpty + ? const Empty(text: aiAssistantEmpty) + : ListView.separated( + padding: const EdgeInsets.fromLTRB(10, 8, 10, 20), + itemBuilder: (context, index) { + final thread = threads[index]; + final selected = + activeThreadId == thread.id || + ((activeThreadId == null || !hasSelectedThread) && + index == 0); + return _AiAssistantThreadTile( + thread: thread, + title: _threadTitle(thread, threads), + selected: selected, + onDelete: () async { + final result = await showConfirmMixinDialog( + context, + aiAssistantDeleteThread, + description: aiAssistantDeleteThreadDescription, + ); + if (result != DialogEvent.positive) return; + await context.database.aiChatMessageDao.deleteThread( + thread.id, + ); + if (activeThreadId == thread.id) { + activeThreadNotifier.state = null; + } + }, + onTap: () { + activeThreadNotifier.state = thread.id; + context.read().pop(); + }, + ); + }, + separatorBuilder: (context, index) => const SizedBox(height: 6), + itemCount: threads.length, + ), + ); + } +} + +class _AiAssistantLeadingActions extends StatelessWidget { + const _AiAssistantLeadingActions({ + required this.addEnabled, + required this.onOpenThreads, + required this.onNewThread, + }); + + final bool addEnabled; + final VoidCallback onOpenThreads; + final VoidCallback onNewThread; + + @override + Widget build(BuildContext context) => Row( + mainAxisSize: MainAxisSize.min, + children: [ + ActionButton( + onTap: onOpenThreads, + child: Icon(Icons.history_rounded, color: context.theme.icon, size: 22), + ), + ActionButton( + name: Resources.assetsImagesIcAddSvg, + color: addEnabled ? context.theme.icon : context.theme.secondaryText, + interactive: addEnabled, + onTap: onNewThread, + ), + ], + ); +} + +class _AiAssistantThreadTile extends StatelessWidget { + const _AiAssistantThreadTile({ + required this.thread, + required this.title, + required this.selected, + required this.onTap, + required this.onDelete, + }); + + final AiChatThread thread; + final String title; + final bool selected; + final VoidCallback onTap; + final VoidCallback onDelete; + + @override + Widget build(BuildContext context) { + final preview = thread.lastMessagePreview?.trim(); + final backgroundColor = selected + ? context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.06), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ) + : Colors.transparent; + + return CustomContextMenuWidget( + desktopMenuWidgetBuilder: CustomDesktopMenuWidgetBuilder(), + menuProvider: (_) => MenusWithSeparator( + childrens: [ + [ + MenuAction( + attributes: const MenuActionAttributes(destructive: true), + image: MenuImage.icon(IconFonts.delete), + title: aiAssistantDeleteThread, + callback: onDelete, + ), + ], + ], + ), + child: Material( + color: backgroundColor, + borderRadius: const BorderRadius.all(Radius.circular(8)), + child: InkWell( + borderRadius: const BorderRadius.all(Radius.circular(8)), + onTap: onTap, + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 14, vertical: 12), + child: Row( + children: [ + Expanded( + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + title, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.text, + fontSize: 15, + fontWeight: FontWeight.w600, + ), + ), + const SizedBox(height: 4), + Text( + preview?.isNotEmpty == true + ? preview! + : aiAssistantEmpty, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 13, + ), + ), + ], + ), + ), + const SizedBox(width: 12), + Text( + '${thread.messageCount}', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + ), + ), + ], + ), + ), + ), + ), + ); + } +} + +String _threadTitle(AiChatThread? thread, List threads) { + if (thread == null) { + return aiAssistantNewThread; + } + final title = thread.title?.trim(); + if (title != null && title.isNotEmpty) { + return title; + } + final index = threads.indexWhere((item) => item.id == thread.id); + return index < 0 ? aiAssistantUntitledThread : 'Chat ${index + 1}'; +} diff --git a/lib/widgets/app_bar.dart b/lib/widgets/app_bar.dart index 0445c15285..bf804787b4 100644 --- a/lib/widgets/app_bar.dart +++ b/lib/widgets/app_bar.dart @@ -11,12 +11,14 @@ class MixinAppBar extends StatelessWidget implements PreferredSizeWidget { this.actions = const [], this.backgroundColor, this.leading, + this.leadingWidth, }); final Widget? title; final List actions; final Color? backgroundColor; final Widget? leading; + final double? leadingWidth; @override Widget build(BuildContext context) { @@ -49,6 +51,7 @@ class MixinAppBar extends StatelessWidget implements PreferredSizeWidget { elevation: 0, centerTitle: true, backgroundColor: backgroundColor ?? context.theme.primary, + leadingWidth: leadingWidth, leading: MoveWindowBarrier( child: Builder( builder: (context) => diff --git a/test/ai/ai_chat_thread_test.dart b/test/ai/ai_chat_thread_test.dart index 13566ede73..763f158188 100644 --- a/test/ai/ai_chat_thread_test.dart +++ b/test/ai/ai_chat_thread_test.dart @@ -3,10 +3,12 @@ import 'package:drift/native.dart'; import 'package:flutter_app/ai/ai_chat_prompt_builder.dart'; import 'package:flutter_app/ai/model/ai_prompt_message.dart'; import 'package:flutter_app/db/ai_database.dart'; +import 'package:flutter_app/db/dao/ai_chat_message_dao.dart'; import 'package:flutter_app/db/database.dart'; import 'package:flutter_app/db/fts_database.dart'; import 'package:flutter_app/db/mixin_database.dart'; import 'package:flutter_test/flutter_test.dart'; +import 'package:sqlite3/sqlite3.dart' as sqlite; void main() { group('AI chat threads', () { @@ -14,15 +16,21 @@ void main() { late FtsDatabase ftsDatabase; late AiDatabase aiDatabase; late Database database; + late bool disposeDatabase; setUp(() { mixinDatabase = MixinDatabase(NativeDatabase.memory()); ftsDatabase = FtsDatabase(NativeDatabase.memory()); aiDatabase = AiDatabase(NativeDatabase.memory()); database = Database(mixinDatabase, ftsDatabase, aiDatabase); + disposeDatabase = true; }); - tearDown(() => database.dispose()); + tearDown(() async { + if (disposeDatabase) { + await database.dispose(); + } + }); test('scopes messages and pending state by thread', () async { const conversationId = 'conversation-id'; @@ -84,6 +92,57 @@ void main() { ); }); + test('maintains thread list metadata from messages', () async { + const conversationId = 'conversation-id'; + final thread = await database.aiChatMessageDao.createThread( + conversationId, + ); + final now = DateTime.now(); + + await database.aiChatMessageDao.insertMessage( + AiChatMessagesCompanion.insert( + id: 'user-message', + threadId: Value(thread.id), + conversationId: conversationId, + role: 'user', + providerId: 'provider-id', + content: 'hello', + status: 'done', + createdAt: now, + updatedAt: now, + ), + ); + await database.aiChatMessageDao.insertMessage( + AiChatMessagesCompanion.insert( + id: 'assistant-message', + threadId: Value(thread.id), + conversationId: conversationId, + role: 'assistant', + providerId: 'provider-id', + content: '', + status: 'pending', + createdAt: now.add(const Duration(milliseconds: 1)), + updatedAt: now.add(const Duration(milliseconds: 1)), + ), + ); + await database.aiChatMessageDao.updateMessageContent( + 'assistant-message', + 'assistant answer', + updatedAt: now.add(const Duration(milliseconds: 2)), + ); + + final updatedThread = await database.aiChatMessageDao.threadById( + thread.id, + ); + + expect(updatedThread?.messageCount, 2); + expect(updatedThread?.lastMessagePreview, 'assistant answer'); + expect( + updatedThread?.lastMessageAt?.millisecondsSinceEpoch, + now.add(const Duration(milliseconds: 1)).millisecondsSinceEpoch, + ); + }); + test('prompt history excludes the current user message', () async { const conversationId = 'conversation-id'; final thread = await database.aiChatMessageDao.createThread( From b4344e0213ed133ee5547f6d6f03d711d868257a Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 21:29:36 +0800 Subject: [PATCH 33/52] feat: enhance AI assistant UI and add response details dialog --- lib/ui/home/chat/input_container.dart | 2 +- .../chat_slide_page/ai_assistant_page.dart | 31 +-- lib/widgets/ai/ai_message_card.dart | 215 +++++++++++++++--- 3 files changed, 195 insertions(+), 53 deletions(-) diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 4bb98a49e1..104aec8e96 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -490,7 +490,7 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { ).data ?? false; - if (aiRequestInFlight) { + if (aiModeEnabled && aiRequestInFlight) { return ActionButton( name: Resources.assetsImagesRecordStopSvg, color: context.theme.accent, diff --git a/lib/ui/home/chat_slide_page/ai_assistant_page.dart b/lib/ui/home/chat_slide_page/ai_assistant_page.dart index 5cfb883168..1f0b1eb26e 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant_page.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant_page.dart @@ -140,19 +140,19 @@ class AiAssistantPage extends HookConsumerWidget { return Scaffold( backgroundColor: context.theme.primary, appBar: MixinAppBar( - leadingWidth: 104, - leading: _AiAssistantLeadingActions( - addEnabled: !isNewThreadPage, - onOpenThreads: () => context.read().pushPage( - ChatSideCubit.aiAssistantThreadsPage, - ), - onNewThread: openNewThreadPage, - ), title: Text(_threadTitle(currentThread, threads)), actions: [ - MixinCloseButton( - onTap: () => context.read().onPopPage(), + _AiAssistantActions( + addEnabled: !isNewThreadPage, + onOpenThreads: () => context.read().pushPage( + ChatSideCubit.aiAssistantThreadsPage, + ), + onNewThread: openNewThreadPage, ), + if (!Navigator.of(context).canPop()) + MixinCloseButton( + onTap: () => context.read().onPopPage(), + ), ], ), body: Column( @@ -215,9 +215,10 @@ class AiAssistantThreadsPage extends HookConsumerWidget { appBar: MixinAppBar( title: const Text(aiAssistantThreads), actions: [ - MixinCloseButton( - onTap: () => context.read().onPopPage(), - ), + if (!Navigator.of(context).canPop()) + MixinCloseButton( + onTap: () => context.read().onPopPage(), + ), ], ), body: threads.isEmpty @@ -261,8 +262,8 @@ class AiAssistantThreadsPage extends HookConsumerWidget { } } -class _AiAssistantLeadingActions extends StatelessWidget { - const _AiAssistantLeadingActions({ +class _AiAssistantActions extends StatelessWidget { + const _AiAssistantActions({ required this.addEnabled, required this.onOpenThreads, required this.onNewThread, diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index f933528dd6..42f799e0b3 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -10,6 +10,7 @@ import '../../db/ai_database.dart'; import '../../utils/datetime_format_utils.dart'; import '../../utils/extension/extension.dart'; import '../../utils/platform.dart'; +import '../dialog.dart'; import '../markdown.dart'; import '../menu.dart'; import '../message/item/text/selectable.dart'; @@ -20,6 +21,7 @@ import '../message/message_style.dart'; import '../qr_code.dart'; const _copyAiMessageTitle = 'Copy AI Message'; +const _showAiResponseDetailsTitle = 'AI Response Details'; class AiMessageCard extends StatelessWidget { const AiMessageCard({ @@ -99,7 +101,7 @@ class _AiUserMessageCard extends StatelessWidget { ); } -class _AiResponseMessageCard extends StatelessWidget { +class _AiResponseMessageCard extends StatefulWidget { const _AiResponseMessageCard({ required this.message, required this.mergedWithPrev, @@ -108,25 +110,40 @@ class _AiResponseMessageCard extends StatelessWidget { final AiChatMessage message; final bool mergedWithPrev; + @override + State<_AiResponseMessageCard> createState() => _AiResponseMessageCardState(); +} + +class _AiResponseMessageCardState extends State<_AiResponseMessageCard> { + bool _hovering = false; + @override Widget build(BuildContext context) => Padding( padding: EdgeInsets.only( - top: mergedWithPrev ? 6 : 18, + top: widget.mergedWithPrev ? 6 : 18, bottom: 6, ), - child: _AiMessageMenu( - message: message, - child: Column( - spacing: 6, - children: [ - _AiResponseMessageBody(message: message), - const SizedBox(height: 4), - _AiResponseFooter( - model: message.model, - metadata: message.metadata, - dateTime: message.createdAt, - ), - ], + child: MouseRegion( + onEnter: (_) { + if (!_hovering) setState(() => _hovering = true); + }, + onExit: (_) { + if (_hovering) setState(() => _hovering = false); + }, + child: _AiMessageMenu( + message: widget.message, + child: Column( + spacing: 6, + children: [ + _AiResponseMessageBody(message: widget.message), + const SizedBox(height: 4), + _AiResponseFooter( + model: widget.message.model, + dateTime: widget.message.createdAt, + showModel: _hovering, + ), + ], + ), ), ), ); @@ -350,6 +367,12 @@ class _AiMessageMenu extends StatelessWidget { ), ], [ + if (message.role != 'user') + MenuAction( + image: MenuImage.icon(Icons.info_outline), + title: _showAiResponseDetailsTitle, + callback: () => _showAiResponseDetails(context, message), + ), MenuAction( image: MenuImage.icon(Icons.data_object), title: _copyAiMessageTitle, @@ -396,16 +419,130 @@ SelectedContent? _findSelectedContent(BuildContext context) { return null; } +void _showAiResponseDetails(BuildContext context, AiChatMessage message) { + final rootMeta = decodeAiMessageMetadata(message.metadata); + final providerMeta = rootMeta['provider']; + final responseMeta = aiMetadataResponse(message.metadata); + final usage = responseMeta['usage']; + final elapsedMs = (responseMeta['elapsedMs'] as num?)?.round(); + final totalTokens = _totalTokens(responseMeta); + final inputTokens = _usageValue(responseMeta, 'inputTokens'); + final outputTokens = _usageValue(responseMeta, 'outputTokens'); + final promptMessageCount = responseMeta['promptMessageCount'] as num?; + final toolCount = responseMeta['toolCount'] as num?; + final outputCharacters = responseMeta['outputCharacters'] as num?; + final providerType = providerMeta is Map ? providerMeta['type'] : null; + final providerModel = providerMeta is Map ? providerMeta['model'] : null; + final model = (message.model?.trim().isNotEmpty ?? false) + ? message.model!.trim() + : providerModel is String + ? providerModel + : null; + final completedAt = _formatIsoDateTime(responseMeta['completedAt']); + final createdAt = DateFormat( + 'yyyy-MM-dd HH:mm:ss', + ).format(message.createdAt.toLocal()); + final details = >[ + MapEntry('Created', createdAt), + if (completedAt != null) MapEntry('Completed', completedAt), + MapEntry('Status', message.status), + if (model != null && model.isNotEmpty) MapEntry('Model', model), + if (providerType is String && providerType.isNotEmpty) + MapEntry('Provider', providerType), + if (elapsedMs != null && elapsedMs > 0) + MapEntry('Elapsed', _formatElapsed(elapsedMs)), + if (totalTokens != null && totalTokens > 0) + MapEntry('Total tokens', _formatFullTokens(totalTokens)), + if (inputTokens != null && inputTokens > 0) + MapEntry('Input tokens', _formatFullTokens(inputTokens)), + if (outputTokens != null && outputTokens > 0) + MapEntry('Output tokens', _formatFullTokens(outputTokens)), + if (promptMessageCount != null && promptMessageCount > 0) + MapEntry('Prompt messages', '${promptMessageCount.round()}'), + if (toolCount != null && toolCount > 0) + MapEntry('Tools', '${toolCount.round()}'), + if (outputCharacters != null && outputCharacters > 0) + MapEntry('Output chars', '${outputCharacters.round()}'), + if (usage is Map && usage.isEmpty) const MapEntry('Usage', 'Empty'), + ]; + final detailsText = details + .map((entry) => '${entry.key}: ${entry.value}') + .join('\n'); + + showMixinDialog( + context: context, + constraints: const BoxConstraints(maxWidth: 420), + child: Builder( + builder: (context) => AlertDialogLayout( + minWidth: 360, + minHeight: 0, + titleMarginBottom: 20, + title: const Text(_showAiResponseDetailsTitle), + content: DefaultTextStyle.merge( + style: TextStyle( + color: context.theme.text, + fontSize: 13, + fontWeight: FontWeight.normal, + ), + child: Column( + mainAxisSize: MainAxisSize.min, + children: details + .map( + (entry) => Padding( + padding: const EdgeInsets.symmetric(vertical: 4), + child: Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + SizedBox( + width: 112, + child: Text( + entry.key, + style: TextStyle( + color: context.theme.secondaryText, + ), + ), + ), + Expanded( + child: SelectableText( + entry.value, + style: TextStyle(color: context.theme.text), + ), + ), + ], + ), + ), + ) + .toList(growable: false), + ), + ), + actions: [ + MixinButton( + backgroundTransparent: true, + onTap: () { + Clipboard.setData(ClipboardData(text: detailsText)); + }, + child: Text(context.l10n.copy), + ), + MixinButton( + onTap: () => Navigator.pop(context), + child: Text(context.l10n.close), + ), + ], + ), + ), + ); +} + class _AiResponseFooter extends StatelessWidget { const _AiResponseFooter({ required this.model, - required this.metadata, required this.dateTime, + required this.showModel, }); final String? model; - final String? metadata; final DateTime dateTime; + final bool showModel; @override Widget build(BuildContext context) { @@ -419,26 +556,23 @@ class _AiResponseFooter extends StatelessWidget { ); final dateTimeText = DateFormat.Hm().format(dateTime.toLocal()); final trimmedModel = model?.trim(); - final responseMeta = aiMetadataResponse(metadata); - final elapsedMs = (responseMeta['elapsedMs'] as num?)?.round(); - final totalTokens = _totalTokens(responseMeta); + final text = [ + dateTimeText, + if (showModel && trimmedModel != null && trimmedModel.isNotEmpty) + trimmedModel, + ].join(' · '); return SelectionContainer.disabled( - child: SizedBox( - width: double.infinity, - child: Wrap( - spacing: 12, - runSpacing: 2, - children: [ - const SizedBox(width: 4), - Text(dateTimeText, style: textStyle), - if (trimmedModel != null && trimmedModel.isNotEmpty) - Text(trimmedModel, style: textStyle), - if (elapsedMs != null && elapsedMs > 0) - Text(_formatElapsed(elapsedMs), style: textStyle), - if (totalTokens != null && totalTokens > 0) - Text(_formatTokens(totalTokens), style: textStyle), - ], + child: Padding( + padding: const EdgeInsets.only(left: 4), + child: SizedBox( + width: double.infinity, + child: Text( + text, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: textStyle, + ), ), ), ); @@ -469,8 +603,15 @@ String _formatElapsed(int elapsedMs) { return '${seconds.toStringAsFixed(seconds >= 10 ? 0 : 1)}s'; } -String _formatTokens(num tokens) => - '${NumberFormat.decimalPattern().format(tokens.round())} tokens'; +String _formatFullTokens(num tokens) => + NumberFormat.decimalPattern().format(tokens.round()); + +String? _formatIsoDateTime(Object? value) { + if (value is! String || value.isEmpty) return null; + final dateTime = DateTime.tryParse(value); + if (dateTime == null) return null; + return DateFormat('yyyy-MM-dd HH:mm:ss').format(dateTime.toLocal()); +} String _displayText(AiChatMessage message) { final content = message.content.trim(); From 61e5b0e753362a878af244ba0fc918f441598b6f Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 21:30:40 +0800 Subject: [PATCH 34/52] refactor: remove unused imports in AI chat thread tests --- test/ai/ai_chat_thread_test.dart | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/ai/ai_chat_thread_test.dart b/test/ai/ai_chat_thread_test.dart index 763f158188..d70d197d79 100644 --- a/test/ai/ai_chat_thread_test.dart +++ b/test/ai/ai_chat_thread_test.dart @@ -3,12 +3,10 @@ import 'package:drift/native.dart'; import 'package:flutter_app/ai/ai_chat_prompt_builder.dart'; import 'package:flutter_app/ai/model/ai_prompt_message.dart'; import 'package:flutter_app/db/ai_database.dart'; -import 'package:flutter_app/db/dao/ai_chat_message_dao.dart'; import 'package:flutter_app/db/database.dart'; import 'package:flutter_app/db/fts_database.dart'; import 'package:flutter_app/db/mixin_database.dart'; import 'package:flutter_test/flutter_test.dart'; -import 'package:sqlite3/sqlite3.dart' as sqlite; void main() { group('AI chat threads', () { From 1d5e6964675a9292a203f7763d3570c0d373ae67 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Tue, 28 Apr 2026 21:41:09 +0800 Subject: [PATCH 35/52] refactor(db): simplify migration strategy and remove unused methods --- lib/db/ai_database.dart | 76 ++--------------------------------------- 1 file changed, 2 insertions(+), 74 deletions(-) diff --git a/lib/db/ai_database.dart b/lib/db/ai_database.dart index 22f6ac3b08..da6b7098b5 100644 --- a/lib/db/ai_database.dart +++ b/lib/db/ai_database.dart @@ -27,82 +27,10 @@ class AiDatabase extends _$AiDatabase { } @override - int get schemaVersion => 2; + int get schemaVersion => 1; @override MigrationStrategy get migration => MigrationStrategy( - onUpgrade: (m, from, to) async { - if (from <= 1) { - await _addColumnIfNotExists(m, aiChatThreads, aiChatThreads.summary); - await _addColumnIfNotExists( - m, - aiChatThreads, - aiChatThreads.lastMessagePreview, - ); - await _addColumnIfNotExists( - m, - aiChatThreads, - aiChatThreads.messageCount, - ); - await _addColumnIfNotExists(m, aiChatThreads, aiChatThreads.status); - await _addColumnIfNotExists(m, aiChatThreads, aiChatThreads.pinnedAt); - await _addColumnIfNotExists(m, aiChatThreads, aiChatThreads.archivedAt); - await _addColumnIfNotExists( - m, - aiChatThreads, - aiChatThreads.lastMessageAt, - ); - await _addColumnIfNotExists(m, aiChatThreads, aiChatThreads.metadata); - await _backfillThreadStats(); - await customStatement( - 'DROP INDEX IF EXISTS index_ai_chat_threads_conversation_id_updated_at', - ); - await m.createIndex(indexAiChatThreadsConversationIdUpdatedAt); - await m.createIndex(indexAiChatThreadsConversationIdLastMessageAt); - } - }, + onUpgrade: (m, from, to) async {}, ); - - Future _addColumnIfNotExists( - Migrator m, - TableInfo table, - GeneratedColumn column, - ) async { - if (!await _checkColumnExists(table.actualTableName, column.name)) { - await m.addColumn(table, column); - } - } - - Future _checkColumnExists(String tableName, String columnName) async { - final queryRow = await customSelect( - "SELECT COUNT(*) AS CNTREC FROM pragma_table_info('$tableName') WHERE name='$columnName'", - ).getSingle(); - return queryRow.read('CNTREC'); - } - - Future _backfillThreadStats() async { - await customStatement(''' -UPDATE ai_chat_threads -SET - message_count = ( - SELECT COUNT(*) - FROM ai_chat_messages - WHERE ai_chat_messages.thread_id = ai_chat_threads.id - ), - last_message_at = ( - SELECT created_at - FROM ai_chat_messages - WHERE ai_chat_messages.thread_id = ai_chat_threads.id - ORDER BY created_at DESC, id DESC - LIMIT 1 - ), - last_message_preview = ( - SELECT substr(trim(replace(replace(content, char(10), ' '), char(13), ' ')), 1, 160) - FROM ai_chat_messages - WHERE ai_chat_messages.thread_id = ai_chat_threads.id - ORDER BY created_at DESC, id DESC - LIMIT 1 - ) -'''); - } } From 367f924fde5095bc93d32b62c70d73dfd0a99b4e Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Wed, 29 Apr 2026 15:49:03 +0800 Subject: [PATCH 36/52] feat(ai): add support for attaching messages as AI context --- lib/ai/ai_chat_controller.dart | 9 + lib/ai/ai_chat_prompt_builder.dart | 34 ++++ lib/ai/ai_message_context.dart | 52 ++++++ lib/ai/model/ai_chat_metadata.dart | 27 +++ lib/ui/home/chat/chat_page.dart | 6 +- lib/ui/home/chat/chat_side_route_names.dart | 1 + lib/ui/home/chat/input_container.dart | 77 +++++++- lib/ui/home/chat/selection_bottom_bar.dart | 64 +++++-- .../ai_assistant/composer.dart | 14 ++ .../ai_assistant/constants.dart | 1 + .../chat_slide_page/ai_assistant_page.dart | 30 +++- .../ai_context_attachment_provider.dart | 54 ++++++ lib/widgets/ai/ai_context_attachment_bar.dart | 169 ++++++++++++++++++ lib/widgets/ai/ai_message_card.dart | 105 ++++++++++- lib/widgets/markdown.dart | 26 ++- lib/widgets/message/message.dart | 34 +++- pubspec.lock | 4 +- pubspec.yaml | 2 +- test/ai/ai_chat_metadata_test.dart | 20 +++ 19 files changed, 697 insertions(+), 32 deletions(-) create mode 100644 lib/ai/ai_message_context.dart create mode 100644 lib/ui/home/chat/chat_side_route_names.dart create mode 100644 lib/ui/provider/ai_context_attachment_provider.dart create mode 100644 lib/widgets/ai/ai_context_attachment_bar.dart diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index 7f44aae5f5..79c6344404 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -8,7 +8,9 @@ import 'package:uuid/uuid.dart'; import '../db/ai_database.dart'; import '../db/dao/ai_chat_message_dao.dart'; import '../db/database.dart'; +import '../db/mixin_database.dart'; import 'ai_chat_prompt_builder.dart'; +import 'ai_message_context.dart'; import 'ai_provider_requester.dart'; import 'model/ai_chat_metadata.dart'; import 'model/ai_prompt_message.dart'; @@ -100,6 +102,7 @@ class AiChatController { required String language, String? threadId, AiProviderConfig? provider, + List attachedMessages = const [], void Function()? onInputAccepted, }) async { final thread = await database.aiChatMessageDao.ensureThread( @@ -147,6 +150,11 @@ class AiChatController { content: input, status: _kAiStatusDone, model: Value(config.model), + metadata: Value( + createAiUserMessageMetadata( + attachedMessages.map(aiMessageContextMetadata).toList(), + ), + ), createdAt: now, updatedAt: now, ), @@ -187,6 +195,7 @@ class AiChatController { input, language, currentMessageId: userMessageId, + attachedMessages: attachedMessages, ); Map? responseMetadata; final result = await _requestText( diff --git a/lib/ai/ai_chat_prompt_builder.dart b/lib/ai/ai_chat_prompt_builder.dart index 0e13e070e9..a0cc04b5f8 100644 --- a/lib/ai/ai_chat_prompt_builder.dart +++ b/lib/ai/ai_chat_prompt_builder.dart @@ -2,6 +2,7 @@ import 'package:mixin_logger/mixin_logger.dart'; import '../db/database.dart'; import '../db/mixin_database.dart'; +import 'ai_message_context.dart'; import 'model/ai_prompt_message.dart'; import 'model/ai_prompt_template.dart'; @@ -20,6 +21,7 @@ class AiChatPromptBuilder { String input, String language, { String? currentMessageId, + List attachedMessages = const [], }) async { final now = DateTime.now(); final recentMessages = await database.messageDao @@ -58,6 +60,12 @@ class AiChatPromptBuilder { language: language, now: now, ); + _appendAttachedMessages( + promptMessages, + attachedMessages: attachedMessages, + language: language, + now: now, + ); final history = aiMessages .where( @@ -90,6 +98,7 @@ class AiChatPromptBuilder { 'AI prompt built: conversationId=$conversationId ' 'threadId=$threadId ' 'recent=${recentMessages.length} ' + 'attached=${attachedMessages.length} ' 'history=${history.length} promptMessages=${promptMessages.length}', ); return promptMessages; @@ -225,6 +234,31 @@ class AiChatPromptBuilder { } } + void _appendAttachedMessages( + List promptMessages, { + required List attachedMessages, + required String language, + required DateTime now, + }) { + if (attachedMessages.isEmpty) { + return; + } + + final lines = attachedMessages.map(aiMessageContextLine).join('\n'); + promptMessages.addAll( + _promptMessages( + role: AiPromptRole.system, + content: + 'User-attached messages for the next request. Treat these as ' + 'the primary quoted context, especially when the user says ' + '"this message", "these messages", or asks for a specific ' + 'message to be handled. Answer in $language unless the user ' + 'explicitly asks for another language. Current time: ' + '${now.toIso8601String()}.\n$lines', + ), + ); + } + String _conversationContextLine({ required DateTime createdAt, required String sender, diff --git a/lib/ai/ai_message_context.dart b/lib/ai/ai_message_context.dart new file mode 100644 index 0000000000..6474a6b602 --- /dev/null +++ b/lib/ai/ai_message_context.dart @@ -0,0 +1,52 @@ +import '../db/extension/message_category.dart'; +import '../db/mixin_database.dart'; +import '../utils/message_optimize.dart'; + +String aiMessageContextText(MessageItem message) { + final content = message.content?.trim(); + if ((message.type.isText || message.type.isPost) && + content != null && + content.isNotEmpty) { + return content; + } + + final caption = message.caption?.trim(); + if (caption != null && caption.isNotEmpty) { + return caption; + } + + final mediaName = message.mediaName?.trim(); + if (mediaName != null && mediaName.isNotEmpty) { + return '[${message.type}] $mediaName'; + } + + return messagePreviewOptimize( + message.status, + message.type, + message.content, + ) ?? + '[${message.type}]'; +} + +String aiMessageContextLine(MessageItem message) => + '[${message.createdAt.toIso8601String()}] ' + '${message.userFullName ?? message.userId} ' + '(message_id=${message.messageId}): ${aiMessageContextText(message)}'; + +String aiMessageContextPreview(MessageItem message, {int maxLength = 96}) { + final text = aiMessageContextText(message).replaceAll(RegExp(r'\s+'), ' '); + if (text.length <= maxLength) { + return text; + } + return '${text.substring(0, maxLength)}...'; +} + +Map aiMessageContextMetadata(MessageItem message) => { + 'messageId': message.messageId, + 'conversationId': message.conversationId, + 'senderId': message.userId, + 'senderName': message.userFullName ?? message.userId, + 'type': message.type, + 'createdAt': message.createdAt.toUtc().toIso8601String(), + 'preview': aiMessageContextPreview(message, maxLength: 180), +}; diff --git a/lib/ai/model/ai_chat_metadata.dart b/lib/ai/model/ai_chat_metadata.dart index 632b8c7637..5fedc8aee2 100644 --- a/lib/ai/model/ai_chat_metadata.dart +++ b/lib/ai/model/ai_chat_metadata.dart @@ -4,6 +4,7 @@ import 'ai_provider_config.dart'; const aiMetadataToolEventsKey = 'toolEvents'; const aiMetadataResponseKey = 'response'; +const aiMetadataAttachmentsKey = 'attachments'; const aiToolEventTypeCall = 'tool_call'; const aiToolEventTypeResult = 'tool_result'; @@ -16,6 +17,17 @@ String createAiMessageMetadata(AiProviderConfig provider) => jsonEncode({ aiMetadataToolEventsKey: const >[], }); +String? createAiUserMessageMetadata( + List> attachments, +) { + if (attachments.isEmpty) { + return null; + } + return jsonEncode({ + aiMetadataAttachmentsKey: attachments, + }); +} + Map decodeAiMessageMetadata(String? metadata) { if (metadata == null || metadata.trim().isEmpty) { return {}; @@ -34,6 +46,21 @@ Map decodeAiMessageMetadata(String? metadata) { return {}; } +List> aiMetadataAttachments(String? metadata) { + final attachments = decodeAiMessageMetadata( + metadata, + )[aiMetadataAttachmentsKey]; + if (attachments is! List) { + return const >[]; + } + return attachments + .whereType() + .map( + (attachment) => attachment.map((key, value) => MapEntry('$key', value)), + ) + .toList(growable: false); +} + String appendAiToolEventToMetadata( String? metadata, Map event, diff --git a/lib/ui/home/chat/chat_page.dart b/lib/ui/home/chat/chat_page.dart index de550e9456..bd8244237b 100644 --- a/lib/ui/home/chat/chat_page.dart +++ b/lib/ui/home/chat/chat_page.dart @@ -56,6 +56,7 @@ import '../home.dart'; import '../hook/pin_message.dart'; import '../route/responsive_navigator.dart'; import 'chat_bar.dart'; +import 'chat_side_route_names.dart'; import 'files_preview.dart'; import 'input_container.dart'; import 'selection_bottom_bar.dart'; @@ -72,7 +73,7 @@ class ChatSideCubit extends AbstractResponsiveNavigatorCubit { static const sharedApps = 'sharedApps'; static const groupsInCommon = 'groupsInCommon'; static const disappearMessages = 'disappearMessages'; - static const aiAssistantPage = 'aiAssistantPage'; + static const aiAssistantPage = chatSideAiAssistantPage; static const aiAssistantThreadsPage = 'aiAssistantThreadsPage'; @override @@ -267,6 +268,9 @@ class ChatPage extends HookConsumerWidget { providers: [ BlocProvider.value(value: blinkCubit), BlocProvider.value(value: chatSideCubit), + BlocProvider.value( + value: chatSideCubit, + ), BlocProvider.value(value: searchConversationKeywordCubit), BlocProvider( create: (context) => MessageBloc( diff --git a/lib/ui/home/chat/chat_side_route_names.dart b/lib/ui/home/chat/chat_side_route_names.dart new file mode 100644 index 0000000000..fb25fe9ee0 --- /dev/null +++ b/lib/ui/home/chat/chat_side_route_names.dart @@ -0,0 +1 @@ +const chatSideAiAssistantPage = 'aiAssistantPage'; diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 104aec8e96..b99bebb2c6 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -38,6 +38,7 @@ import '../../../utils/reg_exp_utils.dart'; import '../../../utils/system/clipboard.dart'; import '../../../widgets/action_button.dart'; import '../../../widgets/actions/actions.dart'; +import '../../../widgets/ai/ai_context_attachment_bar.dart'; import '../../../widgets/high_light_text.dart'; import '../../../widgets/hover_overlay.dart'; import '../../../widgets/mention_panel.dart'; @@ -48,12 +49,15 @@ import '../../../widgets/sticker_page/sticker_page.dart'; import '../../../widgets/toast.dart'; import '../../../widgets/user_selector/conversation_selector.dart'; import '../../provider/abstract_responsive_navigator.dart'; +import '../../provider/ai_context_attachment_provider.dart'; import '../../provider/ai_input_mode_provider.dart'; import '../../provider/conversation_provider.dart'; import '../../provider/mention_cache_provider.dart'; import '../../provider/mention_provider.dart'; import '../../provider/quote_message_provider.dart'; import '../../provider/recall_message_reedit_provider.dart'; +import '../bloc/blink_cubit.dart'; +import '../bloc/message_bloc.dart'; import 'ai_draft_assist_panel.dart'; import 'chat_page.dart'; import 'files_preview.dart'; @@ -121,6 +125,12 @@ class _InputContainer extends HookConsumerWidget { selectedModel: aiModeState.model, ); final aiModeEnabled = aiModeState.enabled; + final attachedMessages = conversationId == null + ? const [] + : ref.watch(aiContextAttachmentProvider(conversationId)); + final attachedMessagesNotifier = conversationId == null + ? null + : ref.read(aiContextAttachmentProvider(conversationId).notifier); final activeAiThread = useMemoizedStream( () => conversationId == null ? Stream.value(null) @@ -329,6 +339,15 @@ class _InputContainer extends HookConsumerWidget { provider: aiProvider, ), const SizedBox(height: 8), + AiContextAttachmentBar( + messages: attachedMessages, + onTap: (message) => + _jumpToAttachedMessage(context, message), + onRemove: (messageId) => + attachedMessagesNotifier?.remove(messageId), + ), + if (attachedMessages.isNotEmpty) + const SizedBox(height: 8), ], if (!aiModeEnabled && !aiDraftAssistState.value.isIdle) ...[ @@ -405,6 +424,7 @@ class _InputContainer extends HookConsumerWidget { aiModeEnabled: aiModeEnabled, providerName: aiProvider?.name, modelName: aiProvider?.model, + aiThreadId: activeAiThread?.id, aiRequestInFlight: aiRequestInFlight, aiDraftAssistState: aiDraftAssistState.value, ), @@ -671,6 +691,11 @@ String _currentLanguageTag(BuildContext context) { return '${locale.languageCode}-$countryCode'; } +void _jumpToAttachedMessage(BuildContext context, MessageItem message) { + context.read().scrollTo(message.messageId); + context.read().blinkByMessageId(message.messageId); +} + void showMaxLengthReachedToast(BuildContext context) => showToastFailed(ToastError(context.l10n.contentTooLong)); @@ -733,6 +758,49 @@ Future _sendMessage( final inlineAiInput = text.startsWith('/ai ') ? text.substring(4).trim() : null; + final attachedMessages = context.providerContainer.read( + aiContextAttachmentProvider(conversationId), + ); + final attachedMessagesNotifier = context.providerContainer.read( + aiContextAttachmentProvider(conversationId).notifier, + ); + + final aiModeState = context.providerContainer.read( + aiInputModeProvider(conversationId), + ); + if (aiModeState.enabled) { + final provider = _resolveAiModeProvider( + selectedAiProvider: context.database.settingProperties.selectedAiProvider, + enabledAiProviders: context.database.settingProperties.aiProviders + .whereType() + .where((element) => element.enabled) + .toList(), + providerId: aiModeState.providerId, + selectedModel: aiModeState.model, + ); + if (provider == null || provider.model.trim().isEmpty) { + showToastFailed(ToastError('Please add an AI provider first')); + return; + } + try { + await AiChatController(context.database).send( + conversationId: conversationId, + threadId: aiThreadId, + input: text, + language: _currentLanguageTag(context), + provider: provider, + attachedMessages: attachedMessages, + onInputAccepted: () { + textEditingController.text = ''; + attachedMessagesNotifier.clear(); + }, + ); + } catch (error, _) { + showToastFailed(error); + } + return; + } + if (inlineAiInput != null && inlineAiInput.isNotEmpty) { final provider = context.database.settingProperties.selectedAiProvider; if (provider == null || provider.model.trim().isEmpty) { @@ -749,7 +817,11 @@ Future _sendMessage( input: inlineAiInput, language: _currentLanguageTag(context), provider: provider, - onInputAccepted: () => textEditingController.text = '', + attachedMessages: attachedMessages, + onInputAccepted: () { + textEditingController.text = ''; + attachedMessagesNotifier.clear(); + }, ); } catch (error, _) { showToastFailed(error); @@ -804,6 +876,7 @@ class _SendTextField extends HookConsumerWidget { required this.aiModeEnabled, required this.providerName, required this.modelName, + required this.aiThreadId, required this.aiRequestInFlight, required this.aiDraftAssistState, }); @@ -815,6 +888,7 @@ class _SendTextField extends HookConsumerWidget { final bool aiModeEnabled; final String? providerName; final String? modelName; + final String? aiThreadId; final bool aiRequestInFlight; final AiDraftAssistViewState aiDraftAssistState; @@ -915,6 +989,7 @@ class _SendTextField extends HookConsumerWidget { context, textEditingController, conversationId: ref.read(currentConversationIdProvider), + aiThreadId: aiThreadId, ), ), ), diff --git a/lib/ui/home/chat/selection_bottom_bar.dart b/lib/ui/home/chat/selection_bottom_bar.dart index dd711a2a02..399d53f21f 100644 --- a/lib/ui/home/chat/selection_bottom_bar.dart +++ b/lib/ui/home/chat/selection_bottom_bar.dart @@ -1,9 +1,12 @@ +import 'dart:async'; + import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; import 'package:flutter_svg/svg.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; import 'package:intl/intl.dart'; +import '../../../ai/model/ai_provider_config.dart'; import '../../../constants/resources.dart'; import '../../../utils/extension/extension.dart'; import '../../../utils/logger.dart'; @@ -12,8 +15,12 @@ import '../../../widgets/dialog.dart'; import '../../../widgets/interactive_decorated_box.dart'; import '../../../widgets/toast.dart'; import '../../../widgets/user_selector/conversation_selector.dart'; +import '../../provider/ai_context_attachment_provider.dart'; import '../../provider/conversation_provider.dart'; import '../../provider/message_selection_provider.dart'; +import '../chat_slide_page/ai_assistant/constants.dart'; +import '../route/responsive_navigator.dart'; +import 'chat_side_route_names.dart'; class SelectionBottomBar extends HookConsumerWidget { const SelectionBottomBar({super.key}); @@ -27,6 +34,9 @@ class SelectionBottomBar extends HookConsumerWidget { final canCombineForward = ref.watch( messageSelectionProvider.select((value) => value.canCombineForward), ); + final canAttachToAi = context.database.settingProperties.aiProviders + .whereType() + .any((item) => item.enabled && item.model.trim().isNotEmpty); return SizedBox( height: 80, @@ -127,6 +137,31 @@ class SelectionBottomBar extends HookConsumerWidget { })(), ), ), + _Button( + label: aiAssistantAttachToAi, + icon: Icons.auto_awesome_rounded, + enable: canAttachToAi, + onTap: () async { + final conversationId = ref.read(currentConversationIdProvider); + if (conversationId == null) return; + final selection = ref.read(messageSelectionProvider); + final messages = await context.database.messageDao + .messageItemByMessageIds( + selection.selectedMessageIds.toList(), + ) + .get(); + if (messages.isEmpty) return; + ref + .read(aiContextAttachmentProvider(conversationId).notifier) + .attachMessages(messages); + selection.clearSelection(); + unawaited( + context.read().replace( + chatSideAiAssistantPage, + ), + ); + }, + ), _Button( label: context.l10n.delete, iconAssetName: Resources.assetsImagesContextMenuDeleteSvg, @@ -174,13 +209,15 @@ class SelectionBottomBar extends HookConsumerWidget { class _Button extends StatelessWidget { const _Button({ required this.label, - required this.iconAssetName, required this.onTap, + this.iconAssetName, + this.icon, this.enable = true, - }); + }) : assert(iconAssetName != null || icon != null); final String label; - final String iconAssetName; + final String? iconAssetName; + final IconData? icon; final VoidCallback onTap; final bool enable; @@ -192,15 +229,18 @@ class _Button extends StatelessWidget { child: Column( mainAxisSize: MainAxisSize.min, children: [ - SvgPicture.asset( - iconAssetName, - width: 24, - height: 24, - colorFilter: ColorFilter.mode( - context.theme.icon, - BlendMode.srcIn, - ), - ), + if (iconAssetName != null) + SvgPicture.asset( + iconAssetName!, + width: 24, + height: 24, + colorFilter: ColorFilter.mode( + context.theme.icon, + BlendMode.srcIn, + ), + ) + else + Icon(icon, size: 24, color: context.theme.icon), const SizedBox(height: 8), Text( label, diff --git a/lib/ui/home/chat_slide_page/ai_assistant/composer.dart b/lib/ui/home/chat_slide_page/ai_assistant/composer.dart index 9abacb5bd5..28f8ce311e 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant/composer.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant/composer.dart @@ -6,9 +6,11 @@ import 'package:flutter/services.dart'; import '../../../../ai/model/ai_provider_config.dart'; import '../../../../constants/constants.dart'; import '../../../../constants/resources.dart'; +import '../../../../db/mixin_database.dart'; import '../../../../utils/extension/extension.dart'; import '../../../../widgets/action_button.dart'; import '../../../../widgets/actions/actions.dart'; +import '../../../../widgets/ai/ai_context_attachment_bar.dart'; import '../../../../widgets/high_light_text.dart'; import '../../../../widgets/menu.dart'; import 'constants.dart'; @@ -18,10 +20,13 @@ class AiAssistantComposer extends StatelessWidget { required this.focusNode, required this.textEditingController, required this.enabled, + required this.attachedMessages, required this.enabledAiProviders, required this.requestInFlight, required this.onSend, required this.onStop, + required this.onTapAttachment, + required this.onRemoveAttachment, required this.onProviderSelected, required this.onModelSelected, this.provider, @@ -32,10 +37,13 @@ class AiAssistantComposer extends StatelessWidget { final TextEditingController textEditingController; final bool enabled; final AiProviderConfig? provider; + final List attachedMessages; final List enabledAiProviders; final bool requestInFlight; final VoidCallback onSend; final VoidCallback onStop; + final ValueChanged onTapAttachment; + final ValueChanged onRemoveAttachment; final ValueChanged onProviderSelected; final ValueChanged onModelSelected; @@ -53,6 +61,12 @@ class AiAssistantComposer extends StatelessWidget { mainAxisSize: MainAxisSize.min, children: [ if (provider != null) ...[ + AiContextAttachmentBar( + messages: attachedMessages, + onTap: onTapAttachment, + onRemove: onRemoveAttachment, + ), + if (attachedMessages.isNotEmpty) const SizedBox(height: 8), _AiAssistantModeBar( provider: provider!, enabledAiProviders: enabledAiProviders, diff --git a/lib/ui/home/chat_slide_page/ai_assistant/constants.dart b/lib/ui/home/chat_slide_page/ai_assistant/constants.dart index d0039e8fdb..8cb59cb5f3 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant/constants.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant/constants.dart @@ -8,5 +8,6 @@ const aiAssistantUntitledThread = 'New chat'; const aiAssistantDeleteThread = 'Delete Chat'; const aiAssistantDeleteThreadDescription = 'This removes the AI messages in this chat.'; +const aiAssistantAttachToAi = 'Attach to AI'; const aiAssistantStickToBottomDistance = 96.0; const aiAssistantMessagePageLimit = 80; diff --git a/lib/ui/home/chat_slide_page/ai_assistant_page.dart b/lib/ui/home/chat_slide_page/ai_assistant_page.dart index 1f0b1eb26e..ddb7a858e6 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant_page.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant_page.dart @@ -9,6 +9,7 @@ import '../../../constants/constants.dart'; import '../../../constants/icon_fonts.dart'; import '../../../constants/resources.dart'; import '../../../db/ai_database.dart'; +import '../../../db/mixin_database.dart'; import '../../../utils/extension/extension.dart'; import '../../../utils/hook.dart'; import '../../../widgets/action_button.dart'; @@ -18,8 +19,11 @@ import '../../../widgets/dialog.dart'; import '../../../widgets/empty.dart'; import '../../../widgets/menu.dart'; import '../../../widgets/toast.dart'; +import '../../provider/ai_context_attachment_provider.dart'; import '../../provider/ai_input_mode_provider.dart'; import '../../provider/conversation_provider.dart'; +import '../bloc/blink_cubit.dart'; +import '../bloc/message_bloc.dart'; import '../chat/chat_page.dart'; import 'ai_assistant/composer.dart'; import 'ai_assistant/constants.dart'; @@ -41,6 +45,12 @@ class AiAssistantPage extends HookConsumerWidget { useListenable(context.database.settingProperties); final conversationId = conversationState.conversationId; + final attachedMessages = ref.watch( + aiContextAttachmentProvider(conversationId), + ); + final attachedMessagesNotifier = ref.read( + aiContextAttachmentProvider(conversationId).notifier, + ); final aiModeState = ref.watch(aiInputModeProvider(conversationId)); final aiModeNotifier = ref.read( aiInputModeProvider(conversationId).notifier, @@ -124,7 +134,11 @@ class AiAssistantPage extends HookConsumerWidget { input: text, language: currentLanguageTag(context), provider: aiProvider, - onInputAccepted: textEditingController.clear, + attachedMessages: attachedMessages, + onInputAccepted: () { + textEditingController.clear(); + attachedMessagesNotifier.clear(); + }, ); activeThreadNotifier.state = threadId; } catch (error, _) { @@ -169,12 +183,16 @@ class AiAssistantPage extends HookConsumerWidget { textEditingController: textEditingController, enabled: aiProvider != null, provider: aiProvider, + attachedMessages: attachedMessages, enabledAiProviders: enabledAiProviders, requestInFlight: requestInFlight, onSend: send, onStop: () => AiChatController( context.database, ).stop(conversationId, threadId: currentThread?.id), + onTapAttachment: (message) => + _jumpToAttachedMessage(context, message), + onRemoveAttachment: attachedMessagesNotifier.remove, onProviderSelected: (value) => aiModeNotifier.updateProvider( providerId: value.id, model: value.model, @@ -187,6 +205,16 @@ class AiAssistantPage extends HookConsumerWidget { } } +void _jumpToAttachedMessage(BuildContext context, MessageItem message) { + context.read().scrollTo(message.messageId); + context.read().blinkByMessageId(message.messageId); + + final chatSideCubit = context.read(); + if (chatSideCubit.state.routeMode) { + chatSideCubit.pop(); + } +} + class AiAssistantThreadsPage extends HookConsumerWidget { const AiAssistantThreadsPage(this.conversationState, {super.key}); diff --git a/lib/ui/provider/ai_context_attachment_provider.dart b/lib/ui/provider/ai_context_attachment_provider.dart new file mode 100644 index 0000000000..8ef080d2ff --- /dev/null +++ b/lib/ui/provider/ai_context_attachment_provider.dart @@ -0,0 +1,54 @@ +import 'package:hooks_riverpod/hooks_riverpod.dart'; + +import '../../db/mixin_database.dart'; +import 'conversation_provider.dart'; + +class AiContextAttachmentNotifier extends StateNotifier> { + AiContextAttachmentNotifier(this.conversationId) : super(const []); + + final String conversationId; + + void attachMessages(Iterable messages) { + final next = { + for (final message in state) message.messageId: message, + }; + for (final message in messages) { + if (message.conversationId != conversationId) { + continue; + } + next[message.messageId] = message; + } + state = next.values.toList(growable: false) + ..sort((a, b) { + final result = a.createdAt.compareTo(b.createdAt); + if (result != 0) return result; + return a.messageId.compareTo(b.messageId); + }); + } + + void remove(String messageId) { + state = state + .where((message) => message.messageId != messageId) + .toList(growable: false); + } + + void clear() { + if (state.isEmpty) return; + state = const []; + } +} + +final aiContextAttachmentProvider = StateNotifierProvider.autoDispose + .family, String>( + (ref, conversationId) { + final keepAlive = ref.keepAlive(); + ref.listen(currentConversationIdProvider, (previous, next) { + if (next == conversationId) { + return; + } + keepAlive.close(); + ref.invalidateSelf(); + }); + return AiContextAttachmentNotifier(conversationId); + }, + ); diff --git a/lib/widgets/ai/ai_context_attachment_bar.dart b/lib/widgets/ai/ai_context_attachment_bar.dart new file mode 100644 index 0000000000..051b51ac4a --- /dev/null +++ b/lib/widgets/ai/ai_context_attachment_bar.dart @@ -0,0 +1,169 @@ +import 'package:flutter/material.dart'; +import 'package:mixin_bot_sdk_dart/mixin_bot_sdk_dart.dart'; + +import '../../ai/ai_message_context.dart'; +import '../../db/mixin_database.dart'; +import '../../utils/extension/extension.dart'; +import '../action_button.dart'; + +class AiContextAttachmentBar extends StatelessWidget { + const AiContextAttachmentBar({ + required this.messages, + required this.onRemove, + required this.onTap, + super.key, + }); + + final List messages; + final ValueChanged onRemove; + final ValueChanged onTap; + + @override + Widget build(BuildContext context) { + if (messages.isEmpty) { + return const SizedBox.shrink(); + } + + final showSenderName = messages.any( + (message) => message.conversionCategory == ConversationCategory.group, + ); + + return SizedBox( + height: showSenderName ? 46 : 30, + child: ListView.separated( + scrollDirection: Axis.horizontal, + padding: const EdgeInsets.symmetric(horizontal: 2), + itemBuilder: (context, index) => _AttachmentChip( + message: messages[index], + showSenderName: showSenderName, + onRemove: onRemove, + onTap: onTap, + ), + separatorBuilder: (context, index) => const SizedBox(width: 6), + itemCount: messages.length, + ), + ); + } +} + +class _AttachmentChip extends StatefulWidget { + const _AttachmentChip({ + required this.message, + required this.showSenderName, + required this.onRemove, + required this.onTap, + }); + + final MessageItem message; + final bool showSenderName; + final ValueChanged onRemove; + final ValueChanged onTap; + + @override + State<_AttachmentChip> createState() => _AttachmentChipState(); +} + +class _AttachmentChipState extends State<_AttachmentChip> { + bool _hovering = false; + + @override + Widget build(BuildContext context) { + final preview = aiMessageContextPreview(widget.message, maxLength: 72); + final senderName = widget.message.userFullName ?? widget.message.userId; + final backgroundColor = context.dynamicColor( + const Color.fromRGBO(245, 247, 250, 1), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ); + final borderColor = _hovering + ? context.theme.accent.withValues(alpha: 0.35) + : context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.06), + darkColor: const Color.fromRGBO(255, 255, 255, 0.08), + ); + + return MouseRegion( + cursor: SystemMouseCursors.click, + onEnter: (_) => setState(() => _hovering = true), + onExit: (_) => setState(() => _hovering = false), + child: ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 200), + child: GestureDetector( + behavior: HitTestBehavior.opaque, + onTap: () => widget.onTap(widget.message), + child: DecoratedBox( + decoration: BoxDecoration( + color: backgroundColor, + border: Border.all(color: borderColor), + borderRadius: const BorderRadius.all(Radius.circular(8)), + ), + child: Padding( + padding: const EdgeInsets.only(left: 8, right: 4), + child: Row( + mainAxisSize: MainAxisSize.min, + children: [ + Flexible( + child: widget.showSenderName + ? Column( + mainAxisAlignment: MainAxisAlignment.center, + crossAxisAlignment: CrossAxisAlignment.start, + mainAxisSize: MainAxisSize.min, + children: [ + Text( + senderName, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 11, + fontWeight: FontWeight.w500, + ), + ), + const SizedBox(height: 2), + _PreviewText(preview), + ], + ) + : _PreviewText(preview), + ), + const SizedBox(width: 4), + AnimatedOpacity( + duration: const Duration(milliseconds: 120), + opacity: _hovering ? 1 : 0, + child: ActionButton( + padding: EdgeInsets.zero, + size: 18, + interactive: _hovering, + onTap: () => widget.onRemove(widget.message.messageId), + child: Icon( + Icons.close_rounded, + color: context.theme.secondaryText, + size: 14, + ), + ), + ), + ], + ), + ), + ), + ), + ), + ); + } +} + +class _PreviewText extends StatelessWidget { + const _PreviewText(this.text); + + final String text; + + @override + Widget build(BuildContext context) => Text( + text, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.text, + fontSize: 12, + fontWeight: FontWeight.w500, + ), + ); +} diff --git a/lib/widgets/ai/ai_message_card.dart b/lib/widgets/ai/ai_message_card.dart index 42f799e0b3..ddaee97419 100644 --- a/lib/widgets/ai/ai_message_card.dart +++ b/lib/widgets/ai/ai_message_card.dart @@ -155,10 +155,95 @@ class _AiUserMessageBody extends StatelessWidget { final AiChatMessage message; @override - Widget build(BuildContext context) => _AiSelectableText( - text: _displayText(message), - style: _aiMessageTextStyle(context, message), - ); + Widget build(BuildContext context) { + final attachments = aiMetadataAttachments(message.metadata); + final body = _AiSelectableText( + text: _displayText(message), + style: _aiMessageTextStyle(context, message), + ); + if (attachments.isEmpty) { + return body; + } + + return Column( + crossAxisAlignment: CrossAxisAlignment.start, + mainAxisSize: MainAxisSize.min, + children: [ + _AiAttachedContextSummary(attachments: attachments), + const SizedBox(height: 6), + body, + ], + ); + } +} + +class _AiAttachedContextSummary extends StatelessWidget { + const _AiAttachedContextSummary({required this.attachments}); + + final List> attachments; + + @override + Widget build(BuildContext context) { + final color = context.dynamicColor( + const Color.fromRGBO(0, 0, 0, 0.08), + darkColor: const Color.fromRGBO(255, 255, 255, 0.1), + ); + + return SelectionContainer.disabled( + child: DecoratedBox( + decoration: BoxDecoration( + color: color, + borderRadius: const BorderRadius.all(Radius.circular(6)), + ), + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 8, vertical: 6), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + mainAxisSize: MainAxisSize.min, + children: [ + Row( + mainAxisSize: MainAxisSize.min, + children: [ + Icon( + Icons.auto_awesome_rounded, + size: 13, + color: context.theme.secondaryText, + ), + const SizedBox(width: 5), + Flexible( + child: Text( + 'AI context · ${attachments.length}', + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + fontWeight: FontWeight.w500, + ), + ), + ), + ], + ), + const SizedBox(height: 4), + for (final attachment in attachments.take(3)) + Padding( + padding: const EdgeInsets.only(top: 2), + child: Text( + _attachmentSummaryText(attachment), + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + ), + ), + ), + ], + ), + ), + ), + ); + } } class _AiResponseMessageBody extends StatelessWidget { @@ -623,6 +708,18 @@ String _displayText(AiChatMessage message) { return message.errorText ?? 'No response'; } +String _attachmentSummaryText(Map attachment) { + final sender = attachment['senderName'] as String?; + final preview = attachment['preview'] as String?; + if (sender == null || sender.isEmpty) { + return preview ?? ''; + } + if (preview == null || preview.isEmpty) { + return sender; + } + return '$sender: $preview'; +} + String _pendingAssistantText(AiChatMessage message) { final activeToolName = _activeToolName(message.metadata); if (activeToolName != null) { diff --git a/lib/widgets/markdown.dart b/lib/widgets/markdown.dart index 4a80721c91..74483bcd2e 100644 --- a/lib/widgets/markdown.dart +++ b/lib/widgets/markdown.dart @@ -11,6 +11,7 @@ import 'package:mixin_markdown_widget/mixin_markdown_widget.dart'; import '../ui/provider/setting_provider.dart'; import '../utils/extension/extension.dart'; import '../utils/uri_utils.dart'; +import 'message/message_style.dart'; import 'mixin_image.dart'; const _kMarkdownControllerCacheLimit = 120; @@ -513,20 +514,26 @@ MarkdownThemeData _createMarkdownTheme( ); final textColor = context.theme.text; final accentColor = context.theme.accent; + final chatBodyFontSize = + MessageStyle.defaultStyle.primaryFontSize + chatFontSizeDelta; + final baseBodyFontSize = base.bodyStyle.fontSize ?? chatBodyFontSize; + final fontSizeScale = baseBodyFontSize == 0 + ? 1.0 + : chatBodyFontSize / baseBodyFontSize; TextStyle applyTextColor(TextStyle style) => style.copyWith(color: textColor); - TextStyle applyFontSizeDelta(TextStyle style) { + TextStyle scaleFontSize(TextStyle style) { final fontSize = style.fontSize; if (fontSize == null) return style; - return style.copyWith(fontSize: fontSize + chatFontSizeDelta); + return style.copyWith(fontSize: fontSize * fontSizeScale); } TextStyle applyTextStyle(TextStyle style) => - applyTextColor(applyFontSizeDelta(style)); + applyTextColor(scaleFontSize(style)); return base.copyWith( bodyStyle: applyTextStyle(base.bodyStyle), - quoteStyle: applyFontSizeDelta( + quoteStyle: scaleFontSize( base.quoteStyle.copyWith( color: base.quoteStyle.color ?? textColor.withValues(alpha: 0.82), ), @@ -535,16 +542,18 @@ MarkdownThemeData _createMarkdownTheme( color: accentColor, decorationColor: accentColor, fontSize: - (base.linkStyle.fontSize ?? base.bodyStyle.fontSize ?? 16) + - chatFontSizeDelta, + (base.linkStyle.fontSize ?? + base.bodyStyle.fontSize ?? + chatBodyFontSize) * + fontSizeScale, + decoration: .none, ), inlineCodeStyle: applyTextStyle(base.inlineCodeStyle), codeBlockStyle: applyTextStyle(base.codeBlockStyle), tableHeaderStyle: applyTextStyle(base.tableHeaderStyle), heading1Style: applyTextStyle( - applyFontSizeDelta( + scaleFontSize( base.heading1Style.copyWith( - fontSize: 32, height: 40 / 32, fontWeight: FontWeight.bold, ), @@ -558,5 +567,6 @@ MarkdownThemeData _createMarkdownTheme( quoteBorderColor: accentColor.withValues(alpha: 0.4), selectionColor: accentColor.withValues(alpha: 0.24), showHeading1Divider: false, + quoteBackgroundColor: Colors.transparent, ); } diff --git a/lib/widgets/message/message.dart b/lib/widgets/message/message.dart index cd85345341..568a73da62 100644 --- a/lib/widgets/message/message.dart +++ b/lib/widgets/message/message.dart @@ -29,6 +29,10 @@ import '../../db/mixin_database.dart' hide Message, Offset; import '../../enum/media_status.dart'; import '../../enum/message_category.dart'; import '../../ui/home/bloc/blink_cubit.dart'; +import '../../ui/home/chat/chat_side_route_names.dart'; +import '../../ui/home/chat_slide_page/ai_assistant/constants.dart'; +import '../../ui/home/route/responsive_navigator.dart'; +import '../../ui/provider/ai_context_attachment_provider.dart'; import '../../ui/provider/conversation_provider.dart'; import '../../ui/provider/is_bot_group_provider.dart'; import '../../ui/provider/message_selection_provider.dart'; @@ -160,6 +164,23 @@ void _quickReply(BuildContext context) { }); } +void _attachMessagesToAi( + BuildContext context, + WidgetRef ref, + List messages, +) { + if (messages.isEmpty) return; + final conversationId = messages.first.conversationId; + ref + .read(aiContextAttachmentProvider(conversationId).notifier) + .attachMessages(messages); + unawaited( + context.read().replace( + chatSideAiAssistantPage, + ), + ); +} + SelectedContent? _findSelectedContent(BuildContext context) { SelectableRegionState? findSelectableRegionState(BuildContext context) { if (context is! Element) { @@ -246,8 +267,8 @@ class MessageItemWidget extends HookConsumerWidget { !(sameUserNext && sameDayNext) && (!showAvatar || isCurrentUser); final datetime = isSameDay(prevDateTime ?? prev?.createdAt, message.createdAt) - ? null - : message.createdAt; + ? null + : message.createdAt; String? userName; String? userId; String? userAvatarUrl; @@ -670,6 +691,15 @@ class MessageItemWidget extends HookConsumerWidget { } final aiActions = [ + if (hasEnabledAiProvider && + !isTranscriptPage && + !isPinnedPage) + MenuAction( + image: MenuImage.icon(Icons.auto_awesome_rounded), + title: aiAssistantAttachToAi, + callback: () => + _attachMessagesToAi(context, ref, [message]), + ), if (aiText != null) MenuAction( image: MenuImage.icon(Icons.translate), diff --git a/pubspec.lock b/pubspec.lock index bcec75369f..854d67413c 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -1328,10 +1328,10 @@ packages: dependency: "direct main" description: name: mixin_markdown_widget - sha256: ea1fd34d1eeb837e6be54641458800739080e4ee5a9ecc100536c82f0b69242f + sha256: "58b366d61d55fe852a91a7b3fb102481a76a7a8f208e0bd948ef823501ba29d9" url: "https://pub.dev" source: hosted - version: "0.2.0" + version: "0.2.1" msix: dependency: "direct dev" description: diff --git a/pubspec.yaml b/pubspec.yaml index ebbeef908e..6621dfa4f2 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -99,7 +99,7 @@ dependencies: local_auth: ^3.0.1 lottie: ^3.3.3 map: ^2.0.2 - mixin_markdown_widget: ^0.2.0 + mixin_markdown_widget: ^0.2.1 mime: ^2.0.0 mixin_bot_sdk_dart: ^1.5.0 mixin_logger: ^0.1.3 diff --git a/test/ai/ai_chat_metadata_test.dart b/test/ai/ai_chat_metadata_test.dart index da6739c685..7003d74e48 100644 --- a/test/ai/ai_chat_metadata_test.dart +++ b/test/ai/ai_chat_metadata_test.dart @@ -58,5 +58,25 @@ void main() { containsPair('totalTokens', 124), ); }); + + test('stores user message attachments', () { + final metadata = createAiUserMessageMetadata([ + const { + 'messageId': 'message-id', + 'senderName': 'Alice', + 'preview': 'Please review this', + }, + ]); + + expect(metadata, isNotNull); + expect( + aiMetadataAttachments(metadata), + [ + containsPair('messageId', 'message-id'), + ], + ); + expect(aiMetadataToolEvents(metadata), isEmpty); + expect(aiMetadataResponse(metadata), isEmpty); + }); }); } From 25cf53e9c50f19b58a03e4a36b55e70682f90c05 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Wed, 29 Apr 2026 15:55:55 +0800 Subject: [PATCH 37/52] fix: set default httpLogLevel to null in createClient --- lib/utils/mixin_api_client.dart | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/utils/mixin_api_client.dart b/lib/utils/mixin_api_client.dart index 3b9d972096..ce149fc71b 100644 --- a/lib/utils/mixin_api_client.dart +++ b/lib/utils/mixin_api_client.dart @@ -38,7 +38,7 @@ Client createClient({ sendTimeout: tenSecond, followRedirects: false, ), - // httpLogLevel: HttpLogLevel.none, + httpLogLevel: null, jsonDecodeCallback: jsonDecode, interceptors: [ ...interceptors, From 6ca2009fa6a652a0d51d6309f990dbe098a9011d Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:02:07 +0800 Subject: [PATCH 38/52] refactor: simplify markdown controller cache and remove warmup logic --- lib/widgets/markdown.dart | 186 ++++++-------------------------------- 1 file changed, 26 insertions(+), 160 deletions(-) diff --git a/lib/widgets/markdown.dart b/lib/widgets/markdown.dart index 74483bcd2e..eb0c7928ec 100644 --- a/lib/widgets/markdown.dart +++ b/lib/widgets/markdown.dart @@ -1,9 +1,6 @@ -import 'dart:async'; -import 'dart:collection'; import 'dart:io'; import 'package:flutter/material.dart'; -import 'package:flutter/scheduler.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; import 'package:mixin_markdown_widget/mixin_markdown_widget.dart'; @@ -15,7 +12,6 @@ import 'message/message_style.dart'; import 'mixin_image.dart'; const _kMarkdownControllerCacheLimit = 120; -const _kMarkdownWarmupPerFrame = 2; String buildMarkdownCacheKey({ required String namespace, @@ -26,19 +22,20 @@ final markdownControllerCache = MarkdownControllerCache(); class MarkdownControllerCache { final _entries = {}; - final _pending = >{}; - final _queuedKeys = {}; - final _warmupQueue = ListQueue<({String key, String data})>(); - bool _warmupScheduled = false; - - MarkdownController? acquire( + MarkdownController acquire( String key, String data, { bool streaming = false, }) { - final entry = _entries[key]; - if (entry == null) return null; + var entry = _entries[key]; + if (entry == null) { + entry = _MarkdownCacheEntry( + data: data, + controller: MarkdownController(data: data), + ); + _entries[key] = entry; + } if (entry.data != data) { _updateEntryData(entry, data, streaming: streaming); } else if (!streaming) { @@ -46,6 +43,7 @@ class MarkdownControllerCache { } _touch(key, entry); entry.retainCount += 1; + _evictIfNeeded(); return entry.controller; } @@ -57,81 +55,6 @@ class MarkdownControllerCache { } } - Future warmup(String key, String data) { - final entry = _entries[key]; - if (entry != null) { - if (entry.data == data) { - _touch(key, entry); - entry.controller.commitStream(); - return Future.value(); - } - _updateEntryData(entry, data, streaming: false); - _touch(key, entry); - return Future.value(); - } - - final pending = _pending[key]; - if (pending != null) return pending.future; - - final completer = Completer(); - _pending[key] = completer; - if (_queuedKeys.add(key)) { - _warmupQueue.add((key: key, data: data)); - _scheduleWarmup(); - } - return completer.future; - } - - void warmupAll(Iterable<({String key, String data})> entries) { - for (final entry in entries) { - unawaited(warmup(entry.key, entry.data)); - } - } - - void _scheduleWarmup() { - if (_warmupScheduled) return; - _warmupScheduled = true; - SchedulerBinding.instance.addPostFrameCallback((_) { - _warmupScheduled = false; - _drainWarmupQueue(); - }); - } - - void _drainWarmupQueue() { - var count = 0; - while (_warmupQueue.isNotEmpty && count < _kMarkdownWarmupPerFrame) { - final task = _warmupQueue.removeFirst(); - _queuedKeys.remove(task.key); - final completer = _pending.remove(task.key); - - try { - final existing = _entries[task.key]; - if (existing != null) { - if (existing.data != task.data) { - _updateEntryData(existing, task.data, streaming: false); - } else { - existing.controller.commitStream(); - } - _touch(task.key, existing); - } else { - _entries[task.key] = _MarkdownCacheEntry( - data: task.data, - controller: MarkdownController(data: task.data), - ); - _evictIfNeeded(); - } - completer?.complete(); - } catch (error, stackTrace) { - completer?.completeError(error, stackTrace); - } - count += 1; - } - - if (_warmupQueue.isNotEmpty) { - _scheduleWarmup(); - } - } - void _touch(String key, _MarkdownCacheEntry entry) { _entries.remove(key); _entries[key] = entry; @@ -300,57 +223,24 @@ class _MarkdownView extends HookWidget { return _buildMarkdownWidget(data: data); } - final controller = - useState<({String key, String data, MarkdownController controller})?>( - null, - ); - - useEffect(() { - var disposed = false; - MarkdownController? retained; - - bool bindCachedController() { - final cached = markdownControllerCache.acquire( - cacheKey!, - data, - streaming: streaming, - ); - if (cached == null || disposed) return false; - retained = cached; - controller.value = (key: cacheKey!, data: data, controller: cached); - return true; - } - - if (!bindCachedController()) { - unawaited( - markdownControllerCache.warmup(cacheKey!, data).then((_) { - if (disposed) return; - bindCachedController(); - }), - ); - } - - return () { - disposed = true; - final current = retained; - if (current != null) { - markdownControllerCache.release(cacheKey!, current); - } - }; - }, [cacheKey, data, streaming]); - - final cachedController = controller.value; - if (cachedController != null && - cachedController.key == cacheKey && - cachedController.data == data) { - return _buildMarkdownWidget(controller: cachedController.controller); - } + final key = cacheKey!; + final controller = useMemoized( + () => markdownControllerCache.acquire( + key, + data, + streaming: streaming, + ), + [key, data, streaming], + ); - return _MarkdownFallback( - data: data, - theme: theme, - padding: padding, + useEffect( + () => () { + markdownControllerCache.release(key, controller); + }, + [key, controller], ); + + return _buildMarkdownWidget(controller: controller); } Widget _buildMarkdownWidget({ @@ -370,30 +260,6 @@ class _MarkdownView extends HookWidget { ); } -class _MarkdownFallback extends StatelessWidget { - const _MarkdownFallback({ - required this.data, - required this.theme, - this.padding, - }); - - final String data; - final MarkdownThemeData theme; - final EdgeInsetsGeometry? padding; - - @override - Widget build(BuildContext context) { - final effectivePadding = padding ?? EdgeInsets.zero; - return Padding( - padding: effectivePadding, - child: Text( - data, - style: theme.bodyStyle, - ), - ); - } -} - Widget _buildMarkdownImage( BuildContext context, ImageBlock block, From 242bc8e80020690b9ac8a9cc4b59611e546edd97 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:24:14 +0800 Subject: [PATCH 39/52] feat: add AI-powered unread message summarization feature --- lib/ai/model/ai_prompt_template.dart | 34 +++++ lib/ui/home/chat/chat_page.dart | 5 +- .../ai_assistant/unread_summary.dart | 128 ++++++++++++++++++ lib/ui/home/conversation/menu_wrapper.dart | 15 ++ lib/widgets/message/message.dart | 80 +++++++++-- 5 files changed, 243 insertions(+), 19 deletions(-) create mode 100644 lib/ui/home/chat_slide_page/ai_assistant/unread_summary.dart diff --git a/lib/ai/model/ai_prompt_template.dart b/lib/ai/model/ai_prompt_template.dart index 89968b5e3e..17f81d5d43 100644 --- a/lib/ai/model/ai_prompt_template.dart +++ b/lib/ai/model/ai_prompt_template.dart @@ -43,6 +43,16 @@ enum AiPromptVariable { 'messages', 'Messages', 'Conversation message lines assembled by the app.', + ), + unreadStartAt( + 'unreadStartAt', + 'Unread Start At', + 'ISO 8601 timestamp of the first unread message.', + ), + firstUnreadMessageId( + 'firstUnreadMessageId', + 'First Unread Message ID', + 'Message ID of the first unread message.', ) ; @@ -57,6 +67,7 @@ enum AiPromptVariable { enum AiPromptTemplateKey { chatSystem, + summarizeUnreadMessages, assistSystem, messageTranslate, messageExplain, @@ -107,6 +118,29 @@ const aiPromptTemplateDefinitions = [ AiPromptVariable.language, ], ), + AiPromptTemplateDefinition( + key: AiPromptTemplateKey.summarizeUnreadMessages, + group: AiPromptTemplateGroup.conversation, + title: 'Summarize Unread Messages Prompt', + description: 'User prompt for summarizing new unread messages.', + defaultValue: + 'Summarize the new information in this conversation since the unread ' + 'section started.\n\n' + 'Unread section start:\n' + '- start_at: {{unreadStartAt}}\n' + '- first_unread_message_id: {{firstUnreadMessageId}}\n\n' + 'Use the conversation tools to inspect messages created at or after ' + 'start_at. Focus only on new information, decisions, questions, ' + 'requests, mentions of the user, links, files, media references, and ' + 'action items.', + variables: [ + AiPromptVariable.conversationId, + AiPromptVariable.currentIsoDateTime, + AiPromptVariable.language, + AiPromptVariable.unreadStartAt, + AiPromptVariable.firstUnreadMessageId, + ], + ), AiPromptTemplateDefinition( key: AiPromptTemplateKey.assistSystem, group: AiPromptTemplateGroup.conversation, diff --git a/lib/ui/home/chat/chat_page.dart b/lib/ui/home/chat/chat_page.dart index bd8244237b..24ca0fe4ba 100644 --- a/lib/ui/home/chat/chat_page.dart +++ b/lib/ui/home/chat/chat_page.dart @@ -192,10 +192,7 @@ class _ChatSidePageBuilder extends HookConsumerWidget { @override Widget build(BuildContext context, WidgetRef ref) { - final conversationId = useMemoized( - () => ref.read(lastConversationIdProvider), - [], - ); + final conversationId = ref.watch(lastConversationIdProvider); final filter = useCallback( (state) => state?.conversationId == conversationId, diff --git a/lib/ui/home/chat_slide_page/ai_assistant/unread_summary.dart b/lib/ui/home/chat_slide_page/ai_assistant/unread_summary.dart new file mode 100644 index 0000000000..b360e6214c --- /dev/null +++ b/lib/ui/home/chat_slide_page/ai_assistant/unread_summary.dart @@ -0,0 +1,128 @@ +import 'package:flutter/widgets.dart'; + +import '../../../../ai/ai_chat_controller.dart'; +import '../../../../ai/model/ai_prompt_template.dart'; +import '../../../../db/dao/conversation_dao.dart'; +import '../../../../db/database.dart'; +import '../../../../db/mixin_database.dart'; +import '../../../../utils/extension/extension.dart'; +import '../../../../widgets/toast.dart'; +import '../../../provider/conversation_provider.dart'; +import '../../chat/chat_page.dart'; +import '../ai_assistant_page.dart'; +import 'helpers.dart'; + +bool hasAvailableAiModel(BuildContext context) => + context.database.settingProperties.selectedAiProvider?.model + .trim() + .isNotEmpty == + true; + +Future summarizeUnreadMessagesWithAi({ + required BuildContext context, + required String conversationId, + required String? lastReadMessageId, + ConversationItem? conversation, + bool selectConversation = false, +}) async { + final database = context.database; + final providerContainer = context.providerContainer; + final language = currentLanguageTag(context); + final provider = database.settingProperties.selectedAiProvider; + if (provider == null || provider.model.trim().isEmpty) { + showToastFailed(ToastError('Please add an AI provider first')); + return; + } + + try { + final firstUnreadMessage = await _firstUnreadMessage( + database, + conversationId: conversationId, + lastReadMessageId: lastReadMessageId, + ); + if (firstUnreadMessage == null) return; + if (!context.mounted) return; + + if (selectConversation) { + await ConversationStateNotifier.selectConversation( + context, + conversationId, + conversation: conversation, + initialChatSidePage: ChatSideCubit.aiAssistantPage, + ); + } else { + await context.read().replace( + ChatSideCubit.aiAssistantPage, + ); + } + + final thread = await database.aiChatMessageDao.createThread( + conversationId, + ); + providerContainer + .read(aiAssistantThreadIdProvider(conversationId).notifier) + .state = + thread.id; + + await AiChatController(database).send( + conversationId: conversationId, + threadId: thread.id, + input: _unreadSummaryPrompt( + database: database, + conversationId: conversationId, + firstUnreadMessage: firstUnreadMessage, + language: language, + ), + language: language, + provider: provider, + ); + } catch (error, _) { + showToastFailed(error); + } +} + +Future _firstUnreadMessage( + Database database, { + required String conversationId, + required String? lastReadMessageId, +}) async { + final messageDao = database.messageDao; + if (lastReadMessageId == null || lastReadMessageId.isEmpty) { + return messageDao + .messagesByConversationIdAndCreatedAtRange(conversationId, limit: 1) + .getSingleOrNull(); + } + + final orderInfo = await messageDao.messageOrderInfo(lastReadMessageId); + if (orderInfo == null) { + return messageDao + .messagesByConversationIdAndCreatedAtRange(conversationId, limit: 1) + .getSingleOrNull(); + } + return messageDao + .afterMessagesByConversationId(orderInfo, conversationId, 1) + .getSingleOrNull(); +} + +String _unreadSummaryPrompt({ + required Database database, + required String conversationId, + required MessageItem firstUnreadMessage, + required String language, +}) { + final startAt = firstUnreadMessage.createdAt.toIso8601String(); + return renderAiPromptTemplate( + database.settingProperties.aiPromptTemplate( + AiPromptTemplateKey.summarizeUnreadMessages, + ), + { + ...buildAiPromptTemplateVariables( + conversationId: conversationId, + language: language, + ), + AiPromptVariable.unreadStartAt.placeholder: startAt, + AiPromptVariable.firstUnreadMessageId.placeholder: + firstUnreadMessage.messageId, + }, + ); +} diff --git a/lib/ui/home/conversation/menu_wrapper.dart b/lib/ui/home/conversation/menu_wrapper.dart index c0e516aef8..21f22ddf34 100644 --- a/lib/ui/home/conversation/menu_wrapper.dart +++ b/lib/ui/home/conversation/menu_wrapper.dart @@ -13,6 +13,7 @@ import '../../../widgets/menu.dart'; import '../../../widgets/toast.dart'; import '../../provider/conversation_provider.dart'; import '../../provider/slide_category_provider.dart'; +import '../chat_slide_page/ai_assistant/unread_summary.dart'; class ConversationMenuWrapper extends HookConsumerWidget { const ConversationMenuWrapper({ @@ -40,6 +41,8 @@ class ConversationMenuWrapper extends HookConsumerWidget { final isGroupConversation = conversation?.isGroupConversation ?? searchConversation!.isGroupConversation; + final lastReadMessageId = conversation?.lastReadMessageId; + final hasUnreadMessages = (conversation?.unseenMessageCount ?? 0) > 0; return CustomContextMenuWidget( desktopMenuWidgetBuilder: CustomDesktopMenuWidgetBuilder(), @@ -58,6 +61,18 @@ class ConversationMenuWrapper extends HookConsumerWidget { return MenusWithSeparator( childrens: [ [ + if (hasUnreadMessages && hasAvailableAiModel(context)) + MenuAction( + image: MenuImage.icon(Icons.auto_awesome_rounded), + title: 'Summarize unread messages', + callback: () => summarizeUnreadMessagesWithAi( + context: context, + conversationId: conversationId, + lastReadMessageId: lastReadMessageId, + conversation: conversation, + selectConversation: true, + ), + ), if (pinTime != null) MenuAction( image: MenuImage.icon(IconFonts.unPin), diff --git a/lib/widgets/message/message.dart b/lib/widgets/message/message.dart index 568a73da62..3bb06c4ac8 100644 --- a/lib/widgets/message/message.dart +++ b/lib/widgets/message/message.dart @@ -31,6 +31,7 @@ import '../../enum/message_category.dart'; import '../../ui/home/bloc/blink_cubit.dart'; import '../../ui/home/chat/chat_side_route_names.dart'; import '../../ui/home/chat_slide_page/ai_assistant/constants.dart'; +import '../../ui/home/chat_slide_page/ai_assistant/unread_summary.dart'; import '../../ui/home/route/responsive_navigator.dart'; import '../../ui/provider/ai_context_attachment_provider.dart'; import '../../ui/provider/conversation_provider.dart'; @@ -850,7 +851,10 @@ class MessageItemWidget extends HookConsumerWidget { ), ), if (message.messageId == lastReadMessageId && next != null) - const _UnreadMessageBar(), + _UnreadMessageBar( + conversationId: message.conversationId, + lastReadMessageId: lastReadMessageId, + ), ], ); @@ -1132,23 +1136,69 @@ class _MessageBubbleMargin extends HookConsumerWidget { } } -class _UnreadMessageBar extends StatelessWidget { - const _UnreadMessageBar(); +class _UnreadMessageBar extends HookConsumerWidget { + const _UnreadMessageBar({ + required this.conversationId, + required this.lastReadMessageId, + }); + + final String conversationId; + final String? lastReadMessageId; @override - Widget build(BuildContext context) => Container( - color: context.theme.background, - padding: const EdgeInsets.symmetric(vertical: 4), - margin: const EdgeInsets.symmetric(vertical: 6), - alignment: Alignment.center, - child: Text( - context.l10n.unreadMessages, - style: TextStyle( - color: context.theme.secondaryText, - fontSize: context.messageStyle.secondaryFontSize, + Widget build(BuildContext context, WidgetRef ref) { + useListenable(context.database.settingProperties); + + final hasAiModel = hasAvailableAiModel(context); + return Container( + color: context.theme.background, + padding: const EdgeInsets.symmetric(vertical: 4), + margin: const EdgeInsets.symmetric(vertical: 6), + child: Row( + children: [ + const SizedBox(width: 44), + Expanded( + child: Text( + context.l10n.unreadMessages, + textAlign: TextAlign.center, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: context.messageStyle.secondaryFontSize, + ), + ), + ), + SizedBox( + width: 44, + child: hasAiModel + ? Align( + child: Tooltip( + message: 'Summarize unread messages', + child: InteractiveDecoratedBox( + decoration: BoxDecoration( + borderRadius: BorderRadius.circular(14), + ), + onTap: () => summarizeUnreadMessagesWithAi( + context: context, + conversationId: conversationId, + lastReadMessageId: lastReadMessageId, + ), + child: Padding( + padding: const EdgeInsets.all(4), + child: Icon( + Icons.auto_awesome_rounded, + size: 16, + color: context.theme.accent, + ), + ), + ), + ), + ) + : null, + ), + ], ), - ), - ); + ); + } } class _MessageSelectionWrapper extends HookConsumerWidget { From f8e03d12d358830c2271a2cebceafd00e2902b7c Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Wed, 29 Apr 2026 18:33:07 +0800 Subject: [PATCH 40/52] feat(chat): make chat side page width dynamic based on page type --- lib/ui/home/chat/chat_page.dart | 23 ++++++++++++++++++----- lib/ui/home/home.dart | 2 ++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/lib/ui/home/chat/chat_page.dart b/lib/ui/home/chat/chat_page.dart index 24ca0fe4ba..2f3fc1d013 100644 --- a/lib/ui/home/chat/chat_page.dart +++ b/lib/ui/home/chat/chat_page.dart @@ -238,6 +238,7 @@ class ChatPage extends HookConsumerWidget { useBlocState( bloc: chatSideCubit, ); + final chatSidePageWidth = _chatSidePageWidth(navigatorState.pages); ref.listen(hasSelectedMessageProvider, (previous, hasSelectedMessage) { if (!hasSelectedMessage) return; @@ -288,7 +289,7 @@ class ChatPage extends HookConsumerWidget { builder: (context, boxConstraints) { final routeMode = boxConstraints.maxWidth < - (kResponsiveNavigationMinWidth + kChatSidePageWidth); + (kResponsiveNavigationMinWidth + chatSidePageWidth); chatSideCubit.updateRouteMode(routeMode); return _ChatMenuHandler( @@ -329,6 +330,17 @@ class ChatPage extends HookConsumerWidget { } } +double _chatSidePageWidth(List> pages) { + final hasAiAssistantPage = pages.any( + (page) => + page.name == ChatSideCubit.aiAssistantPage || + page.name == ChatSideCubit.aiAssistantThreadsPage, + ); + return hasAiAssistantPage + ? kAiAssistantChatSidePageWidth + : kChatSidePageWidth; +} + class _SideRouter extends StatelessWidget { const _SideRouter({ required this.chatSideCubit, @@ -393,11 +405,12 @@ class _AnimatedChatSlide extends HookConsumerWidget { } }, [pages, controller]); + final chatSidePageWidth = _chatSidePageWidth(_pages.value); + return AnimatedBuilder( animation: controller, builder: (context, child) => SizedBox( - width: - kChatSidePageWidth * Curves.easeInOut.transform(controller.value), + width: chatSidePageWidth * Curves.easeInOut.transform(controller.value), height: constraints.maxHeight, child: controller.value != 0 ? child : null, ), @@ -406,8 +419,8 @@ class _AnimatedChatSlide extends HookConsumerWidget { alignment: AlignmentDirectional.centerStart, maxHeight: constraints.maxHeight, minHeight: constraints.maxHeight, - maxWidth: kChatSidePageWidth, - minWidth: kChatSidePageWidth, + maxWidth: chatSidePageWidth, + minWidth: chatSidePageWidth, child: Navigator( pages: _pages.value, onDidRemovePage: onDidRemovePage, diff --git a/lib/ui/home/home.dart b/lib/ui/home/home.dart index 1a4903e84d..52a2e6c1ef 100644 --- a/lib/ui/home/home.dart +++ b/lib/ui/home/home.dart @@ -43,6 +43,8 @@ const kResponsiveNavigationMinWidth = 320.0; const kConversationListWidth = 300.0; // chat side page fixed width, chat info page etc. const kChatSidePageWidth = 300.0; +// AI assistant needs more room for the prompt composer and model controls. +const kAiAssistantChatSidePageWidth = 380.0; final _conversationPageKey = GlobalKey(); From 923592ce5851ca2225c6a4f912e9442a8e106df6 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 30 Apr 2026 09:09:32 +0800 Subject: [PATCH 41/52] feat(ai): add support for explicit thread targets and new thread creation --- devtools_options.yaml | 3 + lib/ai/ai_chat_controller.dart | 13 ++- lib/ai/ai_thread_target.dart | 18 +++ lib/db/dao/ai_chat_message_dao.dart | 28 ++--- lib/ui/home/chat/input_container.dart | 109 +++++++++++++++--- .../ai_assistant/unread_summary.dart | 12 +- .../chat_slide_page/ai_assistant_page.dart | 73 +++++++----- .../ai_assistant_thread_provider.dart | 27 +++++ test/ai/ai_chat_thread_test.dart | 29 +++++ 9 files changed, 246 insertions(+), 66 deletions(-) create mode 100644 devtools_options.yaml create mode 100644 lib/ai/ai_thread_target.dart create mode 100644 lib/ui/provider/ai_assistant_thread_provider.dart diff --git a/devtools_options.yaml b/devtools_options.yaml new file mode 100644 index 0000000000..fa0b357c4f --- /dev/null +++ b/devtools_options.yaml @@ -0,0 +1,3 @@ +description: This file stores settings for Dart & Flutter DevTools. +documentation: https://docs.flutter.dev/tools/devtools/extensions#configure-extension-enablement-states +extensions: diff --git a/lib/ai/ai_chat_controller.dart b/lib/ai/ai_chat_controller.dart index 79c6344404..bc66fd55a5 100644 --- a/lib/ai/ai_chat_controller.dart +++ b/lib/ai/ai_chat_controller.dart @@ -12,6 +12,7 @@ import '../db/mixin_database.dart'; import 'ai_chat_prompt_builder.dart'; import 'ai_message_context.dart'; import 'ai_provider_requester.dart'; +import 'ai_thread_target.dart'; import 'model/ai_chat_metadata.dart'; import 'model/ai_prompt_message.dart'; import 'model/ai_provider_config.dart'; @@ -100,14 +101,15 @@ class AiChatController { required String conversationId, required String input, required String language, - String? threadId, + required AiThreadTarget target, AiProviderConfig? provider, List attachedMessages = const [], + void Function(String threadId)? onThreadReady, void Function()? onInputAccepted, }) async { - final thread = await database.aiChatMessageDao.ensureThread( + final thread = await database.aiChatMessageDao.resolveThreadTarget( conversationId: conversationId, - threadId: threadId, + target: target, ); await database.aiChatMessageDao.resolveStalePendingAssistantMessages( updatedBefore: kAiRuntimeStartedAt, @@ -177,17 +179,18 @@ class AiChatController { ); onInputAccepted?.call(); + onThreadReady?.call(thread.id); final updater = _StreamingMessageUpdater( dao: database.aiChatMessageDao, messageId: assistantMessageId, ); final requestKeys = { - conversationId, thread.id, + assistantMessageId, }; - _activeAiRequests[conversationId] = cancelToken; _activeAiRequests[thread.id] = cancelToken; + _activeAiRequests[assistantMessageId] = cancelToken; try { final messages = await _promptBuilder.buildPromptMessages( conversationId, diff --git a/lib/ai/ai_thread_target.dart b/lib/ai/ai_thread_target.dart new file mode 100644 index 0000000000..dbcd98482e --- /dev/null +++ b/lib/ai/ai_thread_target.dart @@ -0,0 +1,18 @@ +sealed class AiThreadTarget { + const AiThreadTarget(); + + const factory AiThreadTarget.existing(String threadId) = + ExistingAiThreadTarget; + + const factory AiThreadTarget.createNew() = NewAiThreadTarget; +} + +class ExistingAiThreadTarget extends AiThreadTarget { + const ExistingAiThreadTarget(this.threadId); + + final String threadId; +} + +class NewAiThreadTarget extends AiThreadTarget { + const NewAiThreadTarget(); +} diff --git a/lib/db/dao/ai_chat_message_dao.dart b/lib/db/dao/ai_chat_message_dao.dart index 8e2d2afd9f..15409c6e98 100644 --- a/lib/db/dao/ai_chat_message_dao.dart +++ b/lib/db/dao/ai_chat_message_dao.dart @@ -1,6 +1,7 @@ import 'package:drift/drift.dart'; import 'package:uuid/uuid.dart'; +import '../../ai/ai_thread_target.dart'; import '../../ai/model/ai_chat_metadata.dart'; import '../ai_database.dart'; @@ -106,23 +107,22 @@ class AiChatMessageDao extends DatabaseAccessor }); } - Future ensureThread({ + Future resolveThreadTarget({ required String conversationId, - String? threadId, + required AiThreadTarget target, }) async { - if (threadId != null) { - final thread = await threadById(threadId); - if (thread == null || - thread.conversationId != conversationId || - thread.status != activeThreadStatus) { - throw StateError('AI thread not found'); - } - return thread; + switch (target) { + case ExistingAiThreadTarget(:final threadId): + final thread = await threadById(threadId); + if (thread == null || + thread.conversationId != conversationId || + thread.status != activeThreadStatus) { + throw StateError('AI thread not found'); + } + return thread; + case NewAiThreadTarget(): + return createThread(conversationId); } - - final existing = await latestThread(conversationId); - if (existing != null) return existing; - return createThread(conversationId); } Stream> watchThreadMessages(String threadId) => diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index b99bebb2c6..e6c6ecd049 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -20,6 +20,7 @@ import 'package:simple_animations/simple_animations.dart'; import 'package:super_context_menu/super_context_menu.dart'; import '../../../ai/ai_chat_controller.dart'; +import '../../../ai/ai_thread_target.dart'; import '../../../ai/model/ai_prompt_template.dart'; import '../../../ai/model/ai_provider_config.dart'; import '../../../constants/constants.dart'; @@ -49,6 +50,7 @@ import '../../../widgets/sticker_page/sticker_page.dart'; import '../../../widgets/toast.dart'; import '../../../widgets/user_selector/conversation_selector.dart'; import '../../provider/abstract_responsive_navigator.dart'; +import '../../provider/ai_assistant_thread_provider.dart'; import '../../provider/ai_context_attachment_provider.dart'; import '../../provider/ai_input_mode_provider.dart'; import '../../provider/conversation_provider.dart'; @@ -131,22 +133,38 @@ class _InputContainer extends HookConsumerWidget { final attachedMessagesNotifier = conversationId == null ? null : ref.read(aiContextAttachmentProvider(conversationId).notifier); - final activeAiThread = useMemoizedStream( - () => conversationId == null - ? Stream.value(null) - : context.database.aiChatMessageDao.watchLatestThread( - conversationId, - ), - keys: [conversationId], - ).data; + final aiThreadSelection = conversationId == null + ? const AiAssistantThreadSelection.latest() + : ref.watch(aiAssistantThreadSelectionProvider(conversationId)); + final aiThreads = + useMemoizedStream( + () => conversationId == null + ? Stream.value(const []) + : context.database.aiChatMessageDao.watchThreads( + conversationId, + ), + keys: [conversationId], + initialData: const [], + ).data ?? + const []; + final createNewAiThread = + aiThreadSelection.isNewThread || aiThreads.isEmpty; + final selectedAiThread = createNewAiThread + ? null + : aiThreadSelection.isLatest + ? aiThreads.firstOrNull + : aiThreads.firstWhereOrNull( + (item) => item.id == aiThreadSelection.threadId, + ); + final currentAiThread = createNewAiThread ? null : selectedAiThread; final aiMessages = useMemoizedStream( - () => activeAiThread == null + () => currentAiThread == null ? Stream.value(const []) : context.database.aiChatMessageDao.watchThreadMessages( - activeAiThread.id, + currentAiThread.id, ), - keys: [activeAiThread?.id], + keys: [currentAiThread?.id], initialData: const [], ).data ?? const []; @@ -387,6 +405,7 @@ class _InputContainer extends HookConsumerWidget { context, textEditingController, conversationId: conversationId, + createNewAiThread: createNewAiThread, ), ); }, @@ -424,7 +443,8 @@ class _InputContainer extends HookConsumerWidget { aiModeEnabled: aiModeEnabled, providerName: aiProvider?.name, modelName: aiProvider?.model, - aiThreadId: activeAiThread?.id, + aiThreadId: currentAiThread?.id, + createNewAiThread: createNewAiThread, aiRequestInFlight: aiRequestInFlight, aiDraftAssistState: aiDraftAssistState.value, ), @@ -465,7 +485,8 @@ class _InputContainer extends HookConsumerWidget { textEditingValueStream: textEditingValueStream, aiModeEnabled: aiModeEnabled, aiRequestInFlight: aiRequestInFlight, - aiThreadId: activeAiThread?.id, + aiThreadId: currentAiThread?.id, + createNewAiThread: createNewAiThread, ), ], ), @@ -485,6 +506,7 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { const _AnimatedSendOrVoiceButton({ required this.conversationId, required this.aiThreadId, + required this.createNewAiThread, required this.textEditingValueStream, required this.textEditingController, required this.aiModeEnabled, @@ -493,6 +515,7 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { final String? conversationId; final String? aiThreadId; + final bool createNewAiThread; final Stream textEditingValueStream; final TextEditingController textEditingController; final bool aiModeEnabled; @@ -541,6 +564,7 @@ class _AnimatedSendOrVoiceButton extends HookConsumerWidget { textEditingController, conversationId: conversationId, aiThreadId: aiThreadId, + createNewAiThread: createNewAiThread, ), ), ), @@ -729,6 +753,7 @@ Future _sendMessage( TextEditingController textEditingController, { required String? conversationId, String? aiThreadId, + bool createNewAiThread = false, bool silent = false, }) async { final text = textEditingController.value.text.trim(); @@ -782,14 +807,33 @@ Future _sendMessage( showToastFailed(ToastError('Please add an AI provider first')); return; } + final target = _aiThreadTarget( + aiThreadId: aiThreadId, + createNewAiThread: createNewAiThread, + ); + if (target == null) { + showToastFailed(ToastError('AI thread unavailable')); + return; + } try { await AiChatController(context.database).send( conversationId: conversationId, - threadId: aiThreadId, + target: target, input: text, language: _currentLanguageTag(context), provider: provider, attachedMessages: attachedMessages, + onThreadReady: (threadId) { + context.providerContainer + .read( + aiAssistantThreadSelectionProvider( + conversationId, + ).notifier, + ) + .state = AiAssistantThreadSelection.existing( + threadId, + ); + }, onInputAccepted: () { textEditingController.text = ''; attachedMessagesNotifier.clear(); @@ -810,14 +854,33 @@ Future _sendMessage( unawaited( context.read().replace(ChatSideCubit.aiAssistantPage), ); + final target = _aiThreadTarget( + aiThreadId: aiThreadId, + createNewAiThread: createNewAiThread, + ); + if (target == null) { + showToastFailed(ToastError('AI thread unavailable')); + return; + } try { await AiChatController(context.database).send( conversationId: conversationId, - threadId: aiThreadId, + target: target, input: inlineAiInput, language: _currentLanguageTag(context), provider: provider, attachedMessages: attachedMessages, + onThreadReady: (threadId) { + context.providerContainer + .read( + aiAssistantThreadSelectionProvider( + conversationId, + ).notifier, + ) + .state = AiAssistantThreadSelection.existing( + threadId, + ); + }, onInputAccepted: () { textEditingController.text = ''; attachedMessagesNotifier.clear(); @@ -842,6 +905,19 @@ Future _sendMessage( context.providerContainer.read(quoteMessageProvider.notifier).state = null; } +AiThreadTarget? _aiThreadTarget({ + required String? aiThreadId, + required bool createNewAiThread, +}) { + if (createNewAiThread) { + return const AiThreadTarget.createNew(); + } + if (aiThreadId == null) { + return null; + } + return AiThreadTarget.existing(aiThreadId); +} + AiProviderConfig? _resolveAiModeProvider({ required AiProviderConfig? selectedAiProvider, required List enabledAiProviders, @@ -877,6 +953,7 @@ class _SendTextField extends HookConsumerWidget { required this.providerName, required this.modelName, required this.aiThreadId, + required this.createNewAiThread, required this.aiRequestInFlight, required this.aiDraftAssistState, }); @@ -889,6 +966,7 @@ class _SendTextField extends HookConsumerWidget { final String? providerName; final String? modelName; final String? aiThreadId; + final bool createNewAiThread; final bool aiRequestInFlight; final AiDraftAssistViewState aiDraftAssistState; @@ -990,6 +1068,7 @@ class _SendTextField extends HookConsumerWidget { textEditingController, conversationId: ref.read(currentConversationIdProvider), aiThreadId: aiThreadId, + createNewAiThread: createNewAiThread, ), ), ), diff --git a/lib/ui/home/chat_slide_page/ai_assistant/unread_summary.dart b/lib/ui/home/chat_slide_page/ai_assistant/unread_summary.dart index b360e6214c..78ac75812b 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant/unread_summary.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant/unread_summary.dart @@ -1,15 +1,16 @@ import 'package:flutter/widgets.dart'; import '../../../../ai/ai_chat_controller.dart'; +import '../../../../ai/ai_thread_target.dart'; import '../../../../ai/model/ai_prompt_template.dart'; import '../../../../db/dao/conversation_dao.dart'; import '../../../../db/database.dart'; import '../../../../db/mixin_database.dart'; import '../../../../utils/extension/extension.dart'; import '../../../../widgets/toast.dart'; +import '../../../provider/ai_assistant_thread_provider.dart'; import '../../../provider/conversation_provider.dart'; import '../../chat/chat_page.dart'; -import '../ai_assistant_page.dart'; import 'helpers.dart'; bool hasAvailableAiModel(BuildContext context) => @@ -60,13 +61,14 @@ Future summarizeUnreadMessagesWithAi({ conversationId, ); providerContainer - .read(aiAssistantThreadIdProvider(conversationId).notifier) - .state = - thread.id; + .read(aiAssistantThreadSelectionProvider(conversationId).notifier) + .state = AiAssistantThreadSelection.existing( + thread.id, + ); await AiChatController(database).send( conversationId: conversationId, - threadId: thread.id, + target: AiThreadTarget.existing(thread.id), input: _unreadSummaryPrompt( database: database, conversationId: conversationId, diff --git a/lib/ui/home/chat_slide_page/ai_assistant_page.dart b/lib/ui/home/chat_slide_page/ai_assistant_page.dart index ddb7a858e6..69fe328050 100644 --- a/lib/ui/home/chat_slide_page/ai_assistant_page.dart +++ b/lib/ui/home/chat_slide_page/ai_assistant_page.dart @@ -4,6 +4,7 @@ import 'package:hooks_riverpod/hooks_riverpod.dart'; import 'package:super_context_menu/super_context_menu.dart'; import '../../../ai/ai_chat_controller.dart'; +import '../../../ai/ai_thread_target.dart'; import '../../../ai/model/ai_provider_config.dart'; import '../../../constants/constants.dart'; import '../../../constants/icon_fonts.dart'; @@ -19,6 +20,7 @@ import '../../../widgets/dialog.dart'; import '../../../widgets/empty.dart'; import '../../../widgets/menu.dart'; import '../../../widgets/toast.dart'; +import '../../provider/ai_assistant_thread_provider.dart'; import '../../provider/ai_context_attachment_provider.dart'; import '../../provider/ai_input_mode_provider.dart'; import '../../provider/conversation_provider.dart'; @@ -30,11 +32,6 @@ import 'ai_assistant/constants.dart'; import 'ai_assistant/helpers.dart'; import 'ai_assistant/message_list.dart'; -const _newAiAssistantThreadId = ''; - -final aiAssistantThreadIdProvider = StateProvider.autoDispose - .family((ref, conversationId) => null); - class AiAssistantPage extends HookConsumerWidget { const AiAssistantPage(this.conversationState, {super.key}); @@ -72,18 +69,20 @@ class AiAssistantPage extends HookConsumerWidget { initialData: const [], ).data ?? const []; - final activeThreadId = ref.watch( - aiAssistantThreadIdProvider(conversationId), + final threadSelection = ref.watch( + aiAssistantThreadSelectionProvider(conversationId), ); - final activeThreadNotifier = ref.read( - aiAssistantThreadIdProvider(conversationId).notifier, + final threadSelectionNotifier = ref.read( + aiAssistantThreadSelectionProvider(conversationId).notifier, ); - final isNewThreadPage = - activeThreadId == _newAiAssistantThreadId || threads.isEmpty; - final activeThread = threads.firstWhereOrNull( - (item) => item.id == activeThreadId, - ); - final fallbackThread = threads.firstOrNull; + final isNewThreadPage = threadSelection.isNewThread || threads.isEmpty; + final selectedThreadId = threadSelection.threadId; + final activeThread = selectedThreadId == null + ? null + : threads.firstWhereOrNull((item) => item.id == selectedThreadId); + final fallbackThread = threadSelection.isLatest + ? threads.firstOrNull + : null; final currentThread = isNewThreadPage ? null : activeThread ?? fallbackThread; @@ -126,21 +125,34 @@ class AiAssistantPage extends HookConsumerWidget { showToastFailed(ToastError(aiAssistantUnavailable)); return; } + final target = isNewThreadPage + ? const AiThreadTarget.createNew() + : currentThread == null + ? null + : AiThreadTarget.existing(currentThread.id); + if (target == null) { + showToastFailed(ToastError('AI thread unavailable')); + return; + } try { - final threadId = await AiChatController(context.database).send( + await AiChatController(context.database).send( conversationId: conversationId, - threadId: currentThread?.id, + target: target, input: text, language: currentLanguageTag(context), provider: aiProvider, attachedMessages: attachedMessages, + onThreadReady: (threadId) { + threadSelectionNotifier.state = AiAssistantThreadSelection.existing( + threadId, + ); + }, onInputAccepted: () { textEditingController.clear(); attachedMessagesNotifier.clear(); }, ); - activeThreadNotifier.state = threadId; } catch (error, _) { showToastFailed(error); } @@ -148,7 +160,8 @@ class AiAssistantPage extends HookConsumerWidget { void openNewThreadPage() { if (isNewThreadPage) return; - activeThreadNotifier.state = _newAiAssistantThreadId; + threadSelectionNotifier.state = + const AiAssistantThreadSelection.newThread(); } return Scaffold( @@ -230,13 +243,16 @@ class AiAssistantThreadsPage extends HookConsumerWidget { initialData: const [], ).data ?? const []; - final activeThreadId = ref.watch( - aiAssistantThreadIdProvider(conversationId), + final threadSelection = ref.watch( + aiAssistantThreadSelectionProvider(conversationId), ); - final activeThreadNotifier = ref.read( - aiAssistantThreadIdProvider(conversationId).notifier, + final threadSelectionNotifier = ref.read( + aiAssistantThreadSelectionProvider(conversationId).notifier, ); - final hasSelectedThread = threads.any((item) => item.id == activeThreadId); + final activeThreadId = threadSelection.threadId; + final hasSelectedThread = + activeThreadId != null && + threads.any((item) => item.id == activeThreadId); return Scaffold( backgroundColor: context.theme.primary, @@ -257,7 +273,8 @@ class AiAssistantThreadsPage extends HookConsumerWidget { final thread = threads[index]; final selected = activeThreadId == thread.id || - ((activeThreadId == null || !hasSelectedThread) && + (activeThreadId == null && + !hasSelectedThread && index == 0); return _AiAssistantThreadTile( thread: thread, @@ -274,11 +291,13 @@ class AiAssistantThreadsPage extends HookConsumerWidget { thread.id, ); if (activeThreadId == thread.id) { - activeThreadNotifier.state = null; + threadSelectionNotifier.state = + const AiAssistantThreadSelection.latest(); } }, onTap: () { - activeThreadNotifier.state = thread.id; + threadSelectionNotifier.state = + AiAssistantThreadSelection.existing(thread.id); context.read().pop(); }, ); diff --git a/lib/ui/provider/ai_assistant_thread_provider.dart b/lib/ui/provider/ai_assistant_thread_provider.dart new file mode 100644 index 0000000000..2fa7edf4dd --- /dev/null +++ b/lib/ui/provider/ai_assistant_thread_provider.dart @@ -0,0 +1,27 @@ +import 'package:hooks_riverpod/hooks_riverpod.dart'; + +enum AiAssistantThreadSelectionType { latest, newThread, existing } + +class AiAssistantThreadSelection { + const AiAssistantThreadSelection._(this.type, this.threadId); + + const AiAssistantThreadSelection.latest() + : this._(AiAssistantThreadSelectionType.latest, null); + + const AiAssistantThreadSelection.newThread() + : this._(AiAssistantThreadSelectionType.newThread, null); + + const AiAssistantThreadSelection.existing(String threadId) + : this._(AiAssistantThreadSelectionType.existing, threadId); + + final AiAssistantThreadSelectionType type; + final String? threadId; + + bool get isLatest => type == AiAssistantThreadSelectionType.latest; + bool get isNewThread => type == AiAssistantThreadSelectionType.newThread; +} + +final aiAssistantThreadSelectionProvider = StateProvider.autoDispose + .family( + (ref, conversationId) => const AiAssistantThreadSelection.latest(), + ); diff --git a/test/ai/ai_chat_thread_test.dart b/test/ai/ai_chat_thread_test.dart index d70d197d79..72d7ebfe4e 100644 --- a/test/ai/ai_chat_thread_test.dart +++ b/test/ai/ai_chat_thread_test.dart @@ -1,6 +1,7 @@ import 'package:drift/drift.dart'; import 'package:drift/native.dart'; import 'package:flutter_app/ai/ai_chat_prompt_builder.dart'; +import 'package:flutter_app/ai/ai_thread_target.dart'; import 'package:flutter_app/ai/model/ai_prompt_message.dart'; import 'package:flutter_app/db/ai_database.dart'; import 'package:flutter_app/db/database.dart'; @@ -141,6 +142,34 @@ void main() { ); }); + test('resolves explicit thread targets without latest fallback', () async { + const conversationId = 'conversation-id'; + final latestThread = await database.aiChatMessageDao.createThread( + conversationId, + ); + + final existingThread = await database.aiChatMessageDao + .resolveThreadTarget( + conversationId: conversationId, + target: AiThreadTarget.existing(latestThread.id), + ); + final newThread = await database.aiChatMessageDao.resolveThreadTarget( + conversationId: conversationId, + target: const AiThreadTarget.createNew(), + ); + + expect(existingThread.id, latestThread.id); + expect(newThread.id, isNot(latestThread.id)); + expect(newThread.conversationId, conversationId); + expect( + () => database.aiChatMessageDao.resolveThreadTarget( + conversationId: 'other-conversation-id', + target: AiThreadTarget.existing(latestThread.id), + ), + throwsStateError, + ); + }); + test('prompt history excludes the current user message', () async { const conversationId = 'conversation-id'; final thread = await database.aiChatMessageDao.createThread( From b2cc07322c06eaef41029937b8e82c002bd0d31d Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 30 Apr 2026 10:13:29 +0800 Subject: [PATCH 42/52] test: add unit tests for AI conversation context and message search functionality --- lib/ai/ai_chat_prompt_builder.dart | 144 +++++++++--- lib/ai/ai_message_context.dart | 100 +++++++- lib/ai/model/ai_prompt_template.dart | 54 ++++- .../tools/ai_conversation_tool_service.dart | 214 ++++++++++++++++-- lib/db/dao/message_dao.dart | 15 ++ lib/ui/home/chat/input_container.dart | 4 + lib/ui/setting/ai_settings_page.dart | 208 +++++++++++++++++ lib/utils/property/setting_property.dart | 44 +++- lib/widgets/message/message_ai_assist.dart | 7 +- test/ai/ai_conversation_context_test.dart | 213 +++++++++++++++++ test/db/property_storage_test.dart | 46 ++++ 11 files changed, 976 insertions(+), 73 deletions(-) create mode 100644 test/ai/ai_conversation_context_test.dart diff --git a/lib/ai/ai_chat_prompt_builder.dart b/lib/ai/ai_chat_prompt_builder.dart index a0cc04b5f8..cda84b563f 100644 --- a/lib/ai/ai_chat_prompt_builder.dart +++ b/lib/ai/ai_chat_prompt_builder.dart @@ -12,6 +12,10 @@ class AiChatPromptBuilder { static const _aiStatusPending = 'pending'; static const _aiContextMessageLimit = 30; static const _aiHistoryLimit = 12; + static const _attachedContextBeforeLimit = 2; + static const _attachedContextAfterLimit = 2; + static const _attachedQuotedByLimit = 3; + static const _attachedContextMaxTextLength = 1000; final Database database; @@ -60,7 +64,7 @@ class AiChatPromptBuilder { language: language, now: now, ); - _appendAttachedMessages( + await _appendAttachedMessages( promptMessages, attachedMessages: attachedMessages, language: language, @@ -210,10 +214,10 @@ class AiChatPromptBuilder { if (recentMessages.isNotEmpty) { final lines = recentMessages.reversed .map( - (message) => _conversationContextLine( - createdAt: message.createdAt, - sender: message.userFullName ?? message.userId, - content: _messagePlainText(message), + (message) => aiMessageContextLine( + message, + relation: 'recent', + maxTextLength: _attachedContextMaxTextLength, ), ) .join('\n'); @@ -234,17 +238,20 @@ class AiChatPromptBuilder { } } - void _appendAttachedMessages( + Future _appendAttachedMessages( List promptMessages, { required List attachedMessages, required String language, required DateTime now, - }) { + }) async { if (attachedMessages.isEmpty) { return; } - final lines = attachedMessages.map(aiMessageContextLine).join('\n'); + final blocks = []; + for (final message in attachedMessages) { + blocks.add(await _attachedMessageContextBlock(message)); + } promptMessages.addAll( _promptMessages( role: AiPromptRole.system, @@ -254,35 +261,114 @@ class AiChatPromptBuilder { '"this message", "these messages", or asks for a specific ' 'message to be handled. Answer in $language unless the user ' 'explicitly asks for another language. Current time: ' - '${now.toIso8601String()}.\n$lines', + '${now.toIso8601String()}.\n\n${blocks.join('\n\n')}', ), ); } - String _conversationContextLine({ - required DateTime createdAt, - required String sender, - required String content, - }) => '[${createdAt.toIso8601String()}] $sender: $content'; + Future _attachedMessageContextBlock(MessageItem message) async { + final contextMessages = await _messageContextWindow( + message, + beforeLimit: _attachedContextBeforeLimit, + afterLimit: _attachedContextAfterLimit, + ); + final lines = [ + 'Attached context block for message_id=${message.messageId}:', + for (final contextMessage in contextMessages) + aiMessageContextLine( + contextMessage, + relation: contextMessage.messageId == message.messageId + ? 'attached' + : 'nearby', + maxTextLength: _attachedContextMaxTextLength, + ), + ]; - String _messagePlainText(MessageItem message) => _messagePlainTextFromFields( - content: message.content, - mediaName: message.mediaName, - type: message.type, - ); + final missingQuoteLine = await _missingQuoteContextLine(message); + if (missingQuoteLine != null) { + lines.add(' $missingQuoteLine'); + } - String _messagePlainTextFromFields({ - required String? content, - required String? mediaName, - required String type, - }) { - if (content?.trim().isNotEmpty == true) { - return content!.trim(); + final quotedByMessages = await database.messageDao + .messagesByQuoteId( + message.conversationId, + message.messageId, + _attachedQuotedByLimit, + ) + .get(); + if (quotedByMessages.isNotEmpty) { + lines.add('Messages quoting attached message:'); + for (final quotedByMessage in quotedByMessages) { + lines.add( + aiMessageContextLine( + quotedByMessage, + relation: 'quotes_attached', + maxTextLength: _attachedContextMaxTextLength, + ), + ); + } } - if (mediaName?.isNotEmpty == true) { - return '[$type] $mediaName'; + + return lines.join('\n'); + } + + Future> _messageContextWindow( + MessageItem message, { + required int beforeLimit, + required int afterLimit, + }) async { + final orderInfo = await database.messageDao.messageOrderInfo( + message.messageId, + ); + if (orderInfo == null) { + return [message]; + } + + final beforeMessages = beforeLimit <= 0 + ? const [] + : await database.messageDao + .beforeMessagesByConversationId( + orderInfo, + message.conversationId, + beforeLimit, + ) + .get(); + final afterMessages = afterLimit <= 0 + ? const [] + : await database.messageDao + .afterMessagesByConversationId( + orderInfo, + message.conversationId, + afterLimit, + ) + .get(); + final byMessageId = {}; + for (final item in [ + ...beforeMessages.reversed, + message, + ...afterMessages, + ]) { + byMessageId[item.messageId] = item; + } + return byMessageId.values.toList(growable: false); + } + + Future _missingQuoteContextLine(MessageItem message) async { + if (aiMessageQuotedItem(message) != null) { + return null; + } + final quoteId = message.quoteId?.trim(); + if (quoteId == null || quoteId.isEmpty) { + return null; + } + final quote = await database.messageDao.findMessageItemById( + message.conversationId, + quoteId, + ); + if (quote == null) { + return 'quoted_message: message_id=$quoteId (not available)'; } - return '[$type]'; + return aiQuoteMessageContextLine(quote); } List _promptMessages({ diff --git a/lib/ai/ai_message_context.dart b/lib/ai/ai_message_context.dart index 6474a6b602..a4f4ea40d5 100644 --- a/lib/ai/ai_message_context.dart +++ b/lib/ai/ai_message_context.dart @@ -1,3 +1,7 @@ +import 'dart:convert'; + +import '../db/dao/message_dao.dart'; +import '../db/extension/message.dart'; import '../db/extension/message_category.dart'; import '../db/mixin_database.dart'; import '../utils/message_optimize.dart'; @@ -28,10 +32,88 @@ String aiMessageContextText(MessageItem message) { '[${message.type}]'; } -String aiMessageContextLine(MessageItem message) => - '[${message.createdAt.toIso8601String()}] ' - '${message.userFullName ?? message.userId} ' - '(message_id=${message.messageId}): ${aiMessageContextText(message)}'; +String aiMessageContextLine( + MessageItem message, { + String? relation, + int? maxTextLength, +}) { + final relationText = relation == null ? '' : ', relation=$relation'; + final text = _truncateAiContextText( + aiMessageContextText(message), + maxTextLength, + ); + final line = + '[${message.createdAt.toIso8601String()}] ' + '${message.userFullName ?? message.userId} ' + '(message_id=${message.messageId}$relationText): $text'; + final quote = aiMessageQuotedItem(message); + if (quote == null) { + return line; + } + return '$line\n ${aiQuoteMessageContextLine(quote)}'; +} + +QuoteMessageItem? aiMessageQuotedItem(MessageItem message) { + final raw = message.quoteContent?.trim(); + if (raw == null || raw.isEmpty) { + return null; + } + try { + final decoded = jsonDecode(raw); + if (decoded is Map) { + return mapToQuoteMessage(decoded); + } + if (decoded is Map) { + return mapToQuoteMessage(Map.from(decoded)); + } + } catch (_) { + return null; + } + return null; +} + +String aiQuoteMessageContextLine( + QuoteMessageItem message, { + String prefix = 'quoted_message', + int? maxTextLength = 1000, +}) { + final text = _truncateAiContextText( + aiQuoteMessageContextText(message), + maxTextLength, + ); + return '$prefix: [${message.createdAt.toIso8601String()}] ' + '${message.userFullName ?? message.userId} ' + '(message_id=${message.messageId}): $text'; +} + +String aiQuoteMessageContextText(QuoteMessageItem message) { + final content = message.content?.trim(); + if ((message.type.isText || message.type.isPost) && + content != null && + content.isNotEmpty) { + return content; + } + if (content != null && content.isNotEmpty) { + return content; + } + + final mediaName = message.mediaName?.trim(); + if (mediaName != null && mediaName.isNotEmpty) { + return '[${message.type}] $mediaName'; + } + + final assetName = message.assetName?.trim(); + if (assetName != null && assetName.isNotEmpty) { + return '[${message.type}] $assetName'; + } + + return messagePreviewOptimize( + message.status, + message.type, + message.content, + ) ?? + '[${message.type}]'; +} String aiMessageContextPreview(MessageItem message, {int maxLength = 96}) { final text = aiMessageContextText(message).replaceAll(RegExp(r'\s+'), ' '); @@ -50,3 +132,13 @@ Map aiMessageContextMetadata(MessageItem message) => { 'createdAt': message.createdAt.toUtc().toIso8601String(), 'preview': aiMessageContextPreview(message, maxLength: 180), }; + +String _truncateAiContextText(String text, int? maxLength) { + if (maxLength == null || text.length <= maxLength) { + return text; + } + if (maxLength <= 3) { + return text.substring(0, maxLength); + } + return '${text.substring(0, maxLength - 3)}...'; +} diff --git a/lib/ai/model/ai_prompt_template.dart b/lib/ai/model/ai_prompt_template.dart index 17f81d5d43..3648b266a1 100644 --- a/lib/ai/model/ai_prompt_template.dart +++ b/lib/ai/model/ai_prompt_template.dart @@ -105,12 +105,24 @@ const aiPromptTemplateDefinitions = [ description: 'Primary system prompt for AI chat mode.', defaultValue: 'You are a local AI assistant inside a chat application. ' - 'Only use the provided current conversation context. ' 'The current time is {{currentIsoDateTime}}. ' 'Unless the user explicitly asks to preserve the source language, ' 'quote verbatim, translate, or use another language, respond in ' - '{{language}}. Help summarize, answer questions about the ' - 'conversation, and draft practical replies. Be concise.', + '{{language}}. Only use the provided current conversation context ' + 'and read-only conversation tools. Your strongest jobs are to ' + 'retrieve relevant past messages, summarize unread or date-scoped ' + 'activity, extract decisions, open questions, action items, links, ' + 'files, and responsibilities, explain specific messages using the ' + 'surrounding conversation, and draft practical replies. For requests ' + 'about earlier messages, previous discussions, links, files, dates, ' + 'people, decisions, or anything not clearly answered by the recent ' + 'messages, use the conversation tools before answering. When you ' + 'retrieve facts from conversation history, include useful evidence ' + 'such as sender, timestamp, and message_id when it helps the user ' + 'verify the result. Treat quoted_message, quoted_by_messages, and ' + 'nearby context as strong signals that messages belong to the same ' + 'topic. If the answer is not found after a reasonable search, say ' + 'that clearly. Be concise.', variables: [ AiPromptVariable.conversationId, AiPromptVariable.currentIsoDateTime, @@ -129,10 +141,13 @@ const aiPromptTemplateDefinitions = [ 'Unread section start:\n' '- start_at: {{unreadStartAt}}\n' '- first_unread_message_id: {{firstUnreadMessageId}}\n\n' - 'Use the conversation tools to inspect messages created at or after ' - 'start_at. Focus only on new information, decisions, questions, ' - 'requests, mentions of the user, links, files, media references, and ' - 'action items.', + 'Use the conversation tools instead of relying only on recent context. ' + 'First inspect the message count and time range from start_at to the ' + 'current time, then read the unread messages in chunks as needed. ' + 'Focus only on new information, decisions, questions, requests, ' + 'mentions of the user, links, files, media references, and action ' + 'items. Include sender and timestamp when a detail needs to be ' + 'traceable. If there are no unread messages in that range, say so.', variables: [ AiPromptVariable.conversationId, AiPromptVariable.currentIsoDateTime, @@ -314,17 +329,32 @@ const assistUserMessagePromptTemplate = const conversationToolInstructionPromptTemplate = 'Read-only conversation tools are available for the current ' - 'conversation. Use them when you need exhaustive coverage, ' - 'date-scoped summaries, statistics, older messages, or more ' - 'context than the provided messages. Tool results are returned in ' - 'TOON format, a compact tabular notation for structured data. ' + 'conversation. The provided message context is only recent and may be ' + 'incomplete. Use tools before answering when the user asks to find, ' + 'search, recall, verify, compare, or summarize messages beyond the ' + 'visible recent context. Prefer search_conversation_messages for topics, ' + 'names, links, files, quoted phrases, or previous discussions. Paginate ' + 'with anchor_id when more matches are likely. Prefer ' + 'get_conversation_stats, list_conversation_chunks, and ' + 'read_conversation_chunk for unread summaries, date-scoped summaries, ' + 'statistics, or exhaustive coverage. When a search hit needs surrounding ' + 'context, use the returned context_messages first, then read the relevant ' + 'date range around the hit if more context is still needed. Search results ' + 'may include quoted_message and quoted_by_messages; treat those as tighter ' + 'topic links than nearby messages. Tool results are returned in TOON ' + 'format, a compact tabular notation for structured data. ' + 'Ground answers in retrieved messages and include sender, timestamp, or ' + 'message_id when that evidence helps. If retrieval does not find enough ' + 'evidence, say so instead of guessing. ' 'When answering the user, default to {{language}} unless the user ' 'explicitly requires another language or preserving the source ' 'language. Do not call tools when the provided context is already ' 'sufficient.'; const recentConversationContextPromptTemplate = - 'Current conversation recent messages:\n{{messages}}'; + 'Current conversation recent messages, not a complete history. Use ' + 'conversation tools for older messages, retrieval, date-scoped questions, ' + 'or anything not clearly answered here:\n{{messages}}'; Map buildAiPromptTemplateVariables({ String? conversationId, diff --git a/lib/ai/tools/ai_conversation_tool_service.dart b/lib/ai/tools/ai_conversation_tool_service.dart index e92aa4785b..b3389b6c05 100644 --- a/lib/ai/tools/ai_conversation_tool_service.dart +++ b/lib/ai/tools/ai_conversation_tool_service.dart @@ -9,12 +9,16 @@ import 'package:toon_format/toon_format.dart'; import '../../db/dao/message_dao.dart'; import '../../db/database.dart'; import '../../db/mixin_database.dart'; +import '../ai_message_context.dart'; import '../model/ai_chat_metadata.dart'; const _kDefaultConversationChunkSize = 100; const _kMaxConversationChunkSize = 200; const _kDefaultConversationSearchLimit = 8; const _kMaxConversationSearchLimit = 20; +const _kSearchContextBeforeLimit = 2; +const _kSearchContextAfterLimit = 2; +const _kSearchQuotedByLimit = 3; const _kAiToolLogPreviewLength = 480; const _kMaxConversationMessageTextLength = 1000; const _kSearchMessageSnippetRadius = 240; @@ -29,6 +33,9 @@ class AiConversationToolMessage { required this.senderName, required this.type, required this.text, + this.quotedMessage, + this.contextMessages = const [], + this.quotedByMessages = const [], }); final String messageId; @@ -36,6 +43,9 @@ class AiConversationToolMessage { final String senderName; final String type; final String text; + final Map? quotedMessage; + final List> contextMessages; + final List> quotedByMessages; Map toJson() => { 'message_id': messageId, @@ -43,6 +53,9 @@ class AiConversationToolMessage { 'sender_name': senderName, 'type': type, 'text': text, + if (quotedMessage != null) 'quoted_message': quotedMessage, + if (contextMessages.isNotEmpty) 'context_messages': contextMessages, + if (quotedByMessages.isNotEmpty) 'quoted_by_messages': quotedByMessages, }; } @@ -287,10 +300,15 @@ class DatabaseAiConversationToolService implements AiConversationToolService { ? safeOffset + messages.length : null; + final toolMessages = []; + for (final message in messages) { + toolMessages.add(await _messageItemToToolMessage(message)); + } + return AiConversationToolChunkPage( offset: safeOffset, totalMessages: totalMessages, - messages: messages.map(_messageItemToToolMessage).toList(growable: false), + messages: toolMessages, nextOffset: nextOffset, ); } @@ -308,27 +326,75 @@ class DatabaseAiConversationToolService implements AiConversationToolService { conversationIds: [conversationId], anchorMessageId: anchorMessageId, ); + if (messages.isEmpty) { + return const AiConversationToolSearchResult( + messages: [], + nextAnchorId: null, + ); + } + final fullMessages = await database.messageDao + .messageItemByMessageIds( + messages.map((message) => message.messageId).toList(), + ) + .get(); + final fullMessageById = { + for (final message in fullMessages) message.messageId: message, + }; + final toolMessages = []; + for (final message in messages) { + final fullMessage = fullMessageById[message.messageId]; + toolMessages.add( + fullMessage == null + ? _searchMessageToToolMessage(message, query: query) + : await _messageItemToToolMessage( + fullMessage, + query: query, + maxLength: _kSearchMessageSnippetRadius * 2, + includeContext: true, + resolveMissingQuote: true, + ), + ); + } + return AiConversationToolSearchResult( - messages: messages - .map((message) => _searchMessageToToolMessage(message, query: query)) - .toList(growable: false), + messages: toolMessages, nextAnchorId: messages.length < limit ? null : messages.last.messageId, ); } - AiConversationToolMessage _messageItemToToolMessage(MessageItem message) => - AiConversationToolMessage( - messageId: message.messageId, - createdAt: message.createdAt, - senderName: message.userFullName ?? message.userId, + Future _messageItemToToolMessage( + MessageItem message, { + String? query, + int? maxLength, + bool includeContext = false, + bool resolveMissingQuote = false, + }) async { + final contextMessages = includeContext + ? await _contextMessageMapsAround(message) + : const >[]; + final quotedByMessages = includeContext + ? await _quotedByMessageMaps(message) + : const >[]; + return AiConversationToolMessage( + messageId: message.messageId, + createdAt: message.createdAt, + senderName: message.userFullName ?? message.userId, + type: message.type, + text: _messageText( + content: message.content, + mediaName: message.mediaName, type: message.type, - text: _messageText( - content: message.content, - mediaName: message.mediaName, - type: message.type, - maxLength: _kMaxConversationMessageTextLength, - ), - ); + query: query, + maxLength: maxLength ?? _kMaxConversationMessageTextLength, + ), + quotedMessage: await _quotedMessageMap( + message, + resolveMissing: resolveMissingQuote, + ), + contextMessages: contextMessages, + quotedByMessages: quotedByMessages, + ); + } AiConversationToolMessage _searchMessageToToolMessage( SearchMessageDetailItem message, { @@ -347,6 +413,101 @@ class DatabaseAiConversationToolService implements AiConversationToolService { ), ); + Future>> _contextMessageMapsAround( + MessageItem message, + ) async { + final orderInfo = await database.messageDao.messageOrderInfo( + message.messageId, + ); + if (orderInfo == null) { + return const []; + } + + final beforeMessages = await database.messageDao + .beforeMessagesByConversationId( + orderInfo, + message.conversationId, + _kSearchContextBeforeLimit, + ) + .get(); + final afterMessages = await database.messageDao + .afterMessagesByConversationId( + orderInfo, + message.conversationId, + _kSearchContextAfterLimit, + ) + .get(); + return [ + for (final item in beforeMessages.reversed) _messageItemToToolMap(item), + for (final item in afterMessages) _messageItemToToolMap(item), + ]; + } + + Future>> _quotedByMessageMaps( + MessageItem message, + ) async { + final messages = await database.messageDao + .messagesByQuoteId( + message.conversationId, + message.messageId, + _kSearchQuotedByLimit, + ) + .get(); + return messages.map(_messageItemToToolMap).toList(growable: false); + } + + Future?> _quotedMessageMap( + MessageItem message, { + required bool resolveMissing, + }) async { + final quote = aiMessageQuotedItem(message); + if (quote != null) { + return _quoteMessageItemToToolMap(quote); + } + if (!resolveMissing) { + return null; + } + final quoteId = message.quoteId?.trim(); + if (quoteId == null || quoteId.isEmpty) { + return null; + } + final resolved = await database.messageDao.findMessageItemById( + message.conversationId, + quoteId, + ); + if (resolved == null) { + return { + 'message_id': quoteId, + 'unavailable': true, + }; + } + return _quoteMessageItemToToolMap(resolved); + } + + Map _messageItemToToolMap(MessageItem message) => { + 'message_id': message.messageId, + 'created_at': _formatToolDateTime(message.createdAt), + 'sender_name': message.userFullName ?? message.userId, + 'type': message.type, + 'text': _messageText( + content: message.content, + mediaName: message.mediaName, + type: message.type, + maxLength: _kMaxConversationMessageTextLength, + ), + }; + + Map _quoteMessageItemToToolMap(QuoteMessageItem message) => { + 'message_id': message.messageId, + 'created_at': _formatToolDateTime(message.createdAt), + 'sender_name': message.userFullName ?? message.userId, + 'type': message.type, + 'text': _truncateText( + aiQuoteMessageContextText(message), + _kMaxConversationMessageTextLength, + ), + }; + String _messageText({ required String? content, required String? mediaName, @@ -378,7 +539,9 @@ class AiConversationToolKit { genkit.Tool( name: 'get_conversation_stats', description: - 'Get message count and first/last timestamps for the conversation.', + 'Get message count and first/last timestamps for the conversation, ' + 'optionally limited to a date range. Use this before date-scoped or ' + 'unread summaries to understand coverage.', inputSchema: GetConversationStatsInput.schema, fn: (input, context) => _executeTool( conversationId: conversationId, @@ -398,7 +561,10 @@ class AiConversationToolKit { ), genkit.Tool( name: 'list_conversation_chunks', - description: 'List offsets for reading conversation messages in batches.', + description: + 'List offsets for reading conversation messages in batches, ' + 'optionally limited to a date range. Use this to plan exhaustive ' + 'summaries or wide history review.', inputSchema: ListConversationChunksInput.schema, fn: (input, context) => _executeTool( conversationId: conversationId, @@ -419,7 +585,11 @@ class AiConversationToolKit { ), genkit.Tool( name: 'read_conversation_chunk', - description: 'Read conversation messages by offset and limit.', + description: + 'Read conversation messages by offset and limit, optionally limited ' + 'to a date range. Use this for unread summaries, date-scoped ' + 'summaries, or surrounding context after a search hit. Messages may ' + 'include quoted_message when they directly quote another message.', inputSchema: ReadConversationChunkInput.schema, fn: (input, context) => _executeTool( conversationId: conversationId, @@ -441,7 +611,11 @@ class AiConversationToolKit { ), genkit.Tool( name: 'search_conversation_messages', - description: 'Search messages in the current conversation.', + description: + 'Search messages in the current conversation by keyword, phrase, ' + 'person, topic, link, or file name. Use anchor_id to page through ' + 'more matches when needed. Results include nearby context messages ' + 'and quote relationships when available.', inputSchema: SearchConversationMessagesInput.schema, fn: (input, context) => _executeTool( conversationId: conversationId, diff --git a/lib/db/dao/message_dao.dart b/lib/db/dao/message_dao.dart index 4544d0229c..60c1b18373 100644 --- a/lib/db/dao/message_dao.dart +++ b/lib/db/dao/message_dao.dart @@ -788,6 +788,21 @@ class MessageDao extends DatabaseAccessor ).getSingleOrNull(); } + Selectable messagesByQuoteId( + String conversationId, + String quoteMessageId, + int limit, + ) => _baseMessageItems( + (message, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => + message.conversationId.equals(conversationId) & + message.quoteMessageId.equals(quoteMessageId), + (_, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => Limit(limit, 0), + order: (message, _, _, _, _, _, _, _, _, _, _, _, _, _) => OrderBy([ + OrderingTerm.asc(message.createdAt), + OrderingTerm.asc(message.rowId), + ]), + ); + Future updateMessageQuoteContent( String messageId, String? quoteContent, diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index e6c6ecd049..4252c85727 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -695,11 +695,15 @@ Future _requestAiDraftAction( try { final controller = AiChatController(context.database); + final provider = action == AiDraftAction.translate + ? context.database.settingProperties.selectedAiTranslatorProvider + : context.database.settingProperties.selectedAiProvider; final result = await controller.assistText( instruction: instruction, language: language, input: action == AiDraftAction.replyWithContext ? null : original, conversationId: conversationId, + provider: provider, ); return result.trim(); } catch (error, stackTrace) { diff --git a/lib/ui/setting/ai_settings_page.dart b/lib/ui/setting/ai_settings_page.dart index 19e370db1d..c2e0cec8ef 100644 --- a/lib/ui/setting/ai_settings_page.dart +++ b/lib/ui/setting/ai_settings_page.dart @@ -8,6 +8,7 @@ import '../../ai/model/ai_provider_config.dart'; import '../../utils/extension/extension.dart'; import '../../widgets/app_bar.dart'; import '../../widgets/cell.dart'; +import '../../widgets/dialog.dart'; import '../../widgets/toast.dart'; import '../provider/database_provider.dart'; import 'ai_prompt_settings_page.dart'; @@ -23,6 +24,12 @@ class AiSettingsPage extends HookConsumerWidget { final providers = database.settingProperties.aiProviders; final selectedId = database.settingProperties.selectedAiProviderId; final selectedProvider = database.settingProperties.selectedAiProvider; + final selectedTranslatorProvider = + database.settingProperties.selectedAiTranslatorProvider; + final selectedTranslatorProviderId = + database.settingProperties.selectedAiTranslatorProviderId; + final selectedTranslatorModel = + database.settingProperties.selectedAiTranslatorModel; final customizedPromptCount = aiPromptTemplateDefinitions .where( (definition) => @@ -130,6 +137,33 @@ class AiSettingsPage extends HookConsumerWidget { trailing: null, ), ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: CellItem( + title: const Text('Translator Provider'), + leading: Icon( + Icons.translate_rounded, + color: context.theme.icon, + ), + description: Text( + selectedTranslatorProviderId == null + ? 'Default · ${_providerModelSummary(selectedTranslatorProvider)}' + : _providerModelSummary( + selectedTranslatorProvider, + ), + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + onTap: () => _showTranslatorProviderDialog( + context, + providers: providers, + selectedProviderId: selectedTranslatorProviderId, + selectedModel: selectedTranslatorModel, + ), + ), + ), Padding( padding: const EdgeInsets.only( left: 20, @@ -184,6 +218,180 @@ class AiSettingsPage extends HookConsumerWidget { } return '${provider.model} · $modelCount models'; } + + static String _providerModelSummary(AiProviderConfig? provider) { + if (provider == null) return 'No enabled provider'; + return '${provider.name} · ${provider.model}'; + } + + static Future _showTranslatorProviderDialog( + BuildContext context, { + required List providers, + required String? selectedProviderId, + required String? selectedModel, + }) async { + await showMixinDialog( + context: context, + child: _TranslatorProviderDialog( + providers: providers + .where((provider) => provider.enabled) + .where((provider) => provider.model.trim().isNotEmpty) + .toList(growable: false), + selectedProviderId: selectedProviderId, + selectedModel: selectedModel, + ), + ); + } +} + +class _TranslatorProviderDialog extends HookConsumerWidget { + const _TranslatorProviderDialog({ + required this.providers, + required this.selectedProviderId, + required this.selectedModel, + }); + + final List providers; + final String? selectedProviderId; + final String? selectedModel; + + @override + Widget build(BuildContext context, WidgetRef ref) { + final database = ref.watch(databaseProvider).requireValue; + final selection = useState( + _AiProviderModelSelection( + providerId: selectedProviderId, + model: selectedModel, + ), + ); + + return AlertDialogLayout( + title: const Text('Translator Provider'), + titleMarginBottom: 20, + content: ConstrainedBox( + constraints: const BoxConstraints(maxHeight: 360), + child: SingleChildScrollView( + child: Column( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: CrossAxisAlignment.stretch, + children: [ + _ProviderModelOption( + title: 'Use Default Provider', + subtitle: _providerSummary( + database.settingProperties.selectedAiProvider, + ), + selected: selection.value.providerId == null, + onTap: () => + selection.value = const _AiProviderModelSelection(), + ), + for (final provider in providers) + for (final model in provider.models) + _ProviderModelOption( + title: provider.name, + subtitle: model, + selected: + selection.value.providerId == provider.id && + selection.value.model == model, + onTap: () => selection.value = _AiProviderModelSelection( + providerId: provider.id, + model: model, + ), + ), + ], + ), + ), + ), + actions: [ + MixinButton( + backgroundTransparent: true, + onTap: () => Navigator.of(context).pop(), + child: const Text('Cancel'), + ), + MixinButton( + onTap: () { + database.settingProperties.selectedAiTranslatorProviderId = + selection.value.providerId; + database.settingProperties.selectedAiTranslatorModel = + selection.value.model; + Navigator.of(context).pop(); + }, + child: const Text('Save'), + ), + ], + ); + } + + static String _providerSummary(AiProviderConfig? provider) { + if (provider == null) return 'No enabled provider'; + return '${provider.name} · ${provider.model}'; + } +} + +class _AiProviderModelSelection { + const _AiProviderModelSelection({this.providerId, this.model}); + + final String? providerId; + final String? model; +} + +class _ProviderModelOption extends StatelessWidget { + const _ProviderModelOption({ + required this.title, + required this.subtitle, + required this.selected, + required this.onTap, + }); + + final String title; + final String subtitle; + final bool selected; + final VoidCallback onTap; + + @override + Widget build(BuildContext context) => InkWell( + onTap: onTap, + child: Padding( + padding: const EdgeInsets.symmetric(vertical: 10), + child: Row( + children: [ + Icon( + selected + ? Icons.radio_button_checked_rounded + : Icons.radio_button_unchecked_rounded, + color: selected + ? context.theme.accent + : context.theme.secondaryText, + size: 20, + ), + const SizedBox(width: 12), + Expanded( + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + title, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle(color: context.theme.text, fontSize: 15), + ), + const SizedBox(height: 2), + Text( + subtitle, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 13, + fontWeight: FontWeight.w400, + ), + ), + ], + ), + ), + ], + ), + ), + ); } class _ProviderCell extends HookConsumerWidget { diff --git a/lib/utils/property/setting_property.dart b/lib/utils/property/setting_property.dart index d752c8b427..8a008fb110 100644 --- a/lib/utils/property/setting_property.dart +++ b/lib/utils/property/setting_property.dart @@ -15,6 +15,8 @@ const _kSelectedProxyKey = 'selected_proxy'; const _kProxyListKey = 'proxy_list'; const _kAiProviderListKey = 'ai_provider_list'; const _kSelectedAiProviderKey = 'selected_ai_provider'; +const _kSelectedAiTranslatorProviderKey = 'selected_ai_translator_provider'; +const _kSelectedAiTranslatorModelKey = 'selected_ai_translator_model'; const _kAiPromptTemplateOverridesKey = 'ai_prompt_template_overrides'; class SettingPropertyStorage extends PropertyStorage { @@ -91,17 +93,41 @@ class SettingPropertyStorage extends PropertyStorage { set selectedAiProviderId(String? value) => set(_kSelectedAiProviderKey, value); - AiProviderConfig? get selectedAiProvider { + String? get selectedAiTranslatorProviderId => + get(_kSelectedAiTranslatorProviderKey); + + set selectedAiTranslatorProviderId(String? value) => + set(_kSelectedAiTranslatorProviderKey, value); + + String? get selectedAiTranslatorModel => get(_kSelectedAiTranslatorModelKey); + + set selectedAiTranslatorModel(String? value) => + set(_kSelectedAiTranslatorModelKey, value); + + AiProviderConfig? get selectedAiProvider => + _resolveAiProvider(selectedAiProviderId, null); + + AiProviderConfig? get selectedAiTranslatorProvider => + _resolveAiProvider( + selectedAiTranslatorProviderId, + selectedAiTranslatorModel, + ) ?? + selectedAiProvider; + + AiProviderConfig? _resolveAiProvider(String? selectedId, String? model) { final providers = aiProviders.where((element) => element.enabled).toList(); if (providers.isEmpty) { return null; } - final selectedId = selectedAiProviderId; - if (selectedId == null) { - return providers.first; - } - return providers.firstWhereOrNull((element) => element.id == selectedId) ?? - providers.first; + final provider = selectedId == null + ? providers.first + : providers.firstWhereOrNull((element) => element.id == selectedId) ?? + providers.first; + final selectedModel = model?.trim(); + if (selectedModel == null || selectedModel.isEmpty) return provider; + if (!provider.models.contains(selectedModel)) return provider; + if (provider.model == selectedModel) return provider; + return provider.copyWith(model: selectedModel, defaultModel: selectedModel); } void saveAiProvider(AiProviderConfig config) { @@ -128,6 +154,10 @@ class SettingPropertyStorage extends PropertyStorage { if (selectedAiProviderId == id) { selectedAiProviderId = providers.firstOrNull?.id; } + if (selectedAiTranslatorProviderId == id) { + selectedAiTranslatorProviderId = null; + selectedAiTranslatorModel = null; + } } Map get _aiPromptTemplateOverrides { diff --git a/lib/widgets/message/message_ai_assist.dart b/lib/widgets/message/message_ai_assist.dart index 38b6a8a866..6080594fb2 100644 --- a/lib/widgets/message/message_ai_assist.dart +++ b/lib/widgets/message/message_ai_assist.dart @@ -103,7 +103,12 @@ Future runMessageAiAction( required void Function(MessageAiAction, InlineMessageAiEntry) onStateChanged, }) async { final language = _currentLanguageTag(context); - final provider = context.database.settingProperties.selectedAiProvider; + final provider = switch (action) { + MessageAiAction.translate => + context.database.settingProperties.selectedAiTranslatorProvider, + MessageAiAction.explain || MessageAiAction.suggestReplies => + context.database.settingProperties.selectedAiProvider, + }; final model = provider?.model; final templateKey = switch (action) { MessageAiAction.translate => AiPromptTemplateKey.messageTranslate, diff --git a/test/ai/ai_conversation_context_test.dart b/test/ai/ai_conversation_context_test.dart new file mode 100644 index 0000000000..bf0ed521f6 --- /dev/null +++ b/test/ai/ai_conversation_context_test.dart @@ -0,0 +1,213 @@ +import 'package:drift/drift.dart'; +import 'package:drift/native.dart'; +import 'package:flutter_app/ai/ai_message_context.dart'; +import 'package:flutter_app/ai/tools/ai_conversation_tool_service.dart'; +import 'package:flutter_app/db/ai_database.dart'; +import 'package:flutter_app/db/database.dart'; +import 'package:flutter_app/db/extension/message.dart'; +import 'package:flutter_app/db/fts_database.dart'; +import 'package:flutter_app/db/mixin_database.dart'; +import 'package:flutter_app/enum/message_category.dart'; +import 'package:flutter_test/flutter_test.dart'; +import 'package:mixin_bot_sdk_dart/mixin_bot_sdk_dart.dart'; +import 'package:toon_format/toon_format.dart'; + +void main() { + group('AI conversation context', () { + late MixinDatabase mixinDatabase; + late FtsDatabase ftsDatabase; + late AiDatabase aiDatabase; + late Database database; + + setUp(() async { + mixinDatabase = MixinDatabase(NativeDatabase.memory()); + ftsDatabase = FtsDatabase(NativeDatabase.memory()); + aiDatabase = AiDatabase(NativeDatabase.memory()); + database = Database(mixinDatabase, ftsDatabase, aiDatabase); + + await _insertUser(database, 'owner', 'Owner'); + await _insertUser(database, 'alice', 'Alice'); + await _insertUser(database, 'bob', 'Bob'); + await database.mixinDatabase + .into(database.mixinDatabase.conversations) + .insert( + ConversationsCompanion.insert( + conversationId: 'conversation', + ownerId: const Value('owner'), + createdAt: DateTime(2026, 4, 30, 9), + status: ConversationStatus.success, + ), + ); + }); + + tearDown(() async { + await database.dispose(); + }); + + test('message context line includes quoted message content', () async { + final createdAt = DateTime(2026, 4, 30, 9, 1); + await _insertMessage( + database, + id: 'quoted', + userId: 'bob', + content: 'quoted topic detail', + createdAt: createdAt, + ); + final quote = await database.messageDao.findMessageItemById( + 'conversation', + 'quoted', + ); + await _insertMessage( + database, + id: 'reply', + userId: 'alice', + content: 'replying to that', + createdAt: createdAt.add(const Duration(minutes: 1)), + quoteMessageId: 'quoted', + quoteContent: quote!.toJson(), + ); + + final reply = await database.messageDao + .messageItemByMessageId('reply') + .getSingle(); + + expect( + aiMessageContextLine(reply), + contains('quoted_message:'), + ); + expect(aiMessageContextLine(reply), contains('Bob (message_id=quoted)')); + expect(aiMessageContextLine(reply), contains('quoted topic detail')); + }); + + test('search results include nearby and quote-linked messages', () async { + final createdAt = DateTime(2026, 4, 30, 10); + await _insertMessage( + database, + id: 'before', + userId: 'alice', + content: 'setup before topic', + createdAt: createdAt, + ); + await _insertMessage( + database, + id: 'target', + userId: 'bob', + content: 'alpha decision', + createdAt: createdAt.add(const Duration(minutes: 1)), + ); + await _insertMessage( + database, + id: 'after', + userId: 'alice', + content: 'follow up detail', + createdAt: createdAt.add(const Duration(minutes: 2)), + ); + final quote = await database.messageDao.findMessageItemById( + 'conversation', + 'target', + ); + await _insertMessage( + database, + id: 'quote-reply', + userId: 'alice', + content: 'reply via quote', + createdAt: createdAt.add(const Duration(minutes: 3)), + quoteMessageId: 'target', + quoteContent: quote!.toJson(), + ); + + final service = DatabaseAiConversationToolService(database); + final targetResult = await service.searchConversationMessages( + conversationId: 'conversation', + query: 'alpha', + limit: 1, + ); + final targetJson = targetResult.toJson(); + expect(encode(targetJson), contains('context_messages')); + final targetMessage = + (targetJson['messages'] as List).single as Map; + + expect(targetMessage['message_id'], 'target'); + expect( + targetMessage['context_messages'], + contains(containsPair('message_id', 'before')), + ); + expect( + targetMessage['context_messages'], + contains(containsPair('message_id', 'after')), + ); + expect( + targetMessage['quoted_by_messages'], + contains(containsPair('message_id', 'quote-reply')), + ); + + final quoteResult = await service.searchConversationMessages( + conversationId: 'conversation', + query: 'quote', + limit: 1, + ); + final quoteJson = quoteResult.toJson(); + final quoteMessage = + (quoteJson['messages'] as List).single as Map; + + expect(quoteMessage['message_id'], 'quote-reply'); + expect( + quoteMessage['quoted_message'], + containsPair('message_id', 'target'), + ); + }); + }); +} + +Future _insertUser(Database database, String id, String name) => database + .mixinDatabase + .into(database.mixinDatabase.users) + .insert( + UsersCompanion.insert( + userId: id, + identityNumber: id, + fullName: Value(name), + ), + ); + +Future _insertMessage( + Database database, { + required String id, + required String userId, + required String content, + required DateTime createdAt, + String? quoteMessageId, + String? quoteContent, +}) async { + await database.mixinDatabase + .into(database.mixinDatabase.messages) + .insert( + MessagesCompanion.insert( + messageId: id, + conversationId: 'conversation', + userId: userId, + category: MessageCategory.plainText, + content: Value(content), + status: MessageStatus.read, + createdAt: createdAt, + quoteMessageId: Value(quoteMessageId), + quoteContent: Value(quoteContent), + ), + ); + + final rowId = await database.ftsDatabase + .into(database.ftsDatabase.messagesFts) + .insert(MessagesFt(content: content)); + await database.ftsDatabase + .into(database.ftsDatabase.messagesMetas) + .insert( + MessagesMeta( + docId: rowId, + messageId: id, + conversationId: 'conversation', + category: MessageCategory.plainText, + userId: userId, + createdAt: createdAt, + ), + ); +} diff --git a/test/db/property_storage_test.dart b/test/db/property_storage_test.dart index d6d1dc5310..43ed93f103 100644 --- a/test/db/property_storage_test.dart +++ b/test/db/property_storage_test.dart @@ -3,6 +3,8 @@ library; import 'package:drift/native.dart'; import 'package:flutter_app/ai/model/ai_prompt_template.dart'; +import 'package:flutter_app/ai/model/ai_provider_config.dart'; +import 'package:flutter_app/ai/model/ai_provider_type.dart'; import 'package:flutter_app/db/mixin_database.dart'; import 'package:flutter_app/db/util/property_storage.dart'; import 'package:flutter_app/enum/property_group.dart'; @@ -78,4 +80,48 @@ void main() { expect(storage.aiPromptTemplate(key), key.definition.defaultValue); expect(storage.hasAiPromptTemplateOverride(key), isFalse); }); + + test('AI translator provider can use an independent model', () async { + final database = MixinDatabase(NativeDatabase.memory()); + final storage = SettingPropertyStorage(database.propertyDao); + final defaultProvider = AiProviderConfig( + id: 'default', + name: 'Default', + type: AiProviderType.openaiCompatible, + baseUrl: 'https://api.example.com/v1', + apiKey: 'key', + model: 'chat-model', + models: const ['chat-model', 'translate-model'], + defaultModel: 'chat-model', + ); + final translatorProvider = AiProviderConfig( + id: 'translator', + name: 'Translator', + type: AiProviderType.openaiCompatible, + baseUrl: 'https://api.example.com/v1', + apiKey: 'key', + model: 'small', + models: const ['small', 'large'], + defaultModel: 'small', + ); + + storage + ..saveAiProvider(defaultProvider) + ..saveAiProvider(translatorProvider) + ..selectedAiProviderId = defaultProvider.id; + + expect(storage.selectedAiProvider?.id, defaultProvider.id); + expect(storage.selectedAiProvider?.model, 'chat-model'); + expect(storage.selectedAiTranslatorProvider?.id, defaultProvider.id); + expect(storage.selectedAiTranslatorProvider?.model, 'chat-model'); + + storage + ..selectedAiTranslatorProviderId = translatorProvider.id + ..selectedAiTranslatorModel = 'large'; + + expect(storage.selectedAiProvider?.id, defaultProvider.id); + expect(storage.selectedAiProvider?.model, 'chat-model'); + expect(storage.selectedAiTranslatorProvider?.id, translatorProvider.id); + expect(storage.selectedAiTranslatorProvider?.model, 'large'); + }); } From a9686091556877b6ab75a46ea84a783888a218f6 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 30 Apr 2026 10:38:52 +0800 Subject: [PATCH 43/52] feat: add support for message_id validation and markdown linking in conversation prompts --- lib/ai/model/ai_prompt_template.dart | 6 +++-- lib/utils/uri_utils.dart | 34 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/lib/ai/model/ai_prompt_template.dart b/lib/ai/model/ai_prompt_template.dart index 3648b266a1..3e53d14fbd 100644 --- a/lib/ai/model/ai_prompt_template.dart +++ b/lib/ai/model/ai_prompt_template.dart @@ -344,8 +344,10 @@ const conversationToolInstructionPromptTemplate = 'topic links than nearby messages. Tool results are returned in TOON ' 'format, a compact tabular notation for structured data. ' 'Ground answers in retrieved messages and include sender, timestamp, or ' - 'message_id when that evidence helps. If retrieval does not find enough ' - 'evidence, say so instead of guessing. ' + 'message_id when that evidence helps. When citing a retrieved message, ' + 'you may link the message_id with markdown using this exact URL pattern: ' + 'mixin://conversations/{{conversationId}}?message_id=. ' + 'If retrieval does not find enough evidence, say so instead of guessing. ' 'When answering the user, default to {{language}} unless the user ' 'explicitly requires another language or preserving the source ' 'language. Do not call tools when the provided context is already ' diff --git a/lib/utils/uri_utils.dart b/lib/utils/uri_utils.dart index 11cf5a8300..980927fa45 100644 --- a/lib/utils/uri_utils.dart +++ b/lib/utils/uri_utils.dart @@ -296,15 +296,44 @@ Future _selectConversation( } } + final initIndexMessageId = await _validatedMessageIdOfConversation( + context, + conversationId, + uri.messageIdOfConversation, + ); + if (uri.messageIdOfConversation != null && initIndexMessageId == null) { + showToastFailed(null); + return false; + } + await ConversationStateNotifier.selectConversation( context, conversationId, + initIndexMessageId: initIndexMessageId, sync: true, checkCurrentUserExist: true, ); return true; } +Future _validatedMessageIdOfConversation( + BuildContext context, + String conversationId, + String? messageId, +) async { + final trimmedMessageId = messageId?.trim(); + if (trimmedMessageId == null || trimmedMessageId.isEmpty) { + return null; + } + + final messageConversationId = await context.database.messageDao + .findConversationIdByMessageId(trimmedMessageId); + if (messageConversationId != conversationId) { + return null; + } + return trimmedMessageId; +} + extension MixinUriExt on Uri { bool get isSendToUser => !userOfSend.isNullOrBlank(); @@ -378,6 +407,11 @@ extension _MixinUriExtension on Uri { return queryParameters['start']; } + String? get messageIdOfConversation { + if (!isMixin) return null; + return queryParameters['message_id']; + } + String? get userOfSend { if (!isSend) { return null; From 659d9931883ca3def38c3837897b9f81c8bf22b2 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 30 Apr 2026 10:42:27 +0800 Subject: [PATCH 44/52] feat: add transcript context handling and improve prompt generation --- lib/ai/ai_chat_prompt_builder.dart | 67 ++++++++++++++++++--- lib/ai/ai_message_context.dart | 35 +++++++++++ test/ai/ai_conversation_context_test.dart | 73 ++++++++++++++++++++++- 3 files changed, 165 insertions(+), 10 deletions(-) diff --git a/lib/ai/ai_chat_prompt_builder.dart b/lib/ai/ai_chat_prompt_builder.dart index cda84b563f..2196cb31cd 100644 --- a/lib/ai/ai_chat_prompt_builder.dart +++ b/lib/ai/ai_chat_prompt_builder.dart @@ -1,6 +1,8 @@ import 'package:mixin_logger/mixin_logger.dart'; +import '../db/dao/transcript_message_dao.dart'; import '../db/database.dart'; +import '../db/extension/message_category.dart'; import '../db/mixin_database.dart'; import 'ai_message_context.dart'; import 'model/ai_prompt_message.dart'; @@ -16,6 +18,8 @@ class AiChatPromptBuilder { static const _attachedContextAfterLimit = 2; static const _attachedQuotedByLimit = 3; static const _attachedContextMaxTextLength = 1000; + static const _attachedTranscriptLimit = 80; + static const _attachedTranscriptMaxTextLength = 800; final Database database; @@ -274,14 +278,12 @@ class AiChatPromptBuilder { ); final lines = [ 'Attached context block for message_id=${message.messageId}:', - for (final contextMessage in contextMessages) - aiMessageContextLine( - contextMessage, - relation: contextMessage.messageId == message.messageId - ? 'attached' - : 'nearby', - maxTextLength: _attachedContextMaxTextLength, - ), + 'Primary attached message:', + aiMessageContextLine( + message, + relation: 'attached_primary', + maxTextLength: _attachedContextMaxTextLength, + ), ]; final missingQuoteLine = await _missingQuoteContextLine(message); @@ -289,6 +291,33 @@ class AiChatPromptBuilder { lines.add(' $missingQuoteLine'); } + if (message.type.isTranscript) { + final transcriptLines = await _attachedTranscriptContextLines(message); + if (transcriptLines.isNotEmpty) { + lines + ..add('Attached transcript messages:') + ..addAll(transcriptLines); + } + } + + final nearbyMessages = contextMessages + .where( + (contextMessage) => contextMessage.messageId != message.messageId, + ) + .toList(growable: false); + if (nearbyMessages.isNotEmpty) { + lines.add('Nearby context messages, for disambiguation only:'); + for (final contextMessage in nearbyMessages) { + lines.add( + aiMessageContextLine( + contextMessage, + relation: 'nearby', + maxTextLength: _attachedContextMaxTextLength, + ), + ); + } + } + final quotedByMessages = await database.messageDao .messagesByQuoteId( message.conversationId, @@ -312,6 +341,28 @@ class AiChatPromptBuilder { return lines.join('\n'); } + Future> _attachedTranscriptContextLines( + MessageItem message, + ) async { + final transcriptMessages = await database.transcriptMessageDao + .transactionMessageItem(message.messageId) + .get(); + if (transcriptMessages.isEmpty) { + return const []; + } + + return transcriptMessages + .take(_attachedTranscriptLimit) + .map( + (item) => aiMessageContextLine( + item.messageItem, + relation: 'attached_transcript_item', + maxTextLength: _attachedTranscriptMaxTextLength, + ), + ) + .toList(growable: false); + } + Future> _messageContextWindow( MessageItem message, { required int beforeLimit, diff --git a/lib/ai/ai_message_context.dart b/lib/ai/ai_message_context.dart index a4f4ea40d5..b3ea101178 100644 --- a/lib/ai/ai_message_context.dart +++ b/lib/ai/ai_message_context.dart @@ -1,5 +1,6 @@ import 'dart:convert'; +import '../blaze/vo/transcript_minimal.dart'; import '../db/dao/message_dao.dart'; import '../db/extension/message.dart'; import '../db/extension/message_category.dart'; @@ -14,6 +15,10 @@ String aiMessageContextText(MessageItem message) { return content; } + if (message.type.isTranscript) { + return _transcriptContextText(content) ?? '[transcript]'; + } + final caption = message.caption?.trim(); if (caption != null && caption.isNotEmpty) { return caption; @@ -32,6 +37,36 @@ String aiMessageContextText(MessageItem message) { '[${message.type}]'; } +String? _transcriptContextText(String? content) { + if (content == null || content.isEmpty) { + return null; + } + try { + final decoded = jsonDecode(content); + if (decoded is! List) { + return content; + } + final lines = decoded + .map((json) { + final item = TranscriptMinimal.fromJson( + Map.from(json as Map), + ); + final text = + messagePreviewOptimize(null, item.category, item.content) ?? + item.content ?? + '[${item.category}]'; + return '${item.name}: $text'; + }) + .join('\n'); + if (lines.isEmpty) { + return null; + } + return lines; + } catch (_) { + return content; + } +} + String aiMessageContextLine( MessageItem message, { String? relation, diff --git a/test/ai/ai_conversation_context_test.dart b/test/ai/ai_conversation_context_test.dart index bf0ed521f6..68b8f928d7 100644 --- a/test/ai/ai_conversation_context_test.dart +++ b/test/ai/ai_conversation_context_test.dart @@ -1,5 +1,6 @@ import 'package:drift/drift.dart'; import 'package:drift/native.dart'; +import 'package:flutter_app/ai/ai_chat_prompt_builder.dart'; import 'package:flutter_app/ai/ai_message_context.dart'; import 'package:flutter_app/ai/tools/ai_conversation_tool_service.dart'; import 'package:flutter_app/db/ai_database.dart'; @@ -156,6 +157,73 @@ void main() { containsPair('message_id', 'target'), ); }); + + test( + 'attached transcript prompt includes focused transcript items', + () async { + final createdAt = DateTime(2026, 4, 30, 11); + await _insertMessage( + database, + id: 'before-transcript', + userId: 'alice', + content: 'noise before transcript', + createdAt: createdAt, + ); + await _insertMessage( + database, + id: 'transcript', + userId: 'bob', + content: '[Transcript]', + createdAt: createdAt.add(const Duration(minutes: 1)), + category: MessageCategory.plainTranscript, + ); + await _insertMessage( + database, + id: 'after-transcript', + userId: 'alice', + content: 'noise after transcript', + createdAt: createdAt.add(const Duration(minutes: 2)), + ); + await database.mixinDatabase + .into(database.mixinDatabase.transcriptMessages) + .insert( + TranscriptMessagesCompanion.insert( + transcriptId: 'transcript', + messageId: 'transcript-item-1', + category: MessageCategory.plainText, + createdAt: createdAt.add(const Duration(minutes: 3)), + content: const Value('real transcript detail'), + userId: const Value('alice'), + userFullName: const Value('Alice'), + ), + ); + + final attached = await database.messageDao + .messageItemByMessageId('transcript') + .getSingle(); + final promptMessages = await AiChatPromptBuilder(database) + .buildPromptMessages( + 'conversation', + 'thread', + 'what is inside this transcript?', + 'English', + attachedMessages: [attached], + ); + final prompt = promptMessages + .map((message) => message.content) + .join('\n'); + + expect(prompt, contains('Primary attached message:')); + expect(prompt, contains('relation=attached_primary')); + expect(prompt, contains('Attached transcript messages:')); + expect(prompt, contains('real transcript detail')); + expect( + prompt, + contains('Nearby context messages, for disambiguation only'), + ); + expect(prompt, contains('noise before transcript')); + }, + ); }); } @@ -176,6 +244,7 @@ Future _insertMessage( required String userId, required String content, required DateTime createdAt, + String category = MessageCategory.plainText, String? quoteMessageId, String? quoteContent, }) async { @@ -186,7 +255,7 @@ Future _insertMessage( messageId: id, conversationId: 'conversation', userId: userId, - category: MessageCategory.plainText, + category: category, content: Value(content), status: MessageStatus.read, createdAt: createdAt, @@ -205,7 +274,7 @@ Future _insertMessage( docId: rowId, messageId: id, conversationId: 'conversation', - category: MessageCategory.plainText, + category: category, userId: userId, createdAt: createdAt, ), From ef73ce775b4f98e7af4ba61a98476be6c3696eef Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 30 Apr 2026 12:02:04 +0800 Subject: [PATCH 45/52] feat: add OCR support for image messages in AI chat prompts --- lib/ai/ai_chat_prompt_builder.dart | 13 + lib/ai/ai_message_context.dart | 4 + lib/ai/model/ai_prompt_template.dart | 6 +- .../tools/ai_conversation_tool_service.dart | 73 ++ lib/ai/tools/ai_image_ocr_service.dart | 337 +++++++ lib/db/ai_database.dart | 11 +- lib/db/ai_database.g.dart | 942 ++++++++++++++++++ lib/db/dao/ai_image_ocr_dao.dart | 20 + lib/db/dao/ai_image_ocr_dao.g.dart | 13 + lib/db/database.dart | 12 +- lib/db/moor/ai.drift | 15 + lib/ui/provider/database_provider.dart | 1 + macos/Podfile.lock | 6 - macos/Runner.xcodeproj/project.pbxproj | 2 - test/ai/ai_conversation_context_test.dart | 34 + 15 files changed, 1476 insertions(+), 13 deletions(-) create mode 100644 lib/ai/tools/ai_image_ocr_service.dart create mode 100644 lib/db/dao/ai_image_ocr_dao.dart create mode 100644 lib/db/dao/ai_image_ocr_dao.g.dart diff --git a/lib/ai/ai_chat_prompt_builder.dart b/lib/ai/ai_chat_prompt_builder.dart index 2196cb31cd..4ce1bf511f 100644 --- a/lib/ai/ai_chat_prompt_builder.dart +++ b/lib/ai/ai_chat_prompt_builder.dart @@ -7,6 +7,7 @@ import '../db/mixin_database.dart'; import 'ai_message_context.dart'; import 'model/ai_prompt_message.dart'; import 'model/ai_prompt_template.dart'; +import 'tools/ai_image_ocr_service.dart'; class AiChatPromptBuilder { AiChatPromptBuilder(this.database); @@ -22,6 +23,7 @@ class AiChatPromptBuilder { static const _attachedTranscriptMaxTextLength = 800; final Database database; + late final AiImageOcrService _imageOcrService = AiImageOcrService(database); Future> buildPromptMessages( String conversationId, @@ -299,6 +301,17 @@ class AiChatPromptBuilder { ..addAll(transcriptLines); } } + if (message.type.isImage) { + final ocrResult = await _imageOcrService.recognizeMessageImageText( + conversationId: message.conversationId, + messageId: message.messageId, + ); + lines.addAll( + ocrResult.toPromptLines( + 'OCR text from primary attached image:', + ), + ); + } final nearbyMessages = contextMessages .where( diff --git a/lib/ai/ai_message_context.dart b/lib/ai/ai_message_context.dart index b3ea101178..9b71f899cb 100644 --- a/lib/ai/ai_message_context.dart +++ b/lib/ai/ai_message_context.dart @@ -24,6 +24,10 @@ String aiMessageContextText(MessageItem message) { return caption; } + if (message.type.isImage) { + return '[image]'; + } + final mediaName = message.mediaName?.trim(); if (mediaName != null && mediaName.isNotEmpty) { return '[${message.type}] $mediaName'; diff --git a/lib/ai/model/ai_prompt_template.dart b/lib/ai/model/ai_prompt_template.dart index 3e53d14fbd..cc039e3bd9 100644 --- a/lib/ai/model/ai_prompt_template.dart +++ b/lib/ai/model/ai_prompt_template.dart @@ -341,7 +341,11 @@ const conversationToolInstructionPromptTemplate = 'context, use the returned context_messages first, then read the relevant ' 'date range around the hit if more context is still needed. Search results ' 'may include quoted_message and quoted_by_messages; treat those as tighter ' - 'topic links than nearby messages. Tool results are returned in TOON ' + 'topic links than nearby messages. Use read_image_text when the user asks ' + 'about text inside an image, screenshot, photo, document, receipt, or ' + 'error capture message. OCR only recognizes text and may be incomplete; ' + 'do not pretend it provides full visual understanding. Tool results are ' + 'returned in TOON ' 'format, a compact tabular notation for structured data. ' 'Ground answers in retrieved messages and include sender, timestamp, or ' 'message_id when that evidence helps. When citing a retrieved message, ' diff --git a/lib/ai/tools/ai_conversation_tool_service.dart b/lib/ai/tools/ai_conversation_tool_service.dart index b3389b6c05..b99ff889b2 100644 --- a/lib/ai/tools/ai_conversation_tool_service.dart +++ b/lib/ai/tools/ai_conversation_tool_service.dart @@ -8,9 +8,11 @@ import 'package:toon_format/toon_format.dart'; import '../../db/dao/message_dao.dart'; import '../../db/database.dart'; +import '../../db/extension/message_category.dart'; import '../../db/mixin_database.dart'; import '../ai_message_context.dart'; import '../model/ai_chat_metadata.dart'; +import 'ai_image_ocr_service.dart'; const _kDefaultConversationChunkSize = 100; const _kMaxConversationChunkSize = 200; @@ -185,12 +187,18 @@ abstract interface class AiConversationToolService { required int limit, String? anchorMessageId, }); + + Future readImageText({ + required String conversationId, + required String messageId, + }); } class DatabaseAiConversationToolService implements AiConversationToolService { DatabaseAiConversationToolService(this.database); final Database database; + late final AiImageOcrService _imageOcrService = AiImageOcrService(database); @override Future getConversationStats({ @@ -362,6 +370,15 @@ class DatabaseAiConversationToolService implements AiConversationToolService { ); } + @override + Future readImageText({ + required String conversationId, + required String messageId, + }) => _imageOcrService.recognizeMessageImageText( + conversationId: conversationId, + messageId: messageId, + ); + Future _messageItemToToolMessage( MessageItem message, { String? query, @@ -523,6 +540,10 @@ class DatabaseAiConversationToolService implements AiConversationToolService { if (mediaName?.isNotEmpty == true) { return '[$type] $mediaName'; } + if (type.isImage) { + return '[$type image; use read_image_text with message_id when the user ' + 'asks about text in this image]'; + } return '[$type]'; } } @@ -634,6 +655,30 @@ class AiConversationToolKit { }, ), ), + genkit.Tool( + name: 'read_image_text', + description: + 'Run local OCR for an image message in the current conversation. ' + 'Use this when the user asks what text appears in an image, ' + 'screenshot, photo, receipt, document, or error capture. OCR only ' + 'recognizes visible text and may be incomplete; do not treat it as ' + 'full visual understanding.', + inputSchema: ReadImageTextInput.schema, + fn: (input, context) => _executeTool( + conversationId: conversationId, + name: 'read_image_text', + arguments: input.toArguments(), + context: context, + onEvent: onEvent, + fn: () async { + final result = await service.readImageText( + conversationId: conversationId, + messageId: input.messageId, + ); + return result.toJson(); + }, + ), + ), ]; Future _executeTool({ @@ -872,6 +917,34 @@ class SearchConversationMessagesInput { }..removeWhere((_, value) => value == null); } +class ReadImageTextInput { + const ReadImageTextInput({required this.messageId}); + + final String messageId; + + static final schema = SchemanticType.from( + jsonSchema: { + 'type': 'object', + 'properties': { + 'message_id': { + 'type': 'string', + 'description': 'Image message id in the current conversation.', + }, + }, + 'required': ['message_id'], + 'additionalProperties': false, + }, + parse: (value) { + final arguments = _jsonMap(value); + return ReadImageTextInput( + messageId: _parseRequiredString(arguments, 'message_id'), + ); + }, + ); + + Map toArguments() => {'message_id': messageId}; +} + Map _rangeSchema({ Map properties = const {}, List required = const [], diff --git a/lib/ai/tools/ai_image_ocr_service.dart b/lib/ai/tools/ai_image_ocr_service.dart new file mode 100644 index 0000000000..9d9d0ab9ac --- /dev/null +++ b/lib/ai/tools/ai_image_ocr_service.dart @@ -0,0 +1,337 @@ +import 'dart:convert'; +import 'dart:io'; + +import 'package:drift/drift.dart'; +import 'package:ffi/ffi.dart' as pkg_ffi; +import 'package:mixin_logger/mixin_logger.dart'; +import 'package:objective_c/objective_c.dart' as objc; +import 'package:platform_ocr/platform_ocr.dart'; +// ignore: implementation_imports +import 'package:platform_ocr/src/darwin/bindings.g.dart' as darwin; + +import '../../db/ai_database.dart'; +import '../../db/database.dart'; +import '../../db/extension/message_category.dart'; +import '../../db/mixin_database.dart'; +import '../../utils/attachment/attachment_util.dart'; + +const aiImageOcrEngine = 'platform_ocr'; +const _kOcrStatusDone = 'done'; +const _kOcrStatusError = 'error'; + +class AiImageOcrTextResult { + const AiImageOcrTextResult({ + required this.messageId, + required this.conversationId, + required this.engine, + required this.status, + required this.text, + required this.cached, + this.errorText, + this.lines = const [], + }); + + final String messageId; + final String conversationId; + final String engine; + final String status; + final String text; + final bool cached; + final String? errorText; + final List> lines; + + bool get hasText => text.trim().isNotEmpty; + + Map toJson() => { + 'message_id': messageId, + 'conversation_id': conversationId, + 'engine': engine, + 'status': status, + 'cached': cached, + 'text': text, + if (errorText?.isNotEmpty == true) 'error_text': errorText, + if (lines.isNotEmpty) 'lines': lines, + }; + + List toPromptLines(String title) => [ + title, + 'message_id=$messageId engine=$engine status=$status cached=$cached', + if (status == _kOcrStatusDone) hasText ? text.trim() : 'no text recognized', + if (status != _kOcrStatusDone) + 'unavailable: ${errorText ?? 'unknown error'}', + ]; +} + +class AiImageOcrService { + AiImageOcrService(this.database); + + final Database database; + + Future recognizeMessageImageText({ + required String conversationId, + required String messageId, + }) async { + final message = await database.messageDao + .messageItemByMessageId(messageId) + .getSingleOrNull(); + if (message == null) { + return _unavailable( + conversationId: conversationId, + messageId: messageId, + errorText: 'message not found', + ); + } + if (message.conversationId != conversationId) { + return _unavailable( + conversationId: conversationId, + messageId: messageId, + errorText: 'message is not in the current conversation', + ); + } + if (!message.type.isImage) { + return _unavailable( + conversationId: conversationId, + messageId: messageId, + errorText: 'message is not an image', + ); + } + + final file = await _messageImageFile(message); + if (file == null) { + return _unavailable( + conversationId: conversationId, + messageId: messageId, + errorText: 'local image file is not available', + ); + } + final fingerprint = await _mediaFingerprint(message, file); + final cached = await database.aiImageOcrDao.resultByMessageId(messageId); + if (cached != null && + cached.mediaFingerprint == fingerprint && + cached.engine == aiImageOcrEngine) { + return _fromCache(cached); + } + + try { + final result = await _recognizeText(file); + final text = result.text.trim(); + final lines = result.lines.map(_ocrLineToJson).toList(growable: false); + await _saveResult( + message: message, + fingerprint: fingerprint, + status: _kOcrStatusDone, + text: text, + lines: lines, + ); + return AiImageOcrTextResult( + messageId: messageId, + conversationId: conversationId, + engine: aiImageOcrEngine, + status: _kOcrStatusDone, + text: text, + cached: false, + lines: lines, + ); + } catch (error, stacktrace) { + e('AI image OCR failed: $error, $stacktrace'); + final errorText = error.toString(); + await _saveResult( + message: message, + fingerprint: fingerprint, + status: _kOcrStatusError, + text: '', + errorText: errorText, + ); + return _unavailable( + conversationId: conversationId, + messageId: messageId, + errorText: errorText, + ); + } + } + + Future _messageImageFile(MessageItem message) async { + final identityNumber = database.identityNumber; + if (identityNumber == null || identityNumber.isEmpty) { + return null; + } + final path = AttachmentUtilBase.of(identityNumber).convertAbsolutePath( + category: message.type, + conversationId: message.conversationId, + fileName: message.mediaUrl, + ); + if (path.isEmpty) { + return null; + } + final file = File(path); + return file.existsSync() ? file : null; + } + + Future _mediaFingerprint(MessageItem message, File file) async { + final stat = file.statSync(); + return [ + message.mediaUrl ?? '', + stat.size, + stat.modified.toUtc().toIso8601String(), + ].join('|'); + } + + Future _recognizeText(File file) async { + if (Platform.isMacOS || Platform.isIOS) { + return _recognizeDarwinText(file); + } + final ocr = PlatformOcr(); + try { + return await ocr.recognizeText(OcrSource.file(file)); + } finally { + ocr.dispose(); + } + } + + Future _saveResult({ + required MessageItem message, + required String fingerprint, + required String status, + required String text, + List> lines = const [], + String? errorText, + }) { + final now = DateTime.now(); + return database.aiImageOcrDao.upsertResult( + ImageOcrResultsCompanion.insert( + messageId: message.messageId, + conversationId: message.conversationId, + mediaFingerprint: fingerprint, + engine: aiImageOcrEngine, + status: status, + recognizedText: Value(text), + linesJson: Value(lines.isEmpty ? null : jsonEncode(lines)), + errorText: Value(errorText), + createdAt: now, + updatedAt: now, + ), + ); + } +} + +Map _ocrLineToJson(OcrLine line) => { + 'text': line.text, + 'box': { + 'left': line.boundingBox.left, + 'top': line.boundingBox.top, + 'width': line.boundingBox.width, + 'height': line.boundingBox.height, + }, +}; + +Future _recognizeDarwinText(File file) async => + pkg_ffi.using((arena) async { + var result = OcrResult(text: '', lines: []); + objc.autoReleasePool(() { + final request = darwin.VNRecognizeTextRequest.alloc().init() + ..recognitionLevel = darwin + .VNRequestTextRecognitionLevel + .VNRequestTextRecognitionLevelAccurate + ..usesLanguageCorrection = true; + _enableLanguageAutoDetection(request); + + final url = objc.NSURL.fileURLWithPath(objc.NSString(file.path)); + final handler = darwin.VNImageRequestHandler.alloc().initWithURL( + url, + options: objc.NSDictionary.new$(), + ); + final success = handler.performRequests( + objc.NSArray.arrayWithObject(request), + ); + if (!success) { + throw Exception('Vision request failed'); + } + + final resultsArr = request.results; + if (resultsArr == null) { + return; + } + final lines = []; + final fullTextBuffer = StringBuffer(); + for (var i = 0; i < resultsArr.count; i++) { + final obj = resultsArr.objectAtIndex(i); + if (!darwin.VNRecognizedTextObservation.isA(obj)) { + continue; + } + final observation = darwin.VNRecognizedTextObservation.as(obj); + final topCandidates = observation.topCandidates(1); + if (topCandidates.count == 0) { + continue; + } + final recognizedText = darwin.VNRecognizedText.as( + topCandidates.objectAtIndex(0), + ); + final text = recognizedText.string.toDartString(); + final box = observation.boundingBox; + final rect = Rect.fromLTWH( + box.origin.x, + 1.0 - box.origin.y - box.size.height, + box.size.width, + box.size.height, + ); + lines.add(OcrLine(text: text, boundingBox: rect)); + fullTextBuffer.writeln(text); + } + result = OcrResult( + text: fullTextBuffer.toString().trim(), + lines: lines, + ); + }); + return result; + }); + +void _enableLanguageAutoDetection(darwin.VNRecognizeTextRequest request) { + try { + request.automaticallyDetectsLanguage = true; + } catch (_) { + // Available on newer Darwin versions only. + } +} + +AiImageOcrTextResult _fromCache(ImageOcrResult row) => AiImageOcrTextResult( + messageId: row.messageId, + conversationId: row.conversationId, + engine: row.engine, + status: row.status, + text: row.recognizedText, + cached: true, + errorText: row.errorText, + lines: _decodeLines(row.linesJson), +); + +AiImageOcrTextResult _unavailable({ + required String conversationId, + required String messageId, + required String errorText, +}) => AiImageOcrTextResult( + messageId: messageId, + conversationId: conversationId, + engine: aiImageOcrEngine, + status: _kOcrStatusError, + text: '', + cached: false, + errorText: errorText, +); + +List> _decodeLines(String? raw) { + if (raw == null || raw.isEmpty) { + return const []; + } + try { + final decoded = jsonDecode(raw); + if (decoded is! List) { + return const []; + } + return decoded + .whereType() + .map(Map.from) + .toList(growable: false); + } catch (_) { + return const []; + } +} diff --git a/lib/db/ai_database.dart b/lib/db/ai_database.dart index da6b7098b5..3864930e35 100644 --- a/lib/db/ai_database.dart +++ b/lib/db/ai_database.dart @@ -2,13 +2,14 @@ import 'package:drift/drift.dart'; import 'converter/millis_date_converter.dart'; import 'dao/ai_chat_message_dao.dart'; +import 'dao/ai_image_ocr_dao.dart'; import 'util/open_database.dart'; part 'ai_database.g.dart'; @DriftDatabase( include: {'moor/ai.drift'}, - daos: [AiChatMessageDao], + daos: [AiChatMessageDao, AiImageOcrDao], ) class AiDatabase extends _$AiDatabase { AiDatabase(super.e); @@ -27,10 +28,14 @@ class AiDatabase extends _$AiDatabase { } @override - int get schemaVersion => 1; + int get schemaVersion => 2; @override MigrationStrategy get migration => MigrationStrategy( - onUpgrade: (m, from, to) async {}, + onUpgrade: (m, from, to) async { + if (from < 2) { + await m.createTable(imageOcrResults); + } + }, ); } diff --git a/lib/db/ai_database.g.dart b/lib/db/ai_database.g.dart index 03c6a64c7b..2750637b64 100644 --- a/lib/db/ai_database.g.dart +++ b/lib/db/ai_database.g.dart @@ -1501,11 +1501,641 @@ class AiChatThreadsCompanion extends UpdateCompanion { } } +class ImageOcrResults extends Table + with TableInfo { + @override + final GeneratedDatabase attachedDatabase; + final String? _alias; + ImageOcrResults(this.attachedDatabase, [this._alias]); + static const VerificationMeta _messageIdMeta = const VerificationMeta( + 'messageId', + ); + late final GeneratedColumn messageId = GeneratedColumn( + 'message_id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _conversationIdMeta = const VerificationMeta( + 'conversationId', + ); + late final GeneratedColumn conversationId = GeneratedColumn( + 'conversation_id', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _mediaFingerprintMeta = const VerificationMeta( + 'mediaFingerprint', + ); + late final GeneratedColumn mediaFingerprint = GeneratedColumn( + 'media_fingerprint', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _engineMeta = const VerificationMeta('engine'); + late final GeneratedColumn engine = GeneratedColumn( + 'engine', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _statusMeta = const VerificationMeta('status'); + late final GeneratedColumn status = GeneratedColumn( + 'status', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ); + static const VerificationMeta _recognizedTextMeta = const VerificationMeta( + 'recognizedText', + ); + late final GeneratedColumn recognizedText = GeneratedColumn( + 'recognized_text', + aliasedName, + false, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: 'NOT NULL DEFAULT \'\'', + defaultValue: const CustomExpression('\'\''), + ); + static const VerificationMeta _linesJsonMeta = const VerificationMeta( + 'linesJson', + ); + late final GeneratedColumn linesJson = GeneratedColumn( + 'lines_json', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + static const VerificationMeta _errorTextMeta = const VerificationMeta( + 'errorText', + ); + late final GeneratedColumn errorText = GeneratedColumn( + 'error_text', + aliasedName, + true, + type: DriftSqlType.string, + requiredDuringInsert: false, + $customConstraints: '', + ); + late final GeneratedColumnWithTypeConverter createdAt = + GeneratedColumn( + 'created_at', + aliasedName, + false, + type: DriftSqlType.int, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ).withConverter(ImageOcrResults.$convertercreatedAt); + late final GeneratedColumnWithTypeConverter updatedAt = + GeneratedColumn( + 'updated_at', + aliasedName, + false, + type: DriftSqlType.int, + requiredDuringInsert: true, + $customConstraints: 'NOT NULL', + ).withConverter(ImageOcrResults.$converterupdatedAt); + @override + List get $columns => [ + messageId, + conversationId, + mediaFingerprint, + engine, + status, + recognizedText, + linesJson, + errorText, + createdAt, + updatedAt, + ]; + @override + String get aliasedName => _alias ?? actualTableName; + @override + String get actualTableName => $name; + static const String $name = 'image_ocr_results'; + @override + VerificationContext validateIntegrity( + Insertable instance, { + bool isInserting = false, + }) { + final context = VerificationContext(); + final data = instance.toColumns(true); + if (data.containsKey('message_id')) { + context.handle( + _messageIdMeta, + messageId.isAcceptableOrUnknown(data['message_id']!, _messageIdMeta), + ); + } else if (isInserting) { + context.missing(_messageIdMeta); + } + if (data.containsKey('conversation_id')) { + context.handle( + _conversationIdMeta, + conversationId.isAcceptableOrUnknown( + data['conversation_id']!, + _conversationIdMeta, + ), + ); + } else if (isInserting) { + context.missing(_conversationIdMeta); + } + if (data.containsKey('media_fingerprint')) { + context.handle( + _mediaFingerprintMeta, + mediaFingerprint.isAcceptableOrUnknown( + data['media_fingerprint']!, + _mediaFingerprintMeta, + ), + ); + } else if (isInserting) { + context.missing(_mediaFingerprintMeta); + } + if (data.containsKey('engine')) { + context.handle( + _engineMeta, + engine.isAcceptableOrUnknown(data['engine']!, _engineMeta), + ); + } else if (isInserting) { + context.missing(_engineMeta); + } + if (data.containsKey('status')) { + context.handle( + _statusMeta, + status.isAcceptableOrUnknown(data['status']!, _statusMeta), + ); + } else if (isInserting) { + context.missing(_statusMeta); + } + if (data.containsKey('recognized_text')) { + context.handle( + _recognizedTextMeta, + recognizedText.isAcceptableOrUnknown( + data['recognized_text']!, + _recognizedTextMeta, + ), + ); + } + if (data.containsKey('lines_json')) { + context.handle( + _linesJsonMeta, + linesJson.isAcceptableOrUnknown(data['lines_json']!, _linesJsonMeta), + ); + } + if (data.containsKey('error_text')) { + context.handle( + _errorTextMeta, + errorText.isAcceptableOrUnknown(data['error_text']!, _errorTextMeta), + ); + } + return context; + } + + @override + Set get $primaryKey => {messageId}; + @override + ImageOcrResult map(Map data, {String? tablePrefix}) { + final effectivePrefix = tablePrefix != null ? '$tablePrefix.' : ''; + return ImageOcrResult( + messageId: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}message_id'], + )!, + conversationId: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}conversation_id'], + )!, + mediaFingerprint: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}media_fingerprint'], + )!, + engine: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}engine'], + )!, + status: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}status'], + )!, + recognizedText: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}recognized_text'], + )!, + linesJson: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}lines_json'], + ), + errorText: attachedDatabase.typeMapping.read( + DriftSqlType.string, + data['${effectivePrefix}error_text'], + ), + createdAt: ImageOcrResults.$convertercreatedAt.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}created_at'], + )!, + ), + updatedAt: ImageOcrResults.$converterupdatedAt.fromSql( + attachedDatabase.typeMapping.read( + DriftSqlType.int, + data['${effectivePrefix}updated_at'], + )!, + ), + ); + } + + @override + ImageOcrResults createAlias(String alias) { + return ImageOcrResults(attachedDatabase, alias); + } + + static TypeConverter $convertercreatedAt = + const MillisDateConverter(); + static TypeConverter $converterupdatedAt = + const MillisDateConverter(); + @override + List get customConstraints => const ['PRIMARY KEY(message_id)']; + @override + bool get dontWriteConstraints => true; +} + +class ImageOcrResult extends DataClass implements Insertable { + final String messageId; + final String conversationId; + final String mediaFingerprint; + final String engine; + final String status; + final String recognizedText; + final String? linesJson; + final String? errorText; + final DateTime createdAt; + final DateTime updatedAt; + const ImageOcrResult({ + required this.messageId, + required this.conversationId, + required this.mediaFingerprint, + required this.engine, + required this.status, + required this.recognizedText, + this.linesJson, + this.errorText, + required this.createdAt, + required this.updatedAt, + }); + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + map['message_id'] = Variable(messageId); + map['conversation_id'] = Variable(conversationId); + map['media_fingerprint'] = Variable(mediaFingerprint); + map['engine'] = Variable(engine); + map['status'] = Variable(status); + map['recognized_text'] = Variable(recognizedText); + if (!nullToAbsent || linesJson != null) { + map['lines_json'] = Variable(linesJson); + } + if (!nullToAbsent || errorText != null) { + map['error_text'] = Variable(errorText); + } + { + map['created_at'] = Variable( + ImageOcrResults.$convertercreatedAt.toSql(createdAt), + ); + } + { + map['updated_at'] = Variable( + ImageOcrResults.$converterupdatedAt.toSql(updatedAt), + ); + } + return map; + } + + ImageOcrResultsCompanion toCompanion(bool nullToAbsent) { + return ImageOcrResultsCompanion( + messageId: Value(messageId), + conversationId: Value(conversationId), + mediaFingerprint: Value(mediaFingerprint), + engine: Value(engine), + status: Value(status), + recognizedText: Value(recognizedText), + linesJson: linesJson == null && nullToAbsent + ? const Value.absent() + : Value(linesJson), + errorText: errorText == null && nullToAbsent + ? const Value.absent() + : Value(errorText), + createdAt: Value(createdAt), + updatedAt: Value(updatedAt), + ); + } + + factory ImageOcrResult.fromJson( + Map json, { + ValueSerializer? serializer, + }) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return ImageOcrResult( + messageId: serializer.fromJson(json['message_id']), + conversationId: serializer.fromJson(json['conversation_id']), + mediaFingerprint: serializer.fromJson(json['media_fingerprint']), + engine: serializer.fromJson(json['engine']), + status: serializer.fromJson(json['status']), + recognizedText: serializer.fromJson(json['recognized_text']), + linesJson: serializer.fromJson(json['lines_json']), + errorText: serializer.fromJson(json['error_text']), + createdAt: serializer.fromJson(json['created_at']), + updatedAt: serializer.fromJson(json['updated_at']), + ); + } + @override + Map toJson({ValueSerializer? serializer}) { + serializer ??= driftRuntimeOptions.defaultSerializer; + return { + 'message_id': serializer.toJson(messageId), + 'conversation_id': serializer.toJson(conversationId), + 'media_fingerprint': serializer.toJson(mediaFingerprint), + 'engine': serializer.toJson(engine), + 'status': serializer.toJson(status), + 'recognized_text': serializer.toJson(recognizedText), + 'lines_json': serializer.toJson(linesJson), + 'error_text': serializer.toJson(errorText), + 'created_at': serializer.toJson(createdAt), + 'updated_at': serializer.toJson(updatedAt), + }; + } + + ImageOcrResult copyWith({ + String? messageId, + String? conversationId, + String? mediaFingerprint, + String? engine, + String? status, + String? recognizedText, + Value linesJson = const Value.absent(), + Value errorText = const Value.absent(), + DateTime? createdAt, + DateTime? updatedAt, + }) => ImageOcrResult( + messageId: messageId ?? this.messageId, + conversationId: conversationId ?? this.conversationId, + mediaFingerprint: mediaFingerprint ?? this.mediaFingerprint, + engine: engine ?? this.engine, + status: status ?? this.status, + recognizedText: recognizedText ?? this.recognizedText, + linesJson: linesJson.present ? linesJson.value : this.linesJson, + errorText: errorText.present ? errorText.value : this.errorText, + createdAt: createdAt ?? this.createdAt, + updatedAt: updatedAt ?? this.updatedAt, + ); + ImageOcrResult copyWithCompanion(ImageOcrResultsCompanion data) { + return ImageOcrResult( + messageId: data.messageId.present ? data.messageId.value : this.messageId, + conversationId: data.conversationId.present + ? data.conversationId.value + : this.conversationId, + mediaFingerprint: data.mediaFingerprint.present + ? data.mediaFingerprint.value + : this.mediaFingerprint, + engine: data.engine.present ? data.engine.value : this.engine, + status: data.status.present ? data.status.value : this.status, + recognizedText: data.recognizedText.present + ? data.recognizedText.value + : this.recognizedText, + linesJson: data.linesJson.present ? data.linesJson.value : this.linesJson, + errorText: data.errorText.present ? data.errorText.value : this.errorText, + createdAt: data.createdAt.present ? data.createdAt.value : this.createdAt, + updatedAt: data.updatedAt.present ? data.updatedAt.value : this.updatedAt, + ); + } + + @override + String toString() { + return (StringBuffer('ImageOcrResult(') + ..write('messageId: $messageId, ') + ..write('conversationId: $conversationId, ') + ..write('mediaFingerprint: $mediaFingerprint, ') + ..write('engine: $engine, ') + ..write('status: $status, ') + ..write('recognizedText: $recognizedText, ') + ..write('linesJson: $linesJson, ') + ..write('errorText: $errorText, ') + ..write('createdAt: $createdAt, ') + ..write('updatedAt: $updatedAt') + ..write(')')) + .toString(); + } + + @override + int get hashCode => Object.hash( + messageId, + conversationId, + mediaFingerprint, + engine, + status, + recognizedText, + linesJson, + errorText, + createdAt, + updatedAt, + ); + @override + bool operator ==(Object other) => + identical(this, other) || + (other is ImageOcrResult && + other.messageId == this.messageId && + other.conversationId == this.conversationId && + other.mediaFingerprint == this.mediaFingerprint && + other.engine == this.engine && + other.status == this.status && + other.recognizedText == this.recognizedText && + other.linesJson == this.linesJson && + other.errorText == this.errorText && + other.createdAt == this.createdAt && + other.updatedAt == this.updatedAt); +} + +class ImageOcrResultsCompanion extends UpdateCompanion { + final Value messageId; + final Value conversationId; + final Value mediaFingerprint; + final Value engine; + final Value status; + final Value recognizedText; + final Value linesJson; + final Value errorText; + final Value createdAt; + final Value updatedAt; + final Value rowid; + const ImageOcrResultsCompanion({ + this.messageId = const Value.absent(), + this.conversationId = const Value.absent(), + this.mediaFingerprint = const Value.absent(), + this.engine = const Value.absent(), + this.status = const Value.absent(), + this.recognizedText = const Value.absent(), + this.linesJson = const Value.absent(), + this.errorText = const Value.absent(), + this.createdAt = const Value.absent(), + this.updatedAt = const Value.absent(), + this.rowid = const Value.absent(), + }); + ImageOcrResultsCompanion.insert({ + required String messageId, + required String conversationId, + required String mediaFingerprint, + required String engine, + required String status, + this.recognizedText = const Value.absent(), + this.linesJson = const Value.absent(), + this.errorText = const Value.absent(), + required DateTime createdAt, + required DateTime updatedAt, + this.rowid = const Value.absent(), + }) : messageId = Value(messageId), + conversationId = Value(conversationId), + mediaFingerprint = Value(mediaFingerprint), + engine = Value(engine), + status = Value(status), + createdAt = Value(createdAt), + updatedAt = Value(updatedAt); + static Insertable custom({ + Expression? messageId, + Expression? conversationId, + Expression? mediaFingerprint, + Expression? engine, + Expression? status, + Expression? recognizedText, + Expression? linesJson, + Expression? errorText, + Expression? createdAt, + Expression? updatedAt, + Expression? rowid, + }) { + return RawValuesInsertable({ + if (messageId != null) 'message_id': messageId, + if (conversationId != null) 'conversation_id': conversationId, + if (mediaFingerprint != null) 'media_fingerprint': mediaFingerprint, + if (engine != null) 'engine': engine, + if (status != null) 'status': status, + if (recognizedText != null) 'recognized_text': recognizedText, + if (linesJson != null) 'lines_json': linesJson, + if (errorText != null) 'error_text': errorText, + if (createdAt != null) 'created_at': createdAt, + if (updatedAt != null) 'updated_at': updatedAt, + if (rowid != null) 'rowid': rowid, + }); + } + + ImageOcrResultsCompanion copyWith({ + Value? messageId, + Value? conversationId, + Value? mediaFingerprint, + Value? engine, + Value? status, + Value? recognizedText, + Value? linesJson, + Value? errorText, + Value? createdAt, + Value? updatedAt, + Value? rowid, + }) { + return ImageOcrResultsCompanion( + messageId: messageId ?? this.messageId, + conversationId: conversationId ?? this.conversationId, + mediaFingerprint: mediaFingerprint ?? this.mediaFingerprint, + engine: engine ?? this.engine, + status: status ?? this.status, + recognizedText: recognizedText ?? this.recognizedText, + linesJson: linesJson ?? this.linesJson, + errorText: errorText ?? this.errorText, + createdAt: createdAt ?? this.createdAt, + updatedAt: updatedAt ?? this.updatedAt, + rowid: rowid ?? this.rowid, + ); + } + + @override + Map toColumns(bool nullToAbsent) { + final map = {}; + if (messageId.present) { + map['message_id'] = Variable(messageId.value); + } + if (conversationId.present) { + map['conversation_id'] = Variable(conversationId.value); + } + if (mediaFingerprint.present) { + map['media_fingerprint'] = Variable(mediaFingerprint.value); + } + if (engine.present) { + map['engine'] = Variable(engine.value); + } + if (status.present) { + map['status'] = Variable(status.value); + } + if (recognizedText.present) { + map['recognized_text'] = Variable(recognizedText.value); + } + if (linesJson.present) { + map['lines_json'] = Variable(linesJson.value); + } + if (errorText.present) { + map['error_text'] = Variable(errorText.value); + } + if (createdAt.present) { + map['created_at'] = Variable( + ImageOcrResults.$convertercreatedAt.toSql(createdAt.value), + ); + } + if (updatedAt.present) { + map['updated_at'] = Variable( + ImageOcrResults.$converterupdatedAt.toSql(updatedAt.value), + ); + } + if (rowid.present) { + map['rowid'] = Variable(rowid.value); + } + return map; + } + + @override + String toString() { + return (StringBuffer('ImageOcrResultsCompanion(') + ..write('messageId: $messageId, ') + ..write('conversationId: $conversationId, ') + ..write('mediaFingerprint: $mediaFingerprint, ') + ..write('engine: $engine, ') + ..write('status: $status, ') + ..write('recognizedText: $recognizedText, ') + ..write('linesJson: $linesJson, ') + ..write('errorText: $errorText, ') + ..write('createdAt: $createdAt, ') + ..write('updatedAt: $updatedAt, ') + ..write('rowid: $rowid') + ..write(')')) + .toString(); + } +} + abstract class _$AiDatabase extends GeneratedDatabase { _$AiDatabase(QueryExecutor e) : super(e); $AiDatabaseManager get managers => $AiDatabaseManager(this); late final AiChatMessages aiChatMessages = AiChatMessages(this); late final AiChatThreads aiChatThreads = AiChatThreads(this); + late final ImageOcrResults imageOcrResults = ImageOcrResults(this); late final Index indexAiChatMessagesConversationIdCreatedAt = Index( 'index_ai_chat_messages_conversation_id_created_at', 'CREATE INDEX IF NOT EXISTS index_ai_chat_messages_conversation_id_created_at ON ai_chat_messages (conversation_id, created_at DESC)', @@ -1522,9 +2152,14 @@ abstract class _$AiDatabase extends GeneratedDatabase { 'index_ai_chat_threads_conversation_id_last_message_at', 'CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_last_message_at ON ai_chat_threads (conversation_id, status, last_message_at DESC)', ); + late final Index indexImageOcrResultsConversationIdUpdatedAt = Index( + 'index_image_ocr_results_conversation_id_updated_at', + 'CREATE INDEX IF NOT EXISTS index_image_ocr_results_conversation_id_updated_at ON image_ocr_results (conversation_id, updated_at DESC)', + ); late final AiChatMessageDao aiChatMessageDao = AiChatMessageDao( this as AiDatabase, ); + late final AiImageOcrDao aiImageOcrDao = AiImageOcrDao(this as AiDatabase); @override Iterable> get allTables => allSchemaEntities.whereType>(); @@ -1532,10 +2167,12 @@ abstract class _$AiDatabase extends GeneratedDatabase { List get allSchemaEntities => [ aiChatMessages, aiChatThreads, + imageOcrResults, indexAiChatMessagesConversationIdCreatedAt, indexAiChatMessagesThreadIdCreatedAt, indexAiChatThreadsConversationIdUpdatedAt, indexAiChatThreadsConversationIdLastMessageAt, + indexImageOcrResultsConversationIdUpdatedAt, ]; } @@ -2244,6 +2881,309 @@ typedef $AiChatThreadsProcessedTableManager = AiChatThread, PrefetchHooks Function() >; +typedef $ImageOcrResultsCreateCompanionBuilder = + ImageOcrResultsCompanion Function({ + required String messageId, + required String conversationId, + required String mediaFingerprint, + required String engine, + required String status, + Value recognizedText, + Value linesJson, + Value errorText, + required DateTime createdAt, + required DateTime updatedAt, + Value rowid, + }); +typedef $ImageOcrResultsUpdateCompanionBuilder = + ImageOcrResultsCompanion Function({ + Value messageId, + Value conversationId, + Value mediaFingerprint, + Value engine, + Value status, + Value recognizedText, + Value linesJson, + Value errorText, + Value createdAt, + Value updatedAt, + Value rowid, + }); + +class $ImageOcrResultsFilterComposer + extends Composer<_$AiDatabase, ImageOcrResults> { + $ImageOcrResultsFilterComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + ColumnFilters get messageId => $composableBuilder( + column: $table.messageId, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get mediaFingerprint => $composableBuilder( + column: $table.mediaFingerprint, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get engine => $composableBuilder( + column: $table.engine, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get status => $composableBuilder( + column: $table.status, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get recognizedText => $composableBuilder( + column: $table.recognizedText, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get linesJson => $composableBuilder( + column: $table.linesJson, + builder: (column) => ColumnFilters(column), + ); + + ColumnFilters get errorText => $composableBuilder( + column: $table.errorText, + builder: (column) => ColumnFilters(column), + ); + + ColumnWithTypeConverterFilters get createdAt => + $composableBuilder( + column: $table.createdAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); + + ColumnWithTypeConverterFilters get updatedAt => + $composableBuilder( + column: $table.updatedAt, + builder: (column) => ColumnWithTypeConverterFilters(column), + ); +} + +class $ImageOcrResultsOrderingComposer + extends Composer<_$AiDatabase, ImageOcrResults> { + $ImageOcrResultsOrderingComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + ColumnOrderings get messageId => $composableBuilder( + column: $table.messageId, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get mediaFingerprint => $composableBuilder( + column: $table.mediaFingerprint, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get engine => $composableBuilder( + column: $table.engine, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get status => $composableBuilder( + column: $table.status, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get recognizedText => $composableBuilder( + column: $table.recognizedText, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get linesJson => $composableBuilder( + column: $table.linesJson, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get errorText => $composableBuilder( + column: $table.errorText, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get createdAt => $composableBuilder( + column: $table.createdAt, + builder: (column) => ColumnOrderings(column), + ); + + ColumnOrderings get updatedAt => $composableBuilder( + column: $table.updatedAt, + builder: (column) => ColumnOrderings(column), + ); +} + +class $ImageOcrResultsAnnotationComposer + extends Composer<_$AiDatabase, ImageOcrResults> { + $ImageOcrResultsAnnotationComposer({ + required super.$db, + required super.$table, + super.joinBuilder, + super.$addJoinBuilderToRootComposer, + super.$removeJoinBuilderFromRootComposer, + }); + GeneratedColumn get messageId => + $composableBuilder(column: $table.messageId, builder: (column) => column); + + GeneratedColumn get conversationId => $composableBuilder( + column: $table.conversationId, + builder: (column) => column, + ); + + GeneratedColumn get mediaFingerprint => $composableBuilder( + column: $table.mediaFingerprint, + builder: (column) => column, + ); + + GeneratedColumn get engine => + $composableBuilder(column: $table.engine, builder: (column) => column); + + GeneratedColumn get status => + $composableBuilder(column: $table.status, builder: (column) => column); + + GeneratedColumn get recognizedText => $composableBuilder( + column: $table.recognizedText, + builder: (column) => column, + ); + + GeneratedColumn get linesJson => + $composableBuilder(column: $table.linesJson, builder: (column) => column); + + GeneratedColumn get errorText => + $composableBuilder(column: $table.errorText, builder: (column) => column); + + GeneratedColumnWithTypeConverter get createdAt => + $composableBuilder(column: $table.createdAt, builder: (column) => column); + + GeneratedColumnWithTypeConverter get updatedAt => + $composableBuilder(column: $table.updatedAt, builder: (column) => column); +} + +class $ImageOcrResultsTableManager + extends + RootTableManager< + _$AiDatabase, + ImageOcrResults, + ImageOcrResult, + $ImageOcrResultsFilterComposer, + $ImageOcrResultsOrderingComposer, + $ImageOcrResultsAnnotationComposer, + $ImageOcrResultsCreateCompanionBuilder, + $ImageOcrResultsUpdateCompanionBuilder, + ( + ImageOcrResult, + BaseReferences<_$AiDatabase, ImageOcrResults, ImageOcrResult>, + ), + ImageOcrResult, + PrefetchHooks Function() + > { + $ImageOcrResultsTableManager(_$AiDatabase db, ImageOcrResults table) + : super( + TableManagerState( + db: db, + table: table, + createFilteringComposer: () => + $ImageOcrResultsFilterComposer($db: db, $table: table), + createOrderingComposer: () => + $ImageOcrResultsOrderingComposer($db: db, $table: table), + createComputedFieldComposer: () => + $ImageOcrResultsAnnotationComposer($db: db, $table: table), + updateCompanionCallback: + ({ + Value messageId = const Value.absent(), + Value conversationId = const Value.absent(), + Value mediaFingerprint = const Value.absent(), + Value engine = const Value.absent(), + Value status = const Value.absent(), + Value recognizedText = const Value.absent(), + Value linesJson = const Value.absent(), + Value errorText = const Value.absent(), + Value createdAt = const Value.absent(), + Value updatedAt = const Value.absent(), + Value rowid = const Value.absent(), + }) => ImageOcrResultsCompanion( + messageId: messageId, + conversationId: conversationId, + mediaFingerprint: mediaFingerprint, + engine: engine, + status: status, + recognizedText: recognizedText, + linesJson: linesJson, + errorText: errorText, + createdAt: createdAt, + updatedAt: updatedAt, + rowid: rowid, + ), + createCompanionCallback: + ({ + required String messageId, + required String conversationId, + required String mediaFingerprint, + required String engine, + required String status, + Value recognizedText = const Value.absent(), + Value linesJson = const Value.absent(), + Value errorText = const Value.absent(), + required DateTime createdAt, + required DateTime updatedAt, + Value rowid = const Value.absent(), + }) => ImageOcrResultsCompanion.insert( + messageId: messageId, + conversationId: conversationId, + mediaFingerprint: mediaFingerprint, + engine: engine, + status: status, + recognizedText: recognizedText, + linesJson: linesJson, + errorText: errorText, + createdAt: createdAt, + updatedAt: updatedAt, + rowid: rowid, + ), + withReferenceMapper: (p0) => p0 + .map((e) => (e.readTable(table), BaseReferences(db, table, e))) + .toList(), + prefetchHooksCallback: null, + ), + ); +} + +typedef $ImageOcrResultsProcessedTableManager = + ProcessedTableManager< + _$AiDatabase, + ImageOcrResults, + ImageOcrResult, + $ImageOcrResultsFilterComposer, + $ImageOcrResultsOrderingComposer, + $ImageOcrResultsAnnotationComposer, + $ImageOcrResultsCreateCompanionBuilder, + $ImageOcrResultsUpdateCompanionBuilder, + ( + ImageOcrResult, + BaseReferences<_$AiDatabase, ImageOcrResults, ImageOcrResult>, + ), + ImageOcrResult, + PrefetchHooks Function() + >; class $AiDatabaseManager { final _$AiDatabase _db; @@ -2252,4 +3192,6 @@ class $AiDatabaseManager { $AiChatMessagesTableManager(_db, _db.aiChatMessages); $AiChatThreadsTableManager get aiChatThreads => $AiChatThreadsTableManager(_db, _db.aiChatThreads); + $ImageOcrResultsTableManager get imageOcrResults => + $ImageOcrResultsTableManager(_db, _db.imageOcrResults); } diff --git a/lib/db/dao/ai_image_ocr_dao.dart b/lib/db/dao/ai_image_ocr_dao.dart new file mode 100644 index 0000000000..6be52e2e73 --- /dev/null +++ b/lib/db/dao/ai_image_ocr_dao.dart @@ -0,0 +1,20 @@ +import 'package:drift/drift.dart'; + +import '../ai_database.dart'; + +part 'ai_image_ocr_dao.g.dart'; + +@DriftAccessor() +class AiImageOcrDao extends DatabaseAccessor + with _$AiImageOcrDaoMixin { + AiImageOcrDao(super.db); + + Future resultByMessageId(String messageId) => + (select(db.imageOcrResults)..where( + (tbl) => tbl.messageId.equals(messageId), + )) + .getSingleOrNull(); + + Future upsertResult(ImageOcrResultsCompanion row) => + into(db.imageOcrResults).insertOnConflictUpdate(row); +} diff --git a/lib/db/dao/ai_image_ocr_dao.g.dart b/lib/db/dao/ai_image_ocr_dao.g.dart new file mode 100644 index 0000000000..a65934727f --- /dev/null +++ b/lib/db/dao/ai_image_ocr_dao.g.dart @@ -0,0 +1,13 @@ +// GENERATED CODE - DO NOT MODIFY BY HAND + +part of 'ai_image_ocr_dao.dart'; + +// ignore_for_file: type=lint +mixin _$AiImageOcrDaoMixin on DatabaseAccessor { + AiImageOcrDaoManager get managers => AiImageOcrDaoManager(this); +} + +class AiImageOcrDaoManager { + final _$AiImageOcrDaoMixin _db; + AiImageOcrDaoManager(this._db); +} diff --git a/lib/db/database.dart b/lib/db/database.dart index 200f46ce60..412daf3656 100644 --- a/lib/db/database.dart +++ b/lib/db/database.dart @@ -6,6 +6,7 @@ import '../utils/logger.dart'; import '../utils/property/setting_property.dart'; import 'ai_database.dart'; import 'dao/ai_chat_message_dao.dart'; +import 'dao/ai_image_ocr_dao.dart'; import 'dao/app_dao.dart'; import 'dao/asset_dao.dart'; import 'dao/chain_dao.dart'; @@ -39,7 +40,12 @@ import 'fts_database.dart'; import 'mixin_database.dart'; class Database { - Database(this.mixinDatabase, this.ftsDatabase, this.aiDatabase) { + Database( + this.mixinDatabase, + this.ftsDatabase, + this.aiDatabase, { + this.identityNumber, + }) { settingProperties = SettingPropertyStorage(mixinDatabase.propertyDao); } @@ -49,10 +55,14 @@ class Database { final AiDatabase aiDatabase; + final String? identityNumber; + AppDao get appDao => mixinDatabase.appDao; AiChatMessageDao get aiChatMessageDao => aiDatabase.aiChatMessageDao; + AiImageOcrDao get aiImageOcrDao => aiDatabase.aiImageOcrDao; + AssetDao get assetDao => mixinDatabase.assetDao; ChainDao get chainDao => mixinDatabase.chainDao; diff --git a/lib/db/moor/ai.drift b/lib/db/moor/ai.drift index 829ddaa6ef..7aab0abe55 100644 --- a/lib/db/moor/ai.drift +++ b/lib/db/moor/ai.drift @@ -33,7 +33,22 @@ CREATE TABLE ai_chat_threads ( PRIMARY KEY(id) ); +CREATE TABLE image_ocr_results ( + message_id TEXT NOT NULL, + conversation_id TEXT NOT NULL, + media_fingerprint TEXT NOT NULL, + engine TEXT NOT NULL, + status TEXT NOT NULL, + recognized_text TEXT NOT NULL DEFAULT '', + lines_json TEXT, + error_text TEXT, + created_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, + updated_at INTEGER NOT NULL MAPPED BY `const MillisDateConverter()`, + PRIMARY KEY(message_id) +); + CREATE INDEX IF NOT EXISTS index_ai_chat_messages_conversation_id_created_at ON ai_chat_messages(conversation_id, created_at DESC); CREATE INDEX IF NOT EXISTS index_ai_chat_messages_thread_id_created_at ON ai_chat_messages(thread_id, created_at DESC); CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_updated_at ON ai_chat_threads(conversation_id, status, updated_at DESC); CREATE INDEX IF NOT EXISTS index_ai_chat_threads_conversation_id_last_message_at ON ai_chat_threads(conversation_id, status, last_message_at DESC); +CREATE INDEX IF NOT EXISTS index_image_ocr_results_conversation_id_updated_at ON image_ocr_results(conversation_id, updated_at DESC); diff --git a/lib/ui/provider/database_provider.dart b/lib/ui/provider/database_provider.dart index f1d0cbb767..d08074ef63 100644 --- a/lib/ui/provider/database_provider.dart +++ b/lib/ui/provider/database_provider.dart @@ -54,6 +54,7 @@ class DatabaseOpener extends DistinctStateNotifier> { mixinDatabase, await FtsDatabase.connect(identityNumber, fromMainIsolate: true), await AiDatabase.connect(identityNumber, fromMainIsolate: true), + identityNumber: identityNumber, ); // Do a database query, to ensure database has properly initialized. await mixinDatabase.doInitVerify(); diff --git a/macos/Podfile.lock b/macos/Podfile.lock index b0bd50933c..9b34a51357 100644 --- a/macos/Podfile.lock +++ b/macos/Podfile.lock @@ -27,8 +27,6 @@ PODS: - FlutterMacOS - network_info_plus (0.0.1): - FlutterMacOS - - objective_c (0.0.1): - - FlutterMacOS - ogg_opus_player (0.0.1): - FlutterMacOS - open_file_mac (1.0.3): @@ -102,7 +100,6 @@ DEPENDENCIES: - local_auth_darwin (from `Flutter/ephemeral/.symlinks/plugins/local_auth_darwin/darwin`) - mixin_logger (from `Flutter/ephemeral/.symlinks/plugins/mixin_logger/macos`) - network_info_plus (from `Flutter/ephemeral/.symlinks/plugins/network_info_plus/macos`) - - objective_c (from `Flutter/ephemeral/.symlinks/plugins/objective_c/macos`) - ogg_opus_player (from `Flutter/ephemeral/.symlinks/plugins/ogg_opus_player/macos`) - open_file_mac (from `Flutter/ephemeral/.symlinks/plugins/open_file_mac/macos`) - package_info_plus (from `Flutter/ephemeral/.symlinks/plugins/package_info_plus/macos`) @@ -154,8 +151,6 @@ EXTERNAL SOURCES: :path: Flutter/ephemeral/.symlinks/plugins/mixin_logger/macos network_info_plus: :path: Flutter/ephemeral/.symlinks/plugins/network_info_plus/macos - objective_c: - :path: Flutter/ephemeral/.symlinks/plugins/objective_c/macos ogg_opus_player: :path: Flutter/ephemeral/.symlinks/plugins/ogg_opus_player/macos open_file_mac: @@ -202,7 +197,6 @@ SPEC CHECKSUMS: local_auth_darwin: c3ee6cce0a8d56be34c8ccb66ba31f7f180aaebb mixin_logger: 6b31328b08f546a8defd32cd910370562fc48405 network_info_plus: 21d1cd6a015ccb2fdff06a1fbfa88d54b4e92f61 - objective_c: 2f927c775f7ad0d1ee8f78b4b0a5ddf03b2548d7 ogg_opus_player: 40ad7ee05152b420727fdb922afa0a90763b1817 open_file_mac: 76f06c8597551249bdb5e8fd8827a98eae0f4585 package_info_plus: f0052d280d17aa382b932f399edf32507174e870 diff --git a/macos/Runner.xcodeproj/project.pbxproj b/macos/Runner.xcodeproj/project.pbxproj index 078af80835..2e6179e08e 100644 --- a/macos/Runner.xcodeproj/project.pbxproj +++ b/macos/Runner.xcodeproj/project.pbxproj @@ -369,7 +369,6 @@ "${BUILT_PRODUCTS_DIR}/local_auth_darwin/local_auth_darwin.framework", "${BUILT_PRODUCTS_DIR}/mixin_logger/mixin_logger.framework", "${BUILT_PRODUCTS_DIR}/network_info_plus/network_info_plus.framework", - "${BUILT_PRODUCTS_DIR}/objective_c/objective_c.framework", "${BUILT_PRODUCTS_DIR}/ogg_opus_player/ogg_opus_player.framework", "${BUILT_PRODUCTS_DIR}/open_file_mac/open_file_mac.framework", "${BUILT_PRODUCTS_DIR}/package_info_plus/package_info_plus.framework", @@ -402,7 +401,6 @@ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/local_auth_darwin.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/mixin_logger.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/network_info_plus.framework", - "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/objective_c.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/ogg_opus_player.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/open_file_mac.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/package_info_plus.framework", diff --git a/test/ai/ai_conversation_context_test.dart b/test/ai/ai_conversation_context_test.dart index 68b8f928d7..96f238a20f 100644 --- a/test/ai/ai_conversation_context_test.dart +++ b/test/ai/ai_conversation_context_test.dart @@ -224,6 +224,38 @@ void main() { expect(prompt, contains('noise before transcript')); }, ); + + test('attached image prompt includes OCR context status', () async { + final createdAt = DateTime(2026, 4, 30, 12); + await _insertMessage( + database, + id: 'image', + userId: 'alice', + content: '', + createdAt: createdAt, + category: MessageCategory.plainImage, + mediaUrl: 'missing-image.png', + ); + + final attached = await database.messageDao + .messageItemByMessageId('image') + .getSingle(); + final promptMessages = await AiChatPromptBuilder(database) + .buildPromptMessages( + 'conversation', + 'thread', + 'what text is in this image?', + 'English', + attachedMessages: [attached], + ); + final prompt = promptMessages + .map((message) => message.content) + .join('\n'); + + expect(prompt, contains('OCR text from primary attached image:')); + expect(prompt, contains('status=error')); + expect(prompt, contains('local image file is not available')); + }); }); } @@ -245,6 +277,7 @@ Future _insertMessage( required String content, required DateTime createdAt, String category = MessageCategory.plainText, + String? mediaUrl, String? quoteMessageId, String? quoteContent, }) async { @@ -257,6 +290,7 @@ Future _insertMessage( userId: userId, category: category, content: Value(content), + mediaUrl: Value(mediaUrl), status: MessageStatus.read, createdAt: createdAt, quoteMessageId: Value(quoteMessageId), From afae9313f1dd4930e049dad141ea9f9174610928 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Fri, 1 May 2026 11:28:39 +0800 Subject: [PATCH 46/52] feat: add MCP server integration with UI controls and settings --- lib/db/dao/ai_chat_message_dao.dart | 4 + lib/ui/home/chat/input_container.dart | 14 + lib/ui/home/home.dart | 36 +- lib/ui/setting/ai_settings_page.dart | 132 ++++ lib/utils/mcp/mixin_mcp_bridge.dart | 143 ++++ lib/utils/mcp/mixin_mcp_server.dart | 829 +++++++++++++++++++++++ lib/utils/property/setting_property.dart | 25 + 7 files changed, 1179 insertions(+), 4 deletions(-) create mode 100644 lib/utils/mcp/mixin_mcp_bridge.dart create mode 100644 lib/utils/mcp/mixin_mcp_server.dart diff --git a/lib/db/dao/ai_chat_message_dao.dart b/lib/db/dao/ai_chat_message_dao.dart index 15409c6e98..894debb84a 100644 --- a/lib/db/dao/ai_chat_message_dao.dart +++ b/lib/db/dao/ai_chat_message_dao.dart @@ -72,6 +72,10 @@ class AiChatMessageDao extends DatabaseAccessor db.aiChatThreads, )..where((tbl) => tbl.id.equals(threadId))).getSingleOrNull(); + Future messageById(String messageId) => (select( + db.aiChatMessages, + )..where((tbl) => tbl.id.equals(messageId))).getSingleOrNull(); + Future createThread(String conversationId) async { final now = DateTime.now(); final thread = AiChatThread( diff --git a/lib/ui/home/chat/input_container.dart b/lib/ui/home/chat/input_container.dart index 4252c85727..de46e34f54 100644 --- a/lib/ui/home/chat/input_container.dart +++ b/lib/ui/home/chat/input_container.dart @@ -34,6 +34,7 @@ import '../../../utils/app_lifecycle.dart'; import '../../../utils/extension/extension.dart'; import '../../../utils/file.dart'; import '../../../utils/hook.dart'; +import '../../../utils/mcp/mixin_mcp_bridge.dart'; import '../../../utils/platform.dart'; import '../../../utils/reg_exp_utils.dart'; import '../../../utils/system/clipboard.dart'; @@ -186,6 +187,19 @@ class _InputContainer extends HookConsumerWidget { ); }, [conversationId]); + useEffect(() { + final currentConversationId = conversationId; + if (currentConversationId == null) return null; + MixinMcpBridge.instance.bindInputController( + currentConversationId, + textEditingController, + ); + return () => MixinMcpBridge.instance.unbindInputController( + currentConversationId, + textEditingController, + ); + }, [conversationId, textEditingController]); + final textEditingValueStream = useValueNotifierConvertSteam( textEditingController, ); diff --git a/lib/ui/home/home.dart b/lib/ui/home/home.dart index 52a2e6c1ef..0a18970042 100644 --- a/lib/ui/home/home.dart +++ b/lib/ui/home/home.dart @@ -1,3 +1,5 @@ +import 'dart:async'; + import 'package:flutter/material.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart' @@ -11,6 +13,8 @@ import '../../utils/audio_message_player/audio_message_service.dart'; import '../../utils/device_transfer/device_transfer_widget.dart'; import '../../utils/extension/extension.dart'; import '../../utils/hook.dart'; +import '../../utils/mcp/mixin_mcp_bridge.dart'; +import '../../utils/mcp/mixin_mcp_server.dart'; import '../../utils/platform.dart'; import '../../utils/system/package_info.dart'; import '../../utils/system/text_input.dart'; @@ -53,12 +57,36 @@ class HomePage extends HookConsumerWidget { @override Widget build(BuildContext context, WidgetRef ref) { + final database = context.database; + final accountServer = context.accountServer; + useListenable(database.settingProperties); + final enableMcpServer = database.settingProperties.enableMcpServer; + + useEffect(() { + MixinMcpBridge.instance.rootContext = context; + if (enableMcpServer) { + unawaited( + MixinMcpServer.instance.start( + database: database, + userId: accountServer.userId, + currentConversationId: () => + ref.read(currentConversationIdProvider), + ), + ); + } else { + unawaited(MixinMcpServer.instance.stop()); + } + return () { + unawaited(MixinMcpServer.instance.stop()); + }; + }, [database, accountServer.userId, enableMcpServer]); + final localTimeError = useMemoizedStream( - () => context.accountServer.connectedStateStream + () => accountServer.connectedStateStream .map((event) => event == ConnectedState.hasLocalTimeError) .distinct(), - keys: [context.accountServer], + keys: [accountServer], ).data ?? false; @@ -68,8 +96,8 @@ class HomePage extends HookConsumerWidget { final updateRequired = useMemoizedStream( - () => context.accountServer.isUpdateRequired, - keys: [context.accountServer], + () => accountServer.isUpdateRequired, + keys: [accountServer], ).data ?? false; diff --git a/lib/ui/setting/ai_settings_page.dart b/lib/ui/setting/ai_settings_page.dart index c2e0cec8ef..79004dced5 100644 --- a/lib/ui/setting/ai_settings_page.dart +++ b/lib/ui/setting/ai_settings_page.dart @@ -1,11 +1,13 @@ import 'package:flutter/cupertino.dart'; import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; import '../../ai/model/ai_prompt_template.dart'; import '../../ai/model/ai_provider_config.dart'; import '../../utils/extension/extension.dart'; +import '../../utils/mcp/mixin_mcp_server.dart'; import '../../widgets/app_bar.dart'; import '../../widgets/cell.dart'; import '../../widgets/dialog.dart'; @@ -38,6 +40,10 @@ class AiSettingsPage extends HookConsumerWidget { ), ) .length; + final mcpServer = useListenable(MixinMcpServer.instance); + final enableMcpServer = database.settingProperties.enableMcpServer; + final mcpEndpoint = mcpServer.endpoint; + final mcpToken = database.settingProperties.mcpServerToken; return Scaffold( backgroundColor: context.theme.background, @@ -91,6 +97,132 @@ class AiSettingsPage extends HookConsumerWidget { ), ), ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: Column( + children: [ + CellItem( + title: const Text('Local MCP Server'), + leading: Icon( + Icons.hub_outlined, + color: context.theme.icon, + ), + description: Text( + mcpServer.isRunning ? 'Running' : 'Off', + ), + trailing: Transform.scale( + scale: 0.7, + child: CupertinoSwitch( + activeTrackColor: context.theme.accent, + value: enableMcpServer, + onChanged: (value) { + database.settingProperties.enableMcpServer = + value; + }, + ), + ), + ), + if (enableMcpServer) ...[ + Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: context.theme.divider, + ), + CellItem( + title: const Text('Endpoint'), + description: Expanded( + child: Text( + mcpEndpoint?.toString() ?? 'Starting...', + textAlign: TextAlign.end, + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + ), + trailing: IconButton( + onPressed: mcpEndpoint == null + ? null + : () { + Clipboard.setData( + ClipboardData( + text: mcpEndpoint.toString(), + ), + ); + showToastSuccessful(); + }, + icon: Icon( + Icons.copy_rounded, + color: context.theme.icon, + ), + ), + ), + Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: context.theme.divider, + ), + CellItem( + title: const Text('Access Token'), + description: Expanded( + child: Text( + mcpToken ?? 'Unavailable', + textAlign: TextAlign.end, + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + ), + trailing: Row( + mainAxisSize: MainAxisSize.min, + children: [ + IconButton( + onPressed: mcpToken == null + ? null + : () { + Clipboard.setData( + ClipboardData(text: mcpToken), + ); + showToastSuccessful(); + }, + icon: Icon( + Icons.copy_rounded, + color: context.theme.icon, + ), + ), + IconButton( + onPressed: () { + database.settingProperties + .regenerateMcpServerToken(); + showToastSuccessful(); + }, + icon: Icon( + Icons.refresh_rounded, + color: context.theme.icon, + ), + ), + ], + ), + ), + ], + ], + ), + ), + Padding( + padding: const EdgeInsets.only( + left: 20, + bottom: 14, + top: 10, + ), + child: Text( + 'Exposes read-only conversation tools, UI navigation, draft editing, and AI thread inspection on localhost only. It never sends messages.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), + ), + ), CellGroup( padding: const EdgeInsets.only(right: 10, left: 10), cellBackgroundColor: diff --git a/lib/utils/mcp/mixin_mcp_bridge.dart b/lib/utils/mcp/mixin_mcp_bridge.dart new file mode 100644 index 0000000000..eeadb4cb89 --- /dev/null +++ b/lib/utils/mcp/mixin_mcp_bridge.dart @@ -0,0 +1,143 @@ +import 'dart:async'; + +import 'package:flutter/widgets.dart'; + +import '../../db/database.dart'; +import '../../db/mixin_database.dart'; +import '../../ui/home/bloc/blink_cubit.dart'; +import '../../ui/home/bloc/message_bloc.dart'; +import '../../ui/provider/ai_context_attachment_provider.dart'; +import '../../ui/provider/conversation_provider.dart'; +import '../extension/extension.dart'; + +class MixinMcpBridge { + MixinMcpBridge._(); + + static final MixinMcpBridge instance = MixinMcpBridge._(); + + BuildContext? _rootContext; + String? _inputConversationId; + TextEditingController? _inputController; + + String? get activeInputConversationId => _inputConversationId; + + set rootContext(BuildContext context) { + _rootContext = context; + } + + void bindInputController( + String conversationId, + TextEditingController controller, + ) { + _inputConversationId = conversationId; + _inputController = controller; + } + + void unbindInputController( + String conversationId, + TextEditingController controller, + ) { + if (_inputConversationId != conversationId || + _inputController != controller) { + return; + } + _inputConversationId = null; + _inputController = null; + } + + Future openConversation(String conversationId) async { + final context = _requireContext(); + await ConversationStateNotifier.selectConversation(context, conversationId); + } + + Future revealMessage({ + required String conversationId, + required String messageId, + }) async { + final context = _requireContext(); + await ConversationStateNotifier.selectConversation( + context, + conversationId, + initIndexMessageId: messageId, + ); + unawaited( + Future.delayed(const Duration(milliseconds: 120), () { + try { + context.read().scrollTo(messageId); + context.read().blinkByMessageId(messageId); + } catch (_) {} + }), + ); + } + + Future getDraft(Database database, String conversationId) async { + final controller = _controllerFor(conversationId); + if (controller != null) return controller.text; + final conversation = await database.conversationDao + .conversationItem(conversationId) + .getSingleOrNull(); + return conversation?.draft ?? ''; + } + + Future setDraft( + Database database, + String conversationId, + String text, + ) async { + final controller = _controllerFor(conversationId); + if (controller != null) { + controller.value = TextEditingValue( + text: text, + selection: TextSelection.collapsed(offset: text.length), + ); + } + await database.conversationDao.updateDraft(conversationId, text); + } + + Future insertText( + Database database, + String conversationId, + String text, + ) async { + final controller = _controllerFor(conversationId); + if (controller == null) { + final current = await getDraft(database, conversationId); + await setDraft(database, conversationId, '$current$text'); + return; + } + final value = controller.value; + final selection = value.selection; + final start = selection.isValid ? selection.start : value.text.length; + final end = selection.isValid ? selection.end : value.text.length; + final next = value.text.replaceRange(start, end, text); + final offset = start + text.length; + controller.value = TextEditingValue( + text: next, + selection: TextSelection.collapsed(offset: offset), + ); + await database.conversationDao.updateDraft(conversationId, next); + } + + Future attachMessage({ + required String conversationId, + required MessageItem message, + }) async { + final context = _requireContext(); + context.providerContainer + .read(aiContextAttachmentProvider(conversationId).notifier) + .attachMessages([message]); + } + + TextEditingController? _controllerFor(String conversationId) { + if (_inputConversationId != conversationId) return null; + return _inputController; + } + + BuildContext _requireContext() { + final context = _rootContext; + if (context == null || !context.mounted) { + throw StateError('Mixin UI is unavailable'); + } + return context; + } +} diff --git a/lib/utils/mcp/mixin_mcp_server.dart b/lib/utils/mcp/mixin_mcp_server.dart new file mode 100644 index 0000000000..42dc398096 --- /dev/null +++ b/lib/utils/mcp/mixin_mcp_server.dart @@ -0,0 +1,829 @@ +import 'dart:async'; +import 'dart:convert'; +import 'dart:io'; + +import 'package:flutter/foundation.dart'; +import 'package:genkit/genkit.dart' as genkit; +import 'package:mcp_server/mcp_server.dart' as mcp; +import 'package:schemantic/schemantic.dart'; + +import '../../ai/model/ai_chat_metadata.dart'; +import '../../ai/tools/ai_conversation_tool_service.dart'; +import '../../db/ai_database.dart'; +import '../../db/dao/conversation_dao.dart'; +import '../../db/dao/message_dao.dart'; +import '../../db/database.dart'; +import '../../db/mixin_database.dart'; +import '../extension/extension.dart'; +import '../logger.dart'; +import '../system/package_info.dart'; +import 'mixin_mcp_bridge.dart'; + +typedef CurrentConversationIdResolver = String? Function(); + +class MixinMcpServer extends ChangeNotifier { + MixinMcpServer._(); + + static final MixinMcpServer instance = MixinMcpServer._(); + + mcp.Server? _server; + mcp.ServerTransport? _transport; + Database? _database; + String? _userId; + int? _port; + CurrentConversationIdResolver? _currentConversationId; + late AiConversationToolService _conversationTools; + List, Map>> _tools = + const []; + + Uri? get endpoint { + final port = _port; + if (_server == null || port == null) return null; + return Uri( + scheme: 'http', + host: InternetAddress.loopbackIPv4.address, + port: port, + path: '/mcp', + ); + } + + bool get isRunning => _server != null && _transport != null; + + Future start({ + required Database database, + required String userId, + required CurrentConversationIdResolver currentConversationId, + }) async { + if (_server != null && + identical(_database, database) && + _userId == userId) { + return; + } + await stop(); + _database = database; + _userId = userId; + _currentConversationId = currentConversationId; + _conversationTools = DatabaseAiConversationToolService(database); + _tools = _createGenkitTools(); + final token = database.settingProperties.mcpServerToken; + if (token == null || token.isEmpty) { + throw StateError('MCP access token is unavailable'); + } + final port = await _reserveLoopbackPort(); + final transport = mcp.StreamableHttpServerTransport( + config: mcp.StreamableHttpServerConfig( + host: InternetAddress.loopbackIPv4.address, + port: port, + fallbackPorts: const [], + authToken: token, + isJsonResponseEnabled: true, + enableGetStream: false, + ), + ); + await transport.start(); + final server = mcp.Server( + name: 'mixin-local', + version: '0.1.0', + capabilities: mcp.ServerCapabilities.simple(tools: true), + ); + for (final tool in _tools) { + _registerMcpTool(server, tool); + } + server.connect(transport); + _server = server; + _transport = transport; + _port = port; + i('Mixin MCP server listening at $endpoint'); + notifyListeners(); + } + + Future stop() async { + final server = _server; + final transport = _transport; + _server = null; + _transport = null; + _database = null; + _userId = null; + _port = null; + _currentConversationId = null; + _tools = const []; + if (server != null) { + server + ..disconnect() + ..dispose(); + transport?.close(); + i('Mixin MCP server stopped'); + notifyListeners(); + } + } + + Future> _callTool( + String name, + Map arguments, + ) async { + final database = _requireDatabase(); + switch (name) { + case 'mixin_get_app_status': + final info = await getPackageInfo(); + final conversationId = _currentConversationId?.call(); + return { + 'logged_in': _userId != null, + 'user_id': _userId, + 'identity_number': database.identityNumber, + 'active_conversation_id': conversationId, + 'active_input_conversation_id': + MixinMcpBridge.instance.activeInputConversationId, + 'app': { + 'name': info.appName, + 'version': info.version, + 'build_number': info.buildNumber, + }, + 'capabilities': _tools + .map((tool) => tool.name) + .toList(growable: false), + }; + case 'mixin_list_conversations': + final query = _optionalString(arguments, 'query'); + final limit = _int( + arguments, + 'limit', + defaultValue: 30, + min: 1, + max: 100, + ); + final conversations = query == null || query.trim().isEmpty + ? await database.conversationDao.conversationItems().get() + : await _searchConversations(database, query, limit); + return { + 'conversations': conversations + .take(limit) + .map(_conversationToJson) + .toList(growable: false), + }; + case 'mixin_get_conversation': + final conversation = await _conversationById( + database, + _requiredString(arguments, 'conversation_id'), + ); + return {'conversation': _conversationToJson(conversation)}; + case 'mixin_resolve_conversation': + return { + 'conversation': _conversationToJson( + await _resolveConversation(database, arguments), + ), + }; + case 'mixin_get_conversation_stats': + final stats = await _conversationTools.getConversationStats( + conversationId: _requiredString(arguments, 'conversation_id'), + startInclusive: _date(arguments, 'start'), + endExclusive: _date(arguments, 'end'), + ); + return stats.toJson(); + case 'mixin_read_messages': + final messages = await database.messageDao + .messagesByConversationIdAndCreatedAtRange( + _requiredString(arguments, 'conversation_id'), + offset: _int(arguments, 'offset', defaultValue: 0), + limit: _int( + arguments, + 'limit', + defaultValue: 50, + min: 1, + max: 200, + ), + startInclusive: _date(arguments, 'start'), + endExclusive: _date(arguments, 'end'), + ) + .get(); + return {'messages': _messagesToJson(messages)}; + case 'mixin_search_messages': + final conversationId = _optionalString(arguments, 'conversation_id'); + final messages = await database.fuzzySearchMessage( + query: _requiredString(arguments, 'query'), + limit: _int(arguments, 'limit', defaultValue: 20, min: 1, max: 50), + conversationIds: conversationId == null ? const [] : [conversationId], + anchorMessageId: _optionalString(arguments, 'anchor_id'), + ); + return {'messages': _searchMessagesToJson(messages)}; + case 'mixin_get_message': + final message = await _messageById( + database, + _requiredString(arguments, 'message_id'), + ); + return {'message': _messageToJson(message)}; + case 'mixin_get_message_context': + final message = await _messageById( + database, + _requiredString(arguments, 'message_id'), + ); + final limit = _int( + arguments, + 'limit', + defaultValue: 10, + min: 1, + max: 50, + ); + final info = await database.messageDao.messageOrderInfo( + message.messageId, + ); + if (info == null) throw StateError('Message order info not found'); + final before = await database.messageDao + .beforeMessagesByConversationId(info, message.conversationId, limit) + .get(); + final after = await database.messageDao + .afterMessagesByConversationId(info, message.conversationId, limit) + .get(); + return { + 'before': _messagesToJson(before.reversed), + 'message': _messageToJson(message), + 'after': _messagesToJson(after), + }; + case 'mixin_read_image_text': + final result = await _conversationTools.readImageText( + conversationId: _requiredString(arguments, 'conversation_id'), + messageId: _requiredString(arguments, 'message_id'), + ); + return result.toJson(); + case 'mixin_list_attachments': + final messages = await database.messageDao + .messagesByConversationIdAndCreatedAtRange( + _requiredString(arguments, 'conversation_id'), + limit: _int( + arguments, + 'limit', + defaultValue: 50, + min: 1, + max: 200, + ), + startInclusive: _date(arguments, 'start'), + endExclusive: _date(arguments, 'end'), + ) + .get(); + return { + 'attachments': messages + .where(_hasAttachment) + .map(_attachmentToJson) + .toList(growable: false), + }; + case 'mixin_open_conversation': + final conversationId = _requiredString(arguments, 'conversation_id'); + await MixinMcpBridge.instance.openConversation(conversationId); + return {'opened': true, 'conversation_id': conversationId}; + case 'mixin_reveal_message': + final message = await _messageById( + database, + _requiredString(arguments, 'message_id'), + ); + await MixinMcpBridge.instance.revealMessage( + conversationId: message.conversationId, + messageId: message.messageId, + ); + return { + 'revealed': true, + 'conversation_id': message.conversationId, + 'message_id': message.messageId, + }; + case 'mixin_get_draft': + final conversationId = _requiredString(arguments, 'conversation_id'); + return { + 'conversation_id': conversationId, + 'draft': await MixinMcpBridge.instance.getDraft( + database, + conversationId, + ), + }; + case 'mixin_set_draft': + final conversationId = _requiredString(arguments, 'conversation_id'); + await MixinMcpBridge.instance.setDraft( + database, + conversationId, + _requiredString(arguments, 'text'), + ); + return {'updated': true, 'conversation_id': conversationId}; + case 'mixin_insert_text': + final conversationId = _requiredString(arguments, 'conversation_id'); + await MixinMcpBridge.instance.insertText( + database, + conversationId, + _requiredString(arguments, 'text'), + ); + return {'updated': true, 'conversation_id': conversationId}; + case 'mixin_clear_draft': + final conversationId = _requiredString(arguments, 'conversation_id'); + await MixinMcpBridge.instance.setDraft(database, conversationId, ''); + return {'updated': true, 'conversation_id': conversationId}; + case 'mixin_attach_message_to_ai': + final message = await _messageById( + database, + _requiredString(arguments, 'message_id'), + ); + await MixinMcpBridge.instance.attachMessage( + conversationId: message.conversationId, + message: message, + ); + return { + 'attached': true, + 'conversation_id': message.conversationId, + 'message_id': message.messageId, + }; + case 'mixin_list_ai_threads': + final threads = await database.aiChatMessageDao + .watchThreads(_requiredString(arguments, 'conversation_id')) + .first; + return { + 'threads': threads.map(_aiThreadToJson).toList(growable: false), + }; + case 'mixin_read_ai_thread': + final threadId = _requiredString(arguments, 'thread_id'); + final thread = await database.aiChatMessageDao.threadById(threadId); + if (thread == null) throw StateError('AI thread not found'); + final messages = await database.aiChatMessageDao.threadMessages( + threadId, + ); + return { + 'thread': _aiThreadToJson(thread), + 'messages': messages.map(_aiMessageToJson).toList(growable: false), + }; + case 'mixin_get_ai_tool_events': + final messageId = _requiredString(arguments, 'message_id'); + final message = await database.aiChatMessageDao.messageById(messageId); + if (message == null) throw StateError('AI message not found'); + return { + 'message_id': message.id, + 'tool_events': aiMetadataToolEvents(message.metadata), + }; + default: + throw StateError('Unknown tool: $name'); + } + } + + List, Map>> + _createGenkitTools() => _toolSpecs + .map( + (spec) => genkit.Tool, Map>( + name: spec.name, + description: spec.description, + inputSchema: SchemanticType.from>( + jsonSchema: spec.inputSchema, + parse: _jsonMap, + ), + fn: (input, _) => _callTool(spec.name, input), + ), + ) + .toList(growable: false); + + void _registerMcpTool( + mcp.Server server, + genkit.Tool, Map> tool, + ) { + server.addTool( + name: tool.name, + description: tool.description ?? '', + inputSchema: Map.from( + tool.inputSchema?.jsonSchema() ?? _emptyObjectSchema, + ), + handler: (arguments) async { + final result = await tool.runRaw(arguments); + final data = result.result; + return mcp.CallToolResult( + content: [mcp.TextContent(text: const JsonEncoder().convert(data))], + structuredContent: data, + ); + }, + ); + } + + Database _requireDatabase() { + final database = _database; + if (database == null) throw StateError('Database is unavailable'); + return database; + } +} + +Future _conversationById( + Database database, + String conversationId, +) async { + final conversation = await database.conversationDao + .conversationItem(conversationId) + .getSingleOrNull(); + if (conversation == null) throw StateError('Conversation not found'); + return conversation; +} + +Future _resolveConversation( + Database database, + Map arguments, +) async { + final conversationId = _optionalString(arguments, 'conversation_id'); + if (conversationId != null) { + return _conversationById(database, conversationId); + } + final uri = _optionalString(arguments, 'uri'); + if (uri != null) { + final parsed = Uri.tryParse(uri); + final id = parsed?.host == 'conversations' + ? parsed?.pathSegments.firstOrNull + : null; + if (id != null) return _conversationById(database, id); + } + final query = _requiredString(arguments, 'query'); + final result = await database.conversationDao + .fuzzySearchConversation(query, 1) + .getSingleOrNull(); + if (result == null) throw StateError('Conversation not found'); + return _conversationById(database, result.conversationId); +} + +Future> _searchConversations( + Database database, + String query, + int limit, +) async { + final results = await database.conversationDao + .fuzzySearchConversation(query, limit) + .get(); + final conversations = []; + for (final result in results) { + final conversation = await database.conversationDao + .conversationItem(result.conversationId) + .getSingleOrNull(); + if (conversation != null) { + conversations.add(conversation); + } + } + return conversations; +} + +Future _messageById(Database database, String messageId) async { + final message = await database.messageDao + .messageItemByMessageId(messageId) + .getSingleOrNull(); + if (message == null) throw StateError('Message not found'); + return message; +} + +Map _conversationToJson(ConversationItem conversation) => { + 'conversation_id': conversation.conversationId, + 'name': conversation.validName, + 'category': conversation.category?.name, + 'owner_id': conversation.ownerId, + 'owner_identity_number': conversation.ownerIdentityNumber, + 'unread_count': conversation.unseenMessageCount, + 'is_muted': conversation.isMute, + 'is_group': conversation.isGroupConversation, + 'last_read_message_id': conversation.lastReadMessageId, + 'created_at': _dateTime(conversation.createdAt), + 'last_message_created_at': _dateTime(conversation.lastMessageCreatedAt), +}; + +List> _messagesToJson(Iterable messages) => + messages.map(_messageToJson).toList(growable: false); + +List> _searchMessagesToJson( + Iterable messages, +) => messages + .map( + (message) => { + 'message_id': message.messageId, + 'conversation_id': message.conversationId, + 'conversation_name': message.groupName?.trim().isNotEmpty == true + ? message.groupName + : message.ownerFullName, + 'user_id': message.senderId, + 'user_full_name': message.senderFullName, + 'type': message.type, + 'content': message.content, + 'created_at': _dateTime(message.createdAt), + 'status': message.status.name, + 'media_name': message.mediaName, + }..removeWhere((_, value) => value == null), + ) + .toList(growable: false); + +Map _messageToJson(MessageItem message) => { + 'message_id': message.messageId, + 'conversation_id': message.conversationId, + 'user_id': message.userId, + 'user_full_name': message.userFullName, + 'user_identity_number': message.userIdentityNumber, + 'type': message.type, + 'content': _messageContent(message), + 'created_at': _dateTime(message.createdAt), + 'status': message.status.name, + 'media': _hasAttachment(message) ? _attachmentToJson(message) : null, +}..removeWhere((_, value) => value == null); + +String? _messageContent(MessageItem message) { + final content = message.content; + final type = message.type; + if (content?.trim().isNotEmpty == true) return content; + if (type.isImage) return '[image]'; + if (type.isVideo) return '[video] ${message.mediaName ?? ''}'.trim(); + if (type.isAudio) return '[audio]'; + if (type.isData) return '[file] ${message.mediaName ?? ''}'.trim(); + if (type.isSticker) return '[sticker]'; + return content; +} + +bool _hasAttachment(MessageItem message) { + final type = message.type; + return type.isImage || type.isVideo || type.isAudio || type.isData; +} + +Map _attachmentToJson(MessageItem message) => { + 'message_id': message.messageId, + 'conversation_id': message.conversationId, + 'type': message.type, + 'name': message.mediaName, + 'mime_type': message.mediaMimeType, + 'size': message.mediaSize, + 'width': message.mediaWidth, + 'height': message.mediaHeight, + 'duration': message.mediaDuration, + 'status': message.mediaStatus?.name, + 'created_at': _dateTime(message.createdAt), +}..removeWhere((_, value) => value == null); + +Map _aiThreadToJson(AiChatThread thread) => { + 'thread_id': thread.id, + 'conversation_id': thread.conversationId, + 'title': thread.title, + 'summary': thread.summary, + 'last_message_preview': thread.lastMessagePreview, + 'message_count': thread.messageCount, + 'created_at': _dateTime(thread.createdAt), + 'updated_at': _dateTime(thread.updatedAt), + 'last_message_at': _dateTime(thread.lastMessageAt), +}..removeWhere((_, value) => value == null); + +Map _aiMessageToJson(AiChatMessage message) => { + 'message_id': message.id, + 'thread_id': message.threadId, + 'conversation_id': message.conversationId, + 'role': message.role, + 'provider_id': message.providerId, + 'model': message.model, + 'content': message.content, + 'status': message.status, + 'error_text': message.errorText, + 'metadata': message.metadata, + 'created_at': _dateTime(message.createdAt), + 'updated_at': _dateTime(message.updatedAt), +}..removeWhere((_, value) => value == null); + +String? _dateTime(DateTime? value) => value?.toIso8601String(); + +String _requiredString(Map arguments, String key) { + final value = _optionalString(arguments, key); + if (value == null || value.isEmpty) { + throw ArgumentError('$key is required'); + } + return value; +} + +String? _optionalString(Map arguments, String key) { + final value = arguments[key]; + if (value == null) return null; + final text = value.toString().trim(); + return text.isEmpty ? null : text; +} + +int _int( + Map arguments, + String key, { + required int defaultValue, + int min = 0, + int max = 1 << 31, +}) { + final value = arguments[key]; + final parsed = value is int ? value : int.tryParse(value?.toString() ?? ''); + return (parsed ?? defaultValue).clamp(min, max); +} + +DateTime? _date(Map arguments, String key) { + final text = _optionalString(arguments, key); + return text == null ? null : DateTime.parse(text); +} + +Future _reserveLoopbackPort() async { + final socket = await ServerSocket.bind(InternetAddress.loopbackIPv4, 0); + final port = socket.port; + await socket.close(); + return port; +} + +Map _jsonMap(dynamic json) { + if (json == null) return {}; + if (json is Map) return json; + if (json is Map) return json.cast(); + throw ArgumentError('Expected JSON object'); +} + +const _emptyObjectSchema = { + 'type': 'object', + 'properties': {}, + 'additionalProperties': false, +}; + +const _toolSpecs = [ + _Tool( + 'mixin_get_app_status', + 'Get login, active conversation, app version, and MCP capability status.', + ), + _Tool( + 'mixin_list_conversations', + 'List recent conversations or search conversations by query.', + properties: { + 'query': {'type': 'string'}, + 'limit': {'type': 'integer'}, + }, + ), + _Tool( + 'mixin_get_conversation', + 'Get one conversation by conversation_id.', + required: ['conversation_id'], + properties: { + 'conversation_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_resolve_conversation', + 'Resolve a conversation from conversation_id, mixin URI, or query.', + properties: { + 'conversation_id': {'type': 'string'}, + 'uri': {'type': 'string'}, + 'query': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_get_conversation_stats', + 'Get message count and first/last timestamps for a conversation.', + required: ['conversation_id'], + properties: _conversationRangeProperties, + ), + _Tool( + 'mixin_read_messages', + 'Read conversation messages by range, offset, and limit.', + required: ['conversation_id'], + properties: { + ..._conversationRangeProperties, + 'offset': {'type': 'integer'}, + 'limit': {'type': 'integer'}, + }, + ), + _Tool( + 'mixin_search_messages', + 'Search messages globally or inside a conversation.', + required: ['query'], + properties: { + 'query': {'type': 'string'}, + 'conversation_id': {'type': 'string'}, + 'limit': {'type': 'integer'}, + 'anchor_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_get_message', + 'Get a message by message_id.', + required: ['message_id'], + properties: { + 'message_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_get_message_context', + 'Read messages around a message_id.', + required: ['message_id'], + properties: { + 'message_id': {'type': 'string'}, + 'limit': {'type': 'integer'}, + }, + ), + _Tool( + 'mixin_read_image_text', + 'Run local OCR for an image message.', + required: ['conversation_id', 'message_id'], + properties: { + 'conversation_id': {'type': 'string'}, + 'message_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_list_attachments', + 'List attachment metadata for a conversation.', + required: ['conversation_id'], + properties: { + ..._conversationRangeProperties, + 'limit': {'type': 'integer'}, + }, + ), + _Tool( + 'mixin_open_conversation', + 'Open a conversation in the Mixin UI.', + required: ['conversation_id'], + properties: { + 'conversation_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_reveal_message', + 'Open the message conversation and reveal the message in the Mixin UI.', + required: ['message_id'], + properties: { + 'message_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_get_draft', + 'Get the current draft text for a conversation.', + required: ['conversation_id'], + properties: { + 'conversation_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_set_draft', + 'Replace the draft text for a conversation. Does not send.', + required: ['conversation_id', 'text'], + properties: { + 'conversation_id': {'type': 'string'}, + 'text': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_insert_text', + 'Insert text into the active input, or append to stored draft.', + required: ['conversation_id', 'text'], + properties: { + 'conversation_id': {'type': 'string'}, + 'text': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_clear_draft', + 'Clear the draft text for a conversation. Does not send.', + required: ['conversation_id'], + properties: { + 'conversation_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_attach_message_to_ai', + 'Attach a message to the app AI context chip for its conversation.', + required: ['message_id'], + properties: { + 'message_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_list_ai_threads', + 'List AI threads for a conversation.', + required: ['conversation_id'], + properties: { + 'conversation_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_read_ai_thread', + 'Read one AI thread and its messages.', + required: ['thread_id'], + properties: { + 'thread_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_get_ai_tool_events', + 'Read stored AI tool call/result events for an AI message.', + required: ['message_id'], + properties: { + 'message_id': {'type': 'string'}, + }, + ), +]; + +const _conversationRangeProperties = { + 'conversation_id': {'type': 'string'}, + 'start': {'type': 'string', 'description': 'Inclusive ISO-8601 timestamp.'}, + 'end': {'type': 'string', 'description': 'Exclusive ISO-8601 timestamp.'}, +}; + +class _Tool { + const _Tool( + this.name, + this.description, { + this.required = const [], + this.properties = const {}, + }); + + final String name; + final String description; + final List required; + final Map properties; + + Map get inputSchema => { + ..._emptyObjectSchema, + 'properties': properties, + 'required': required, + }; +} diff --git a/lib/utils/property/setting_property.dart b/lib/utils/property/setting_property.dart index 8a008fb110..8f4058c85d 100644 --- a/lib/utils/property/setting_property.dart +++ b/lib/utils/property/setting_property.dart @@ -1,4 +1,5 @@ import 'dart:convert'; +import 'dart:math'; import 'package:mixin_logger/mixin_logger.dart'; @@ -18,14 +19,38 @@ const _kSelectedAiProviderKey = 'selected_ai_provider'; const _kSelectedAiTranslatorProviderKey = 'selected_ai_translator_provider'; const _kSelectedAiTranslatorModelKey = 'selected_ai_translator_model'; const _kAiPromptTemplateOverridesKey = 'ai_prompt_template_overrides'; +const _kEnableMcpServerKey = 'enable_mcp_server'; +const _kMcpServerTokenKey = 'mcp_server_token'; class SettingPropertyStorage extends PropertyStorage { SettingPropertyStorage(PropertyDao dao) : super(PropertyGroup.setting, dao); + static final Random _secureRandom = Random.secure(); + bool get enableProxy => get(_kEnableProxyKey) ?? false; set enableProxy(bool value) => set(_kEnableProxyKey, value); + bool get enableMcpServer => get(_kEnableMcpServerKey) ?? false; + + set enableMcpServer(bool value) { + if (value && mcpServerToken == null) { + regenerateMcpServerToken(); + } + set(_kEnableMcpServerKey, value); + } + + String? get mcpServerToken => get(_kMcpServerTokenKey); + + String regenerateMcpServerToken() { + final token = List.generate( + 32, + (_) => _secureRandom.nextInt(256).toRadixString(16).padLeft(2, '0'), + ).join(); + set(_kMcpServerTokenKey, token); + return token; + } + String? get selectedProxyId => get(_kSelectedProxyKey); set selectedProxyId(String? value) => set(_kSelectedProxyKey, value); From b452252181d5d49a80854e9f2ecc033283ba4f3f Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Wed, 6 May 2026 21:08:53 +0800 Subject: [PATCH 47/52] feat: add circle management and enhanced MCP server functionalities --- lib/db/dao/message_dao.dart | 77 ++- lib/db/dao/pin_message_dao.dart | 11 + lib/ui/setting/ai_settings_page.dart | 82 ++- lib/utils/mcp/mixin_mcp_bridge.dart | 44 ++ lib/utils/mcp/mixin_mcp_server.dart | 726 +++++++++++++++++++++-- lib/utils/property/setting_property.dart | 12 + pubspec.lock | 35 +- pubspec.yaml | 11 +- 8 files changed, 929 insertions(+), 69 deletions(-) diff --git a/lib/db/dao/message_dao.dart b/lib/db/dao/message_dao.dart index 60c1b18373..48e1a1e34e 100644 --- a/lib/db/dao/message_dao.dart +++ b/lib/db/dao/message_dao.dart @@ -23,6 +23,21 @@ class MessageOrderInfo { final int createdAt; } +const _attachmentMessageCategories = [ + MessageCategory.signalImage, + MessageCategory.signalVideo, + MessageCategory.signalData, + MessageCategory.signalAudio, + MessageCategory.plainImage, + MessageCategory.plainVideo, + MessageCategory.plainData, + MessageCategory.plainAudio, + MessageCategory.encryptedImage, + MessageCategory.encryptedVideo, + MessageCategory.encryptedData, + MessageCategory.encryptedAudio, +]; + @DriftAccessor(include: {'../moor/dao/message.drift'}) class MessageDao extends DatabaseAccessor with _$MessageDaoMixin { @@ -690,13 +705,25 @@ class MessageDao extends DatabaseAccessor int offset = 0, DateTime? startInclusive, DateTime? endExclusive, + String? senderId, + String? senderIdentityNumber, + List categories = const [], bool ascending = true, }) { final startMillis = startInclusive?.millisecondsSinceEpoch; final endMillis = endExclusive?.millisecondsSinceEpoch; return _baseMessageItems( - (message, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => + (message, sender, _, _, _, _, _, _, _, _, _, _, _, _, em) => message.conversationId.equals(conversationId) & + (senderId == null + ? const Constant(true) + : message.userId.equals(senderId)) & + (senderIdentityNumber == null + ? const Constant(true) + : sender.identityNumber.equals(senderIdentityNumber)) & + (categories.isEmpty + ? const Constant(true) + : message.category.isIn(categories)) & (startMillis == null ? const Constant(true) : message.createdAt.isBiggerOrEqualValue(startMillis)) & @@ -713,6 +740,54 @@ class MessageDao extends DatabaseAccessor ); } + Selectable mentionMessagesByConversationId( + String conversationId, { + required int limit, + int offset = 0, + bool unreadOnly = false, + }) => _baseMessageItems( + (message, _, _, _, _, _, _, _, _, _, _, _, messageMention, _, em) => + message.conversationId.equals(conversationId) & + messageMention.messageId.isNotNull() & + (unreadOnly + ? messageMention.hasRead.equals(false) + : const Constant(true)), + (_, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => Limit(limit, offset), + ); + + Selectable attachmentMessagesByConversationId( + String conversationId, { + required int limit, + int offset = 0, + DateTime? startInclusive, + DateTime? endExclusive, + String? senderId, + String? senderIdentityNumber, + List categories = const [], + }) => messagesByConversationIdAndCreatedAtRange( + conversationId, + limit: limit, + offset: offset, + startInclusive: startInclusive, + endExclusive: endExclusive, + senderId: senderId, + senderIdentityNumber: senderIdentityNumber, + categories: categories.isEmpty ? _attachmentMessageCategories : categories, + ); + + Selectable linkMessagesByConversationId( + String conversationId, { + required int limit, + int offset = 0, + }) => _baseMessageItems( + (message, _, _, _, _, _, _, _, _, hyperlink, _, _, _, _, em) => + message.conversationId.equals(conversationId) & + message.hyperlink.isNotNull() & + message.hyperlink.equals('').not() & + hyperlink.hyperlink.isNotNull(), + (_, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => Limit(limit, offset), + ); + Selectable messageCountByConversationIdAndCreatedAtRange( String conversationId, { DateTime? startInclusive, diff --git a/lib/db/dao/pin_message_dao.dart b/lib/db/dao/pin_message_dao.dart index 7d0a276383..495d856ce9 100644 --- a/lib/db/dao/pin_message_dao.dart +++ b/lib/db/dao/pin_message_dao.dart @@ -74,6 +74,17 @@ class PinMessageDao extends DatabaseAccessor (_, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => maxLimit, ); + Future> pinMessagesByConversationId({ + required String conversationId, + required int limit, + required int offset, + }) => + (select(db.pinMessages) + ..where((tbl) => tbl.conversationId.equals(conversationId)) + ..orderBy([(tbl) => OrderingTerm.desc(tbl.createdAt)]) + ..limit(limit, offset: offset)) + .get(); + Future> getPinMessages({ required int limit, required int offset, diff --git a/lib/ui/setting/ai_settings_page.dart b/lib/ui/setting/ai_settings_page.dart index 79004dced5..e52a2d07cd 100644 --- a/lib/ui/setting/ai_settings_page.dart +++ b/lib/ui/setting/ai_settings_page.dart @@ -42,8 +42,13 @@ class AiSettingsPage extends HookConsumerWidget { .length; final mcpServer = useListenable(MixinMcpServer.instance); final enableMcpServer = database.settingProperties.enableMcpServer; + final enableMcpDraftTools = database.settingProperties.enableMcpDraftTools; + final enableMcpCircleManagement = + database.settingProperties.enableMcpCircleManagement; final mcpEndpoint = mcpServer.endpoint; final mcpToken = database.settingProperties.mcpServerToken; + const mcpPort = MixinMcpServer.defaultPort; + final mcpError = mcpServer.lastStartError; return Scaffold( backgroundColor: context.theme.background, @@ -135,7 +140,8 @@ class AiSettingsPage extends HookConsumerWidget { title: const Text('Endpoint'), description: Expanded( child: Text( - mcpEndpoint?.toString() ?? 'Starting...', + mcpEndpoint?.toString() ?? + 'http://127.0.0.1:$mcpPort/mcp', textAlign: TextAlign.end, maxLines: 1, overflow: TextOverflow.ellipsis, @@ -158,6 +164,29 @@ class AiSettingsPage extends HookConsumerWidget { ), ), ), + if (mcpError != null) ...[ + Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: context.theme.divider, + ), + CellItem( + title: const Text('Status'), + description: Expanded( + child: Text( + 'Failed to bind port $mcpPort', + textAlign: TextAlign.end, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.red, + ), + ), + ), + trailing: null, + ), + ], Divider( height: 0.5, indent: 16, @@ -205,6 +234,52 @@ class AiSettingsPage extends HookConsumerWidget { ], ), ), + Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: context.theme.divider, + ), + CellItem( + title: const Text('Draft Editing'), + description: const Text('Draft write tools'), + trailing: Transform.scale( + scale: 0.7, + child: CupertinoSwitch( + activeTrackColor: context.theme.accent, + value: enableMcpDraftTools, + onChanged: (value) { + database + .settingProperties + .enableMcpDraftTools = + value; + }, + ), + ), + ), + Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: context.theme.divider, + ), + CellItem( + title: const Text('Circle Management'), + description: const Text('Create and edit circles'), + trailing: Transform.scale( + scale: 0.7, + child: CupertinoSwitch( + activeTrackColor: context.theme.accent, + value: enableMcpCircleManagement, + onChanged: (value) { + database + .settingProperties + .enableMcpCircleManagement = + value; + }, + ), + ), + ), ], ], ), @@ -216,7 +291,10 @@ class AiSettingsPage extends HookConsumerWidget { top: 10, ), child: Text( - 'Exposes read-only conversation tools, UI navigation, draft editing, and AI thread inspection on localhost only. It never sends messages.', + 'Exposes read-only conversation tools, UI navigation, ' + 'and AI thread inspection on localhost only at port ' + '$mcpPort. Draft and circle write tools require their ' + 'own switches. It never sends messages.', style: TextStyle( color: context.theme.secondaryText, fontSize: 14, diff --git a/lib/utils/mcp/mixin_mcp_bridge.dart b/lib/utils/mcp/mixin_mcp_bridge.dart index eeadb4cb89..a127b15518 100644 --- a/lib/utils/mcp/mixin_mcp_bridge.dart +++ b/lib/utils/mcp/mixin_mcp_bridge.dart @@ -1,6 +1,8 @@ import 'dart:async'; import 'package:flutter/widgets.dart'; +import 'package:mixin_bot_sdk_dart/mixin_bot_sdk_dart.dart' + show CircleConversationRequest; import '../../db/database.dart'; import '../../db/mixin_database.dart'; @@ -128,6 +130,48 @@ class MixinMcpBridge { .attachMessages([message]); } + Future createCircle({ + required String name, + required List conversations, + }) async { + final context = _requireContext(); + await context.accountServer.createCircle(name, conversations); + } + + Future renameCircle({ + required String circleId, + required String name, + }) async { + final context = _requireContext(); + await context.accountServer.updateCircle(circleId, name); + } + + Future deleteCircle(String circleId) async { + final context = _requireContext(); + await context.accountServer.deleteCircle(circleId); + } + + Future addConversationsToCircle({ + required String circleId, + required List conversations, + }) async { + final context = _requireContext(); + await context.accountServer.editCircleConversation(circleId, conversations); + } + + Future removeConversationsFromCircle({ + required String circleId, + required List conversationIds, + }) async { + final context = _requireContext(); + for (final conversationId in conversationIds) { + await context.accountServer.circleRemoveConversation( + circleId, + conversationId, + ); + } + } + TextEditingController? _controllerFor(String conversationId) { if (_inputConversationId != conversationId) return null; return _inputController; diff --git a/lib/utils/mcp/mixin_mcp_server.dart b/lib/utils/mcp/mixin_mcp_server.dart index 42dc398096..4fb9167662 100644 --- a/lib/utils/mcp/mixin_mcp_server.dart +++ b/lib/utils/mcp/mixin_mcp_server.dart @@ -5,13 +5,17 @@ import 'dart:io'; import 'package:flutter/foundation.dart'; import 'package:genkit/genkit.dart' as genkit; import 'package:mcp_server/mcp_server.dart' as mcp; +import 'package:mixin_bot_sdk_dart/mixin_bot_sdk_dart.dart' + show CircleConversationAction, CircleConversationRequest; import 'package:schemantic/schemantic.dart'; import '../../ai/model/ai_chat_metadata.dart'; import '../../ai/tools/ai_conversation_tool_service.dart'; import '../../db/ai_database.dart'; +import '../../db/dao/circle_dao.dart'; import '../../db/dao/conversation_dao.dart'; import '../../db/dao/message_dao.dart'; +import '../../db/dao/participant_dao.dart'; import '../../db/database.dart'; import '../../db/mixin_database.dart'; import '../extension/extension.dart'; @@ -21,34 +25,52 @@ import 'mixin_mcp_bridge.dart'; typedef CurrentConversationIdResolver = String? Function(); +enum _McpPermissionScope { + read, + appControl, + draftWrite, + circleManagement, +} + +extension on _McpPermissionScope { + String get key => switch (this) { + _McpPermissionScope.read => 'read', + _McpPermissionScope.appControl => 'app_control', + _McpPermissionScope.draftWrite => 'draft_write', + _McpPermissionScope.circleManagement => 'circle_management', + }; +} + class MixinMcpServer extends ChangeNotifier { MixinMcpServer._(); static final MixinMcpServer instance = MixinMcpServer._(); + static const int defaultPort = 55001; mcp.Server? _server; mcp.ServerTransport? _transport; Database? _database; String? _userId; - int? _port; CurrentConversationIdResolver? _currentConversationId; late AiConversationToolService _conversationTools; List, Map>> _tools = const []; + Object? _lastStartError; Uri? get endpoint { - final port = _port; - if (_server == null || port == null) return null; + if (_server == null) return null; return Uri( scheme: 'http', host: InternetAddress.loopbackIPv4.address, - port: port, + port: defaultPort, path: '/mcp', ); } bool get isRunning => _server != null && _transport != null; + Object? get lastStartError => _lastStartError; + Future start({ required Database database, required String userId, @@ -69,32 +91,43 @@ class MixinMcpServer extends ChangeNotifier { if (token == null || token.isEmpty) { throw StateError('MCP access token is unavailable'); } - final port = await _reserveLoopbackPort(); - final transport = mcp.StreamableHttpServerTransport( - config: mcp.StreamableHttpServerConfig( - host: InternetAddress.loopbackIPv4.address, - port: port, - fallbackPorts: const [], - authToken: token, - isJsonResponseEnabled: true, - enableGetStream: false, - ), - ); - await transport.start(); - final server = mcp.Server( - name: 'mixin-local', - version: '0.1.0', - capabilities: mcp.ServerCapabilities.simple(tools: true), - ); - for (final tool in _tools) { - _registerMcpTool(server, tool); + _lastStartError = null; + try { + final transport = mcp.StreamableHttpServerTransport( + config: mcp.StreamableHttpServerConfig( + host: InternetAddress.loopbackIPv4.address, + port: defaultPort, + fallbackPorts: const [], + authToken: token, + isJsonResponseEnabled: true, + enableGetStream: false, + ), + ); + await transport.start(); + final server = mcp.Server( + name: 'mixin-local', + version: '0.1.0', + capabilities: mcp.ServerCapabilities.simple(tools: true), + ); + for (final tool in _tools) { + _registerMcpTool(server, tool); + } + server.connect(transport); + _server = server; + _transport = transport; + i('Mixin MCP server listening at $endpoint'); + notifyListeners(); + } catch (error, stacktrace) { + _lastStartError = error; + notifyListeners(); + e( + 'Failed to start Mixin MCP server on ' + '${InternetAddress.loopbackIPv4.address}:$defaultPort: ' + '$error', + stacktrace, + ); + rethrow; } - server.connect(transport); - _server = server; - _transport = transport; - _port = port; - i('Mixin MCP server listening at $endpoint'); - notifyListeners(); } Future stop() async { @@ -104,7 +137,6 @@ class MixinMcpServer extends ChangeNotifier { _transport = null; _database = null; _userId = null; - _port = null; _currentConversationId = null; _tools = const []; if (server != null) { @@ -122,6 +154,7 @@ class MixinMcpServer extends ChangeNotifier { Map arguments, ) async { final database = _requireDatabase(); + _ensureToolEnabled(database, name); switch (name) { case 'mixin_get_app_status': final info = await getPackageInfo(); @@ -138,12 +171,18 @@ class MixinMcpServer extends ChangeNotifier { 'version': info.version, 'build_number': info.buildNumber, }, + 'permission_scopes': _permissionScopes(database), 'capabilities': _tools .map((tool) => tool.name) .toList(growable: false), + 'enabled_capabilities': _toolSpecs + .where((spec) => _toolEnabled(database, spec)) + .map((spec) => spec.name) + .toList(growable: false), }; case 'mixin_list_conversations': final query = _optionalString(arguments, 'query'); + final circleId = _optionalString(arguments, 'circle_id'); final limit = _int( arguments, 'limit', @@ -151,11 +190,17 @@ class MixinMcpServer extends ChangeNotifier { min: 1, max: 100, ); - final conversations = query == null || query.trim().isEmpty + final offset = _int(arguments, 'offset', defaultValue: 0); + final conversations = circleId != null + ? await database.conversationDao + .conversationsByCircleId(circleId, limit, offset) + .get() + : query == null || query.trim().isEmpty ? await database.conversationDao.conversationItems().get() - : await _searchConversations(database, query, limit); + : await _searchConversations(database, query, offset + limit); return { 'conversations': conversations + .skip(circleId == null ? offset : 0) .take(limit) .map(_conversationToJson) .toList(growable: false), @@ -193,15 +238,39 @@ class MixinMcpServer extends ChangeNotifier { ), startInclusive: _date(arguments, 'start'), endExclusive: _date(arguments, 'end'), + senderId: _optionalString(arguments, 'sender_id'), + senderIdentityNumber: _optionalString( + arguments, + 'sender_identity_number', + ), + categories: _optionalStringList(arguments, 'message_types'), ) .get(); - return {'messages': _messagesToJson(messages)}; + return { + 'messages': _messagesToJson( + messages, + includePinState: _bool(arguments, 'include_pin_state'), + ), + }; case 'mixin_search_messages': final conversationId = _optionalString(arguments, 'conversation_id'); + final circleId = _optionalString(arguments, 'circle_id'); + final conversationIds = conversationId == null + ? circleId == null + ? const [] + : await database.conversationDao.conversationIdsByCircleId( + circleId, + ) + : [conversationId]; + if (circleId != null && conversationIds.isEmpty) { + return {'messages': const >[]}; + } final messages = await database.fuzzySearchMessage( query: _requiredString(arguments, 'query'), limit: _int(arguments, 'limit', defaultValue: 20, min: 1, max: 50), - conversationIds: conversationId == null ? const [] : [conversationId], + conversationIds: conversationIds, + userId: _optionalString(arguments, 'sender_id'), + categories: _optionalStringList(arguments, 'message_types'), anchorMessageId: _optionalString(arguments, 'anchor_id'), ); return {'messages': _searchMessagesToJson(messages)}; @@ -210,7 +279,7 @@ class MixinMcpServer extends ChangeNotifier { database, _requiredString(arguments, 'message_id'), ); - return {'message': _messageToJson(message)}; + return {'message': _messageToJson(message, includePinState: true)}; case 'mixin_get_message_context': final message = await _messageById( database, @@ -234,9 +303,9 @@ class MixinMcpServer extends ChangeNotifier { .afterMessagesByConversationId(info, message.conversationId, limit) .get(); return { - 'before': _messagesToJson(before.reversed), - 'message': _messageToJson(message), - 'after': _messagesToJson(after), + 'before': _messagesToJson(before.reversed, includePinState: true), + 'message': _messageToJson(message, includePinState: true), + 'after': _messagesToJson(after, includePinState: true), }; case 'mixin_read_image_text': final result = await _conversationTools.readImageText( @@ -246,8 +315,9 @@ class MixinMcpServer extends ChangeNotifier { return result.toJson(); case 'mixin_list_attachments': final messages = await database.messageDao - .messagesByConversationIdAndCreatedAtRange( + .attachmentMessagesByConversationId( _requiredString(arguments, 'conversation_id'), + offset: _int(arguments, 'offset', defaultValue: 0), limit: _int( arguments, 'limit', @@ -257,6 +327,12 @@ class MixinMcpServer extends ChangeNotifier { ), startInclusive: _date(arguments, 'start'), endExclusive: _date(arguments, 'end'), + senderId: _optionalString(arguments, 'sender_id'), + senderIdentityNumber: _optionalString( + arguments, + 'sender_identity_number', + ), + categories: _optionalStringList(arguments, 'message_types'), ) .get(); return { @@ -265,6 +341,135 @@ class MixinMcpServer extends ChangeNotifier { .map(_attachmentToJson) .toList(growable: false), }; + case 'mixin_list_pinned_messages': + final conversationId = _requiredString(arguments, 'conversation_id'); + final pins = await database.pinMessageDao.pinMessagesByConversationId( + conversationId: conversationId, + limit: _int(arguments, 'limit', defaultValue: 50, min: 1, max: 200), + offset: _int(arguments, 'offset', defaultValue: 0), + ); + final messageIds = pins.map((pin) => pin.messageId).toList(); + final messages = await database.messageDao + .messageItemByMessageIds(messageIds) + .get(); + final messagesById = { + for (final message in messages) message.messageId: message, + }; + return { + 'messages': pins + .map((pin) { + final message = messagesById[pin.messageId]; + if (message == null) return null; + return { + ..._messageToJson(message, includePinState: true), + 'pinned_at': _dateTime(pin.createdAt), + }; + }) + .nonNulls + .toList(growable: false), + }; + case 'mixin_list_participants': + final query = _optionalString(arguments, 'query'); + final offset = _int(arguments, 'offset', defaultValue: 0); + final limit = _int( + arguments, + 'limit', + defaultValue: 50, + min: 1, + max: 200, + ); + final participants = await database.participantDao + .groupParticipantsByConversationId( + _requiredString(arguments, 'conversation_id'), + ) + .get(); + final filtered = query == null + ? participants + : participants + .where( + (participant) => _participantMatches(participant, query), + ) + .toList(growable: false); + return { + 'participants': filtered + .skip(offset) + .take(limit) + .map(_participantToJson) + .toList(growable: false), + }; + case 'mixin_resolve_user_in_conversation': + final query = _requiredString(arguments, 'query'); + final participants = await database.participantDao + .groupParticipantsByConversationId( + _requiredString(arguments, 'conversation_id'), + ) + .get(); + return { + 'participants': participants + .where((participant) => _participantMatches(participant, query)) + .take(_int(arguments, 'limit', defaultValue: 5, min: 1, max: 20)) + .map(_participantToJson) + .toList(growable: false), + }; + case 'mixin_list_circles': + final circles = await database.circleDao.allCircles().get(); + return { + 'circles': circles.map(_circleToJson).toList(growable: false), + }; + case 'mixin_list_circle_conversations': + final circleId = _requiredString(arguments, 'circle_id'); + final limit = _int( + arguments, + 'limit', + defaultValue: 50, + min: 1, + max: 200, + ); + final offset = _int(arguments, 'offset', defaultValue: 0); + final conversations = await database.conversationDao + .conversationsByCircleId(circleId, limit, offset) + .get(); + return { + 'circle_id': circleId, + 'conversations': conversations + .map(_conversationToJson) + .toList(growable: false), + }; + case 'mixin_read_mentions': + final messages = await database.messageDao + .mentionMessagesByConversationId( + _requiredString(arguments, 'conversation_id'), + limit: _int( + arguments, + 'limit', + defaultValue: 50, + min: 1, + max: 200, + ), + offset: _int(arguments, 'offset', defaultValue: 0), + unreadOnly: _bool(arguments, 'unread_only'), + ) + .get(); + return { + 'messages': _messagesToJson(messages, includePinState: true), + }; + case 'mixin_list_links': + final messages = await database.messageDao + .linkMessagesByConversationId( + _requiredString(arguments, 'conversation_id'), + limit: _int( + arguments, + 'limit', + defaultValue: 50, + min: 1, + max: 200, + ), + offset: _int(arguments, 'offset', defaultValue: 0), + ) + .get(); + return { + 'links': messages.map(_linkToJson).toList(growable: false), + }; case 'mixin_open_conversation': final conversationId = _requiredString(arguments, 'conversation_id'); await MixinMcpBridge.instance.openConversation(conversationId); @@ -312,6 +517,71 @@ class MixinMcpServer extends ChangeNotifier { final conversationId = _requiredString(arguments, 'conversation_id'); await MixinMcpBridge.instance.setDraft(database, conversationId, ''); return {'updated': true, 'conversation_id': conversationId}; + case 'mixin_create_circle': + final name = _requiredString(arguments, 'name'); + final conversationIds = _optionalStringList( + arguments, + 'conversation_ids', + ); + await MixinMcpBridge.instance.createCircle( + name: name, + conversations: await _circleConversationRequests( + database, + conversationIds, + CircleConversationAction.add, + ), + ); + return { + 'created': true, + 'name': name, + 'conversation_ids': conversationIds, + }; + case 'mixin_rename_circle': + final circleId = _requiredString(arguments, 'circle_id'); + final name = _requiredString(arguments, 'name'); + await MixinMcpBridge.instance.renameCircle( + circleId: circleId, + name: name, + ); + return {'updated': true, 'circle_id': circleId, 'name': name}; + case 'mixin_delete_circle': + final circleId = _requiredString(arguments, 'circle_id'); + await MixinMcpBridge.instance.deleteCircle(circleId); + return {'deleted': true, 'circle_id': circleId}; + case 'mixin_add_conversations_to_circle': + final circleId = _requiredString(arguments, 'circle_id'); + final conversationIds = _requiredStringList( + arguments, + 'conversation_ids', + ); + await MixinMcpBridge.instance.addConversationsToCircle( + circleId: circleId, + conversations: await _circleConversationRequests( + database, + conversationIds, + CircleConversationAction.add, + ), + ); + return { + 'updated': true, + 'circle_id': circleId, + 'conversation_ids': conversationIds, + }; + case 'mixin_remove_conversations_from_circle': + final circleId = _requiredString(arguments, 'circle_id'); + final conversationIds = _requiredStringList( + arguments, + 'conversation_ids', + ); + await MixinMcpBridge.instance.removeConversationsFromCircle( + circleId: circleId, + conversationIds: conversationIds, + ); + return { + 'updated': true, + 'circle_id': circleId, + 'conversation_ids': conversationIds, + }; case 'mixin_attach_message_to_ai': final message = await _messageById( database, @@ -383,8 +653,20 @@ class MixinMcpServer extends ChangeNotifier { tool.inputSchema?.jsonSchema() ?? _emptyObjectSchema, ), handler: (arguments) async { - final result = await tool.runRaw(arguments); - final data = result.result; + final startedAt = DateTime.now(); + i('MCP tool call ${tool.name}: ${_auditArguments(arguments)}'); + Map data; + try { + final result = await tool.runRaw(arguments); + data = _toolResult(tool.name, result.result, startedAt); + i( + 'MCP tool result ${tool.name}: ok ' + '${data['elapsed_ms']}ms', + ); + } catch (error, stacktrace) { + data = _toolErrorResult(tool.name, error, startedAt); + e('MCP tool error ${tool.name}: $error', stacktrace); + } return mcp.CallToolResult( content: [mcp.TextContent(text: const JsonEncoder().convert(data))], structuredContent: data, @@ -400,6 +682,91 @@ class MixinMcpServer extends ChangeNotifier { } } +void _ensureToolEnabled(Database database, String name) { + for (final spec in _toolSpecs) { + if (spec.name != name) continue; + if (_toolEnabled(database, spec)) return; + throw StateError('MCP permission scope "${spec.scope.key}" is disabled'); + } +} + +Map _permissionScopes(Database database) => { + _McpPermissionScope.read.key: true, + _McpPermissionScope.appControl.key: true, + _McpPermissionScope.draftWrite.key: + database.settingProperties.enableMcpDraftTools, + _McpPermissionScope.circleManagement.key: + database.settingProperties.enableMcpCircleManagement, + 'account_write': false, + 'message_send': false, +}; + +bool _toolEnabled(Database database, _Tool spec) { + switch (spec.scope) { + case _McpPermissionScope.read: + case _McpPermissionScope.appControl: + return true; + case _McpPermissionScope.draftWrite: + return database.settingProperties.enableMcpDraftTools; + case _McpPermissionScope.circleManagement: + return database.settingProperties.enableMcpCircleManagement; + } +} + +Map _toolResult( + String name, + Map result, + DateTime startedAt, +) => { + 'ok': true, + 'tool': name, + ...result, + 'elapsed_ms': DateTime.now().difference(startedAt).inMilliseconds, +}; + +Map _toolErrorResult( + String name, + Object error, + DateTime startedAt, +) => { + 'ok': false, + 'tool': name, + 'error': { + 'type': error.runtimeType.toString(), + 'message': error.toString(), + }, + 'elapsed_ms': DateTime.now().difference(startedAt).inMilliseconds, +}; + +String _auditArguments(Map arguments) => + const JsonEncoder().convert(_redactForAudit(arguments)); + +Object? _redactForAudit(Object? value, [String? key]) { + if (value is Map) { + return { + for (final entry in value.entries) + entry.key.toString(): _redactForAudit( + entry.value, + entry.key.toString(), + ), + }; + } + if (value is Iterable) { + return value.map(_redactForAudit).toList(growable: false); + } + final normalizedKey = key?.toLowerCase(); + if (value is String && + normalizedKey != null && + (normalizedKey.contains('token') || + normalizedKey.contains('secret') || + normalizedKey == 'text' || + normalizedKey == 'content' || + normalizedKey == 'draft')) { + return '<${value.length} chars>'; + } + return value; +} + Future _conversationById( Database database, String conversationId, @@ -455,6 +822,25 @@ Future> _searchConversations( return conversations; } +Future> _circleConversationRequests( + Database database, + List conversationIds, + CircleConversationAction action, +) async { + final requests = []; + for (final conversationId in conversationIds) { + final conversation = await _conversationById(database, conversationId); + requests.add( + CircleConversationRequest( + conversationId: conversation.conversationId, + action: action, + userId: conversation.ownerId, + ), + ); + } + return requests; +} + Future _messageById(Database database, String messageId) async { final message = await database.messageDao .messageItemByMessageId(messageId) @@ -477,8 +863,12 @@ Map _conversationToJson(ConversationItem conversation) => { 'last_message_created_at': _dateTime(conversation.lastMessageCreatedAt), }; -List> _messagesToJson(Iterable messages) => - messages.map(_messageToJson).toList(growable: false); +List> _messagesToJson( + Iterable messages, { + bool includePinState = false, +}) => messages + .map((message) => _messageToJson(message, includePinState: includePinState)) + .toList(growable: false); List> _searchMessagesToJson( Iterable messages, @@ -501,7 +891,10 @@ List> _searchMessagesToJson( ) .toList(growable: false); -Map _messageToJson(MessageItem message) => { +Map _messageToJson( + MessageItem message, { + bool includePinState = false, +}) => { 'message_id': message.messageId, 'conversation_id': message.conversationId, 'user_id': message.userId, @@ -509,8 +902,13 @@ Map _messageToJson(MessageItem message) => { 'user_identity_number': message.userIdentityNumber, 'type': message.type, 'content': _messageContent(message), + 'quote_message_id': message.quoteId, + 'quote_content': message.quoteContent, + 'caption': message.caption, 'created_at': _dateTime(message.createdAt), 'status': message.status.name, + if (includePinState || message.pinned) 'is_pinned': message.pinned, + 'link': _linkPreviewToJson(message), 'media': _hasAttachment(message) ? _attachmentToJson(message) : null, }..removeWhere((_, value) => value == null); @@ -545,6 +943,54 @@ Map _attachmentToJson(MessageItem message) => { 'created_at': _dateTime(message.createdAt), }..removeWhere((_, value) => value == null); +Map? _linkPreviewToJson(MessageItem message) { + final link = { + 'site_name': message.siteName, + 'title': message.siteTitle, + 'description': message.siteDescription, + 'image': message.siteImage, + }..removeWhere((_, value) => value == null); + return link.isEmpty ? null : link; +} + +Map _linkToJson(MessageItem message) => { + ..._messageToJson(message, includePinState: true), + 'link': _linkPreviewToJson(message), +}..removeWhere((_, value) => value == null); + +Map _participantToJson(ParticipantUser participant) => { + 'conversation_id': participant.conversationId, + 'user_id': participant.userId, + 'identity_number': participant.identityNumber, + 'full_name': participant.fullName, + 'role': participant.role?.name, + 'relationship': participant.relationship?.name, + 'biography': participant.biography, + 'avatar_url': participant.avatarUrl, + 'is_verified': participant.isVerified, + 'is_bot': participant.appId != null, + 'app_id': participant.appId, + 'membership': participant.membership?.toJson(), + 'created_at': _dateTime(participant.createdAt), +}..removeWhere((_, value) => value == null); + +bool _participantMatches(ParticipantUser participant, String query) { + final needle = query.toLowerCase(); + return participant.userId.toLowerCase().contains(needle) || + participant.identityNumber.toLowerCase().contains(needle) || + (participant.fullName?.toLowerCase().contains(needle) ?? false); +} + +Map _circleToJson(ConversationCircleItem circle) => { + 'circle_id': circle.circleId, + 'name': circle.name, + 'conversation_count': circle.count, + 'unseen_conversation_count': circle.unseenConversationCount, + 'unseen_muted_conversation_count': circle.unseenMutedConversationCount, + 'created_at': _dateTime(circle.createdAt), + 'ordered_at': _dateTime(circle.orderedAt), +}; + Map _aiThreadToJson(AiChatThread thread) => { 'thread_id': thread.id, 'conversation_id': thread.conversationId, @@ -589,6 +1035,43 @@ String? _optionalString(Map arguments, String key) { return text.isEmpty ? null : text; } +List _requiredStringList(Map arguments, String key) { + final list = _optionalStringList(arguments, key); + if (list.isEmpty) throw ArgumentError('$key is required'); + return list; +} + +List _optionalStringList(Map arguments, String key) { + final value = arguments[key]; + if (value == null) return const []; + if (value is Iterable) { + return value + .map((item) => item.toString().trim()) + .where((item) => item.isNotEmpty) + .toList(growable: false); + } + return value + .toString() + .split(',') + .map((item) => item.trim()) + .where((item) => item.isNotEmpty) + .toList(growable: false); +} + +bool _bool( + Map arguments, + String key, { + bool defaultValue = false, +}) { + final value = arguments[key]; + if (value == null) return defaultValue; + if (value is bool) return value; + final text = value.toString().trim().toLowerCase(); + if (text == 'true' || text == '1' || text == 'yes') return true; + if (text == 'false' || text == '0' || text == 'no') return false; + return defaultValue; +} + int _int( Map arguments, String key, { @@ -606,13 +1089,6 @@ DateTime? _date(Map arguments, String key) { return text == null ? null : DateTime.parse(text); } -Future _reserveLoopbackPort() async { - final socket = await ServerSocket.bind(InternetAddress.loopbackIPv4, 0); - final port = socket.port; - await socket.close(); - return port; -} - Map _jsonMap(dynamic json) { if (json == null) return {}; if (json is Map) return json; @@ -626,6 +1102,11 @@ const _emptyObjectSchema = { 'additionalProperties': false, }; +const _stringArraySchema = { + 'type': 'array', + 'items': {'type': 'string'}, +}; + const _toolSpecs = [ _Tool( 'mixin_get_app_status', @@ -633,9 +1114,11 @@ const _toolSpecs = [ ), _Tool( 'mixin_list_conversations', - 'List recent conversations or search conversations by query.', + 'List recent conversations, search conversations, or list a circle.', properties: { 'query': {'type': 'string'}, + 'circle_id': {'type': 'string'}, + 'offset': {'type': 'integer'}, 'limit': {'type': 'integer'}, }, ), @@ -664,21 +1147,28 @@ const _toolSpecs = [ ), _Tool( 'mixin_read_messages', - 'Read conversation messages by range, offset, and limit.', + 'Read conversation messages by range, sender, type, offset, and limit.', required: ['conversation_id'], properties: { ..._conversationRangeProperties, + 'sender_id': {'type': 'string'}, + 'sender_identity_number': {'type': 'string'}, + 'message_types': _stringArraySchema, + 'include_pin_state': {'type': 'boolean'}, 'offset': {'type': 'integer'}, 'limit': {'type': 'integer'}, }, ), _Tool( 'mixin_search_messages', - 'Search messages globally or inside a conversation.', + 'Search messages globally, inside a conversation, or inside a circle.', required: ['query'], properties: { 'query': {'type': 'string'}, 'conversation_id': {'type': 'string'}, + 'circle_id': {'type': 'string'}, + 'sender_id': {'type': 'string'}, + 'message_types': _stringArraySchema, 'limit': {'type': 'integer'}, 'anchor_id': {'type': 'string'}, }, @@ -715,12 +1205,83 @@ const _toolSpecs = [ required: ['conversation_id'], properties: { ..._conversationRangeProperties, + 'sender_id': {'type': 'string'}, + 'sender_identity_number': {'type': 'string'}, + 'message_types': _stringArraySchema, + 'offset': {'type': 'integer'}, + 'limit': {'type': 'integer'}, + }, + ), + _Tool( + 'mixin_list_pinned_messages', + 'List pinned messages for a conversation.', + required: ['conversation_id'], + properties: { + 'conversation_id': {'type': 'string'}, + 'offset': {'type': 'integer'}, + 'limit': {'type': 'integer'}, + }, + ), + _Tool( + 'mixin_list_participants', + 'List or search participants in a conversation.', + required: ['conversation_id'], + properties: { + 'conversation_id': {'type': 'string'}, + 'query': {'type': 'string'}, + 'offset': {'type': 'integer'}, + 'limit': {'type': 'integer'}, + }, + ), + _Tool( + 'mixin_resolve_user_in_conversation', + 'Resolve participants by user_id, identity number, or name.', + required: ['conversation_id', 'query'], + properties: { + 'conversation_id': {'type': 'string'}, + 'query': {'type': 'string'}, + 'limit': {'type': 'integer'}, + }, + ), + _Tool( + 'mixin_list_circles', + 'List local circles and their conversation counts.', + ), + _Tool( + 'mixin_list_circle_conversations', + 'List conversations in a circle.', + required: ['circle_id'], + properties: { + 'circle_id': {'type': 'string'}, + 'offset': {'type': 'integer'}, + 'limit': {'type': 'integer'}, + }, + ), + _Tool( + 'mixin_read_mentions', + 'Read mention messages in a conversation without marking them read.', + required: ['conversation_id'], + properties: { + 'conversation_id': {'type': 'string'}, + 'unread_only': {'type': 'boolean'}, + 'offset': {'type': 'integer'}, + 'limit': {'type': 'integer'}, + }, + ), + _Tool( + 'mixin_list_links', + 'List messages with link previews in a conversation.', + required: ['conversation_id'], + properties: { + 'conversation_id': {'type': 'string'}, + 'offset': {'type': 'integer'}, 'limit': {'type': 'integer'}, }, ), _Tool( 'mixin_open_conversation', 'Open a conversation in the Mixin UI.', + scope: _McpPermissionScope.appControl, required: ['conversation_id'], properties: { 'conversation_id': {'type': 'string'}, @@ -729,6 +1290,7 @@ const _toolSpecs = [ _Tool( 'mixin_reveal_message', 'Open the message conversation and reveal the message in the Mixin UI.', + scope: _McpPermissionScope.appControl, required: ['message_id'], properties: { 'message_id': {'type': 'string'}, @@ -737,6 +1299,7 @@ const _toolSpecs = [ _Tool( 'mixin_get_draft', 'Get the current draft text for a conversation.', + scope: _McpPermissionScope.draftWrite, required: ['conversation_id'], properties: { 'conversation_id': {'type': 'string'}, @@ -745,6 +1308,7 @@ const _toolSpecs = [ _Tool( 'mixin_set_draft', 'Replace the draft text for a conversation. Does not send.', + scope: _McpPermissionScope.draftWrite, required: ['conversation_id', 'text'], properties: { 'conversation_id': {'type': 'string'}, @@ -754,6 +1318,7 @@ const _toolSpecs = [ _Tool( 'mixin_insert_text', 'Insert text into the active input, or append to stored draft.', + scope: _McpPermissionScope.draftWrite, required: ['conversation_id', 'text'], properties: { 'conversation_id': {'type': 'string'}, @@ -763,14 +1328,65 @@ const _toolSpecs = [ _Tool( 'mixin_clear_draft', 'Clear the draft text for a conversation. Does not send.', + scope: _McpPermissionScope.draftWrite, required: ['conversation_id'], properties: { 'conversation_id': {'type': 'string'}, }, ), + _Tool( + 'mixin_create_circle', + 'Create a circle, optionally with initial conversations.', + scope: _McpPermissionScope.circleManagement, + required: ['name'], + properties: { + 'name': {'type': 'string'}, + 'conversation_ids': _stringArraySchema, + }, + ), + _Tool( + 'mixin_rename_circle', + 'Rename a circle.', + scope: _McpPermissionScope.circleManagement, + required: ['circle_id', 'name'], + properties: { + 'circle_id': {'type': 'string'}, + 'name': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_delete_circle', + 'Delete a circle.', + scope: _McpPermissionScope.circleManagement, + required: ['circle_id'], + properties: { + 'circle_id': {'type': 'string'}, + }, + ), + _Tool( + 'mixin_add_conversations_to_circle', + 'Add conversations to a circle.', + scope: _McpPermissionScope.circleManagement, + required: ['circle_id', 'conversation_ids'], + properties: { + 'circle_id': {'type': 'string'}, + 'conversation_ids': _stringArraySchema, + }, + ), + _Tool( + 'mixin_remove_conversations_from_circle', + 'Remove conversations from a circle.', + scope: _McpPermissionScope.circleManagement, + required: ['circle_id', 'conversation_ids'], + properties: { + 'circle_id': {'type': 'string'}, + 'conversation_ids': _stringArraySchema, + }, + ), _Tool( 'mixin_attach_message_to_ai', 'Attach a message to the app AI context chip for its conversation.', + scope: _McpPermissionScope.appControl, required: ['message_id'], properties: { 'message_id': {'type': 'string'}, @@ -812,12 +1428,14 @@ class _Tool { const _Tool( this.name, this.description, { + this.scope = _McpPermissionScope.read, this.required = const [], this.properties = const {}, }); final String name; final String description; + final _McpPermissionScope scope; final List required; final Map properties; diff --git a/lib/utils/property/setting_property.dart b/lib/utils/property/setting_property.dart index 8f4058c85d..055b4e53ad 100644 --- a/lib/utils/property/setting_property.dart +++ b/lib/utils/property/setting_property.dart @@ -21,6 +21,8 @@ const _kSelectedAiTranslatorModelKey = 'selected_ai_translator_model'; const _kAiPromptTemplateOverridesKey = 'ai_prompt_template_overrides'; const _kEnableMcpServerKey = 'enable_mcp_server'; const _kMcpServerTokenKey = 'mcp_server_token'; +const _kEnableMcpDraftToolsKey = 'enable_mcp_draft_tools'; +const _kEnableMcpCircleManagementKey = 'enable_mcp_circle_management'; class SettingPropertyStorage extends PropertyStorage { SettingPropertyStorage(PropertyDao dao) : super(PropertyGroup.setting, dao); @@ -42,6 +44,16 @@ class SettingPropertyStorage extends PropertyStorage { String? get mcpServerToken => get(_kMcpServerTokenKey); + bool get enableMcpDraftTools => get(_kEnableMcpDraftToolsKey) ?? false; + + set enableMcpDraftTools(bool value) => set(_kEnableMcpDraftToolsKey, value); + + bool get enableMcpCircleManagement => + get(_kEnableMcpCircleManagementKey) ?? false; + + set enableMcpCircleManagement(bool value) => + set(_kEnableMcpCircleManagementKey, value); + String regenerateMcpServerToken() { final token = List.generate( 32, diff --git a/pubspec.lock b/pubspec.lock index 854d67413c..c4feb1b20a 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -243,7 +243,7 @@ packages: source: hosted version: "1.1.2" code_assets: - dependency: transitive + dependency: "direct overridden" description: name: code_assets sha256: "83ccdaa064c980b5596c35dd64a8d3ecc68620174ab9b90b6343b753aa721687" @@ -343,8 +343,8 @@ packages: dependency: "direct main" description: path: "packages/data_detector" - ref: "08c1ce40eb6abfad6049fb6aad8bd30312ec5319" - resolved-ref: "08c1ce40eb6abfad6049fb6aad8bd30312ec5319" + ref: "821be771429135163704ede47acf95be5ba82095" + resolved-ref: "821be771429135163704ede47acf95be5ba82095" url: "https://github.com/MixinNetwork/flutter-plugins.git" source: git version: "0.0.1" @@ -916,7 +916,7 @@ packages: source: hosted version: "1.1.0" hooks: - dependency: transitive + dependency: "direct overridden" description: name: hooks sha256: "7a08a0d684cb3b8fb604b78455d5d352f502b68079f7b80b831c62220ab0a4f6" @@ -1292,6 +1292,14 @@ packages: url: "https://pub.dev" source: hosted version: "0.13.0" + mcp_server: + dependency: "direct main" + description: + name: mcp_server + sha256: d5b2603a51b524c60753ea9031c52759f80b53dd6606f38ae70305b5eaa28d4d + url: "https://pub.dev" + source: hosted + version: "2.0.0" meta: dependency: transitive description: @@ -1327,10 +1335,9 @@ packages: mixin_markdown_widget: dependency: "direct main" description: - name: mixin_markdown_widget - sha256: "58b366d61d55fe852a91a7b3fb102481a76a7a8f208e0bd948ef823501ba29d9" - url: "https://pub.dev" - source: hosted + path: "../flutter-plugins/packages/mixin_markdown_widget" + relative: true + source: path version: "0.2.1" msix: dependency: "direct dev" @@ -1385,10 +1392,10 @@ packages: dependency: transitive description: name: objective_c - sha256: "77c341fce45bb3865a7bc3ddee4201605799e3de2f7af200e8dae26369d210ea" + sha256: "100a1c87616ab6ed41ec263b083c0ef3261ee6cd1dc3b0f35f8ddfa4f996fe52" url: "https://pub.dev" source: hosted - version: "1.1.0" + version: "9.3.0" octo_image: dependency: "direct main" description: @@ -1629,6 +1636,14 @@ packages: url: "https://pub.dev" source: hosted version: "3.1.6" + platform_ocr: + dependency: "direct main" + description: + name: platform_ocr + sha256: a50c7ac0d8667b3d2a1a3900f1b221b966d9602539ddc2351565d3de43e881e3 + url: "https://pub.dev" + source: hosted + version: "1.0.0" plugin_platform_interface: dependency: transitive description: diff --git a/pubspec.yaml b/pubspec.yaml index 6621dfa4f2..2f6abf49d1 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -99,7 +99,8 @@ dependencies: local_auth: ^3.0.1 lottie: ^3.3.3 map: ^2.0.2 - mixin_markdown_widget: ^0.2.1 + mixin_markdown_widget: + path: ../flutter-plugins/packages/mixin_markdown_widget mime: ^2.0.0 mixin_bot_sdk_dart: ^1.5.0 mixin_logger: ^0.1.3 @@ -112,6 +113,7 @@ dependencies: path: ^1.8.0 path_provider: ^2.1.2 photo_view: ^0.15.0 + platform_ocr: ^1.0.0 pin_code_fields: ^8.0.1 pretty_qr_code: ^3.6.0 protocol_handler: ^0.2.0 @@ -166,7 +168,7 @@ dependencies: data_detector: git: url: https://github.com/MixinNetwork/flutter-plugins.git - ref: 08c1ce40eb6abfad6049fb6aad8bd30312ec5319 + ref: 821be771429135163704ede47acf95be5ba82095 path: packages/data_detector envied: ^1.3.4 genkit: ^0.12.1 @@ -178,6 +180,7 @@ dependencies: git: url: https://github.com/toon-format/toon-dart.git ref: 51fa0e9311837b84c24e30827b53891041378448 + mcp_server: ^2.0.0 dev_dependencies: build_runner: ^2.13.1 @@ -239,6 +242,10 @@ flutter_intl: class_name: Localization use_deferred_loading: false +dependency_overrides: + code_assets: ^1.0.0 + hooks: ^1.0.0 + analyzer: plugins: - moor From 91a9faa68b05f46f728499ddc7ebd31c167e0452 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 7 May 2026 08:54:47 +0800 Subject: [PATCH 48/52] Refine local MCP settings --- lib/ui/setting/ai_mcp_settings_page.dart | 381 +++++++++++++++++++++++ lib/ui/setting/ai_settings_page.dart | 211 ++----------- lib/utils/mcp/mixin_mcp_server.dart | 38 +++ 3 files changed, 442 insertions(+), 188 deletions(-) create mode 100644 lib/ui/setting/ai_mcp_settings_page.dart diff --git a/lib/ui/setting/ai_mcp_settings_page.dart b/lib/ui/setting/ai_mcp_settings_page.dart new file mode 100644 index 0000000000..3de1f98645 --- /dev/null +++ b/lib/ui/setting/ai_mcp_settings_page.dart @@ -0,0 +1,381 @@ +import 'package:flutter/cupertino.dart'; +import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; +import 'package:flutter_hooks/flutter_hooks.dart'; +import 'package:hooks_riverpod/hooks_riverpod.dart'; + +import '../../utils/extension/extension.dart'; +import '../../utils/mcp/mixin_mcp_server.dart'; +import '../../widgets/app_bar.dart'; +import '../../widgets/cell.dart'; +import '../../widgets/toast.dart'; +import '../provider/database_provider.dart'; + +class AiMcpSettingsPage extends HookConsumerWidget { + const AiMcpSettingsPage({super.key}); + + @override + Widget build(BuildContext context, WidgetRef ref) { + final database = ref.watch(databaseProvider).requireValue; + useListenable(database.settingProperties); + final mcpServer = useListenable(MixinMcpServer.instance); + final settings = database.settingProperties; + final enableMcpServer = settings.enableMcpServer; + final mcpEndpoint = + mcpServer.endpoint?.toString() ?? _defaultMcpEndpointText; + final mcpToken = settings.mcpServerToken; + final mcpError = mcpServer.lastStartError; + final tools = MixinMcpServer.toolInfos(database); + final enabledToolCount = tools.where((tool) => tool.enabled).length; + final statusText = _serverStatusText( + enabled: enableMcpServer, + running: mcpServer.isRunning, + error: mcpError, + ); + + return Scaffold( + backgroundColor: context.theme.background, + appBar: const MixinAppBar(title: Text('Local MCP Server')), + body: Align( + alignment: Alignment.topCenter, + child: ConstrainedBox( + constraints: const BoxConstraints(maxWidth: 600), + child: SingleChildScrollView( + child: Padding( + padding: const EdgeInsets.only(top: 20, bottom: 20), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: Padding( + padding: const EdgeInsets.all(16), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + statusText == 'Running' + ? 'Running on localhost' + : statusText, + style: TextStyle( + color: context.theme.text, + fontSize: 15, + fontWeight: FontWeight.w600, + ), + ), + const SizedBox(height: 8), + Text( + 'Exposes Mixin desktop tools to local MCP clients at $_defaultMcpEndpointText. It never sends messages.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + height: 1.4, + ), + ), + if (enableMcpServer && mcpError != null) ...[ + const SizedBox(height: 8), + Text( + 'Failed to bind port ${MixinMcpServer.defaultPort}.', + style: TextStyle( + color: context.theme.red, + fontSize: 13, + height: 1.4, + ), + ), + ], + ], + ), + ), + ), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: Column( + children: [ + CellItem( + title: const Text('Server'), + description: Text(statusText), + trailing: Transform.scale( + scale: 0.7, + child: CupertinoSwitch( + activeTrackColor: context.theme.accent, + value: enableMcpServer, + onChanged: (value) { + settings.enableMcpServer = value; + }, + ), + ), + ), + if (enableMcpServer) ...[ + _Divider(), + CellItem( + title: const Text('Endpoint'), + description: Expanded( + child: Text( + mcpEndpoint, + textAlign: TextAlign.end, + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + ), + trailing: IconButton( + onPressed: () { + Clipboard.setData( + ClipboardData(text: mcpEndpoint), + ); + showToastSuccessful(); + }, + icon: Icon( + Icons.copy_rounded, + color: context.theme.icon, + ), + ), + ), + _Divider(), + CellItem( + title: const Text('Access Token'), + description: Expanded( + child: Text( + _maskedToken(mcpToken), + textAlign: TextAlign.end, + maxLines: 1, + overflow: TextOverflow.ellipsis, + ), + ), + trailing: Row( + mainAxisSize: MainAxisSize.min, + children: [ + IconButton( + onPressed: mcpToken == null + ? null + : () { + Clipboard.setData( + ClipboardData(text: mcpToken), + ); + showToastSuccessful(); + }, + icon: Icon( + Icons.copy_rounded, + color: context.theme.icon, + ), + ), + IconButton( + onPressed: () { + settings.regenerateMcpServerToken(); + showToastSuccessful(); + }, + icon: Icon( + Icons.refresh_rounded, + color: context.theme.icon, + ), + ), + ], + ), + ), + _Divider(), + CellItem( + title: const Text('Draft Editing'), + description: const Text('Draft write tools'), + trailing: Transform.scale( + scale: 0.7, + child: CupertinoSwitch( + activeTrackColor: context.theme.accent, + value: settings.enableMcpDraftTools, + onChanged: (value) { + settings.enableMcpDraftTools = value; + }, + ), + ), + ), + _Divider(), + CellItem( + title: const Text('Circle Management'), + description: const Text('Create and edit circles'), + trailing: Transform.scale( + scale: 0.7, + child: CupertinoSwitch( + activeTrackColor: context.theme.accent, + value: settings.enableMcpCircleManagement, + onChanged: (value) { + settings.enableMcpCircleManagement = value; + }, + ), + ), + ), + ], + ], + ), + ), + Padding( + padding: const EdgeInsets.only( + left: 20, + bottom: 14, + top: 10, + ), + child: Text( + '$enabledToolCount/${tools.length} tools enabled. Draft and circle tools require their own switches.', + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 14, + ), + ), + ), + for (final group in _toolGroups(tools)) ...[ + _SectionLabel(title: group.key), + CellGroup( + padding: const EdgeInsets.only(right: 10, left: 10), + cellBackgroundColor: + context.theme.settingCellBackgroundColor, + child: Column( + children: [ + for (var i = 0; i < group.value.length; i++) ...[ + _ToolCell(tool: group.value[i]), + if (i != group.value.length - 1) _Divider(), + ], + ], + ), + ), + ], + ], + ), + ), + ), + ), + ), + ); + } +} + +class _ToolCell extends StatelessWidget { + const _ToolCell({required this.tool}); + + final MixinMcpToolInfo tool; + + @override + Widget build(BuildContext context) { + final requiredText = tool.requiredArguments.isEmpty + ? null + : 'Required: ${tool.requiredArguments.join(', ')}'; + + return CellItem( + title: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + tool.name, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: const TextStyle(fontFamily: 'Menlo', fontSize: 14), + ), + const SizedBox(height: 4), + Text( + tool.description, + maxLines: 2, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 13, + height: 1.3, + ), + ), + if (requiredText != null) ...[ + const SizedBox(height: 4), + Text( + requiredText, + maxLines: 1, + overflow: TextOverflow.ellipsis, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 12, + ), + ), + ], + ], + ), + description: SizedBox( + width: 44, + child: Text( + tool.enabled ? 'On' : 'Off', + textAlign: TextAlign.end, + style: TextStyle( + color: tool.enabled ? context.theme.accent : context.theme.red, + fontSize: 13, + fontWeight: FontWeight.w600, + ), + ), + ), + trailing: null, + ); + } +} + +class _SectionLabel extends StatelessWidget { + const _SectionLabel({required this.title}); + + final String title; + + @override + Widget build(BuildContext context) => Padding( + padding: const EdgeInsets.only(left: 20, bottom: 8, top: 12), + child: Text( + title, + style: TextStyle( + color: context.theme.secondaryText, + fontSize: 13, + fontWeight: FontWeight.w600, + ), + ), + ); +} + +class _Divider extends StatelessWidget { + @override + Widget build(BuildContext context) => Divider( + height: 0.5, + indent: 16, + endIndent: 16, + color: context.theme.divider, + ); +} + +List>> _toolGroups( + List tools, +) { + const scopeOrder = [ + 'read', + 'app_control', + 'draft_write', + 'circle_management', + ]; + return [ + for (final scope in scopeOrder) + if (tools.any((tool) => tool.scopeKey == scope)) + MapEntry( + tools.firstWhere((tool) => tool.scopeKey == scope).scopeTitle, + tools.where((tool) => tool.scopeKey == scope).toList(growable: false), + ), + ]; +} + +String _maskedToken(String? token) { + if (token == null || token.isEmpty) return 'Unavailable'; + if (token.length <= 8) return '********'; + return '********${token.substring(token.length - 6)}'; +} + +String _serverStatusText({ + required bool enabled, + required bool running, + required Object? error, +}) { + if (running) return 'Running'; + if (!enabled) return 'Off'; + if (error != null) return 'Error'; + return 'On'; +} + +const _defaultMcpEndpointText = + 'http://127.0.0.1:${MixinMcpServer.defaultPort}/mcp'; diff --git a/lib/ui/setting/ai_settings_page.dart b/lib/ui/setting/ai_settings_page.dart index e52a2d07cd..942d76db9a 100644 --- a/lib/ui/setting/ai_settings_page.dart +++ b/lib/ui/setting/ai_settings_page.dart @@ -1,6 +1,5 @@ import 'package:flutter/cupertino.dart'; import 'package:flutter/material.dart'; -import 'package:flutter/services.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; @@ -13,6 +12,7 @@ import '../../widgets/cell.dart'; import '../../widgets/dialog.dart'; import '../../widgets/toast.dart'; import '../provider/database_provider.dart'; +import 'ai_mcp_settings_page.dart'; import 'ai_prompt_settings_page.dart'; import 'ai_provider_edit_page.dart'; @@ -41,14 +41,13 @@ class AiSettingsPage extends HookConsumerWidget { ) .length; final mcpServer = useListenable(MixinMcpServer.instance); - final enableMcpServer = database.settingProperties.enableMcpServer; - final enableMcpDraftTools = database.settingProperties.enableMcpDraftTools; - final enableMcpCircleManagement = - database.settingProperties.enableMcpCircleManagement; - final mcpEndpoint = mcpServer.endpoint; - final mcpToken = database.settingProperties.mcpServerToken; - const mcpPort = MixinMcpServer.defaultPort; - final mcpError = mcpServer.lastStartError; + final mcpTools = MixinMcpServer.toolInfos(database); + final enabledMcpToolCount = mcpTools.where((tool) => tool.enabled).length; + final mcpStatus = mcpServer.isRunning + ? 'Running' + : database.settingProperties.enableMcpServer + ? 'On' + : 'Off'; return Scaffold( backgroundColor: context.theme.background, @@ -106,182 +105,20 @@ class AiSettingsPage extends HookConsumerWidget { padding: const EdgeInsets.only(right: 10, left: 10), cellBackgroundColor: context.theme.settingCellBackgroundColor, - child: Column( - children: [ - CellItem( - title: const Text('Local MCP Server'), - leading: Icon( - Icons.hub_outlined, - color: context.theme.icon, - ), - description: Text( - mcpServer.isRunning ? 'Running' : 'Off', - ), - trailing: Transform.scale( - scale: 0.7, - child: CupertinoSwitch( - activeTrackColor: context.theme.accent, - value: enableMcpServer, - onChanged: (value) { - database.settingProperties.enableMcpServer = - value; - }, - ), - ), + child: CellItem( + title: const Text('Local MCP Server'), + leading: Icon( + Icons.hub_outlined, + color: context.theme.icon, + ), + description: Text( + '$mcpStatus · $enabledMcpToolCount/${mcpTools.length} tools', + ), + onTap: () => Navigator.of(context).push( + MaterialPageRoute( + builder: (_) => const AiMcpSettingsPage(), ), - if (enableMcpServer) ...[ - Divider( - height: 0.5, - indent: 16, - endIndent: 16, - color: context.theme.divider, - ), - CellItem( - title: const Text('Endpoint'), - description: Expanded( - child: Text( - mcpEndpoint?.toString() ?? - 'http://127.0.0.1:$mcpPort/mcp', - textAlign: TextAlign.end, - maxLines: 1, - overflow: TextOverflow.ellipsis, - ), - ), - trailing: IconButton( - onPressed: mcpEndpoint == null - ? null - : () { - Clipboard.setData( - ClipboardData( - text: mcpEndpoint.toString(), - ), - ); - showToastSuccessful(); - }, - icon: Icon( - Icons.copy_rounded, - color: context.theme.icon, - ), - ), - ), - if (mcpError != null) ...[ - Divider( - height: 0.5, - indent: 16, - endIndent: 16, - color: context.theme.divider, - ), - CellItem( - title: const Text('Status'), - description: Expanded( - child: Text( - 'Failed to bind port $mcpPort', - textAlign: TextAlign.end, - maxLines: 1, - overflow: TextOverflow.ellipsis, - style: TextStyle( - color: context.theme.red, - ), - ), - ), - trailing: null, - ), - ], - Divider( - height: 0.5, - indent: 16, - endIndent: 16, - color: context.theme.divider, - ), - CellItem( - title: const Text('Access Token'), - description: Expanded( - child: Text( - mcpToken ?? 'Unavailable', - textAlign: TextAlign.end, - maxLines: 1, - overflow: TextOverflow.ellipsis, - ), - ), - trailing: Row( - mainAxisSize: MainAxisSize.min, - children: [ - IconButton( - onPressed: mcpToken == null - ? null - : () { - Clipboard.setData( - ClipboardData(text: mcpToken), - ); - showToastSuccessful(); - }, - icon: Icon( - Icons.copy_rounded, - color: context.theme.icon, - ), - ), - IconButton( - onPressed: () { - database.settingProperties - .regenerateMcpServerToken(); - showToastSuccessful(); - }, - icon: Icon( - Icons.refresh_rounded, - color: context.theme.icon, - ), - ), - ], - ), - ), - Divider( - height: 0.5, - indent: 16, - endIndent: 16, - color: context.theme.divider, - ), - CellItem( - title: const Text('Draft Editing'), - description: const Text('Draft write tools'), - trailing: Transform.scale( - scale: 0.7, - child: CupertinoSwitch( - activeTrackColor: context.theme.accent, - value: enableMcpDraftTools, - onChanged: (value) { - database - .settingProperties - .enableMcpDraftTools = - value; - }, - ), - ), - ), - Divider( - height: 0.5, - indent: 16, - endIndent: 16, - color: context.theme.divider, - ), - CellItem( - title: const Text('Circle Management'), - description: const Text('Create and edit circles'), - trailing: Transform.scale( - scale: 0.7, - child: CupertinoSwitch( - activeTrackColor: context.theme.accent, - value: enableMcpCircleManagement, - onChanged: (value) { - database - .settingProperties - .enableMcpCircleManagement = - value; - }, - ), - ), - ), - ], - ], + ), ), ), Padding( @@ -291,10 +128,8 @@ class AiSettingsPage extends HookConsumerWidget { top: 10, ), child: Text( - 'Exposes read-only conversation tools, UI navigation, ' - 'and AI thread inspection on localhost only at port ' - '$mcpPort. Draft and circle write tools require their ' - 'own switches. It never sends messages.', + 'Manage the localhost MCP endpoint, access token, write ' + 'permissions, and supported tool list.', style: TextStyle( color: context.theme.secondaryText, fontSize: 14, diff --git a/lib/utils/mcp/mixin_mcp_server.dart b/lib/utils/mcp/mixin_mcp_server.dart index 4fb9167662..bb71956045 100644 --- a/lib/utils/mcp/mixin_mcp_server.dart +++ b/lib/utils/mcp/mixin_mcp_server.dart @@ -25,6 +25,24 @@ import 'mixin_mcp_bridge.dart'; typedef CurrentConversationIdResolver = String? Function(); +class MixinMcpToolInfo { + const MixinMcpToolInfo({ + required this.name, + required this.description, + required this.scopeKey, + required this.scopeTitle, + required this.enabled, + required this.requiredArguments, + }); + + final String name; + final String description; + final String scopeKey; + final String scopeTitle; + final bool enabled; + final List requiredArguments; +} + enum _McpPermissionScope { read, appControl, @@ -39,6 +57,13 @@ extension on _McpPermissionScope { _McpPermissionScope.draftWrite => 'draft_write', _McpPermissionScope.circleManagement => 'circle_management', }; + + String get title => switch (this) { + _McpPermissionScope.read => 'Read', + _McpPermissionScope.appControl => 'App Control', + _McpPermissionScope.draftWrite => 'Draft Editing', + _McpPermissionScope.circleManagement => 'Circle Management', + }; } class MixinMcpServer extends ChangeNotifier { @@ -71,6 +96,19 @@ class MixinMcpServer extends ChangeNotifier { Object? get lastStartError => _lastStartError; + static List toolInfos(Database database) => _toolSpecs + .map( + (spec) => MixinMcpToolInfo( + name: spec.name, + description: spec.description, + scopeKey: spec.scope.key, + scopeTitle: spec.scope.title, + enabled: _toolEnabled(database, spec), + requiredArguments: spec.required, + ), + ) + .toList(growable: false); + Future start({ required Database database, required String userId, From 0c8803e4890135723f208b05b28edd5feedc309e Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 7 May 2026 09:47:32 +0800 Subject: [PATCH 49/52] chore(deps): update dependencies and Podfile.lock for ogg_opus_player and mixin_markdown_widget --- macos/Podfile.lock | 7 ++++--- pubspec.lock | 13 +++++++------ pubspec.yaml | 5 ++--- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/macos/Podfile.lock b/macos/Podfile.lock index 9b34a51357..0bce530e72 100644 --- a/macos/Podfile.lock +++ b/macos/Podfile.lock @@ -28,6 +28,7 @@ PODS: - network_info_plus (0.0.1): - FlutterMacOS - ogg_opus_player (0.0.1): + - Flutter - FlutterMacOS - open_file_mac (1.0.3): - FlutterMacOS @@ -100,7 +101,7 @@ DEPENDENCIES: - local_auth_darwin (from `Flutter/ephemeral/.symlinks/plugins/local_auth_darwin/darwin`) - mixin_logger (from `Flutter/ephemeral/.symlinks/plugins/mixin_logger/macos`) - network_info_plus (from `Flutter/ephemeral/.symlinks/plugins/network_info_plus/macos`) - - ogg_opus_player (from `Flutter/ephemeral/.symlinks/plugins/ogg_opus_player/macos`) + - ogg_opus_player (from `Flutter/ephemeral/.symlinks/plugins/ogg_opus_player/darwin`) - open_file_mac (from `Flutter/ephemeral/.symlinks/plugins/open_file_mac/macos`) - package_info_plus (from `Flutter/ephemeral/.symlinks/plugins/package_info_plus/macos`) - path_provider_foundation (from `Flutter/ephemeral/.symlinks/plugins/path_provider_foundation/darwin`) @@ -152,7 +153,7 @@ EXTERNAL SOURCES: network_info_plus: :path: Flutter/ephemeral/.symlinks/plugins/network_info_plus/macos ogg_opus_player: - :path: Flutter/ephemeral/.symlinks/plugins/ogg_opus_player/macos + :path: Flutter/ephemeral/.symlinks/plugins/ogg_opus_player/darwin open_file_mac: :path: Flutter/ephemeral/.symlinks/plugins/open_file_mac/macos package_info_plus: @@ -197,7 +198,7 @@ SPEC CHECKSUMS: local_auth_darwin: c3ee6cce0a8d56be34c8ccb66ba31f7f180aaebb mixin_logger: 6b31328b08f546a8defd32cd910370562fc48405 network_info_plus: 21d1cd6a015ccb2fdff06a1fbfa88d54b4e92f61 - ogg_opus_player: 40ad7ee05152b420727fdb922afa0a90763b1817 + ogg_opus_player: 954784304f4e2722780018c4abf284e4f93cddf5 open_file_mac: 76f06c8597551249bdb5e8fd8827a98eae0f4585 package_info_plus: f0052d280d17aa382b932f399edf32507174e870 path_provider_foundation: 080d55be775b7414fd5a5ef3ac137b97b097e564 diff --git a/pubspec.lock b/pubspec.lock index c4feb1b20a..14f93ae117 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -1335,10 +1335,11 @@ packages: mixin_markdown_widget: dependency: "direct main" description: - path: "../flutter-plugins/packages/mixin_markdown_widget" - relative: true - source: path - version: "0.2.1" + name: mixin_markdown_widget + sha256: "4b4f6430c3be9be5766ae06b0d6c78ab6bd3059c1f2df076ef325028b6f030d5" + url: "https://pub.dev" + source: hosted + version: "0.3.1" msix: dependency: "direct dev" description: @@ -1408,10 +1409,10 @@ packages: dependency: "direct main" description: name: ogg_opus_player - sha256: cc839bf53bae215e3b4f8a796040038042173337d13d11604eae720b54f41e9d + sha256: d9bba5c2e276ff13ceae1c216a2650560c805d6b06de4bf4eb608bc38f1b75f4 url: "https://pub.dev" source: hosted - version: "0.7.0" + version: "0.8.0" open_file: dependency: "direct main" description: diff --git a/pubspec.yaml b/pubspec.yaml index 2f6abf49d1..53821db514 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -99,14 +99,13 @@ dependencies: local_auth: ^3.0.1 lottie: ^3.3.3 map: ^2.0.2 - mixin_markdown_widget: - path: ../flutter-plugins/packages/mixin_markdown_widget + mixin_markdown_widget: ^0.3.1 mime: ^2.0.0 mixin_bot_sdk_dart: ^1.5.0 mixin_logger: ^0.1.3 network_info_plus: ^7.0.0 octo_image: ^2.0.0 - ogg_opus_player: ^0.7.0 + ogg_opus_player: ^0.8.0 open_file: ^3.5.11 overlay_support: ^2.1.0 package_info_plus: ^9.0.1 From 640e0450410d44a4adf822da9a2dccc93e2e2f1e Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 7 May 2026 10:36:15 +0800 Subject: [PATCH 50/52] Refine local MCP tool surface --- lib/db/dao/message_dao.dart | 62 +- lib/db/dao/pin_message_dao.dart | 76 +- lib/utils/mcp/mixin_mcp_server.dart | 1290 ++++++++++++++++++++------- 3 files changed, 1085 insertions(+), 343 deletions(-) diff --git a/lib/db/dao/message_dao.dart b/lib/db/dao/message_dao.dart index 48e1a1e34e..4f2340cde2 100644 --- a/lib/db/dao/message_dao.dart +++ b/lib/db/dao/message_dao.dart @@ -705,6 +705,8 @@ class MessageDao extends DatabaseAccessor int offset = 0, DateTime? startInclusive, DateTime? endExclusive, + MessageOrderInfo? before, + MessageOrderInfo? after, String? senderId, String? senderIdentityNumber, List categories = const [], @@ -729,7 +731,17 @@ class MessageDao extends DatabaseAccessor : message.createdAt.isBiggerOrEqualValue(startMillis)) & (endMillis == null ? const Constant(true) - : message.createdAt.isSmallerThanValue(endMillis)), + : message.createdAt.isSmallerThanValue(endMillis)) & + (before == null + ? const Constant(true) + : message.createdAt.isSmallerThanValue(before.createdAt) | + (message.createdAt.equals(before.createdAt) & + message.rowId.isSmallerThanValue(before.rowId))) & + (after == null + ? const Constant(true) + : message.createdAt.isBiggerThanValue(after.createdAt) | + (message.createdAt.equals(after.createdAt) & + message.rowId.isBiggerThanValue(after.rowId))), (_, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => Limit(limit, offset), order: (message, _, _, _, _, _, _, _, _, _, _, _, _, em) => OrderBy([ if (ascending) OrderingTerm.asc(message.createdAt), @@ -745,14 +757,33 @@ class MessageDao extends DatabaseAccessor required int limit, int offset = 0, bool unreadOnly = false, + MessageOrderInfo? before, + MessageOrderInfo? after, + bool ascending = true, }) => _baseMessageItems( (message, _, _, _, _, _, _, _, _, _, _, _, messageMention, _, em) => message.conversationId.equals(conversationId) & messageMention.messageId.isNotNull() & (unreadOnly ? messageMention.hasRead.equals(false) - : const Constant(true)), + : const Constant(true)) & + (before == null + ? const Constant(true) + : message.createdAt.isSmallerThanValue(before.createdAt) | + (message.createdAt.equals(before.createdAt) & + message.rowId.isSmallerThanValue(before.rowId))) & + (after == null + ? const Constant(true) + : message.createdAt.isBiggerThanValue(after.createdAt) | + (message.createdAt.equals(after.createdAt) & + message.rowId.isBiggerThanValue(after.rowId))), (_, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => Limit(limit, offset), + order: (message, _, _, _, _, _, _, _, _, _, _, _, _, em) => OrderBy([ + if (ascending) OrderingTerm.asc(message.createdAt), + if (ascending) OrderingTerm.asc(message.rowId), + if (!ascending) OrderingTerm.desc(message.createdAt), + if (!ascending) OrderingTerm.desc(message.rowId), + ]), ); Selectable attachmentMessagesByConversationId( @@ -761,31 +792,56 @@ class MessageDao extends DatabaseAccessor int offset = 0, DateTime? startInclusive, DateTime? endExclusive, + MessageOrderInfo? before, + MessageOrderInfo? after, String? senderId, String? senderIdentityNumber, List categories = const [], + bool ascending = true, }) => messagesByConversationIdAndCreatedAtRange( conversationId, limit: limit, offset: offset, startInclusive: startInclusive, endExclusive: endExclusive, + before: before, + after: after, senderId: senderId, senderIdentityNumber: senderIdentityNumber, categories: categories.isEmpty ? _attachmentMessageCategories : categories, + ascending: ascending, ); Selectable linkMessagesByConversationId( String conversationId, { required int limit, int offset = 0, + MessageOrderInfo? before, + MessageOrderInfo? after, + bool ascending = true, }) => _baseMessageItems( (message, _, _, _, _, _, _, _, _, hyperlink, _, _, _, _, em) => message.conversationId.equals(conversationId) & message.hyperlink.isNotNull() & message.hyperlink.equals('').not() & - hyperlink.hyperlink.isNotNull(), + hyperlink.hyperlink.isNotNull() & + (before == null + ? const Constant(true) + : message.createdAt.isSmallerThanValue(before.createdAt) | + (message.createdAt.equals(before.createdAt) & + message.rowId.isSmallerThanValue(before.rowId))) & + (after == null + ? const Constant(true) + : message.createdAt.isBiggerThanValue(after.createdAt) | + (message.createdAt.equals(after.createdAt) & + message.rowId.isBiggerThanValue(after.rowId))), (_, _, _, _, _, _, _, _, _, _, _, _, _, _, em) => Limit(limit, offset), + order: (message, _, _, _, _, _, _, _, _, _, _, _, _, em) => OrderBy([ + if (ascending) OrderingTerm.asc(message.createdAt), + if (ascending) OrderingTerm.asc(message.rowId), + if (!ascending) OrderingTerm.desc(message.createdAt), + if (!ascending) OrderingTerm.desc(message.rowId), + ]), ); Selectable messageCountByConversationIdAndCreatedAtRange( diff --git a/lib/db/dao/pin_message_dao.dart b/lib/db/dao/pin_message_dao.dart index 495d856ce9..2dfeb47334 100644 --- a/lib/db/dao/pin_message_dao.dart +++ b/lib/db/dao/pin_message_dao.dart @@ -77,13 +77,77 @@ class PinMessageDao extends DatabaseAccessor Future> pinMessagesByConversationId({ required String conversationId, required int limit, - required int offset, + String? beforeMessageId, + String? afterMessageId, + bool ascending = false, + }) async { + final before = beforeMessageId == null + ? null + : await pinMessageByMessageId( + conversationId: conversationId, + messageId: beforeMessageId, + ); + if (beforeMessageId != null && before == null) { + throw StateError('Pinned cursor message not found'); + } + final after = afterMessageId == null + ? null + : await pinMessageByMessageId( + conversationId: conversationId, + messageId: afterMessageId, + ); + if (afterMessageId != null && after == null) { + throw StateError('Pinned cursor message not found'); + } + return (select(db.pinMessages) + ..where( + (tbl) => + tbl.conversationId.equals(conversationId) & + (before == null + ? const Constant(true) + : tbl.createdAt.isSmallerThanValue( + before.createdAt.millisecondsSinceEpoch, + ) | + (tbl.createdAt.equals( + before.createdAt.millisecondsSinceEpoch, + ) & + tbl.messageId.isSmallerThanValue( + before.messageId, + ))) & + (after == null + ? const Constant(true) + : tbl.createdAt.isBiggerThanValue( + after.createdAt.millisecondsSinceEpoch, + ) | + (tbl.createdAt.equals( + after.createdAt.millisecondsSinceEpoch, + ) & + tbl.messageId.isBiggerThanValue( + after.messageId, + ))), + ) + ..orderBy([ + (tbl) => ascending + ? OrderingTerm.asc(tbl.createdAt) + : OrderingTerm.desc(tbl.createdAt), + (tbl) => ascending + ? OrderingTerm.asc(tbl.messageId) + : OrderingTerm.desc(tbl.messageId), + ]) + ..limit(limit)) + .get(); + } + + Future pinMessageByMessageId({ + required String conversationId, + required String messageId, }) => - (select(db.pinMessages) - ..where((tbl) => tbl.conversationId.equals(conversationId)) - ..orderBy([(tbl) => OrderingTerm.desc(tbl.createdAt)]) - ..limit(limit, offset: offset)) - .get(); + (select(db.pinMessages)..where( + (tbl) => + tbl.conversationId.equals(conversationId) & + tbl.messageId.equals(messageId), + )) + .getSingleOrNull(); Future> getPinMessages({ required int limit, diff --git a/lib/utils/mcp/mixin_mcp_server.dart b/lib/utils/mcp/mixin_mcp_server.dart index bb71956045..30d0babe8a 100644 --- a/lib/utils/mcp/mixin_mcp_server.dart +++ b/lib/utils/mcp/mixin_mcp_server.dart @@ -18,6 +18,7 @@ import '../../db/dao/message_dao.dart'; import '../../db/dao/participant_dao.dart'; import '../../db/database.dart'; import '../../db/mixin_database.dart'; +import '../../enum/message_category.dart'; import '../extension/extension.dart'; import '../logger.dart'; import '../system/package_info.dart'; @@ -219,30 +220,15 @@ class MixinMcpServer extends ChangeNotifier { .toList(growable: false), }; case 'mixin_list_conversations': - final query = _optionalString(arguments, 'query'); final circleId = _optionalString(arguments, 'circle_id'); - final limit = _int( - arguments, - 'limit', - defaultValue: 30, - min: 1, - max: 100, - ); - final offset = _int(arguments, 'offset', defaultValue: 0); - final conversations = circleId != null - ? await database.conversationDao - .conversationsByCircleId(circleId, limit, offset) - .get() - : query == null || query.trim().isEmpty - ? await database.conversationDao.conversationItems().get() - : await _searchConversations(database, query, offset + limit); + final page = await _readConversationPage(database, arguments); return { - 'conversations': conversations - .skip(circleId == null ? offset : 0) - .take(limit) + 'circle_id': circleId, + 'conversations': page.conversations .map(_conversationToJson) .toList(growable: false), - }; + 'pagination': page.toJson(), + }..removeWhere((_, value) => value == null); case 'mixin_get_conversation': final conversation = await _conversationById( database, @@ -262,56 +248,8 @@ class MixinMcpServer extends ChangeNotifier { endExclusive: _date(arguments, 'end'), ); return stats.toJson(); - case 'mixin_read_messages': - final messages = await database.messageDao - .messagesByConversationIdAndCreatedAtRange( - _requiredString(arguments, 'conversation_id'), - offset: _int(arguments, 'offset', defaultValue: 0), - limit: _int( - arguments, - 'limit', - defaultValue: 50, - min: 1, - max: 200, - ), - startInclusive: _date(arguments, 'start'), - endExclusive: _date(arguments, 'end'), - senderId: _optionalString(arguments, 'sender_id'), - senderIdentityNumber: _optionalString( - arguments, - 'sender_identity_number', - ), - categories: _optionalStringList(arguments, 'message_types'), - ) - .get(); - return { - 'messages': _messagesToJson( - messages, - includePinState: _bool(arguments, 'include_pin_state'), - ), - }; - case 'mixin_search_messages': - final conversationId = _optionalString(arguments, 'conversation_id'); - final circleId = _optionalString(arguments, 'circle_id'); - final conversationIds = conversationId == null - ? circleId == null - ? const [] - : await database.conversationDao.conversationIdsByCircleId( - circleId, - ) - : [conversationId]; - if (circleId != null && conversationIds.isEmpty) { - return {'messages': const >[]}; - } - final messages = await database.fuzzySearchMessage( - query: _requiredString(arguments, 'query'), - limit: _int(arguments, 'limit', defaultValue: 20, min: 1, max: 50), - conversationIds: conversationIds, - userId: _optionalString(arguments, 'sender_id'), - categories: _optionalStringList(arguments, 'message_types'), - anchorMessageId: _optionalString(arguments, 'anchor_id'), - ); - return {'messages': _searchMessagesToJson(messages)}; + case 'mixin_list_messages': + return _listMessages(database, arguments); case 'mixin_get_message': final message = await _messageById( database, @@ -345,70 +283,14 @@ class MixinMcpServer extends ChangeNotifier { 'message': _messageToJson(message, includePinState: true), 'after': _messagesToJson(after, includePinState: true), }; - case 'mixin_read_image_text': + case 'mixin_read_image_message_text': final result = await _conversationTools.readImageText( conversationId: _requiredString(arguments, 'conversation_id'), messageId: _requiredString(arguments, 'message_id'), ); return result.toJson(); - case 'mixin_list_attachments': - final messages = await database.messageDao - .attachmentMessagesByConversationId( - _requiredString(arguments, 'conversation_id'), - offset: _int(arguments, 'offset', defaultValue: 0), - limit: _int( - arguments, - 'limit', - defaultValue: 50, - min: 1, - max: 200, - ), - startInclusive: _date(arguments, 'start'), - endExclusive: _date(arguments, 'end'), - senderId: _optionalString(arguments, 'sender_id'), - senderIdentityNumber: _optionalString( - arguments, - 'sender_identity_number', - ), - categories: _optionalStringList(arguments, 'message_types'), - ) - .get(); - return { - 'attachments': messages - .where(_hasAttachment) - .map(_attachmentToJson) - .toList(growable: false), - }; - case 'mixin_list_pinned_messages': - final conversationId = _requiredString(arguments, 'conversation_id'); - final pins = await database.pinMessageDao.pinMessagesByConversationId( - conversationId: conversationId, - limit: _int(arguments, 'limit', defaultValue: 50, min: 1, max: 200), - offset: _int(arguments, 'offset', defaultValue: 0), - ); - final messageIds = pins.map((pin) => pin.messageId).toList(); - final messages = await database.messageDao - .messageItemByMessageIds(messageIds) - .get(); - final messagesById = { - for (final message in messages) message.messageId: message, - }; - return { - 'messages': pins - .map((pin) { - final message = messagesById[pin.messageId]; - if (message == null) return null; - return { - ..._messageToJson(message, includePinState: true), - 'pinned_at': _dateTime(pin.createdAt), - }; - }) - .nonNulls - .toList(growable: false), - }; - case 'mixin_list_participants': + case 'mixin_list_conversation_participants': final query = _optionalString(arguments, 'query'); - final offset = _int(arguments, 'offset', defaultValue: 0); final limit = _int( arguments, 'limit', @@ -428,14 +310,18 @@ class MixinMcpServer extends ChangeNotifier { (participant) => _participantMatches(participant, query), ) .toList(growable: false); + final page = _participantPageFromRows( + filtered, + limit: limit, + cursorUserId: _optionalString(arguments, 'cursor_user_id'), + ); return { - 'participants': filtered - .skip(offset) - .take(limit) + 'participants': page.participants .map(_participantToJson) .toList(growable: false), + 'pagination': page.toJson(), }; - case 'mixin_resolve_user_in_conversation': + case 'mixin_resolve_conversation_participant': final query = _requiredString(arguments, 'query'); final participants = await database.participantDao .groupParticipantsByConversationId( @@ -454,60 +340,6 @@ class MixinMcpServer extends ChangeNotifier { return { 'circles': circles.map(_circleToJson).toList(growable: false), }; - case 'mixin_list_circle_conversations': - final circleId = _requiredString(arguments, 'circle_id'); - final limit = _int( - arguments, - 'limit', - defaultValue: 50, - min: 1, - max: 200, - ); - final offset = _int(arguments, 'offset', defaultValue: 0); - final conversations = await database.conversationDao - .conversationsByCircleId(circleId, limit, offset) - .get(); - return { - 'circle_id': circleId, - 'conversations': conversations - .map(_conversationToJson) - .toList(growable: false), - }; - case 'mixin_read_mentions': - final messages = await database.messageDao - .mentionMessagesByConversationId( - _requiredString(arguments, 'conversation_id'), - limit: _int( - arguments, - 'limit', - defaultValue: 50, - min: 1, - max: 200, - ), - offset: _int(arguments, 'offset', defaultValue: 0), - unreadOnly: _bool(arguments, 'unread_only'), - ) - .get(); - return { - 'messages': _messagesToJson(messages, includePinState: true), - }; - case 'mixin_list_links': - final messages = await database.messageDao - .linkMessagesByConversationId( - _requiredString(arguments, 'conversation_id'), - limit: _int( - arguments, - 'limit', - defaultValue: 50, - min: 1, - max: 200, - ), - offset: _int(arguments, 'offset', defaultValue: 0), - ) - .get(); - return { - 'links': messages.map(_linkToJson).toList(growable: false), - }; case 'mixin_open_conversation': final conversationId = _requiredString(arguments, 'conversation_id'); await MixinMcpBridge.instance.openConversation(conversationId); @@ -526,7 +358,7 @@ class MixinMcpServer extends ChangeNotifier { 'conversation_id': message.conversationId, 'message_id': message.messageId, }; - case 'mixin_get_draft': + case 'mixin_get_conversation_draft': final conversationId = _requiredString(arguments, 'conversation_id'); return { 'conversation_id': conversationId, @@ -535,7 +367,7 @@ class MixinMcpServer extends ChangeNotifier { conversationId, ), }; - case 'mixin_set_draft': + case 'mixin_set_conversation_draft': final conversationId = _requiredString(arguments, 'conversation_id'); await MixinMcpBridge.instance.setDraft( database, @@ -543,7 +375,7 @@ class MixinMcpServer extends ChangeNotifier { _requiredString(arguments, 'text'), ); return {'updated': true, 'conversation_id': conversationId}; - case 'mixin_insert_text': + case 'mixin_insert_conversation_text': final conversationId = _requiredString(arguments, 'conversation_id'); await MixinMcpBridge.instance.insertText( database, @@ -551,7 +383,7 @@ class MixinMcpServer extends ChangeNotifier { _requiredString(arguments, 'text'), ); return {'updated': true, 'conversation_id': conversationId}; - case 'mixin_clear_draft': + case 'mixin_clear_conversation_draft': final conversationId = _requiredString(arguments, 'conversation_id'); await MixinMcpBridge.instance.setDraft(database, conversationId, ''); return {'updated': true, 'conversation_id': conversationId}; @@ -620,7 +452,7 @@ class MixinMcpServer extends ChangeNotifier { 'circle_id': circleId, 'conversation_ids': conversationIds, }; - case 'mixin_attach_message_to_ai': + case 'mixin_attach_message_to_ai_context': final message = await _messageById( database, _requiredString(arguments, 'message_id'), @@ -634,7 +466,7 @@ class MixinMcpServer extends ChangeNotifier { 'conversation_id': message.conversationId, 'message_id': message.messageId, }; - case 'mixin_list_ai_threads': + case 'mixin_list_conversation_ai_threads': final threads = await database.aiChatMessageDao .watchThreads(_requiredString(arguments, 'conversation_id')) .first; @@ -652,7 +484,7 @@ class MixinMcpServer extends ChangeNotifier { 'thread': _aiThreadToJson(thread), 'messages': messages.map(_aiMessageToJson).toList(growable: false), }; - case 'mixin_get_ai_tool_events': + case 'mixin_get_ai_message_tool_events': final messageId = _requiredString(arguments, 'message_id'); final message = await database.aiChatMessageDao.messageById(messageId); if (message == null) throw StateError('AI message not found'); @@ -805,6 +637,690 @@ Object? _redactForAudit(Object? value, [String? key]) { return value; } +const _messagePageLatest = 'latest'; +const _messagePageBefore = 'before'; +const _messagePageAfter = 'after'; +const _messageKindAll = 'all'; +const _messageKindAttachments = 'attachments'; +const _messageKindPinned = 'pinned'; +const _messageKindMentions = 'mentions'; +const _messageKindLinks = 'links'; +const _conversationScanLimit = 5000; +const _attachmentMessageCategories = [ + MessageCategory.signalImage, + MessageCategory.signalVideo, + MessageCategory.signalData, + MessageCategory.signalAudio, + MessageCategory.plainImage, + MessageCategory.plainVideo, + MessageCategory.plainData, + MessageCategory.plainAudio, + MessageCategory.encryptedImage, + MessageCategory.encryptedVideo, + MessageCategory.encryptedData, + MessageCategory.encryptedAudio, +]; + +class _ConversationPage { + const _ConversationPage({ + required this.conversations, + required this.limit, + required this.hasMore, + required this.order, + this.cursorConversationId, + }); + + final List conversations; + final int limit; + final bool hasMore; + final String order; + final String? cursorConversationId; + + Map toJson() => { + 'order': order, + 'limit': limit, + 'cursor_conversation_id': cursorConversationId, + 'next_cursor_conversation_id': hasMore + ? conversations.lastOrNull?.conversationId + : null, + 'has_more': hasMore, + }..removeWhere((_, value) => value == null); +} + +class _ParticipantPage { + const _ParticipantPage({ + required this.participants, + required this.limit, + required this.hasMore, + this.cursorUserId, + }); + + final List participants; + final int limit; + final bool hasMore; + final String? cursorUserId; + + Map toJson() => { + 'order': 'full_name_identity_number_user_id', + 'limit': limit, + 'cursor_user_id': cursorUserId, + 'next_cursor_user_id': hasMore ? participants.lastOrNull?.userId : null, + 'has_more': hasMore, + }..removeWhere((_, value) => value == null); +} + +class _MessagePage { + const _MessagePage({ + required this.messages, + required this.page, + required this.limit, + required this.hasMore, + this.cursorMessageId, + }); + + final List messages; + final String page; + final int limit; + final bool hasMore; + final String? cursorMessageId; + + Map toJson() => _cursorPaginationToJson( + page: page, + limit: limit, + hasMore: hasMore, + cursorMessageId: cursorMessageId, + oldestMessageId: messages.firstOrNull?.messageId, + newestMessageId: messages.lastOrNull?.messageId, + ); +} + +class _PinnedMessagePage { + const _PinnedMessagePage({ + required this.pins, + required this.page, + required this.limit, + required this.hasMore, + this.cursorMessageId, + }); + + final List pins; + final String page; + final int limit; + final bool hasMore; + final String? cursorMessageId; + + Map toJson() => { + ..._cursorPaginationToJson( + page: page, + limit: limit, + hasMore: hasMore, + cursorMessageId: cursorMessageId, + oldestMessageId: pins.firstOrNull?.messageId, + newestMessageId: pins.lastOrNull?.messageId, + ), + 'order': 'oldest_to_newest_by_pinned_at', + }; +} + +Future> _listMessages( + Database database, + Map arguments, +) async { + final query = _optionalString(arguments, 'query'); + return query == null + ? _listConversationMessages(database, arguments) + : _searchMessages(database, arguments, query); +} + +Future> _listConversationMessages( + Database database, + Map arguments, +) async { + final conversationId = _requiredString(arguments, 'conversation_id'); + final kind = _messageKind(arguments); + if (_optionalString(arguments, 'circle_id') != null) { + throw ArgumentError('circle_id only applies when query is set'); + } + return switch (kind) { + _messageKindAll => () async { + final page = await _readMessagePage(database, arguments); + return { + 'messages': _messagesToJson( + page.messages, + includePinState: _bool(arguments, 'include_pin_state'), + ), + 'pagination': page.toJson(), + }; + }(), + _messageKindAttachments => () async { + final page = await _readMessagePage( + database, + arguments, + attachmentMessagesOnly: true, + ); + return { + 'messages': _messagesToJson(page.messages, includePinState: true), + 'pagination': page.toJson(), + }; + }(), + _messageKindPinned => () async { + final page = await _readPinnedMessagePage( + database, + conversationId: conversationId, + arguments: arguments, + ); + return { + 'messages': await _pinnedMessagesToJson(database, page.pins), + 'pagination': page.toJson(), + }; + }(), + _messageKindMentions => () async { + final page = await _readMentionMessagePage(database, arguments); + return { + 'messages': _messagesToJson(page.messages, includePinState: true), + 'pagination': page.toJson(), + }; + }(), + _messageKindLinks => () async { + final page = await _readLinkMessagePage(database, arguments); + return { + 'messages': _messagesToJson(page.messages, includePinState: true), + 'pagination': page.toJson(), + }; + }(), + _ => throw ArgumentError('Unsupported message kind: $kind'), + }; +} + +Future> _searchMessages( + Database database, + Map arguments, + String query, +) async { + final kind = _messageKind(arguments); + if (kind != _messageKindAll && kind != _messageKindAttachments) { + throw ArgumentError( + 'kind is only supported as all or attachments when query is set', + ); + } + final conversationId = _optionalString(arguments, 'conversation_id'); + final circleId = _optionalString(arguments, 'circle_id'); + final conversationIds = conversationId == null + ? circleId == null + ? const [] + : await database.conversationDao.conversationIdsByCircleId(circleId) + : [conversationId]; + if (circleId != null && conversationIds.isEmpty) { + return { + 'messages': const >[], + 'pagination': { + 'limit': _searchMessageLimit(arguments), + 'has_more': false, + }, + }; + } + final limit = _searchMessageLimit(arguments); + final messages = await database.fuzzySearchMessage( + query: query, + limit: limit + 1, + conversationIds: conversationIds, + userId: _optionalString(arguments, 'sender_id'), + categories: _searchMessageCategories(arguments, kind), + anchorMessageId: _optionalString(arguments, 'cursor_message_id'), + ); + final hasMore = messages.length > limit; + final selected = messages.take(limit).toList(growable: false); + return { + 'messages': _searchMessagesToJson(selected), + 'pagination': { + 'limit': limit, + 'cursor_message_id': _optionalString(arguments, 'cursor_message_id'), + 'next_cursor_message_id': hasMore ? selected.lastOrNull?.messageId : null, + 'has_more': hasMore, + }..removeWhere((_, value) => value == null), + }; +} + +Future>> _pinnedMessagesToJson( + Database database, + List pins, +) async { + final messageIds = pins.map((pin) => pin.messageId).toList(); + final messages = await database.messageDao + .messageItemByMessageIds(messageIds) + .get(); + final messagesById = { + for (final message in messages) message.messageId: message, + }; + return pins + .map((pin) { + final message = messagesById[pin.messageId]; + if (message == null) return null; + return { + ..._messageToJson(message, includePinState: true), + 'pinned_at': _dateTime(pin.createdAt), + }; + }) + .nonNulls + .toList(growable: false); +} + +int _searchMessageLimit(Map arguments) => + _int(arguments, 'limit', defaultValue: 100, min: 1, max: 200); + +List _searchMessageCategories( + Map arguments, + String kind, +) { + final explicit = _optionalStringList(arguments, 'message_types'); + if (explicit.isNotEmpty) return explicit; + return kind == _messageKindAttachments + ? _attachmentMessageCategories + : const []; +} + +String _messageKind(Map arguments) { + final kind = _optionalString(arguments, 'kind') ?? _messageKindAll; + return switch (kind) { + _messageKindAll || + _messageKindAttachments || + _messageKindPinned || + _messageKindMentions || + _messageKindLinks => kind, + _ => throw ArgumentError( + 'kind must be one of all, attachments, pinned, mentions, or links', + ), + }; +} + +Future<_MessagePage> _readMessagePage( + Database database, + Map arguments, { + bool attachmentMessagesOnly = false, +}) async { + final conversationId = _requiredString(arguments, 'conversation_id'); + final limit = _int(arguments, 'limit', defaultValue: 100, min: 1, max: 200); + final page = _messagePage(arguments); + final cursorMessageId = _cursorMessageId(arguments, page); + final before = page == _messagePageBefore + ? await _messageOrderInfoForCursor( + database, + conversationId, + cursorMessageId, + ) + : null; + final after = page == _messagePageAfter + ? await _messageOrderInfoForCursor( + database, + conversationId, + cursorMessageId, + ) + : null; + final ascending = page == _messagePageAfter; + final rows = attachmentMessagesOnly + ? await database.messageDao + .attachmentMessagesByConversationId( + conversationId, + limit: limit + 1, + startInclusive: _date(arguments, 'start'), + endExclusive: _date(arguments, 'end'), + before: before, + after: after, + senderId: _optionalString(arguments, 'sender_id'), + senderIdentityNumber: _optionalString( + arguments, + 'sender_identity_number', + ), + categories: _optionalStringList(arguments, 'message_types'), + ascending: ascending, + ) + .get() + : await database.messageDao + .messagesByConversationIdAndCreatedAtRange( + conversationId, + limit: limit + 1, + startInclusive: _date(arguments, 'start'), + endExclusive: _date(arguments, 'end'), + before: before, + after: after, + senderId: _optionalString(arguments, 'sender_id'), + senderIdentityNumber: _optionalString( + arguments, + 'sender_identity_number', + ), + categories: _optionalStringList(arguments, 'message_types'), + ascending: ascending, + ) + .get(); + return _messagePageFromRows( + rows, + page: page, + limit: limit, + cursorMessageId: cursorMessageId, + ascending: ascending, + ); +} + +Future<_MessagePage> _readMentionMessagePage( + Database database, + Map arguments, +) async { + final conversationId = _requiredString(arguments, 'conversation_id'); + final limit = _int(arguments, 'limit', defaultValue: 100, min: 1, max: 200); + final page = _messagePage(arguments); + final cursorMessageId = _cursorMessageId(arguments, page); + final before = page == _messagePageBefore + ? await _messageOrderInfoForCursor( + database, + conversationId, + cursorMessageId, + ) + : null; + final after = page == _messagePageAfter + ? await _messageOrderInfoForCursor( + database, + conversationId, + cursorMessageId, + ) + : null; + final ascending = page == _messagePageAfter; + final rows = await database.messageDao + .mentionMessagesByConversationId( + conversationId, + limit: limit + 1, + unreadOnly: _bool(arguments, 'unread_only'), + before: before, + after: after, + ascending: ascending, + ) + .get(); + return _messagePageFromRows( + rows, + page: page, + limit: limit, + cursorMessageId: cursorMessageId, + ascending: ascending, + ); +} + +Future<_MessagePage> _readLinkMessagePage( + Database database, + Map arguments, +) async { + final conversationId = _requiredString(arguments, 'conversation_id'); + final limit = _int(arguments, 'limit', defaultValue: 100, min: 1, max: 200); + final page = _messagePage(arguments); + final cursorMessageId = _cursorMessageId(arguments, page); + final before = page == _messagePageBefore + ? await _messageOrderInfoForCursor( + database, + conversationId, + cursorMessageId, + ) + : null; + final after = page == _messagePageAfter + ? await _messageOrderInfoForCursor( + database, + conversationId, + cursorMessageId, + ) + : null; + final ascending = page == _messagePageAfter; + final rows = await database.messageDao + .linkMessagesByConversationId( + conversationId, + limit: limit + 1, + before: before, + after: after, + ascending: ascending, + ) + .get(); + return _messagePageFromRows( + rows, + page: page, + limit: limit, + cursorMessageId: cursorMessageId, + ascending: ascending, + ); +} + +Future<_PinnedMessagePage> _readPinnedMessagePage( + Database database, { + required String conversationId, + required Map arguments, +}) async { + final limit = _int(arguments, 'limit', defaultValue: 100, min: 1, max: 200); + final page = _messagePage(arguments); + final cursorMessageId = _cursorMessageId(arguments, page); + final pins = await database.pinMessageDao.pinMessagesByConversationId( + conversationId: conversationId, + limit: limit + 1, + beforeMessageId: page == _messagePageBefore ? cursorMessageId : null, + afterMessageId: page == _messagePageAfter ? cursorMessageId : null, + ascending: page == _messagePageAfter, + ); + final hasMore = pins.length > limit; + final selected = pins.take(limit).toList(growable: false); + return _PinnedMessagePage( + pins: page == _messagePageAfter + ? selected + : selected.reversed.toList(growable: false), + page: page, + limit: limit, + hasMore: hasMore, + cursorMessageId: cursorMessageId, + ); +} + +_MessagePage _messagePageFromRows( + List rows, { + required String page, + required int limit, + required String? cursorMessageId, + required bool ascending, +}) { + final hasMore = rows.length > limit; + final selected = rows.take(limit).toList(growable: false); + return _MessagePage( + messages: ascending ? selected : selected.reversed.toList(growable: false), + page: page, + limit: limit, + hasMore: hasMore, + cursorMessageId: cursorMessageId, + ); +} + +Map _cursorPaginationToJson({ + required String page, + required int limit, + required bool hasMore, + required String? cursorMessageId, + required String? oldestMessageId, + required String? newestMessageId, +}) => { + 'order': 'oldest_to_newest', + 'page': page, + 'limit': limit, + 'cursor_message_id': cursorMessageId, + 'has_more': hasMore, + 'has_more_direction': page == _messagePageAfter ? 'newer' : 'older', + 'oldest_message_id': oldestMessageId, + 'newest_message_id': newestMessageId, + 'older_page': oldestMessageId == null + ? null + : { + 'page': _messagePageBefore, + 'cursor_message_id': oldestMessageId, + }, + 'newer_page': newestMessageId == null + ? null + : { + 'page': _messagePageAfter, + 'cursor_message_id': newestMessageId, + }, +}..removeWhere((_, value) => value == null); + +String _messagePage(Map arguments) { + final page = _optionalString(arguments, 'page') ?? _messagePageLatest; + return switch (page) { + _messagePageLatest || _messagePageBefore || _messagePageAfter => page, + _ => throw ArgumentError( + 'page must be one of latest, before, or after', + ), + }; +} + +String? _cursorMessageId(Map arguments, String page) { + final cursorMessageId = _optionalString(arguments, 'cursor_message_id'); + if (page == _messagePageLatest) { + if (cursorMessageId != null) { + throw ArgumentError( + 'cursor_message_id is only valid when page is before or after', + ); + } + return null; + } + if (cursorMessageId == null) { + throw ArgumentError('cursor_message_id is required when page is $page'); + } + return cursorMessageId; +} + +Future<_ConversationPage> _readConversationPage( + Database database, + Map arguments, +) async { + final limit = _int(arguments, 'limit', defaultValue: 30, min: 1, max: 100); + final query = _optionalString(arguments, 'query'); + final circleId = _optionalString(arguments, 'circle_id'); + final rows = query == null + ? circleId == null + ? await database.conversationDao.conversationItems().get() + : await database.conversationDao + .conversationsByCircleId(circleId, _conversationScanLimit, 0) + .get() + : await _searchConversations(database, query, _conversationScanLimit); + final scopedRows = query == null || circleId == null + ? rows + : await _filterConversationsByCircle(database, rows, circleId); + return _conversationPageFromRows( + scopedRows, + limit: limit, + cursorConversationId: _optionalString( + arguments, + 'cursor_conversation_id', + ), + order: query == null ? 'app_chat_list' : 'search_relevance', + ); +} + +_ConversationPage _conversationPageFromRows( + List rows, { + required int limit, + required String? cursorConversationId, + required String order, +}) { + final start = _cursorStartIndex( + rows, + cursorConversationId, + (conversation) => conversation.conversationId, + 'cursor_conversation_id', + ); + final selected = rows.skip(start).take(limit + 1).toList(growable: false); + final hasMore = selected.length > limit; + return _ConversationPage( + conversations: selected.take(limit).toList(growable: false), + limit: limit, + hasMore: hasMore, + order: order, + cursorConversationId: cursorConversationId, + ); +} + +Future> _filterConversationsByCircle( + Database database, + List rows, + String circleId, +) async { + final conversationIds = await database.conversationDao + .conversationIdsByCircleId(circleId, limit: _conversationScanLimit); + final conversationIdSet = conversationIds.toSet(); + return rows + .where( + (conversation) => + conversationIdSet.contains(conversation.conversationId), + ) + .toList(growable: false); +} + +_ParticipantPage _participantPageFromRows( + List rows, { + required int limit, + required String? cursorUserId, +}) { + final sorted = [...rows] + ..sort((a, b) { + final name = _compareNullableText(a.fullName, b.fullName); + if (name != 0) return name; + final identityNumber = a.identityNumber.compareTo(b.identityNumber); + if (identityNumber != 0) return identityNumber; + return a.userId.compareTo(b.userId); + }); + final start = _cursorStartIndex( + sorted, + cursorUserId, + (participant) => participant.userId, + 'cursor_user_id', + ); + final selected = sorted.skip(start).take(limit + 1).toList(growable: false); + final hasMore = selected.length > limit; + return _ParticipantPage( + participants: selected.take(limit).toList(growable: false), + limit: limit, + hasMore: hasMore, + cursorUserId: cursorUserId, + ); +} + +int _cursorStartIndex( + List rows, + String? cursor, + String Function(T row) idOf, + String cursorName, +) { + if (cursor == null) return 0; + final index = rows.indexWhere((row) => idOf(row) == cursor); + if (index < 0) throw ArgumentError('$cursorName not found'); + return index + 1; +} + +int _compareNullableText(String? a, String? b) { + final left = a?.trim().toLowerCase(); + final right = b?.trim().toLowerCase(); + if (left == null || left.isEmpty) { + return right == null || right.isEmpty ? 0 : 1; + } + if (right == null || right.isEmpty) return -1; + return left.compareTo(right); +} + +Future _messageOrderInfoForCursor( + Database database, + String conversationId, + String? cursorMessageId, +) async { + if (cursorMessageId == null) { + throw ArgumentError('cursor_message_id is required'); + } + final message = await _messageById(database, cursorMessageId); + if (message.conversationId != conversationId) { + throw ArgumentError('cursor_message_id is not in conversation_id'); + } + final info = await database.messageDao.messageOrderInfo(cursorMessageId); + if (info == null) throw StateError('Message order info not found'); + return info; +} + Future _conversationById( Database database, String conversationId, @@ -991,11 +1507,6 @@ Map? _linkPreviewToJson(MessageItem message) { return link.isEmpty ? null : link; } -Map _linkToJson(MessageItem message) => { - ..._messageToJson(message, includePinState: true), - 'link': _linkPreviewToJson(message), -}..removeWhere((_, value) => value == null); - Map _participantToJson(ParticipantUser participant) => { 'conversation_id': participant.conversationId, 'user_id': participant.userId, @@ -1145,19 +1656,104 @@ const _stringArraySchema = { 'items': {'type': 'string'}, }; +const _conversationIdProperty = { + 'type': 'string', + 'description': + 'Mixin conversation_id. Use mixin_resolve_conversation first when only a name or mixin:// URL is known.', +}; + +const _messageIdProperty = { + 'type': 'string', + 'description': 'Mixin message_id.', +}; + +const _limit100Property = { + 'type': 'integer', + 'description': 'Maximum number of items to return.', + 'default': 100, + 'minimum': 1, + 'maximum': 200, +}; + +const _limit50Property = { + 'type': 'integer', + 'description': 'Maximum number of items to return.', + 'default': 50, + 'minimum': 1, + 'maximum': 200, +}; + +const _limit30Property = { + 'type': 'integer', + 'description': 'Maximum number of items to return.', + 'default': 30, + 'minimum': 1, + 'maximum': 100, +}; + +const _conversationCursorProperty = { + 'type': 'string', + 'description': + 'Optional cursor for the next page. Use pagination.next_cursor_conversation_id from the previous result.', +}; + +const _participantCursorProperty = { + 'type': 'string', + 'description': + 'Optional cursor for the next page. Use pagination.next_cursor_user_id from the previous result.', +}; + +const _messageCursorProperties = { + 'page': { + 'type': 'string', + 'enum': [_messagePageLatest, _messagePageBefore, _messagePageAfter], + 'default': _messagePageLatest, + 'description': + 'latest returns the latest matching messages. before returns messages older than cursor_message_id. after returns messages newer than cursor_message_id. Results are always oldest_to_newest.', + }, + 'cursor_message_id': { + 'type': 'string', + 'description': + 'Required when page is before or after. Use pagination.older_page.cursor_message_id or pagination.newer_page.cursor_message_id from the previous result.', + }, + 'limit': _limit100Property, +}; + +const _conversationRangeProperties = { + 'conversation_id': _conversationIdProperty, + 'start': { + 'type': 'string', + 'format': 'date-time', + 'description': 'Inclusive ISO-8601 lower bound for message created_at.', + }, + 'end': { + 'type': 'string', + 'format': 'date-time', + 'description': 'Exclusive ISO-8601 upper bound for message created_at.', + }, +}; + const _toolSpecs = [ _Tool( 'mixin_get_app_status', - 'Get login, active conversation, app version, and MCP capability status.', + 'Get login state, app version, active conversation ids, permission scopes, and enabled MCP tools.', ), _Tool( 'mixin_list_conversations', - 'List recent conversations, search conversations, or list a circle.', + 'List conversations. Without query, returns app chat-list order. With query, searches conversations. With circle_id, restricts results to that circle. Use cursor_conversation_id to continue.', properties: { - 'query': {'type': 'string'}, - 'circle_id': {'type': 'string'}, - 'offset': {'type': 'integer'}, - 'limit': {'type': 'integer'}, + 'query': { + 'type': 'string', + 'description': + 'Optional fuzzy conversation name search. Omit to list conversations.', + }, + 'circle_id': { + 'type': 'string', + 'description': + 'Optional circle id from mixin_list_circles. When set, only conversations in that circle are returned.', + }, + 'limit': _limit30Property, + 'cursor_conversation_id': _conversationCursorProperty, }, ), _Tool( @@ -1165,50 +1761,96 @@ const _toolSpecs = [ 'Get one conversation by conversation_id.', required: ['conversation_id'], properties: { - 'conversation_id': {'type': 'string'}, + 'conversation_id': _conversationIdProperty, }, ), _Tool( 'mixin_resolve_conversation', - 'Resolve a conversation from conversation_id, mixin URI, or query.', + 'Resolve exactly one conversation from conversation_id, mixin://conversations/, or a fuzzy query. Provide one of conversation_id, uri, or query.', properties: { - 'conversation_id': {'type': 'string'}, - 'uri': {'type': 'string'}, - 'query': {'type': 'string'}, + 'conversation_id': _conversationIdProperty, + 'uri': { + 'type': 'string', + 'description': + 'Mixin URI such as mixin://conversations/.', + }, + 'query': { + 'type': 'string', + 'description': + 'Fuzzy conversation name search. Returns the best match.', + }, + }, + schema: { + 'oneOf': [ + { + 'required': ['conversation_id'], + }, + { + 'required': ['uri'], + }, + { + 'required': ['query'], + }, + ], }, ), _Tool( 'mixin_get_conversation_stats', - 'Get message count and first/last timestamps for a conversation.', + 'Get message_count, first_message_at, and last_message_at for a conversation and optional time range.', required: ['conversation_id'], properties: _conversationRangeProperties, ), _Tool( - 'mixin_read_messages', - 'Read conversation messages by range, sender, type, offset, and limit.', - required: ['conversation_id'], + 'mixin_list_messages', + 'List or search messages. With query, searches globally or inside conversation_id/circle_id. Without query, conversation_id is required and messages are listed by cursor. Use kind to list all messages, attachments, pinned messages, mentions, or links.', properties: { ..._conversationRangeProperties, - 'sender_id': {'type': 'string'}, - 'sender_identity_number': {'type': 'string'}, - 'message_types': _stringArraySchema, - 'include_pin_state': {'type': 'boolean'}, - 'offset': {'type': 'integer'}, - 'limit': {'type': 'integer'}, - }, - ), - _Tool( - 'mixin_search_messages', - 'Search messages globally, inside a conversation, or inside a circle.', - required: ['query'], - properties: { - 'query': {'type': 'string'}, - 'conversation_id': {'type': 'string'}, - 'circle_id': {'type': 'string'}, - 'sender_id': {'type': 'string'}, - 'message_types': _stringArraySchema, - 'limit': {'type': 'integer'}, - 'anchor_id': {'type': 'string'}, + ..._messageCursorProperties, + 'query': { + 'type': 'string', + 'description': + 'Optional search text. When omitted, conversation_id is required and the tool lists messages from that conversation.', + }, + 'circle_id': { + 'type': 'string', + 'description': 'Optional search scope. Only applies when query is set.', + }, + 'kind': { + 'type': 'string', + 'enum': [ + _messageKindAll, + _messageKindAttachments, + _messageKindPinned, + _messageKindMentions, + _messageKindLinks, + ], + 'default': _messageKindAll, + 'description': + 'Message filter. For search, only all and attachments are supported. For conversation listing, all values are supported.', + }, + 'sender_id': { + 'type': 'string', + 'description': 'Optional sender user_id filter.', + }, + 'sender_identity_number': { + 'type': 'string', + 'description': + 'Optional sender identity number filter. Only applies when query is omitted.', + }, + 'message_types': { + ..._stringArraySchema, + 'description': 'Optional Mixin message category filters.', + }, + 'include_pin_state': { + 'type': 'boolean', + 'default': false, + 'description': 'Whether every returned message includes is_pinned.', + }, + 'unread_only': { + 'type': 'boolean', + 'default': false, + 'description': 'Only applies when kind is mentions.', + }, }, ), _Tool( @@ -1216,113 +1858,78 @@ const _toolSpecs = [ 'Get a message by message_id.', required: ['message_id'], properties: { - 'message_id': {'type': 'string'}, + 'message_id': _messageIdProperty, }, ), _Tool( 'mixin_get_message_context', - 'Read messages around a message_id.', + 'Read messages immediately before and after one message_id.', required: ['message_id'], properties: { - 'message_id': {'type': 'string'}, - 'limit': {'type': 'integer'}, + 'message_id': _messageIdProperty, + 'limit': { + 'type': 'integer', + 'description': + 'Number of messages to read before and after the target.', + 'default': 10, + 'minimum': 1, + 'maximum': 50, + }, }, ), _Tool( - 'mixin_read_image_text', - 'Run local OCR for an image message.', + 'mixin_read_image_message_text', + 'Run local OCR for one image message.', required: ['conversation_id', 'message_id'], properties: { - 'conversation_id': {'type': 'string'}, - 'message_id': {'type': 'string'}, + 'conversation_id': _conversationIdProperty, + 'message_id': _messageIdProperty, }, ), _Tool( - 'mixin_list_attachments', - 'List attachment metadata for a conversation.', + 'mixin_list_conversation_participants', + 'List or search participants in a conversation. Results are ordered by full_name, identity_number, then user_id.', required: ['conversation_id'], properties: { - ..._conversationRangeProperties, - 'sender_id': {'type': 'string'}, - 'sender_identity_number': {'type': 'string'}, - 'message_types': _stringArraySchema, - 'offset': {'type': 'integer'}, - 'limit': {'type': 'integer'}, - }, - ), - _Tool( - 'mixin_list_pinned_messages', - 'List pinned messages for a conversation.', - required: ['conversation_id'], - properties: { - 'conversation_id': {'type': 'string'}, - 'offset': {'type': 'integer'}, - 'limit': {'type': 'integer'}, - }, - ), - _Tool( - 'mixin_list_participants', - 'List or search participants in a conversation.', - required: ['conversation_id'], - properties: { - 'conversation_id': {'type': 'string'}, - 'query': {'type': 'string'}, - 'offset': {'type': 'integer'}, - 'limit': {'type': 'integer'}, + 'conversation_id': _conversationIdProperty, + 'query': { + 'type': 'string', + 'description': 'Optional user_id, identity number, or name search.', + }, + 'limit': _limit50Property, + 'cursor_user_id': _participantCursorProperty, }, ), _Tool( - 'mixin_resolve_user_in_conversation', - 'Resolve participants by user_id, identity number, or name.', + 'mixin_resolve_conversation_participant', + 'Resolve participants in one conversation by user_id, identity number, or name.', required: ['conversation_id', 'query'], properties: { - 'conversation_id': {'type': 'string'}, - 'query': {'type': 'string'}, - 'limit': {'type': 'integer'}, + 'conversation_id': _conversationIdProperty, + 'query': { + 'type': 'string', + 'description': 'User id, identity number, or name.', + }, + 'limit': { + 'type': 'integer', + 'description': 'Maximum number of matching participants to return.', + 'default': 5, + 'minimum': 1, + 'maximum': 20, + }, }, ), _Tool( 'mixin_list_circles', 'List local circles and their conversation counts.', ), - _Tool( - 'mixin_list_circle_conversations', - 'List conversations in a circle.', - required: ['circle_id'], - properties: { - 'circle_id': {'type': 'string'}, - 'offset': {'type': 'integer'}, - 'limit': {'type': 'integer'}, - }, - ), - _Tool( - 'mixin_read_mentions', - 'Read mention messages in a conversation without marking them read.', - required: ['conversation_id'], - properties: { - 'conversation_id': {'type': 'string'}, - 'unread_only': {'type': 'boolean'}, - 'offset': {'type': 'integer'}, - 'limit': {'type': 'integer'}, - }, - ), - _Tool( - 'mixin_list_links', - 'List messages with link previews in a conversation.', - required: ['conversation_id'], - properties: { - 'conversation_id': {'type': 'string'}, - 'offset': {'type': 'integer'}, - 'limit': {'type': 'integer'}, - }, - ), _Tool( 'mixin_open_conversation', 'Open a conversation in the Mixin UI.', scope: _McpPermissionScope.appControl, required: ['conversation_id'], properties: { - 'conversation_id': {'type': 'string'}, + 'conversation_id': _conversationIdProperty, }, ), _Tool( @@ -1331,45 +1938,51 @@ const _toolSpecs = [ scope: _McpPermissionScope.appControl, required: ['message_id'], properties: { - 'message_id': {'type': 'string'}, + 'message_id': _messageIdProperty, }, ), _Tool( - 'mixin_get_draft', + 'mixin_get_conversation_draft', 'Get the current draft text for a conversation.', scope: _McpPermissionScope.draftWrite, required: ['conversation_id'], properties: { - 'conversation_id': {'type': 'string'}, + 'conversation_id': _conversationIdProperty, }, ), _Tool( - 'mixin_set_draft', + 'mixin_set_conversation_draft', 'Replace the draft text for a conversation. Does not send.', scope: _McpPermissionScope.draftWrite, required: ['conversation_id', 'text'], properties: { - 'conversation_id': {'type': 'string'}, - 'text': {'type': 'string'}, + 'conversation_id': _conversationIdProperty, + 'text': { + 'type': 'string', + 'description': 'Draft text. This never sends a message.', + }, }, ), _Tool( - 'mixin_insert_text', + 'mixin_insert_conversation_text', 'Insert text into the active input, or append to stored draft.', scope: _McpPermissionScope.draftWrite, required: ['conversation_id', 'text'], properties: { - 'conversation_id': {'type': 'string'}, - 'text': {'type': 'string'}, + 'conversation_id': _conversationIdProperty, + 'text': { + 'type': 'string', + 'description': 'Text to insert. This never sends a message.', + }, }, ), _Tool( - 'mixin_clear_draft', + 'mixin_clear_conversation_draft', 'Clear the draft text for a conversation. Does not send.', scope: _McpPermissionScope.draftWrite, required: ['conversation_id'], properties: { - 'conversation_id': {'type': 'string'}, + 'conversation_id': _conversationIdProperty, }, ), _Tool( @@ -1378,8 +1991,14 @@ const _toolSpecs = [ scope: _McpPermissionScope.circleManagement, required: ['name'], properties: { - 'name': {'type': 'string'}, - 'conversation_ids': _stringArraySchema, + 'name': { + 'type': 'string', + 'description': 'Circle name.', + }, + 'conversation_ids': { + ..._stringArraySchema, + 'description': 'Optional initial conversation_ids.', + }, }, ), _Tool( @@ -1408,7 +2027,10 @@ const _toolSpecs = [ required: ['circle_id', 'conversation_ids'], properties: { 'circle_id': {'type': 'string'}, - 'conversation_ids': _stringArraySchema, + 'conversation_ids': { + ..._stringArraySchema, + 'description': 'Conversation ids to add.', + }, }, ), _Tool( @@ -1418,24 +2040,27 @@ const _toolSpecs = [ required: ['circle_id', 'conversation_ids'], properties: { 'circle_id': {'type': 'string'}, - 'conversation_ids': _stringArraySchema, + 'conversation_ids': { + ..._stringArraySchema, + 'description': 'Conversation ids to remove.', + }, }, ), _Tool( - 'mixin_attach_message_to_ai', + 'mixin_attach_message_to_ai_context', 'Attach a message to the app AI context chip for its conversation.', scope: _McpPermissionScope.appControl, required: ['message_id'], properties: { - 'message_id': {'type': 'string'}, + 'message_id': _messageIdProperty, }, ), _Tool( - 'mixin_list_ai_threads', + 'mixin_list_conversation_ai_threads', 'List AI threads for a conversation.', required: ['conversation_id'], properties: { - 'conversation_id': {'type': 'string'}, + 'conversation_id': _conversationIdProperty, }, ), _Tool( @@ -1447,21 +2072,15 @@ const _toolSpecs = [ }, ), _Tool( - 'mixin_get_ai_tool_events', + 'mixin_get_ai_message_tool_events', 'Read stored AI tool call/result events for an AI message.', required: ['message_id'], properties: { - 'message_id': {'type': 'string'}, + 'message_id': _messageIdProperty, }, ), ]; -const _conversationRangeProperties = { - 'conversation_id': {'type': 'string'}, - 'start': {'type': 'string', 'description': 'Inclusive ISO-8601 timestamp.'}, - 'end': {'type': 'string', 'description': 'Exclusive ISO-8601 timestamp.'}, -}; - class _Tool { const _Tool( this.name, @@ -1469,6 +2088,7 @@ class _Tool { this.scope = _McpPermissionScope.read, this.required = const [], this.properties = const {}, + this.schema = const {}, }); final String name; @@ -1476,10 +2096,12 @@ class _Tool { final _McpPermissionScope scope; final List required; final Map properties; + final Map schema; Map get inputSchema => { ..._emptyObjectSchema, 'properties': properties, 'required': required, + ...schema, }; } From 5195c783755715bfd0869a6b7fe345f1f23a8130 Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 7 May 2026 11:01:51 +0800 Subject: [PATCH 51/52] Use TOON for MCP tool text results --- lib/ai/tools/ai_conversation_tool_service.dart | 6 +++--- lib/utils/mcp/mixin_mcp_server.dart | 2 +- test/ai/ai_conversation_context_test.dart | 15 +++++++++++++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/lib/ai/tools/ai_conversation_tool_service.dart b/lib/ai/tools/ai_conversation_tool_service.dart index b99ff889b2..a976b4d78a 100644 --- a/lib/ai/tools/ai_conversation_tool_service.dart +++ b/lib/ai/tools/ai_conversation_tool_service.dart @@ -701,7 +701,7 @@ class AiConversationToolKit { ); try { final result = await fn(); - final encodedResult = _encodeToolResult(result); + final encodedResult = encodeAiToolResult(result); d( 'AI tool execute done: conversationId=$conversationId ' 'tool=$name id=$id elapsedMs=${stopwatch.elapsedMilliseconds} ' @@ -728,7 +728,7 @@ class AiConversationToolKit { errorText: error.toString(), ), ); - return _encodeToolResult({'error': '$error'}); + return encodeAiToolResult({'error': '$error'}); } } } @@ -1094,7 +1094,7 @@ String _truncateText(String text, int? maxLength) { return '${text.substring(0, end)}$suffix'; } -String _encodeToolResult(Map result) => +String encodeAiToolResult(Map result) => encode(_stripNullValues(result)); Object? _stripNullValues(Object? value) { diff --git a/lib/utils/mcp/mixin_mcp_server.dart b/lib/utils/mcp/mixin_mcp_server.dart index 30d0babe8a..f12d27367b 100644 --- a/lib/utils/mcp/mixin_mcp_server.dart +++ b/lib/utils/mcp/mixin_mcp_server.dart @@ -538,7 +538,7 @@ class MixinMcpServer extends ChangeNotifier { e('MCP tool error ${tool.name}: $error', stacktrace); } return mcp.CallToolResult( - content: [mcp.TextContent(text: const JsonEncoder().convert(data))], + content: [mcp.TextContent(text: encodeAiToolResult(data))], structuredContent: data, ); }, diff --git a/test/ai/ai_conversation_context_test.dart b/test/ai/ai_conversation_context_test.dart index 96f238a20f..445361fed2 100644 --- a/test/ai/ai_conversation_context_test.dart +++ b/test/ai/ai_conversation_context_test.dart @@ -11,9 +11,20 @@ import 'package:flutter_app/db/mixin_database.dart'; import 'package:flutter_app/enum/message_category.dart'; import 'package:flutter_test/flutter_test.dart'; import 'package:mixin_bot_sdk_dart/mixin_bot_sdk_dart.dart'; -import 'package:toon_format/toon_format.dart'; void main() { + test('tool result encoder uses TOON text and strips null values', () { + final encoded = encodeAiToolResult({ + 'messages': [ + {'message_id': '1', 'content': 'hello', 'unused': null}, + ], + }); + + expect(encoded, contains('messages')); + expect(encoded, isNot(contains('"messages"'))); + expect(encoded, isNot(contains('unused'))); + }); + group('AI conversation context', () { late MixinDatabase mixinDatabase; late FtsDatabase ftsDatabase; @@ -124,7 +135,7 @@ void main() { limit: 1, ); final targetJson = targetResult.toJson(); - expect(encode(targetJson), contains('context_messages')); + expect(encodeAiToolResult(targetJson), contains('context_messages')); final targetMessage = (targetJson['messages'] as List).single as Map; From de3a61f824b5b8982a868797b24e9d5dbe88729e Mon Sep 17 00:00:00 2001 From: bin <17426470+boyan01@users.noreply.github.com> Date: Thu, 7 May 2026 20:10:16 +0800 Subject: [PATCH 52/52] fix: normalize Anthropic base URL and update related hints and tests --- lib/ai/ai_provider_requester.dart | 9 ++++++++- lib/ui/setting/ai_provider_edit_page.dart | 6 +++--- test/ai/ai_provider_requester_test.dart | 24 +++++++++++++++++++++++ 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/lib/ai/ai_provider_requester.dart b/lib/ai/ai_provider_requester.dart index 59d847bb47..79f73e578e 100644 --- a/lib/ai/ai_provider_requester.dart +++ b/lib/ai/ai_provider_requester.dart @@ -169,7 +169,7 @@ class AiProviderRequester { ), AiProviderType.anthropic => anthropic( apiKey: config.apiKey, - baseUrl: _emptyToNull(config.baseUrl), + baseUrl: normalizeAiProviderBaseUrl(config.type, config.baseUrl), ), AiProviderType.gemini => googleAI(apiKey: config.apiKey), }; @@ -187,6 +187,13 @@ class AiProviderRequester { } } +String? normalizeAiProviderBaseUrl(AiProviderType type, String value) { + final trimmed = value.trim(); + if (trimmed.isEmpty) return null; + if (type != AiProviderType.anthropic) return trimmed; + return trimmed.replaceFirst(RegExp(r'/v1/?$'), ''); +} + Map _genkitResponseMetadata( genkit.GenerateResponseHelper response, { required int elapsedMs, diff --git a/lib/ui/setting/ai_provider_edit_page.dart b/lib/ui/setting/ai_provider_edit_page.dart index 66429690aa..6d4f0da0fc 100644 --- a/lib/ui/setting/ai_provider_edit_page.dart +++ b/lib/ui/setting/ai_provider_edit_page.dart @@ -508,13 +508,13 @@ class AiProviderEditPage extends HookConsumerWidget { static String _defaultBaseUrlFor(AiProviderType type) => switch (type) { AiProviderType.openaiCompatible => '', - AiProviderType.anthropic => 'https://api.anthropic.com/v1', + AiProviderType.anthropic => 'https://api.anthropic.com', AiProviderType.gemini => '', }; static String _baseUrlHintFor(AiProviderType type) => switch (type) { AiProviderType.openaiCompatible => 'https://api.example.com/v1', - AiProviderType.anthropic => 'https://api.anthropic.com/v1', + AiProviderType.anthropic => 'https://api.anthropic.com', AiProviderType.gemini => 'https://generativelanguage.googleapis.com/v1beta', }; @@ -522,7 +522,7 @@ class AiProviderEditPage extends HookConsumerWidget { AiProviderType.openaiCompatible => 'For OpenAI-compatible APIs, use the server root that exposes /chat/completions.', AiProviderType.anthropic => - 'Anthropic uses the Messages API under /v1/messages.', + 'Use the API host root. The app appends /v1/messages automatically.', AiProviderType.gemini => 'Gemini uses the Google Generative Language API and appends /models/{model}:streamGenerateContent automatically.', }; diff --git a/test/ai/ai_provider_requester_test.dart b/test/ai/ai_provider_requester_test.dart index bb55fd83fe..cdf05111e9 100644 --- a/test/ai/ai_provider_requester_test.dart +++ b/test/ai/ai_provider_requester_test.dart @@ -66,5 +66,29 @@ void main() { ), ); }); + + test('normalizes Anthropic base URL to the API host root', () { + expect( + normalizeAiProviderBaseUrl( + AiProviderType.anthropic, + 'https://api.anthropic.com/v1', + ), + 'https://api.anthropic.com', + ); + expect( + normalizeAiProviderBaseUrl( + AiProviderType.anthropic, + 'https://api.anthropic.com/v1/', + ), + 'https://api.anthropic.com', + ); + expect( + normalizeAiProviderBaseUrl( + AiProviderType.openaiCompatible, + 'https://api.example.com/v1', + ), + 'https://api.example.com/v1', + ); + }); }); }