diff --git a/apps/mobile/src/app/_layout.tsx b/apps/mobile/src/app/_layout.tsx index 52ff29a585..5d7a66ab35 100644 --- a/apps/mobile/src/app/_layout.tsx +++ b/apps/mobile/src/app/_layout.tsx @@ -11,7 +11,7 @@ import * as Sentry from '@sentry/react-native'; import { QueryClientProvider } from '@tanstack/react-query'; import { isRunningInExpoGo } from 'expo'; import { useFonts } from 'expo-font'; -import { type Href, Slot, useNavigationContainerRef, useRouter, useSegments } from 'expo-router'; +import { Slot, useNavigationContainerRef, useRouter, useSegments } from 'expo-router'; import * as SplashScreen from 'expo-splash-screen'; import { StatusBar } from 'expo-status-bar'; import { requestTrackingPermissionsAsync } from 'expo-tracking-transparency'; @@ -130,7 +130,7 @@ function RootLayoutNav() { // Navigate to pending notification deep link (cold start / background tap) const pendingLink = getPendingNotificationLink(); if (pendingLink) { - router.push(pendingLink as Href); + router.push(pendingLink); } } }, [token, isLoading, updateRequired, inAuthGroup, inForceUpdate, router]); diff --git a/apps/mobile/src/components/agents/cloud-agent-notification-prompt.tsx b/apps/mobile/src/components/agents/cloud-agent-notification-prompt.tsx new file mode 100644 index 0000000000..0a60937dcb --- /dev/null +++ b/apps/mobile/src/components/agents/cloud-agent-notification-prompt.tsx @@ -0,0 +1,167 @@ +import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query'; +import * as SecureStore from 'expo-secure-store'; +import { Bell } from 'lucide-react-native'; +import { useCallback, useEffect, useMemo, useState } from 'react'; +import { Alert, Linking, View } from 'react-native'; +import Animated, { FadeIn, FadeOut } from 'react-native-reanimated'; +import { toast } from 'sonner-native'; + +import { Button } from '@/components/ui/button'; +import { Text } from '@/components/ui/text'; +import { useThemeColors } from '@/lib/hooks/use-theme-colors'; +import { + getDevicePushToken, + getNotificationPermissionStatus, + getPlatform, + registerForPushNotifications, +} from '@/lib/notifications'; +import { CLOUD_AGENT_NOTIFICATION_PROMPT_SEEN_KEY } from '@/lib/storage-keys'; +import { useTRPC } from '@/lib/trpc'; + +const promptDelayMs = 10_000; + +export function CloudAgentNotificationPrompt({ enabled }: { enabled: boolean }) { + const [visible, setVisible] = useState(false); + const colors = useThemeColors(); + const trpc = useTRPC(); + const queryClient = useQueryClient(); + + const pushTokensOptions = trpc.user.getMyPushTokens.queryOptions(); + const pushTokensQueryKey = useMemo(() => pushTokensOptions.queryKey, [pushTokensOptions]); + + const pushTokensQuery = useQuery({ + ...pushTokensOptions, + enabled, + retry: false, + }); + + const registerToken = useMutation( + trpc.user.registerPushToken.mutationOptions({ + onError: error => { + toast.error(error.message); + }, + onSuccess: async () => { + await queryClient.invalidateQueries({ queryKey: pushTokensQueryKey }); + }, + }) + ); + + useEffect(() => { + if (!enabled || pushTokensQuery.isPending || pushTokensQuery.isError) { + return undefined; + } + + const abortController = new AbortController(); + const { signal } = abortController; + let timeout: ReturnType | null = null; + + // oxlint's flow analysis can't tell that `signal.aborted` flips + // asynchronously from the cleanup callback, so it flags each read as + // "always falsy". The check is load-bearing — bail if the effect was + // cleaned up while an `await` was pending. + /* eslint-disable @typescript-eslint/no-unnecessary-condition */ + async function check() { + const seen = await SecureStore.getItemAsync(CLOUD_AGENT_NOTIFICATION_PROMPT_SEEN_KEY); + if (seen || signal.aborted) { + return; + } + + const status = await getNotificationPermissionStatus(); + if (signal.aborted) { + return; + } + if (status === 'granted') { + const deviceToken = await getDevicePushToken(); + if (signal.aborted) { + return; + } + const alreadyRegistered = Boolean( + deviceToken && (pushTokensQuery.data ?? []).some(t => t.token === deviceToken) + ); + if (alreadyRegistered) { + return; + } + } + + timeout = setTimeout(() => { + if (!signal.aborted) { + setVisible(true); + } + }, promptDelayMs); + } + /* eslint-enable @typescript-eslint/no-unnecessary-condition */ + + void check(); + + return () => { + abortController.abort(); + if (timeout) { + clearTimeout(timeout); + } + }; + }, [enabled, pushTokensQuery.data, pushTokensQuery.isError, pushTokensQuery.isPending]); + + const handleEnable = useCallback(async () => { + const currentStatus = await getNotificationPermissionStatus(); + + if (currentStatus === 'denied') { + Alert.alert( + 'Notifications Disabled', + 'To enable notifications, turn them on in your device settings.', + [ + { text: 'Cancel', style: 'cancel' }, + { text: 'Open Settings', onPress: () => void Linking.openSettings() }, + ] + ); + return; + } + + const token = await registerForPushNotifications(); + if (!token) { + toast.error('Notification permission was not granted'); + return; + } + + registerToken.mutate( + { token, platform: getPlatform() }, + { + onSuccess: () => { + void SecureStore.setItemAsync(CLOUD_AGENT_NOTIFICATION_PROMPT_SEEN_KEY, 'true'); + setVisible(false); + toast.success('Notifications enabled'); + }, + } + ); + }, [registerToken]); + + const handleDismiss = useCallback(async () => { + await SecureStore.setItemAsync(CLOUD_AGENT_NOTIFICATION_PROMPT_SEEN_KEY, 'true'); + setVisible(false); + }, []); + + if (!visible) { + return null; + } + + return ( + + + + + Get notified when your agent finishes + + We'll ping your phone when a task completes, so you can close the app. + + + + + + + + + ); +} diff --git a/apps/mobile/src/components/agents/session-detail-content.tsx b/apps/mobile/src/components/agents/session-detail-content.tsx index 1435dc6d26..a07de34a48 100644 --- a/apps/mobile/src/components/agents/session-detail-content.tsx +++ b/apps/mobile/src/components/agents/session-detail-content.tsx @@ -1,21 +1,22 @@ -import { useCallback, useEffect, useMemo, useState } from 'react'; -import { ActivityIndicator, FlatList, KeyboardAvoidingView, Platform, View } from 'react-native'; -import { useAtomValue } from 'jotai'; import { type CloudStatus, type KiloSessionId, type StoredMessage } from 'cloud-agent-sdk'; -import { toast } from 'sonner-native'; +import { useAtomValue } from 'jotai'; +import { useCallback, useEffect, useMemo } from 'react'; +import { ActivityIndicator, FlatList, KeyboardAvoidingView, Platform, View } from 'react-native'; import { useSafeAreaInsets } from 'react-native-safe-area-context'; +import { toast } from 'sonner-native'; import { ChatComposer } from '@/components/agents/chat-composer'; +import { CloudAgentNotificationPrompt } from '@/components/agents/cloud-agent-notification-prompt'; import { ConnectivityBanner } from '@/components/agents/connectivity-banner'; import { MessageBubble } from '@/components/agents/message-bubble'; -import { normalizeAgentMode } from '@/components/agents/mode-options'; -import { type AgentMode } from '@/components/agents/mode-selector'; import { PermissionCard } from '@/components/agents/permission-card'; import { QuestionCard } from '@/components/agents/question-card'; import { useSessionManager } from '@/components/agents/session-provider'; import { SessionStatusIndicator } from '@/components/agents/session-status-indicator'; import { useInteractionHandlers } from '@/components/agents/use-interaction-handlers'; +import { useMarkSessionRead } from '@/components/agents/use-mark-session-read'; import { useSessionAutoScroll } from '@/components/agents/use-session-auto-scroll'; +import { useSessionConfigSync } from '@/components/agents/use-session-config-sync'; import { WorkingIndicator } from '@/components/agents/working-indicator'; import { ScreenHeader } from '@/components/screen-header'; import { Text } from '@/components/ui/text'; @@ -59,6 +60,8 @@ export function SessionDetailContent({ sessionId }: Readonly(() => - normalizeAgentMode(fetchedData?.mode) - ); - - const [currentModel, setCurrentModel] = useState(fetchedData?.model ?? ''); - const [currentVariant, setCurrentVariant] = useState(fetchedData?.variant ?? ''); - - // Sync mode/model/variant from session data and SDK session config. - // The SDK's sessionConfig is updated from assistant messages during snapshot - // replay, so it captures the model actually used in the conversation. - useEffect(() => { - const mode = sessionConfig?.mode ?? fetchedData?.mode; - if (mode) { - setCurrentMode(normalizeAgentMode(mode)); - } - - const model = sessionConfig?.model ?? fetchedData?.model; - if (model) { - setCurrentModel(model); - } - - const variant = sessionConfig?.variant ?? fetchedData?.variant; - if (variant) { - setCurrentVariant(variant); - } - }, [ - sessionConfig?.mode, - sessionConfig?.model, - sessionConfig?.variant, - fetchedData?.mode, - fetchedData?.model, - fetchedData?.variant, - ]); - - // Auto-select first available model when session has no model (e.g. remote CLI sessions) - useEffect(() => { - if (currentModel || modelOptions.length === 0 || fetchedData === null) { - return; - } - const firstModel = modelOptions[0]; - if (firstModel) { - setCurrentModel(firstModel.id); - setCurrentVariant(firstModel.variants[0] ?? ''); - } - }, [currentModel, modelOptions, fetchedData]); + const { + currentMode, + currentModel, + currentVariant, + setCurrentMode, + setCurrentModel, + setCurrentVariant, + } = useSessionConfigSync({ fetchedData, sessionConfig, modelOptions }); const { flatListRef, @@ -181,6 +147,13 @@ export function SessionDetailContent({ sessionId }: Readonly { @@ -221,6 +194,8 @@ export function SessionDetailContent({ sessionId }: Readonly {renderContent()} + + {activeQuestion ? ( { + void Notifications.setBadgeCountAsync(badgeCount); + }, + onError: err => { + toast.error(err.message || 'Failed to update badge count'); + }, + }) + ); + + useFocusEffect( + useCallback(() => { + isFocusedRef.current = true; + setActiveCliSession(sessionId); + markChatRead({ channelId: sessionId }); + + // If a notification for this session arrives while the screen is already open it is + // visually suppressed, but the worker still incremented the server-side count. + const subscription = Notifications.addNotificationReceivedListener(notification => { + const data = parseNotificationData(notification.request.content.data); + if (data?.type === 'cloud_agent_session' && data.cliSessionId === sessionId) { + markChatRead({ channelId: sessionId }); + } + }); + + return () => { + isFocusedRef.current = false; + setActiveCliSession(null); + subscription.remove(); + }; + }, [sessionId, markChatRead]) + ); + + // Clear badge when the app returns to the foreground while this session is focused. + // `useFocusEffect` already handles the focus/sessionId change case; this effect + // only fires on the inactive -> active transition to avoid a duplicate call. + const wasActiveRef = useRef(isActive); + useEffect(() => { + const becameActive = isActive && !wasActiveRef.current; + wasActiveRef.current = isActive; + if (becameActive && isFocusedRef.current) { + markChatRead({ channelId: sessionId }); + } + }, [isActive, sessionId, markChatRead]); +} diff --git a/apps/mobile/src/components/agents/use-session-config-sync.ts b/apps/mobile/src/components/agents/use-session-config-sync.ts new file mode 100644 index 0000000000..cc8618babd --- /dev/null +++ b/apps/mobile/src/components/agents/use-session-config-sync.ts @@ -0,0 +1,86 @@ +import { useEffect, useState } from 'react'; + +import { normalizeAgentMode } from '@/components/agents/mode-options'; +import { type AgentMode } from '@/components/agents/mode-selector'; +import { type ModelOption } from '@/lib/hooks/use-available-models'; + +type SessionConfigSnapshot = { + mode?: string | null; + model?: string | null; + variant?: string | null; +}; + +type UseSessionConfigSyncOptions = { + fetchedData: SessionConfigSnapshot | null; + sessionConfig: SessionConfigSnapshot | null | undefined; + modelOptions: ModelOption[]; +}; + +type UseSessionConfigSyncResult = { + currentMode: AgentMode; + currentModel: string; + currentVariant: string; + setCurrentMode: (mode: AgentMode) => void; + setCurrentModel: (model: string) => void; + setCurrentVariant: (variant: string) => void; +}; + +// Keeps the composer's mode/model/variant in sync with the session's +// fetched data and the SDK session config (which is updated from assistant +// messages during snapshot replay). For sessions without a configured model +// (e.g. remote CLI sessions), auto-selects the first available model. +export function useSessionConfigSync({ + fetchedData, + sessionConfig, + modelOptions, +}: UseSessionConfigSyncOptions): UseSessionConfigSyncResult { + const [currentMode, setCurrentMode] = useState(() => + normalizeAgentMode(fetchedData?.mode) + ); + const [currentModel, setCurrentModel] = useState(fetchedData?.model ?? ''); + const [currentVariant, setCurrentVariant] = useState(fetchedData?.variant ?? ''); + + useEffect(() => { + const mode = sessionConfig?.mode ?? fetchedData?.mode; + if (mode) { + setCurrentMode(normalizeAgentMode(mode)); + } + + const model = sessionConfig?.model ?? fetchedData?.model; + if (model) { + setCurrentModel(model); + } + + const variant = sessionConfig?.variant ?? fetchedData?.variant; + if (variant) { + setCurrentVariant(variant); + } + }, [ + sessionConfig?.mode, + sessionConfig?.model, + sessionConfig?.variant, + fetchedData?.mode, + fetchedData?.model, + fetchedData?.variant, + ]); + + useEffect(() => { + if (currentModel || modelOptions.length === 0 || fetchedData === null) { + return; + } + const firstModel = modelOptions[0]; + if (firstModel) { + setCurrentModel(firstModel.id); + setCurrentVariant(firstModel.variants[0] ?? ''); + } + }, [currentModel, modelOptions, fetchedData]); + + return { + currentMode, + currentModel, + currentVariant, + setCurrentMode, + setCurrentModel, + setCurrentVariant, + }; +} diff --git a/apps/mobile/src/lib/auth/auth-context.tsx b/apps/mobile/src/lib/auth/auth-context.tsx index 5c19c1a13e..b437955ada 100644 --- a/apps/mobile/src/lib/auth/auth-context.tsx +++ b/apps/mobile/src/lib/auth/auth-context.tsx @@ -13,6 +13,7 @@ import { trackEvent } from '@/lib/appsflyer'; import { queryClient } from '@/lib/query-client'; import { AUTH_TOKEN_KEY, + CLOUD_AGENT_NOTIFICATION_PROMPT_SEEN_KEY, NOTIFICATION_PROMPT_SEEN_KEY, ORGANIZATION_STORAGE_KEY, SESSION_FILTERS_KEY, @@ -58,6 +59,7 @@ export function AuthProvider({ children }: { readonly children: ReactNode }) { await SecureStore.deleteItemAsync(ORGANIZATION_STORAGE_KEY); await SecureStore.deleteItemAsync(SESSION_FILTERS_KEY); await SecureStore.deleteItemAsync(NOTIFICATION_PROMPT_SEEN_KEY); + await SecureStore.deleteItemAsync(CLOUD_AGENT_NOTIFICATION_PROMPT_SEEN_KEY); queryClient.clear(); setToken(undefined); }, []); diff --git a/apps/mobile/src/lib/notifications.ts b/apps/mobile/src/lib/notifications.ts index e524feedae..8b532ce5a3 100644 --- a/apps/mobile/src/lib/notifications.ts +++ b/apps/mobile/src/lib/notifications.ts @@ -4,13 +4,14 @@ import { type Href, router } from 'expo-router'; import { Platform } from 'react-native'; import { z } from 'zod'; +const easConfigSchema = z.object({ projectId: z.string().min(1) }); + function getProjectId(): string { - const eas = expoConstants.expoConfig?.extra?.eas as { projectId?: string } | undefined; - const projectId = eas?.projectId; - if (!projectId) { + const parsed = easConfigSchema.safeParse(expoConstants.expoConfig?.extra?.eas); + if (!parsed.success) { throw new Error('Missing extra.eas.projectId in app config'); } - return projectId; + return parsed.data.projectId; } // Tracks which chat instance screen is currently focused. @@ -19,14 +20,20 @@ function getProjectId(): string { // A module-level variable (not React state) because the notification handler // is registered once and must always read the latest value without stale closures. let activeChatInstanceId: string | null = null; +let activeCliSessionId: string | null = null; export function setActiveChatInstance(instanceId: string | null) { activeChatInstanceId = instanceId; } +export function setActiveCliSession(cliSessionId: string | null) { + activeCliSessionId = cliSessionId; +} + // Keep in sync with the `data` payloads emitted by: // - services/notifications/src/dos/NotificationChannelDO.ts (chat) // - services/notifications/src/lib/notifications-service.ts (instance-lifecycle) +// - services/cloud-agent-next notifications producer (cloud_agent_session) const notificationDataSchema = z.discriminatedUnion('type', [ z.object({ type: z.literal('chat'), @@ -37,6 +44,10 @@ const notificationDataSchema = z.discriminatedUnion('type', [ event: z.enum(['ready', 'start_failed']), instanceId: z.string().min(1), }), + z.object({ + type: z.literal('cloud_agent_session'), + cliSessionId: z.string().min(1), + }), ]); type NotificationData = z.infer; @@ -55,7 +66,7 @@ const shown = { shouldSetBadge: true, shouldShowBanner: true, shouldShowList: true, -} as const; +} satisfies Notifications.NotificationBehavior; const suppressed = { shouldShowAlert: false, @@ -63,7 +74,22 @@ const suppressed = { shouldSetBadge: false, shouldShowBanner: false, shouldShowList: false, -} as const; +} satisfies Notifications.NotificationBehavior; + +function getNotificationPath(data: NotificationData): Href { + if (data.type === 'cloud_agent_session') { + return { + pathname: '/(app)/agent-chat/[session-id]', + params: { 'session-id': data.cliSessionId }, + }; + } + + // chat + instance-lifecycle both deep-link to the same chat route by instanceId. + return { + pathname: '/(app)/chat/[instance-id]', + params: { 'instance-id': data.instanceId }, + }; +} export function setupNotificationHandler() { Notifications.setNotificationHandler({ @@ -71,8 +97,11 @@ export function setupNotificationHandler() { handleNotification: async notification => { const data = parseNotificationData(notification.request.content.data); - // Suppress only if the user is already viewing this exact chat - if (data && data.instanceId === activeChatInstanceId) { + // Suppress only if the user is already viewing this exact chat/session. + if (data?.type === 'chat' && data.instanceId === activeChatInstanceId) { + return suppressed; + } + if (data?.type === 'cloud_agent_session' && data.cliSessionId === activeCliSessionId) { return suppressed; } @@ -83,35 +112,26 @@ export function setupNotificationHandler() { // Pending deep link from a notification tap (cold start or background). // Consumed by the root nav after auth/navigation is ready. -let pendingNotificationLink: string | null = null; +let pendingNotificationLink: Href | null = null; -export function getPendingNotificationLink(): string | null { +export function getPendingNotificationLink(): Href | null { const link = pendingNotificationLink; pendingNotificationLink = null; return link; } -function instanceChatPath(data: NotificationData | null): string | null { - if (!data) { - return null; - } - // Both chat and instance-lifecycle payloads carry `instanceId` and deep-link - // to the same chat route. - return `/(app)/chat/${data.instanceId}`; -} - export function setupNotificationResponseHandler() { const subscription = Notifications.addNotificationResponseReceivedListener(response => { const data = parseNotificationData(response.notification.request.content.data); - const path = instanceChatPath(data); - if (!path) { + if (!data) { return; } - // If the router is ready (has segments), navigate immediately. - // Otherwise store as pending for consumption after auth completes. + const path = getNotificationPath(data); + Notifications.clearLastNotificationResponse(); + // If the router is ready, navigate immediately; otherwise store as pending. try { - router.replace(path as Href); + router.replace(path); } catch { pendingNotificationLink = path; } @@ -127,9 +147,9 @@ export function checkInitialNotification(): void { return; } const data = parseNotificationData(response.notification.request.content.data); - const path = instanceChatPath(data); - if (path) { - pendingNotificationLink = path; + if (data) { + pendingNotificationLink = getNotificationPath(data); + Notifications.clearLastNotificationResponse(); } } @@ -173,5 +193,12 @@ export async function getNotificationPermissionStatus(): Promise< } export function getPlatform(): 'ios' | 'android' { - return Platform.OS as 'ios' | 'android'; + if (Platform.OS === 'ios') { + return 'ios'; + } + if (Platform.OS === 'android') { + return 'android'; + } + + throw new Error('Unsupported platform for push notifications'); } diff --git a/apps/mobile/src/lib/storage-keys.ts b/apps/mobile/src/lib/storage-keys.ts index c24cb82abc..c92c390e37 100644 --- a/apps/mobile/src/lib/storage-keys.ts +++ b/apps/mobile/src/lib/storage-keys.ts @@ -10,3 +10,4 @@ export const ORGANIZATION_STORAGE_KEY = 'selected-organization'; export const SESSION_FILTERS_KEY = 'agent-session-filters'; export const NOTIFICATION_PROMPT_SEEN_KEY = 'notification-prompt-seen'; export const LAST_ACTIVE_INSTANCE_KEY = 'last-active-chat-instance'; +export const CLOUD_AGENT_NOTIFICATION_PROMPT_SEEN_KEY = 'cloud-agent-notification-prompt-seen'; diff --git a/dev/local/services.ts b/dev/local/services.ts index 8150c6fe25..a1a110f515 100644 --- a/dev/local/services.ts +++ b/dev/local/services.ts @@ -21,12 +21,13 @@ const groups: ServiceGroup[] = [ alwaysOn: false, sectionBreakBefore: true, }, - { id: 'kiloclaw', label: 'KiloClaw', alwaysOn: false }, + { id: 'notifications', label: 'Notifications', alwaysOn: false }, + { id: 'kiloclaw', label: 'KiloClaw', alwaysOn: false, groupDependsOn: ['notifications'] }, { id: 'cloud-agent', label: 'Cloud Agent', alwaysOn: false, - groupDependsOn: ['git-token-service'], + groupDependsOn: ['git-token-service', 'notifications'], }, { id: 'code-review', label: 'Code Review', alwaysOn: false, groupDependsOn: ['cloud-agent'] }, { id: 'app-builder', label: 'App Builder', alwaysOn: false, groupDependsOn: ['cloud-agent'] }, @@ -70,7 +71,13 @@ const serviceMeta: Record = { // cloud-agent 'cloud-agent-next': { group: 'cloud-agent', - dependsOn: ['postgres', 'nextjs', 'cloudflare-session-ingest', 'cloudflare-git-token-service'], + dependsOn: [ + 'postgres', + 'nextjs', + 'cloudflare-session-ingest', + 'cloudflare-git-token-service', + 'notifications', + ], dir: 'services/cloud-agent-next', useLanIp: true, }, @@ -137,7 +144,7 @@ const serviceMeta: Record = { 'kiloclaw-stripe': { group: 'kiloclaw', dependsOn: [] }, 'kiloclaw-docker-tcp': { group: 'kiloclaw', dependsOn: [] }, notifications: { - group: 'kiloclaw', + group: 'notifications', dependsOn: ['postgres'], dir: 'services/notifications', }, diff --git a/services/cloud-agent-next/src/notifications-binding.ts b/services/cloud-agent-next/src/notifications-binding.ts new file mode 100644 index 0000000000..b4c986cdc9 --- /dev/null +++ b/services/cloud-agent-next/src/notifications-binding.ts @@ -0,0 +1,29 @@ +/** + * RPC method types for the NOTIFICATIONS service binding. + * + * `wrangler types` only sees `Fetcher` for service bindings; the actual RPC + * shape comes from the notifications worker's WorkerEntrypoint and is declared + * here so the generated file can be freely regenerated. + * + * Keep in sync with: services/notifications/src/lib/notifications-service.ts (NotificationsService). + */ + +import type { CloudAgentPushStatus } from './notifications/types.js'; + +export type SendCloudAgentSessionNotificationParams = { + userId: string; + cliSessionId: string; + status: CloudAgentPushStatus; + body: string; +}; + +export type SendCloudAgentSessionNotificationResult = { + dispatched: boolean; + reason?: 'missing_user' | 'missing_session'; +}; + +export type NotificationsBinding = Fetcher & { + sendCloudAgentSessionNotification( + params: SendCloudAgentSessionNotificationParams + ): Promise; +}; diff --git a/services/cloud-agent-next/src/notifications/producer.test.ts b/services/cloud-agent-next/src/notifications/producer.test.ts new file mode 100644 index 0000000000..d720ccb9a5 --- /dev/null +++ b/services/cloud-agent-next/src/notifications/producer.test.ts @@ -0,0 +1,23 @@ +import { describe, expect, it } from 'vitest'; + +import { buildCloudAgentPushBody, truncatePushSnippet } from './producer.js'; + +describe('push notification body helpers', () => { + it('uses completed snippet or fallback', () => { + expect(buildCloudAgentPushBody('completed', ' Done ')).toBe('Done'); + expect(buildCloudAgentPushBody('completed', undefined)).toBe('Task completed'); + }); + + it('prefixes failed and interrupted bodies with fallbacks', () => { + expect(buildCloudAgentPushBody('failed', 'bad')).toBe('Failed: bad'); + expect(buildCloudAgentPushBody('failed', undefined, 'boom')).toBe('Failed: boom'); + expect(buildCloudAgentPushBody('failed', undefined)).toBe('Failed: Task failed'); + expect(buildCloudAgentPushBody('interrupted', 'stopped')).toBe('Interrupted: stopped'); + expect(buildCloudAgentPushBody('interrupted', undefined)).toBe('Interrupted: Task interrupted'); + }); + + it('truncates and normalizes snippets', () => { + expect(truncatePushSnippet('abcdefghij', 6)).toBe('abc...'); + expect(truncatePushSnippet('hello\n\nworld')).toBe('hello world'); + }); +}); diff --git a/services/cloud-agent-next/src/notifications/producer.ts b/services/cloud-agent-next/src/notifications/producer.ts new file mode 100644 index 0000000000..642e0c6f1a --- /dev/null +++ b/services/cloud-agent-next/src/notifications/producer.ts @@ -0,0 +1,32 @@ +import type { CloudAgentPushStatus } from './types.js'; + +const PUSH_SNIPPET_MAX_LENGTH = 100; +const ELLIPSIS = '...'; + +export function truncatePushSnippet(text: string, maxLength = PUSH_SNIPPET_MAX_LENGTH): string { + const singleLineText = text.trim().replace(/\s+/g, ' '); + if (singleLineText.length <= maxLength) return singleLineText; + if (maxLength <= ELLIPSIS.length) return ELLIPSIS.slice(0, maxLength); + return singleLineText.slice(0, maxLength - ELLIPSIS.length) + ELLIPSIS; +} + +export function buildCloudAgentPushBody( + status: CloudAgentPushStatus, + snippet?: string, + error?: string +): string { + const truncatedSnippet = snippet ? truncatePushSnippet(snippet) : undefined; + + if (status === 'completed') { + return truncatedSnippet ?? 'Task completed'; + } + + if (status === 'failed') { + const detail = + truncatedSnippet ?? (error ? truncatePushSnippet(error) : undefined) ?? 'Task failed'; + return `Failed: ${detail}`; + } + + const detail = truncatedSnippet ?? 'Task interrupted'; + return `Interrupted: ${detail}`; +} diff --git a/services/cloud-agent-next/src/notifications/types.ts b/services/cloud-agent-next/src/notifications/types.ts new file mode 100644 index 0000000000..0c36b98750 --- /dev/null +++ b/services/cloud-agent-next/src/notifications/types.ts @@ -0,0 +1 @@ +export type CloudAgentPushStatus = 'completed' | 'failed' | 'interrupted'; diff --git a/services/cloud-agent-next/src/persistence/CloudAgentSession.ts b/services/cloud-agent-next/src/persistence/CloudAgentSession.ts index 0db98ec77e..ae5f18b1f8 100644 --- a/services/cloud-agent-next/src/persistence/CloudAgentSession.ts +++ b/services/cloud-agent-next/src/persistence/CloudAgentSession.ts @@ -15,6 +15,8 @@ import { } from './schemas.js'; import type { EncryptedSecrets } from '../router/schemas.js'; import type { CallbackJob, CallbackTarget } from '../callbacks/index.js'; +import { buildCloudAgentPushBody } from '../notifications/producer.js'; +import type { CloudAgentPushStatus } from '../notifications/types.js'; import { drizzle } from 'drizzle-orm/durable-sqlite'; import { logger } from '../logger.js'; import { Limits } from '../schema.js'; @@ -46,7 +48,11 @@ import type { UpdateStatusError, SetActiveError, } from '../session/queries/executions.js'; -import { createStreamHandler, type StreamHandler } from '../websocket/stream.js'; +import { + createStreamHandler, + getConnectedStreamClientCount, + type StreamHandler, +} from '../websocket/stream.js'; import { createIngestHandler, type IngestHandler, @@ -152,15 +158,13 @@ export class CloudAgentSession extends DurableObject { private ingestHandlerSessionId?: SessionId; private sessionId?: SessionId; private orchestrator?: ExecutionOrchestrator; - private isTerminalStatus( - status: ExecutionStatus - ): status is 'completed' | 'failed' | 'interrupted' { + private isTerminalStatus(status: ExecutionStatus): status is CloudAgentPushStatus { return status === 'completed' || status === 'failed' || status === 'interrupted'; } private async enqueueCallbackNotification( executionId: ExecutionId, - status: 'completed' | 'failed' | 'interrupted', + status: CloudAgentPushStatus, error?: string, gateResult?: 'pass' | 'fail' ): Promise { @@ -171,13 +175,6 @@ export class CloudAgentSession extends DurableObject { return; } - logger.info('Enqueued callback job', { - cloudAgentSessionId: metadata.sessionId, - kiloSessionId: metadata.kiloSessionId, - executionId, - callbackUrl: metadata.callbackTarget.url, - }); - const resolvedSessionId = await this.resolveSessionId(metadata.sessionId as SessionId); const sessionId = resolvedSessionId ?? metadata.sessionId ?? ''; @@ -199,8 +196,9 @@ export class CloudAgentSession extends DurableObject { }, }; - // Fire-and-forget enqueue - don't block execution completion - callbackQueue.send(callbackJob).catch(err => { + try { + await callbackQueue.send(callbackJob); + } catch (err) { logger .withFields({ sessionId, @@ -208,7 +206,49 @@ export class CloudAgentSession extends DurableObject { error: err instanceof Error ? err.message : String(err), }) .error('Failed to enqueue callback job'); - }); + } + } + + private async dispatchPushNotification( + status: CloudAgentPushStatus, + error?: string + ): Promise { + if (getConnectedStreamClientCount(this.ctx) > 0) { + return; + } + + const notifications = this.env.NOTIFICATIONS; + + const metadata = await this.getMetadata(); + const cliSessionId = metadata?.kiloSessionId; + if (!metadata?.userId || !cliSessionId) { + return; + } + + const snippet = status === 'completed' ? await this.getLatestAssistantMessageText() : undefined; + const body = buildCloudAgentPushBody(status, snippet, error); + + try { + const result = await notifications.sendCloudAgentSessionNotification({ + userId: metadata.userId, + cliSessionId, + status, + body, + }); + if (!result.dispatched) { + logger + .withFields({ cliSessionId, status, reason: result.reason }) + .warn('Cloud-agent push notification skipped by notifications service'); + } + } catch (err) { + logger + .withFields({ + cliSessionId, + status, + error: err instanceof Error ? err.message : String(err), + }) + .error('Failed to dispatch cloud-agent push notification'); + } } constructor(ctx: DurableObjectState, env: WorkerEnv) { @@ -326,7 +366,7 @@ export class CloudAgentSession extends DurableObject { }, updateExecutionStatus: async ( executionId: string, - status: 'completed' | 'failed' | 'interrupted', + status: ExecutionStatus, error?: string, gateResult?: 'pass' | 'fail' ) => { @@ -589,7 +629,7 @@ export class CloudAgentSession extends DurableObject { * @returns Number of active WebSocket connections */ getConnectedClientCount(): number { - return this.streamHandler?.getConnectedClientCount() ?? 0; + return getConnectedStreamClientCount(this.ctx); } // --------------------------------------------------------------------------- @@ -620,7 +660,7 @@ export class CloudAgentSession extends DurableObject { } catch (err) { logger .withFields({ error: err instanceof Error ? err.message : String(err) }) - .warn('Failed to fetch latest assistant message for callback'); + .warn('Failed to fetch latest assistant message text'); return undefined; } } @@ -1694,8 +1734,6 @@ export class CloudAgentSession extends DurableObject { .withFields({ sessionId: this.sessionId, sandboxId, rpcElapsedMs: Date.now() - rpcStart }) .debug('stopKiloServer RPC completed'); - // Clear the activity timestamp since server is stopped - // Must merge with existing metadata since updateMetadata validates the full schema const updated = { ...metadata, kiloServerLastActivity: undefined, @@ -1707,7 +1745,6 @@ export class CloudAgentSession extends DurableObject { .withFields({ sessionId: this.sessionId, sandboxId }) .info('Idle kilo server stopped successfully'); } catch (error) { - // Log but don't fail - server may already be stopped or sandbox recycled logger .withFields({ sessionId: this.sessionId, @@ -1722,12 +1759,9 @@ export class CloudAgentSession extends DurableObject { * active execution. * * The wrapper heartbeat travels over an outbound WebSocket that bypasses - * `containerFetch()`, so it never calls `renewActivityTimeout()`. Calling + * `containerFetch()`, so it never calls `renewActivityTimeout()`. Calling * `setSleepAfter()` with the same value is a lightweight RPC that resets * the timer without changing the configuration. - * - * Called from the DO context's `updateHeartbeat` callback (debounced - * to every 30 s by the ingest handler) while an execution is running. */ private async keepContainerAlive(): Promise { try { @@ -1774,21 +1808,30 @@ export class CloudAgentSession extends DurableObject { * When `suppressCallback` is true the status is persisted but no callback * notification is enqueued. Used on the followup path where the caller * (orchestrator) handles the error synchronously and enqueuing a callback - * would race with a fallback session's callbacks. + * would race with a fallback session's callbacks. When `suppressPush` is + * true, no terminal push notification is enqueued. The two flags are + * independent: callers that want both suppressed must pass both. */ async updateExecutionStatus( params: UpdateExecutionStatusParams, - opts?: { suppressCallback?: boolean } + opts?: { suppressCallback?: boolean; suppressPush?: boolean } ): Promise> { const result = await this.executionQueries.updateStatus(params); - if (result.ok && this.isTerminalStatus(params.status) && !opts?.suppressCallback) { - await this.enqueueCallbackNotification( - params.executionId, - params.status, - params.error, - params.gateResult - ); + if (result.ok && this.isTerminalStatus(params.status)) { + // Enqueue notifications synchronously so callers (and tests) observe + // completed sends. Both helpers catch and log send failures internally. + if (!opts?.suppressCallback) { + await this.enqueueCallbackNotification( + params.executionId, + params.status, + params.error, + params.gateResult + ); + } + if (!opts?.suppressPush) { + await this.dispatchPushNotification(params.status, params.error); + } } return result; @@ -1924,7 +1967,7 @@ export class CloudAgentSession extends DurableObject { // decide whether to clean up the interrupt flag afterward. const wasActive = (await this.executionQueries.getActiveExecutionId()) === executionId; - // 1. Update status (enqueues callback notification on terminal unless suppressed) + // 1. Update status (enqueues callback/push notifications on terminal unless suppressed) const statusResult = await this.updateExecutionStatus( { executionId, @@ -1932,7 +1975,7 @@ export class CloudAgentSession extends DurableObject { error, completedAt: Date.now(), }, - { suppressCallback: params.suppressCallback } + { suppressCallback: params.suppressCallback, suppressPush: params.suppressCallback } ); if (!statusResult.ok) { @@ -2811,7 +2854,7 @@ export class CloudAgentSession extends DurableObject { */ async onExecutionComplete( executionId: ExecutionId, - status: 'completed' | 'failed' | 'interrupted', + status: CloudAgentPushStatus, error?: string ): Promise { const sessionId = await this.resolveSessionId(); diff --git a/services/cloud-agent-next/src/persistence/types.ts b/services/cloud-agent-next/src/persistence/types.ts index 550c867148..4bcb7c7c48 100644 --- a/services/cloud-agent-next/src/persistence/types.ts +++ b/services/cloud-agent-next/src/persistence/types.ts @@ -3,6 +3,7 @@ import type { Sandbox } from '@cloudflare/sandbox'; import type { CloudAgentSession } from './CloudAgentSession.js'; import type { EncryptedSecrets } from '../router/schemas.js'; import type { CallbackTarget } from '../callbacks/index.js'; +import type { NotificationsBinding } from '../notifications-binding.js'; import type { Images } from './schemas.js'; import type { SessionIngestBinding } from '../session-ingest-binding.js'; @@ -187,6 +188,9 @@ export type PersistenceEnv = { /** URL for session ingest service, injected into sandbox session env vars */ KILO_SESSION_INGEST_URL?: string; + /** Service binding for dispatching push notifications */ + NOTIFICATIONS: NotificationsBinding; + /** Shared secret for internal service-to-service authentication */ INTERNAL_API_SECRET_PROD: SecretsStoreSecret; diff --git a/services/cloud-agent-next/src/router.test.ts b/services/cloud-agent-next/src/router.test.ts index 1f61519712..ec018093e6 100644 --- a/services/cloud-agent-next/src/router.test.ts +++ b/services/cloud-agent-next/src/router.test.ts @@ -311,6 +311,7 @@ describe('router sessionId validation', () => { INTERNAL_API_SECRET_PROD: { get: vi.fn().mockResolvedValue('test-secret'), } as unknown as TRPCContext['env']['INTERNAL_API_SECRET_PROD'], + NOTIFICATIONS: {} as TRPCContext['env']['NOTIFICATIONS'], }, }; cloudAgentSession = mockContext.env.CLOUD_AGENT_SESSION as unknown as MockCAS; @@ -682,6 +683,7 @@ describe('router sessionId validation', () => { INTERNAL_API_SECRET_PROD: { get: vi.fn().mockResolvedValue('test-secret'), } as unknown as TRPCContext['env']['INTERNAL_API_SECRET_PROD'], + NOTIFICATIONS: {} as TRPCContext['env']['NOTIFICATIONS'], }, }; cloudAgentSession = mockContext.env.CLOUD_AGENT_SESSION as unknown as MockCAS; @@ -936,6 +938,7 @@ describe('router sessionId validation', () => { INTERNAL_API_SECRET_PROD: { get: vi.fn().mockResolvedValue('test-secret'), } as unknown as TRPCContext['env']['INTERNAL_API_SECRET_PROD'], + NOTIFICATIONS: {} as TRPCContext['env']['NOTIFICATIONS'], }, }; cloudAgentSession = mockContext.env.CLOUD_AGENT_SESSION as unknown as MockCAS; diff --git a/services/cloud-agent-next/src/session-service.test.ts b/services/cloud-agent-next/src/session-service.test.ts index aaf5d48813..8b610b26e8 100644 --- a/services/cloud-agent-next/src/session-service.test.ts +++ b/services/cloud-agent-next/src/session-service.test.ts @@ -108,6 +108,7 @@ describe('SessionService', () => { INTERNAL_API_SECRET_PROD: { get: vi.fn().mockResolvedValue('test-secret'), } as unknown as PersistenceEnv['INTERNAL_API_SECRET_PROD'], + NOTIFICATIONS: {} as unknown as PersistenceEnv['NOTIFICATIONS'], }; const createMetadataEnv = ( diff --git a/services/cloud-agent-next/src/session/ingest-handlers/execution-lifecycle.ts b/services/cloud-agent-next/src/session/ingest-handlers/execution-lifecycle.ts index 6a54afc02e..a62715d91c 100644 --- a/services/cloud-agent-next/src/session/ingest-handlers/execution-lifecycle.ts +++ b/services/cloud-agent-next/src/session/ingest-handlers/execution-lifecycle.ts @@ -1,4 +1,4 @@ -export type ExecutionStatus = 'completed' | 'failed' | 'interrupted'; +import type { ExecutionStatus } from '../../core/execution.js'; export type ExecutionLifecycleContext = { updateExecutionStatus: ( diff --git a/services/cloud-agent-next/src/session/ingest-handlers/index.ts b/services/cloud-agent-next/src/session/ingest-handlers/index.ts index 2d96ec272d..b67e3b1f5f 100644 --- a/services/cloud-agent-next/src/session/ingest-handlers/index.ts +++ b/services/cloud-agent-next/src/session/ingest-handlers/index.ts @@ -4,9 +4,5 @@ export { type KiloSessionCaptureState, } from './kilo-session-capture.js'; export { handleBranchCapture, type BranchCaptureContext } from './branch-capture.js'; -export { - handleExecutionComplete, - type ExecutionLifecycleContext, - type ExecutionStatus, -} from './execution-lifecycle.js'; +export { handleExecutionComplete, type ExecutionLifecycleContext } from './execution-lifecycle.js'; export { extractEntityId } from './entity-id.js'; diff --git a/services/cloud-agent-next/src/types.ts b/services/cloud-agent-next/src/types.ts index a17fa35b5d..c1e2dbc316 100644 --- a/services/cloud-agent-next/src/types.ts +++ b/services/cloud-agent-next/src/types.ts @@ -1,6 +1,7 @@ import type { getSandbox, ExecutionSession, Sandbox } from '@cloudflare/sandbox'; import type { CloudAgentSession } from './persistence/CloudAgentSession.js'; import type { CallbackJob } from './callbacks/index.js'; +import type { NotificationsBinding } from './notifications-binding.js'; import type { SessionIngestBinding } from './session-ingest-binding.js'; import * as z from 'zod'; import { Limits } from './schema.js'; @@ -140,6 +141,8 @@ export type Env = { CALLBACK_QUEUE?: Queue; /** Service binding for centralized git token generation */ GIT_TOKEN_SERVICE: GitTokenService; + /** Service binding for dispatching push notifications */ + NOTIFICATIONS: NotificationsBinding; /** GitHub Lite App slug for git commit attribution (e.g., 'kiloconnect-lite') */ GITHUB_LITE_APP_SLUG?: string; /** GitHub Lite App bot user ID for git commit email */ diff --git a/services/cloud-agent-next/src/websocket/ingest.test.ts b/services/cloud-agent-next/src/websocket/ingest.test.ts index a6e6b19692..1c4182370c 100644 --- a/services/cloud-agent-next/src/websocket/ingest.test.ts +++ b/services/cloud-agent-next/src/websocket/ingest.test.ts @@ -142,12 +142,13 @@ describe('createIngestHandler', () => { const eventQueries = createFakeEventQueries(); (eventQueries as unknown as Record).upsert = vi.fn().mockReturnValue(42); const broadcastFn = vi.fn(); + const doContext = createFakeDOContext(); const handler = createIngestHandler( createFakeState(), eventQueries, SESSION_ID, broadcastFn, - createFakeDOContext() + doContext ); const ws = createFakeWebSocket(makeAttachment()); @@ -176,12 +177,13 @@ describe('createIngestHandler', () => { ])('kilocode %s is plain-inserted', async eventName => { const eventQueries = createFakeEventQueries(); const broadcastFn = vi.fn(); + const doContext = createFakeDOContext(); const handler = createIngestHandler( createFakeState(), eventQueries, SESSION_ID, broadcastFn, - createFakeDOContext() + doContext ); const ws = createFakeWebSocket(makeAttachment()); @@ -204,12 +206,13 @@ describe('createIngestHandler', () => { ])('kilocode %s is broadcast-only', async eventName => { const eventQueries = createFakeEventQueries(); const broadcastFn = vi.fn(); + const doContext = createFakeDOContext(); const handler = createIngestHandler( createFakeState(), eventQueries, SESSION_ID, broadcastFn, - createFakeDOContext() + doContext ); const ws = createFakeWebSocket(makeAttachment()); diff --git a/services/cloud-agent-next/src/websocket/ingest.ts b/services/cloud-agent-next/src/websocket/ingest.ts index d7bfff422b..660ddc0bc7 100644 --- a/services/cloud-agent-next/src/websocket/ingest.ts +++ b/services/cloud-agent-next/src/websocket/ingest.ts @@ -26,6 +26,7 @@ import { type KiloSessionCaptureState, } from '../session/ingest-handlers/index.js'; import type { CompleteEventData, KilocodeEventData, CloudStatusData } from '../shared/protocol.js'; +import type { ExecutionStatus } from '../core/execution.js'; // --------------------------------------------------------------------------- // Ingest Attachment @@ -58,6 +59,11 @@ const errorEventSchema = z.object({ message: z.string().optional(), }); +function getObject(value: unknown): Record | undefined { + if (typeof value !== 'object' || value === null || Array.isArray(value)) return undefined; + return Object.fromEntries(Object.entries(value)); +} + // --------------------------------------------------------------------------- // Persistence Allowlists // --------------------------------------------------------------------------- @@ -96,7 +102,7 @@ const PERSISTED_KILO_EVENT_NAMES: ReadonlySet = new Set([ const createExecutionLifecycleContext = (doContext: IngestDOContext) => ({ updateExecutionStatus: ( id: string, - status: 'completed' | 'failed' | 'interrupted', + status: ExecutionStatus, err?: string, gateResult?: 'pass' | 'fail' ) => doContext.updateExecutionStatus(id, status, err, gateResult), @@ -168,7 +174,7 @@ export type IngestDOContext = { /** Update execution status when complete/failed/interrupted */ updateExecutionStatus: ( executionId: string, - status: 'completed' | 'failed' | 'interrupted', + status: ExecutionStatus, error?: string, gateResult?: 'pass' | 'fail' ) => Promise; @@ -353,30 +359,33 @@ export function createIngestHandler( // Only events in the allowlists are written to SQLite; // everything else is broadcast to /stream clients with eventId 0. if (eventType === 'kilocode') { - const kiloEventName = (ingestEvent.data as Record | undefined)?.event as - | string - | undefined; - const data = ingestEvent.data as Record; - const entityId = extractEntityId(kiloEventName ?? '', data); - if (entityId) { - eventId = eventQueries.upsert({ - executionId, - sessionId, - streamEventType: eventType, - payload, - timestamp, - entityId, - }); - } else if (kiloEventName && PERSISTED_KILO_EVENT_NAMES.has(kiloEventName)) { - eventId = eventQueries.insert({ - executionId, - sessionId, - streamEventType: eventType, - payload, - timestamp, - }); - } else { + const data = getObject(ingestEvent.data); + if (!data) { eventId = 0; + } else { + const eventName = data.event; + const kiloEventName = typeof eventName === 'string' ? eventName : undefined; + const entityId = extractEntityId(kiloEventName ?? '', data); + if (entityId) { + eventId = eventQueries.upsert({ + executionId, + sessionId, + streamEventType: eventType, + payload, + timestamp, + entityId, + }); + } else if (kiloEventName && PERSISTED_KILO_EVENT_NAMES.has(kiloEventName)) { + eventId = eventQueries.insert({ + executionId, + sessionId, + streamEventType: eventType, + payload, + timestamp, + }); + } else { + eventId = 0; + } } } else if (PERSISTED_STREAM_EVENT_TYPES.has(eventType)) { eventId = eventQueries.insert({ diff --git a/services/cloud-agent-next/src/websocket/stream.ts b/services/cloud-agent-next/src/websocket/stream.ts index ee16eb904c..3fdf770a79 100644 --- a/services/cloud-agent-next/src/websocket/stream.ts +++ b/services/cloud-agent-next/src/websocket/stream.ts @@ -90,6 +90,16 @@ export type StreamHandlerOptions = { * @param options - Optional derivation functions for the `connected` event * @returns Stream handler object with methods for WebSocket operations */ +/** + * Number of active /stream WebSocket connections. + * + * Stateless so callers can check the count without instantiating a + * StreamHandler (and without knowing the internal 'stream' tag). + */ +export function getConnectedStreamClientCount(state: DurableObjectState): number { + return state.getWebSockets('stream').length; +} + export function createStreamHandler( state: DurableObjectState, eventQueries: EventQueries, @@ -253,15 +263,6 @@ export function createStreamHandler( } } }, - - /** - * Get count of connected stream clients. - * - * @returns Number of active WebSocket connections with 'stream' tag - */ - getConnectedClientCount(): number { - return state.getWebSockets('stream').length; - }, }; } diff --git a/services/cloud-agent-next/test/env.d.ts b/services/cloud-agent-next/test/env.d.ts index 3a48a709b0..1b7176ee93 100644 --- a/services/cloud-agent-next/test/env.d.ts +++ b/services/cloud-agent-next/test/env.d.ts @@ -3,8 +3,13 @@ import type { Env } from '../src/types'; +type TestWorkerSelf = { + fetch(input: RequestInfo | URL, init?: RequestInit): Promise; +}; + declare module 'cloudflare:test' { // ProvidedEnv extends your worker's Env interface // This gives you typed access to bindings like env.CLOUD_AGENT_SESSION interface ProvidedEnv extends Env {} + export const SELF: TestWorkerSelf; } diff --git a/services/cloud-agent-next/test/integration/session/push-notifications.test.ts b/services/cloud-agent-next/test/integration/session/push-notifications.test.ts new file mode 100644 index 0000000000..064e9c7d12 --- /dev/null +++ b/services/cloud-agent-next/test/integration/session/push-notifications.test.ts @@ -0,0 +1,271 @@ +import { env, SELF, runInDurableObject } from 'cloudflare:test'; +import { describe, expect, it } from 'vitest'; + +import type { CloudAgentSessionState } from '../../../src/persistence/types.js'; + +const KILO_SESSION_ID = 'ses-root'; + +function createMetadata( + sessionId: string, + overrides: Partial = {} +): CloudAgentSessionState { + return { + version: 1, + sessionId, + userId: 'user_push', + timestamp: Date.now(), + kiloSessionId: KILO_SESSION_ID, + ...overrides, + }; +} + +async function createSession() { + const sessionId = `agent_${crypto.randomUUID()}`; + const id = env.CLOUD_AGENT_SESSION.idFromName(`user_push:${sessionId}`); + return { sessionId, stub: env.CLOUD_AGENT_SESSION.get(id) }; +} + +async function getNotificationJobs(): Promise { + const response = await SELF.fetch('http://test/test/notification-jobs'); + return response.json(); +} + +describe('CloudAgentSession push notification producer', () => { + it('enqueues a terminal push job with the last assistant text', async () => { + const clearResponse = await SELF.fetch('http://test/test/notification-jobs', { + method: 'DELETE', + }); + expect(clearResponse.ok).toBe(true); + const { sessionId, stub } = await createSession(); + + await runInDurableObject(stub, async instance => { + await instance.updateMetadata(createMetadata(sessionId)); + await instance.addExecution({ + executionId: 'exc_push_1', + mode: 'code', + streamingMode: 'websocket', + ingestToken: 'exc_push_1', + }); + + await instance.fetch( + new Request('http://do/ingest?executionId=exc_push_1', { + headers: { Upgrade: 'websocket' }, + }) + ); + + const ingestSocket = instance.ctx.getWebSockets('ingest:exc_push_1')[0]; + if (!ingestSocket) throw new Error('missing ingest socket'); + + await instance.webSocketMessage( + ingestSocket, + JSON.stringify({ + streamEventType: 'kilocode', + timestamp: new Date().toISOString(), + data: { + event: 'session.created', + properties: { info: { id: KILO_SESSION_ID } }, + }, + }) + ); + await instance.webSocketMessage( + ingestSocket, + JSON.stringify({ + streamEventType: 'kilocode', + timestamp: new Date().toISOString(), + data: { + event: 'message.updated', + properties: { + info: { id: 'msg-1', sessionID: KILO_SESSION_ID, role: 'assistant' }, + }, + }, + }) + ); + await instance.webSocketMessage( + ingestSocket, + JSON.stringify({ + streamEventType: 'kilocode', + timestamp: new Date().toISOString(), + data: { + event: 'message.part.updated', + properties: { + part: { + id: 'part-1', + messageID: 'msg-1', + sessionID: KILO_SESSION_ID, + type: 'text', + text: 'Assistant finished the task.', + }, + }, + }, + }) + ); + + await instance.updateExecutionStatus({ + executionId: 'exc_push_1', + status: 'completed', + }); + }); + + await expect + .poll(() => getNotificationJobs()) + .toEqual([ + { + userId: 'user_push', + cliSessionId: KILO_SESSION_ID, + status: 'completed', + body: 'Assistant finished the task.', + }, + ]); + }); + + it('captures assistant text from messages with mixed tool and text parts', async () => { + const clearResponse = await SELF.fetch('http://test/test/notification-jobs', { + method: 'DELETE', + }); + expect(clearResponse.ok).toBe(true); + const { sessionId, stub } = await createSession(); + + await runInDurableObject(stub, async instance => { + await instance.updateMetadata(createMetadata(sessionId)); + await instance.addExecution({ + executionId: 'exc_push_mixed', + mode: 'code', + streamingMode: 'websocket', + ingestToken: 'exc_push_mixed', + }); + + await instance.fetch( + new Request('http://do/ingest?executionId=exc_push_mixed', { + headers: { Upgrade: 'websocket' }, + }) + ); + + const ingestSocket = instance.ctx.getWebSockets('ingest:exc_push_mixed')[0]; + if (!ingestSocket) throw new Error('missing ingest socket'); + + await instance.webSocketMessage( + ingestSocket, + JSON.stringify({ + streamEventType: 'kilocode', + timestamp: new Date().toISOString(), + data: { + event: 'message.updated', + properties: { + info: { id: 'msg-mixed', sessionID: KILO_SESSION_ID, role: 'assistant' }, + }, + }, + }) + ); + // Tool part arrives BEFORE the text part. The previous tracker-based + // implementation would forget the message once a non-text part appeared + // and drop all later text parts for that message. + await instance.webSocketMessage( + ingestSocket, + JSON.stringify({ + streamEventType: 'kilocode', + timestamp: new Date().toISOString(), + data: { + event: 'message.part.updated', + properties: { + part: { + id: 'part-tool', + messageID: 'msg-mixed', + sessionID: KILO_SESSION_ID, + type: 'tool', + }, + }, + }, + }) + ); + await instance.webSocketMessage( + ingestSocket, + JSON.stringify({ + streamEventType: 'kilocode', + timestamp: new Date().toISOString(), + data: { + event: 'message.part.updated', + properties: { + part: { + id: 'part-text', + messageID: 'msg-mixed', + sessionID: KILO_SESSION_ID, + type: 'text', + text: 'Final answer after tool use.', + }, + }, + }, + }) + ); + + await instance.updateExecutionStatus({ + executionId: 'exc_push_mixed', + status: 'completed', + }); + }); + + await expect + .poll(() => getNotificationJobs()) + .toEqual([ + { + userId: 'user_push', + cliSessionId: KILO_SESSION_ID, + status: 'completed', + body: 'Final answer after tool use.', + }, + ]); + }); + + it('suppresses push jobs while a stream client is connected', async () => { + const clearResponse = await SELF.fetch('http://test/test/notification-jobs', { + method: 'DELETE', + }); + expect(clearResponse.ok).toBe(true); + const { sessionId, stub } = await createSession(); + + await runInDurableObject(stub, async instance => { + await instance.updateMetadata(createMetadata(sessionId)); + await instance.addExecution({ + executionId: 'exc_push_2', + mode: 'code', + streamingMode: 'websocket', + }); + const pair = new WebSocketPair(); + instance.ctx.acceptWebSocket(pair[1], ['stream']); + + await instance.updateExecutionStatus({ + executionId: 'exc_push_2', + status: 'completed', + }); + }); + + await expect.poll(() => getNotificationJobs()).toEqual([]); + }); + + it('suppresses push jobs when suppressPush is set', async () => { + const clearResponse = await SELF.fetch('http://test/test/notification-jobs', { + method: 'DELETE', + }); + expect(clearResponse.ok).toBe(true); + const { sessionId, stub } = await createSession(); + + await runInDurableObject(stub, async instance => { + await instance.updateMetadata(createMetadata(sessionId)); + await instance.addExecution({ + executionId: 'exc_push_3', + mode: 'code', + streamingMode: 'websocket', + }); + + await instance.updateExecutionStatus( + { + executionId: 'exc_push_3', + status: 'failed', + error: 'suppressed internal failure', + }, + { suppressCallback: true, suppressPush: true } + ); + }); + + await expect.poll(() => getNotificationJobs()).toEqual([]); + }); +}); diff --git a/services/cloud-agent-next/test/test-worker.ts b/services/cloud-agent-next/test/test-worker.ts index a7944fa78f..a719960093 100644 --- a/services/cloud-agent-next/test/test-worker.ts +++ b/services/cloud-agent-next/test/test-worker.ts @@ -9,10 +9,49 @@ * to avoid the @cloudflare/sandbox import chain. */ -import type { CloudAgentSession } from '../src/persistence/CloudAgentSession.js'; +import { CloudAgentSession as RealCloudAgentSession } from '../src/persistence/CloudAgentSession'; +import type { + NotificationsBinding, + SendCloudAgentSessionNotificationParams, + SendCloudAgentSessionNotificationResult, +} from '../src/notifications-binding.js'; -// Re-export CloudAgentSession for DO binding -export { CloudAgentSession } from '../src/persistence/CloudAgentSession'; +type RecordedPushCall = SendCloudAgentSessionNotificationParams; + +const recordedNotificationCalls: RecordedPushCall[] = []; + +// In the Workers test runtime, we don't want to actually provision the real +// notifications service binding. Swap it with an in-memory stub that records +// every RPC call so integration tests can assert on dispatches. +function createNotificationsStub(): NotificationsBinding { + const noopFetcher: Fetcher = { + // Minimal Fetcher surface — tests never invoke fetch() on this stub. + fetch: () => Promise.resolve(new Response('', { status: 501 })), + connect: () => { + throw new Error('connect not implemented on test notifications stub'); + }, + } as Fetcher; + + return { + ...noopFetcher, + async sendCloudAgentSessionNotification( + params: SendCloudAgentSessionNotificationParams + ): Promise { + recordedNotificationCalls.push(params); + return { dispatched: true }; + }, + } satisfies NotificationsBinding; +} + +const notificationsStub = createNotificationsStub(); + +// Re-export CloudAgentSession with the service binding replaced by the stub +// so tests observe push dispatches without requiring the real Worker. +export class CloudAgentSession extends RealCloudAgentSession { + constructor(ctx: DurableObjectState, env: Env) { + super(ctx, { ...env, NOTIFICATIONS: notificationsStub }); + } +} // Minimal Env type for tests type TestEnv = { @@ -43,6 +82,14 @@ export default { return stub.fetch(request); } + if (url.pathname === '/test/notification-jobs') { + if (request.method === 'DELETE') { + recordedNotificationCalls.length = 0; + return Response.json({ ok: true }); + } + return Response.json([...recordedNotificationCalls]); + } + return new Response('Not Found', { status: 404 }); }, }; diff --git a/services/cloud-agent-next/wrangler.jsonc b/services/cloud-agent-next/wrangler.jsonc index 6ca6a91568..1191453b8a 100644 --- a/services/cloud-agent-next/wrangler.jsonc +++ b/services/cloud-agent-next/wrangler.jsonc @@ -90,6 +90,11 @@ "service": "git-token-service", "entrypoint": "GitTokenRPCEntrypoint", }, + { + "binding": "NOTIFICATIONS", + "service": "notifications", + "entrypoint": "NotificationsService", + }, ], "secrets_store_secrets": [ { @@ -260,6 +265,11 @@ "service": "git-token-service-dev", "entrypoint": "GitTokenRPCEntrypoint", }, + { + "binding": "NOTIFICATIONS", + "service": "notifications-dev", + "entrypoint": "NotificationsService", + }, ], "secrets_store_secrets": [ { diff --git a/services/notifications/package.json b/services/notifications/package.json index 952f310a92..6b9fc135f3 100644 --- a/services/notifications/package.json +++ b/services/notifications/package.json @@ -4,9 +4,9 @@ "private": true, "scripts": { "deploy": "wrangler deploy", - "dev": "wrangler dev", - "start": "wrangler dev", - "test": "vitest", + "dev": "wrangler dev --env dev", + "start": "wrangler dev --env dev", + "test": "vitest --config vitest.config.ts", "cf-typegen": "wrangler types", "typecheck": "tsgo --noEmit", "lint": "pnpm -w exec oxlint --config .oxlintrc.json services/notifications/src" diff --git a/services/notifications/src/dos/NotificationChannelDO.ts b/services/notifications/src/dos/NotificationChannelDO.ts index a58bfda8bc..7f5e2e94a4 100644 --- a/services/notifications/src/dos/NotificationChannelDO.ts +++ b/services/notifications/src/dos/NotificationChannelDO.ts @@ -1,15 +1,10 @@ import { DurableObject } from 'cloudflare:workers'; import { getWorkerDb } from '@kilocode/db/client'; -import { channel_badge_counts, kiloclaw_instances, user_push_tokens } from '@kilocode/db/schema'; -import { and, eq, inArray, isNull, sql, sum } from 'drizzle-orm'; +import { kiloclaw_instances } from '@kilocode/db/schema'; +import { and, eq, isNull } from 'drizzle-orm'; import type { Event } from 'stream-chat'; -import type { ExpoPushMessage, TicketTokenPair } from '../lib/expo-push'; -import { sendPushNotifications } from '../lib/expo-push'; - -type ReceiptCheckMessage = { - ticketTokenPairs: TicketTokenPair[]; -}; +import { sendChannelPush } from '../lib/channel-push'; type PendingMessage = { messageId: string; @@ -135,58 +130,16 @@ export class NotificationChannelDO extends DurableObject { return; } - // Increment the badge count for this channel and return the new total across all channels. - // Done before the token guard so unread state is always persisted even if the user - // temporarily has no registered push tokens (e.g. between reinstalls). - // Uses UPSERT so the row is created on first notification for this channel. - await db - .insert(channel_badge_counts) - .values({ user_id: instance.user_id, channel_id: sandboxId, badge_count: 1 }) - .onConflictDoUpdate({ - target: [channel_badge_counts.user_id, channel_badge_counts.channel_id], - set: { badge_count: sql`${channel_badge_counts.badge_count} + 1` }, - }); - - const [totals] = await db - .select({ total: sum(channel_badge_counts.badge_count) }) - .from(channel_badge_counts) - .where(eq(channel_badge_counts.user_id, instance.user_id)); - - const badgeCount = Number(totals?.total ?? 0); - - const tokens = await db - .select({ token: user_push_tokens.token }) - .from(user_push_tokens) - .where(eq(user_push_tokens.user_id, instance.user_id)); - - if (tokens.length === 0) { - return; - } - const truncatedMessage = msg.text.length > 100 ? msg.text.slice(0, 97) + '...' : msg.text; - const messages: ExpoPushMessage[] = tokens.map(({ token }) => ({ - to: token, + await sendChannelPush({ + env: this.env, + userId: instance.user_id, + channelId: sandboxId, title: instance.name ?? 'KiloClaw', body: truncatedMessage, - // Keep in sync with NotificationData in apps/mobile/src/lib/notifications.ts data: { type: 'chat', instanceId: sandboxId }, - badge: badgeCount, - sound: 'default' as const, - priority: 'high' as const, - })); - - const accessToken = await this.env.EXPO_ACCESS_TOKEN.get(); - const { ticketTokenPairs, staleTokens } = await sendPushNotifications(messages, accessToken); - - if (staleTokens.length > 0) { - await db.delete(user_push_tokens).where(inArray(user_push_tokens.token, staleTokens)); - } - - if (ticketTokenPairs.length > 0) { - const receiptMsg: ReceiptCheckMessage = { ticketTokenPairs }; - await this.env.RECEIPTS_QUEUE.send(receiptMsg, { delaySeconds: 900 }); - } + }); } private async markWebhookSeen(webhookId: string): Promise { @@ -199,10 +152,7 @@ export class NotificationChannelDO extends DurableObject { } } -export function getNotificationChannelDO( - env: Env, - channelId: string -): DurableObjectStub { +export function getNotificationChannelDO(env: Env, channelId: string) { const id = env.NOTIFICATION_CHANNEL_DO.idFromName(channelId); - return env.NOTIFICATION_CHANNEL_DO.get(id) as DurableObjectStub; + return env.NOTIFICATION_CHANNEL_DO.get(id); } diff --git a/services/notifications/src/lib/channel-push.test.ts b/services/notifications/src/lib/channel-push.test.ts new file mode 100644 index 0000000000..a9409f87e8 --- /dev/null +++ b/services/notifications/src/lib/channel-push.test.ts @@ -0,0 +1,245 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +const { mockGetWorkerDb, mockSendPushNotifications } = vi.hoisted(() => ({ + mockGetWorkerDb: vi.fn(), + mockSendPushNotifications: vi.fn(), +})); + +vi.mock('@kilocode/db/client', () => ({ + getWorkerDb: mockGetWorkerDb, +})); + +vi.mock('./expo-push', () => ({ + sendPushNotifications: mockSendPushNotifications, +})); + +import { sendChannelPush } from './channel-push'; + +type TestMock = ReturnType; + +type QueryBuilder = { + values: TestMock; + onConflictDoUpdate: TestMock; + from: TestMock; + where: TestMock; +}; + +type DbMock = ReturnType['db']; + +function useDbMock(db: DbMock): void { + mockGetWorkerDb.mockImplementation(() => db); +} + +type TestEnv = Env & { + EXPO_ACCESS_TOKEN: SecretsStoreSecret & { get: TestMock }; + RECEIPTS_QUEUE: Queue & { + send: TestMock; + }; +}; + +function createUnusedSocket(): Socket { + return { + readable: new ReadableStream(), + writable: new WritableStream(), + closed: Promise.resolve(), + opened: new Promise(() => undefined), + upgraded: false, + secureTransport: 'off', + close: vi.fn(async () => undefined), + startTls() { + return this; + }, + }; +} + +function createHyperdrive(): Hyperdrive { + return { + connectionString: 'postgres://test', + connect() { + return createUnusedSocket(); + }, + host: 'localhost', + port: 5432, + user: 'postgres', + password: 'postgres', + database: 'postgres', + }; +} + +function createNotificationChannelNamespace(): Env['NOTIFICATION_CHANNEL_DO'] { + return { + newUniqueId() { + throw new Error('unused'); + }, + idFromName() { + throw new Error('unused'); + }, + idFromString() { + throw new Error('unused'); + }, + get() { + throw new Error('unused'); + }, + getByName() { + throw new Error('unused'); + }, + jurisdiction() { + return createNotificationChannelNamespace(); + }, + }; +} + +function createEnv(): TestEnv { + return { + HYPERDRIVE: createHyperdrive(), + EXPO_ACCESS_TOKEN: { get: vi.fn(async () => 'expo-token') }, + RECEIPTS_QUEUE: { + send: vi.fn(async () => undefined), + sendBatch: vi.fn(async () => undefined), + }, + NOTIFICATION_CHANNEL_DO: createNotificationChannelNamespace(), + STREAM_CHAT_API_SECRET: { get: vi.fn(async () => 'stream-secret') }, + } satisfies TestEnv; +} + +function createDbMock(options: { tokens: { token: string }[] }) { + const insertBuilder: QueryBuilder = { + values: vi.fn(), + onConflictDoUpdate: vi.fn(async () => undefined), + from: vi.fn(), + where: vi.fn(), + }; + insertBuilder.values.mockReturnValue(insertBuilder); + + const totalBuilder: QueryBuilder = { + values: vi.fn(), + onConflictDoUpdate: vi.fn(), + from: vi.fn(), + where: vi.fn(async () => [{ total: '4' }]), + }; + totalBuilder.from.mockReturnValue(totalBuilder); + + const tokensBuilder: QueryBuilder = { + values: vi.fn(), + onConflictDoUpdate: vi.fn(), + from: vi.fn(), + where: vi.fn(async () => options.tokens), + }; + tokensBuilder.from.mockReturnValue(tokensBuilder); + + const deleteBuilder = { + where: vi.fn(async () => undefined), + }; + + const db = { + insert: vi.fn(() => insertBuilder), + select: vi.fn().mockReturnValueOnce(totalBuilder).mockReturnValueOnce(tokensBuilder), + delete: vi.fn(() => deleteBuilder), + }; + + return { db, insertBuilder, tokensBuilder, deleteBuilder }; +} + +describe('sendChannelPush', () => { + beforeEach(() => { + vi.resetAllMocks(); + }); + + it('updates badge, sends Expo push messages, removes stale tokens, and enqueues receipt checks', async () => { + const env = createEnv(); + const { db, insertBuilder, deleteBuilder } = createDbMock({ + tokens: [{ token: 'ExponentPushToken[ok]' }, { token: 'ExponentPushToken[stale]' }], + }); + useDbMock(db); + mockSendPushNotifications.mockResolvedValue({ + ticketTokenPairs: [{ ticketId: 'ticket-1', token: 'ExponentPushToken[ok]' }], + staleTokens: ['ExponentPushToken[stale]'], + }); + + await sendChannelPush({ + env, + userId: 'user-1', + channelId: 'channel-1', + title: 'Title', + body: 'Body', + data: { type: 'cloud_agent_session', cliSessionId: 'channel-1' }, + }); + + expect(insertBuilder.values).toHaveBeenCalledWith({ + user_id: 'user-1', + channel_id: 'channel-1', + badge_count: 1, + }); + expect(mockSendPushNotifications).toHaveBeenCalledWith( + [ + { + to: 'ExponentPushToken[ok]', + title: 'Title', + body: 'Body', + data: { type: 'cloud_agent_session', cliSessionId: 'channel-1' }, + badge: 4, + sound: 'default', + priority: 'high', + }, + { + to: 'ExponentPushToken[stale]', + title: 'Title', + body: 'Body', + data: { type: 'cloud_agent_session', cliSessionId: 'channel-1' }, + badge: 4, + sound: 'default', + priority: 'high', + }, + ], + 'expo-token' + ); + expect(deleteBuilder.where).toHaveBeenCalledOnce(); + expect(env.RECEIPTS_QUEUE.send).toHaveBeenCalledWith( + { ticketTokenPairs: [{ ticketId: 'ticket-1', token: 'ExponentPushToken[ok]' }] }, + { delaySeconds: 900 } + ); + }); + + it('swallows Expo send failures so queue retries do not re-increment the badge', async () => { + const env = createEnv(); + const { db, insertBuilder, deleteBuilder } = createDbMock({ + tokens: [{ token: 'ExponentPushToken[ok]' }], + }); + useDbMock(db); + mockSendPushNotifications.mockRejectedValue(new Error('expo down')); + + await expect( + sendChannelPush({ + env, + userId: 'user-1', + channelId: 'channel-1', + title: 'Title', + body: 'Body', + data: { type: 'chat', instanceId: 'channel-1' }, + }) + ).resolves.toBeUndefined(); + + expect(insertBuilder.onConflictDoUpdate).toHaveBeenCalledOnce(); + expect(deleteBuilder.where).not.toHaveBeenCalled(); + expect(env.RECEIPTS_QUEUE.send).not.toHaveBeenCalled(); + }); + + it('updates badge and skips Expo send when the user has no push tokens', async () => { + const env = createEnv(); + const { db, insertBuilder } = createDbMock({ tokens: [] }); + useDbMock(db); + + await sendChannelPush({ + env, + userId: 'user-1', + channelId: 'channel-1', + title: 'Title', + body: 'Body', + data: { type: 'chat', instanceId: 'channel-1' }, + }); + + expect(insertBuilder.onConflictDoUpdate).toHaveBeenCalledOnce(); + expect(mockSendPushNotifications).not.toHaveBeenCalled(); + expect(env.RECEIPTS_QUEUE.send).not.toHaveBeenCalled(); + }); +}); diff --git a/services/notifications/src/lib/channel-push.ts b/services/notifications/src/lib/channel-push.ts new file mode 100644 index 0000000000..9b958f8855 --- /dev/null +++ b/services/notifications/src/lib/channel-push.ts @@ -0,0 +1,113 @@ +import { getWorkerDb } from '@kilocode/db/client'; +import { channel_badge_counts, user_push_tokens } from '@kilocode/db/schema'; +import { eq, inArray, sql, sum } from 'drizzle-orm'; + +import type { ExpoPushMessage, TicketTokenPair } from './expo-push'; +import { sendPushNotifications } from './expo-push'; + +type ReceiptCheckMessage = { + ticketTokenPairs: TicketTokenPair[]; +}; + +type SendChannelPushEnv = Env & { + RECEIPTS_QUEUE: Queue; +}; + +export type ChannelPushData = + | { type: 'chat'; instanceId: string } + | { type: 'cloud_agent_session'; cliSessionId: string }; + +export type SendChannelPushOptions = { + env: SendChannelPushEnv; + userId: string; + channelId: string; + title: string; + body: string; + data: ChannelPushData; +}; + +export async function sendChannelPush({ + env, + userId, + channelId, + title, + body, + data, +}: SendChannelPushOptions): Promise { + const db = getWorkerDb(env.HYPERDRIVE.connectionString); + + await db + .insert(channel_badge_counts) + .values({ user_id: userId, channel_id: channelId, badge_count: 1 }) + .onConflictDoUpdate({ + target: [channel_badge_counts.user_id, channel_badge_counts.channel_id], + set: { badge_count: sql`${channel_badge_counts.badge_count} + 1` }, + }); + + const [totals] = await db + .select({ total: sum(channel_badge_counts.badge_count) }) + .from(channel_badge_counts) + .where(eq(channel_badge_counts.user_id, userId)); + + const badgeCount = Number(totals?.total ?? 0); + + const tokens = await db + .select({ token: user_push_tokens.token }) + .from(user_push_tokens) + .where(eq(user_push_tokens.user_id, userId)); + + if (tokens.length === 0) { + return; + } + + // Everything after the badge increment is best-effort: once we've mutated + // the badge count, letting a downstream failure (Expo send, stale-token + // cleanup, receipt enqueue) bubble up would trigger a queue retry that + // re-increments the badge and may re-send pushes that partially succeeded + // in the failing attempt. Log and swallow instead. + try { + const accessToken = await env.EXPO_ACCESS_TOKEN.get(); + const messages = tokens.map( + ({ token }) => + ({ + to: token, + title, + body, + data, + badge: badgeCount, + sound: 'default', + priority: 'high', + }) satisfies ExpoPushMessage + ); + const { ticketTokenPairs, staleTokens } = await sendPushNotifications(messages, accessToken); + + if (staleTokens.length > 0) { + try { + await db.delete(user_push_tokens).where(inArray(user_push_tokens.token, staleTokens)); + } catch (err) { + console.error('Failed to clean up stale push tokens', { + staleCount: staleTokens.length, + error: err instanceof Error ? err.message : String(err), + }); + } + } + + if (ticketTokenPairs.length > 0) { + try { + await env.RECEIPTS_QUEUE.send({ ticketTokenPairs }, { delaySeconds: 900 }); + } catch (err) { + console.error('Failed to enqueue Expo receipt check', { + ticketCount: ticketTokenPairs.length, + error: err instanceof Error ? err.message : String(err), + }); + } + } + } catch (err) { + console.error('Failed to send channel push', { + userId, + channelId, + tokenCount: tokens.length, + error: err instanceof Error ? err.message : String(err), + }); + } +} diff --git a/services/notifications/src/lib/notifications-service-cloud-agent.test.ts b/services/notifications/src/lib/notifications-service-cloud-agent.test.ts new file mode 100644 index 0000000000..9a2ee7e42c --- /dev/null +++ b/services/notifications/src/lib/notifications-service-cloud-agent.test.ts @@ -0,0 +1,146 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +const { mockGetWorkerDb, mockSendChannelPush } = vi.hoisted(() => ({ + mockGetWorkerDb: vi.fn(), + mockSendChannelPush: vi.fn(async () => undefined), +})); + +vi.mock('cloudflare:workers', () => ({ + WorkerEntrypoint: class WorkerEntrypoint { + ctx: ExecutionContext; + env: Env; + constructor(ctx: ExecutionContext, env: Env) { + this.ctx = ctx; + this.env = env; + } + }, +})); + +vi.mock('@kilocode/db/client', () => ({ + getWorkerDb: mockGetWorkerDb, +})); + +vi.mock('./channel-push', () => ({ + sendChannelPush: mockSendChannelPush, +})); + +import { NotificationsService } from './notifications-service'; + +type TestMock = ReturnType; + +function createQueryBuilder(result: unknown[]) { + const builder = { + from: vi.fn(), + where: vi.fn(), + limit: vi.fn(async () => result), + }; + builder.from.mockReturnValue(builder); + builder.where.mockReturnValue(builder); + return builder; +} + +function createDbMock(options: { userRows?: unknown[]; sessionRows?: unknown[] } = {}) { + return { + select: vi + .fn() + .mockReturnValueOnce(createQueryBuilder(options.userRows ?? [{ id: 'user-1' }])) + .mockReturnValueOnce( + createQueryBuilder(options.sessionRows ?? [{ title: 'Resolved title' }]) + ), + }; +} + +type ServiceEnv = Parameters[0]; + +function makeService(env: { + HYPERDRIVE: { connectionString: string }; + EXPO_ACCESS_TOKEN?: { get: TestMock }; +}): NotificationsService { + return new NotificationsService({} as ExecutionContext, env as unknown as Env); +} + +function createEnv(): ServiceEnv { + return { + HYPERDRIVE: { connectionString: 'postgres://test' }, + EXPO_ACCESS_TOKEN: { get: vi.fn(async () => 'expo-token') }, + }; +} + +describe('NotificationsService.sendCloudAgentSessionNotification', () => { + beforeEach(() => { + vi.resetAllMocks(); + }); + + it('dispatches the push via sendChannelPush with the resolved session title', async () => { + const db = createDbMock(); + mockGetWorkerDb.mockReturnValue(db); + const service = makeService(createEnv()); + + const result = await service.sendCloudAgentSessionNotification({ + userId: 'user-1', + cliSessionId: 'ses_1', + status: 'completed', + body: 'Finished', + }); + + expect(result).toEqual({ dispatched: true }); + expect(mockSendChannelPush).toHaveBeenCalledWith( + expect.objectContaining({ + userId: 'user-1', + channelId: 'ses_1', + title: 'Resolved title', + body: 'Finished', + data: { type: 'cloud_agent_session', cliSessionId: 'ses_1' }, + }) + ); + }); + + it('returns missing_user without dispatching when the user row is absent', async () => { + const db = createDbMock({ userRows: [] }); + mockGetWorkerDb.mockReturnValue(db); + const service = makeService(createEnv()); + + const result = await service.sendCloudAgentSessionNotification({ + userId: 'user-1', + cliSessionId: 'ses_1', + status: 'completed', + body: 'Finished', + }); + + expect(result).toEqual({ dispatched: false, reason: 'missing_user' }); + expect(mockSendChannelPush).not.toHaveBeenCalled(); + }); + + it('returns missing_session without dispatching when the session row is absent', async () => { + const db = createDbMock({ sessionRows: [] }); + mockGetWorkerDb.mockReturnValue(db); + const service = makeService(createEnv()); + + const result = await service.sendCloudAgentSessionNotification({ + userId: 'user-1', + cliSessionId: 'ses_missing', + status: 'completed', + body: 'Finished', + }); + + expect(result).toEqual({ dispatched: false, reason: 'missing_session' }); + expect(mockSendChannelPush).not.toHaveBeenCalled(); + }); + + it('rejects invalid params before touching the db', async () => { + mockGetWorkerDb.mockImplementation(() => { + throw new Error('should not be called'); + }); + const service = makeService(createEnv()); + + await expect( + service.sendCloudAgentSessionNotification({ + userId: '', + cliSessionId: 'ses_1', + status: 'completed', + body: 'Finished', + }) + ).rejects.toThrow(); + expect(mockGetWorkerDb).not.toHaveBeenCalled(); + }); +}); diff --git a/services/notifications/src/lib/notifications-service.ts b/services/notifications/src/lib/notifications-service.ts index 8178d5529f..4d3b145de1 100644 --- a/services/notifications/src/lib/notifications-service.ts +++ b/services/notifications/src/lib/notifications-service.ts @@ -1,8 +1,10 @@ import { WorkerEntrypoint } from 'cloudflare:workers'; import { getWorkerDb } from '@kilocode/db/client'; -import { user_push_tokens } from '@kilocode/db/schema'; -import { eq, inArray } from 'drizzle-orm'; +import { cli_sessions_v2, kilocode_users, user_push_tokens } from '@kilocode/db/schema'; +import { and, eq, inArray } from 'drizzle-orm'; +import { z } from 'zod'; +import { sendChannelPush } from './channel-push'; import type { TicketTokenPair } from './expo-push'; import { sendPushNotifications } from './expo-push'; import { @@ -21,6 +23,28 @@ type ReceiptCheckMessage = { ticketTokenPairs: TicketTokenPair[]; }; +export type CloudAgentSessionPushStatus = 'completed' | 'failed' | 'interrupted'; + +export type SendCloudAgentSessionNotificationParams = { + userId: string; + cliSessionId: string; + status: CloudAgentSessionPushStatus; + body: string; +}; + +export type SendCloudAgentSessionNotificationResult = { + dispatched: boolean; + /** Reason the dispatch was skipped. Useful for producer-side logging. */ + reason?: 'missing_user' | 'missing_session'; +}; + +const CloudAgentSessionParamsSchema = z.object({ + userId: z.string().min(1), + cliSessionId: z.string().min(1), + status: z.enum(['completed', 'failed', 'interrupted']), + body: z.string(), +}) satisfies z.ZodType; + /** * RPC entrypoint for other Workers to send non-chat push notifications. * @@ -60,4 +84,47 @@ export class NotificationsService extends WorkerEntrypoint { return result; } + + async sendCloudAgentSessionNotification( + params: SendCloudAgentSessionNotificationParams + ): Promise { + const parsed = CloudAgentSessionParamsSchema.parse(params); + const db = getWorkerDb(this.env.HYPERDRIVE.connectionString); + + const [user] = await db + .select({ id: kilocode_users.id }) + .from(kilocode_users) + .where(eq(kilocode_users.id, parsed.userId)) + .limit(1); + + if (!user) { + return { dispatched: false, reason: 'missing_user' }; + } + + const [session] = await db + .select({ title: cli_sessions_v2.title }) + .from(cli_sessions_v2) + .where( + and( + eq(cli_sessions_v2.session_id, parsed.cliSessionId), + eq(cli_sessions_v2.kilo_user_id, parsed.userId) + ) + ) + .limit(1); + + if (!session) { + return { dispatched: false, reason: 'missing_session' }; + } + + await sendChannelPush({ + env: this.env, + userId: parsed.userId, + channelId: parsed.cliSessionId, + title: session.title ?? 'Agent session', + body: parsed.body, + data: { type: 'cloud_agent_session', cliSessionId: parsed.cliSessionId }, + }); + + return { dispatched: true }; + } } diff --git a/services/notifications/tsconfig.json b/services/notifications/tsconfig.json index 635e98f321..a07d241e3b 100644 --- a/services/notifications/tsconfig.json +++ b/services/notifications/tsconfig.json @@ -16,6 +16,6 @@ "skipLibCheck": true, "types": ["./worker-configuration.d.ts", "node"] }, - "exclude": ["test"], - "include": ["worker-configuration.d.ts", "src/**/*.ts"] + "exclude": ["test", "vitest.config.mts"], + "include": ["worker-configuration.d.ts", "src/**/*.ts", "vitest.config.ts"] } diff --git a/services/notifications/vitest.config.mts b/services/notifications/vitest.config.mts index d9430c7554..8b1c6bf799 100644 --- a/services/notifications/vitest.config.mts +++ b/services/notifications/vitest.config.mts @@ -2,6 +2,7 @@ import { defineWorkersConfig } from '@cloudflare/vitest-pool-workers/config'; export default defineWorkersConfig({ test: { + include: ['src/**/*.workers.test.ts'], poolOptions: { workers: { wrangler: { configPath: './wrangler.jsonc' }, diff --git a/services/notifications/vitest.config.ts b/services/notifications/vitest.config.ts new file mode 100644 index 0000000000..0231bd17ce --- /dev/null +++ b/services/notifications/vitest.config.ts @@ -0,0 +1,11 @@ +/// + +import { defineConfig } from 'vitest/config'; + +export default defineConfig({ + test: { + globals: true, + environment: 'node', + include: ['src/**/*.test.ts'], + }, +}); diff --git a/services/notifications/worker-configuration.d.ts b/services/notifications/worker-configuration.d.ts index cbc9201506..37b72d2660 100644 --- a/services/notifications/worker-configuration.d.ts +++ b/services/notifications/worker-configuration.d.ts @@ -1,17 +1,24 @@ /* eslint-disable */ -// Generated by Wrangler by running `wrangler types` (hash: b336c1c1e874405e99f5e26c8c9319df) +// Generated by Wrangler by running `wrangler types` (hash: 1088871fafeed99c9a08ee0a832c2dcb) // Runtime types generated with workerd@1.20260312.1 2026-02-01 nodejs_compat declare namespace Cloudflare { interface GlobalProps { mainModule: typeof import("./src/index"); durableNamespaces: "NotificationChannelDO"; } + interface DevEnv { + HYPERDRIVE: Hyperdrive; + RECEIPTS_QUEUE: Queue; + STREAM_CHAT_API_SECRET: SecretsStoreSecret; + EXPO_ACCESS_TOKEN: SecretsStoreSecret; + NOTIFICATION_CHANNEL_DO: DurableObjectNamespace; + } interface Env { HYPERDRIVE: Hyperdrive; RECEIPTS_QUEUE: Queue; STREAM_CHAT_API_SECRET: SecretsStoreSecret; EXPO_ACCESS_TOKEN: SecretsStoreSecret; - NOTIFICATION_CHANNEL_DO: DurableObjectNamespace /* NotificationChannelDO */; + NOTIFICATION_CHANNEL_DO: DurableObjectNamespace; } } interface Env extends Cloudflare.Env {} diff --git a/services/notifications/wrangler.jsonc b/services/notifications/wrangler.jsonc index 943bd8176a..84f45aae05 100644 --- a/services/notifications/wrangler.jsonc +++ b/services/notifications/wrangler.jsonc @@ -69,4 +69,53 @@ "new_classes": ["NotificationChannelDO"], }, ], + + "env": { + "dev": { + "name": "notifications-dev", + "hyperdrive": [ + { + "binding": "HYPERDRIVE", + "id": "624ec80650dd414199349f4e217ddb10", + "localConnectionString": "postgres://postgres:postgres@localhost:5432/postgres", + }, + ], + "durable_objects": { + "bindings": [ + { + "name": "NOTIFICATION_CHANNEL_DO", + "class_name": "NotificationChannelDO", + }, + ], + }, + "queues": { + "producers": [ + { + "binding": "RECEIPTS_QUEUE", + "queue": "notifications-receipts", + }, + ], + "consumers": [ + { + "queue": "notifications-receipts", + "max_retries": 3, + "dead_letter_queue": "notifications-receipts-dlq", + "retry_delay": 60, + }, + ], + }, + "secrets_store_secrets": [ + { + "binding": "STREAM_CHAT_API_SECRET", + "store_id": "342a86d9e3a94da698e82d0c6e2a36f0", + "secret_name": "STREAM_CHAT_API_SECRET", + }, + { + "binding": "EXPO_ACCESS_TOKEN", + "store_id": "342a86d9e3a94da698e82d0c6e2a36f0", + "secret_name": "EXPO_ACCESS_TOKEN", + }, + ], + }, + }, }