diff --git a/desktop/src-tauri/src/lib.rs b/desktop/src-tauri/src/lib.rs index ba614573..185326ab 100644 --- a/desktop/src-tauri/src/lib.rs +++ b/desktop/src-tauri/src/lib.rs @@ -622,6 +622,24 @@ fn setup_menu(app: &mut App) -> Result<(), DynError> { Ok(()) } +/// Restore input focus to the main webview after a native GTK dialog +/// is dismissed. On Linux/WebKitGTK, native dialogs can leave the +/// webview in a frozen state where it renders but does not process +/// input events. +fn restore_webview_focus(handle: &AppHandle) { + let handle = handle.clone(); + // Delay focus restoration so the native GTK dialog has time to + // fully close and release window focus. Without this, set_focus() + // fires while the dialog still owns focus and the webview stays + // unresponsive. + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(100)); + if let Some(window) = handle.get_webview_window("main") { + let _ = window.set_focus(); + } + }); +} + static UPDATE_CHECK_ACTIVE: AtomicBool = AtomicBool::new(false); // Guard that clears UPDATE_CHECK_ACTIVE on drop, ensuring the @@ -654,11 +672,12 @@ async fn check_for_updates(handle: &AppHandle, silent: bool) { .is_err() { if !silent { + let h = handle.clone(); handle .dialog() .message("An update check is already in progress.") .title("Update Check") - .show(|_| {}); + .show(move |_| restore_webview_focus(&h)); } return; } @@ -669,11 +688,12 @@ async fn check_for_updates(handle: &AppHandle, silent: bool) { Err(err) => { eprintln!("[agentsview] updater unavailable: {err}"); if !silent { + let h = handle.clone(); handle .dialog() .message("Could not check for updates. The updater is not configured.") .title("Update Check") - .show(|_| {}); + .show(move |_| restore_webview_focus(&h)); } return; } @@ -684,11 +704,12 @@ async fn check_for_updates(handle: &AppHandle, silent: bool) { Err(err) => { eprintln!("[agentsview] update check failed: {err}"); if !silent { + let h = handle.clone(); handle .dialog() .message("Could not check for updates. Please try again later.") .title("Update Check") - .show(|_| {}); + .show(move |_| restore_webview_focus(&h)); } return; } @@ -696,11 +717,12 @@ async fn check_for_updates(handle: &AppHandle, silent: bool) { let Some(update) = update else { if !silent { + let h = handle.clone(); handle .dialog() .message("You're running the latest version.") .title("No Updates Available") - .show(|_| {}); + .show(move |_| restore_webview_focus(&h)); } return; }; @@ -722,6 +744,7 @@ async fn check_for_updates(handle: &AppHandle, silent: bool) { if let Err(err) = update.download_and_install(|_, _| {}, || {}).await { eprintln!("[agentsview] update install failed: {err}"); + let h = handle.clone(); handle .dialog() .message( @@ -729,7 +752,7 @@ async fn check_for_updates(handle: &AppHandle, silent: bool) { Please try downloading manually from the releases page.", ) .title("Update Failed") - .show(|_| {}); + .show(move |_| restore_webview_focus(&h)); return; } @@ -752,12 +775,14 @@ async fn dialog_confirm( message: &str, ) -> bool { let (tx, rx) = tokio::sync::oneshot::channel(); + let h = handle.clone(); handle .dialog() .message(message) .title(title) .buttons(MessageDialogButtons::OkCancel) .show(move |confirmed| { + restore_webview_focus(&h); let _ = tx.send(confirmed); }); rx.await.unwrap_or(false) diff --git a/desktop/src-tauri/tauri.conf.json b/desktop/src-tauri/tauri.conf.json index d6de4651..e6d682df 100644 --- a/desktop/src-tauri/tauri.conf.json +++ b/desktop/src-tauri/tauri.conf.json @@ -20,7 +20,7 @@ } ], "security": { - "csp": "default-src 'self'; connect-src 'self' http://127.0.0.1:* ws://127.0.0.1:*; img-src 'self' data:; style-src 'self' 'unsafe-inline'; font-src 'self' data:; object-src 'none'; frame-ancestors 'none'; base-uri 'none';" + "csp": "default-src 'self' http://127.0.0.1:* http://localhost:*; script-src 'self' http://127.0.0.1:* http://localhost:*; connect-src 'self' http://127.0.0.1:* http://localhost:* ws://127.0.0.1:* ws://localhost:*; img-src 'self' data: http://127.0.0.1:* http://localhost:*; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com http://127.0.0.1:* http://localhost:*; font-src 'self' data: https://fonts.gstatic.com http://127.0.0.1:* http://localhost:*; object-src 'none'; base-uri 'none'; frame-ancestors 'none'" } }, "bundle": { diff --git a/internal/server/middleware_test.go b/internal/server/middleware_test.go index 456ff636..f84cf230 100644 --- a/internal/server/middleware_test.go +++ b/internal/server/middleware_test.go @@ -156,6 +156,160 @@ func TestMiddlewareTimeout(t *testing.T) { } } +func TestCSPMiddlewareSetsHeaderOnNonAPIRoutes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + path string + host string + port int + publicOrigins []string + bindAllIPs map[string]bool + wantCSP bool + wantParts []string + wantAbsent []string // substrings that must NOT appear + }{ + { + name: "SPA_root_gets_CSP_with_pinned_origin", + path: "/", + host: "127.0.0.1", + port: 8081, + wantCSP: true, + wantParts: []string{ + "script-src 'self' http://127.0.0.1:8081", + "default-src 'self' http://127.0.0.1:8081", + "connect-src 'self' http://127.0.0.1:8081", + "ws://127.0.0.1:8081", + "style-src 'self' http://127.0.0.1:8081 'unsafe-inline' https://fonts.googleapis.com", + "font-src 'self' http://127.0.0.1:8081 data: https://fonts.gstatic.com", + "frame-ancestors 'none'", + }, + }, + { + name: "SPA_subpath_gets_CSP", + path: "/sessions/abc", + host: "127.0.0.1", + port: 9090, + wantCSP: true, + wantParts: []string{ + "http://127.0.0.1:9090", + "ws://127.0.0.1:9090", + }, + }, + { + name: "API_route_no_CSP", + path: "/api/v1/sessions", + host: "127.0.0.1", + port: 8081, + wantCSP: false, + }, + { + name: "API_subpath_no_CSP", + path: "/api/v1/stats", + host: "127.0.0.1", + port: 8081, + wantCSP: false, + }, + { + name: "IPv6_loopback_brackets", + path: "/", + host: "::1", + port: 8081, + wantCSP: true, + wantParts: []string{ + "script-src 'self' http://[::1]:8081", + "connect-src", + "ws://[::1]:8081", + "http://127.0.0.1:8081", + }, + }, + { + name: "BindAll_connect_src_includes_LAN_IPs", + path: "/", + host: "0.0.0.0", + port: 8080, + bindAllIPs: map[string]bool{ + "127.0.0.1": true, + "::1": true, + "192.168.1.5": true, + }, + wantCSP: true, + wantParts: []string{ + // Pinned origin in all directives + "script-src 'self' http://0.0.0.0:8080", + // LAN IPs in connect-src + "http://192.168.1.5:8080", + "ws://192.168.1.5:8080", + "http://127.0.0.1:8080", + "http://localhost:8080", + }, + wantAbsent: []string{ + // LAN IPs must NOT be in script-src + "script-src 'self' http://0.0.0.0:8080 http://192", + }, + }, + { + name: "PublicOrigin_in_connect_src_only", + path: "/", + host: "127.0.0.1", + port: 8081, + publicOrigins: []string{"https://view.example.com"}, + wantCSP: true, + wantParts: []string{ + // Pinned origin in script-src + "script-src 'self' http://127.0.0.1:8081", + // Public origin in connect-src + "https://view.example.com", + "wss://view.example.com", + }, + wantAbsent: []string{ + // Public origin must NOT be in script-src + "script-src 'self' http://127.0.0.1:8081 https://view", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := cspMiddleware(tt.host, tt.port, tt.publicOrigins, tt.bindAllIPs, inner) + + req := httptest.NewRequest(http.MethodGet, tt.path, nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + csp := w.Header().Get("Content-Security-Policy") + if tt.wantCSP { + if csp == "" { + t.Fatal("expected CSP header, got empty") + } + for _, part := range tt.wantParts { + if !strings.Contains(csp, part) { + t.Errorf("CSP missing %q; got %q", part, csp) + } + } + for _, absent := range tt.wantAbsent { + if strings.Contains(csp, absent) { + t.Errorf("CSP should not contain %q; got %q", absent, csp) + } + } + xfo := w.Header().Get("X-Frame-Options") + if xfo != "DENY" { + t.Errorf("expected X-Frame-Options DENY, got %q", xfo) + } + } else { + if csp != "" { + t.Errorf("expected no CSP header on API route, got %q", csp) + } + } + }) + } +} + func TestCORSMiddlewareMergesVaryHeader(t *testing.T) { t.Parallel() diff --git a/internal/server/server.go b/internal/server/server.go index dce1eeb3..b6f5c848 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -358,11 +358,13 @@ func (s *Server) Handler() http.Handler { if bindAll { bindAllIPs = localInterfaceIPs() } - h := s.authMiddleware( - hostCheckMiddleware( - allowedHosts, bindAll, s.cfg.Port, bindAllIPs, - corsMiddleware( - allowedOrigins, bindAll, s.cfg.Port, bindAllIPs, logMiddleware(s.mux), + h := cspMiddleware(s.cfg.Host, s.cfg.Port, s.cfg.PublicOrigins, bindAllIPs, + s.authMiddleware( + hostCheckMiddleware( + allowedHosts, bindAll, s.cfg.Port, bindAllIPs, + corsMiddleware( + allowedOrigins, bindAll, s.cfg.Port, bindAllIPs, logMiddleware(s.mux), + ), ), ), ) @@ -392,6 +394,108 @@ func (s *Server) Handler() http.Handler { return h } +// cspMiddleware sets a Content-Security-Policy header on non-API +// responses. The policy pins the exact host:port origin so that +// even if Tauri's compile-time CSP uses a wildcard port, the +// intersection narrows to the actual runtime port. +func cspMiddleware(host string, port int, publicOrigins []string, bindAllIPs map[string]bool, next http.Handler) http.Handler { + policy := buildCSPPolicy(host, port, publicOrigins, bindAllIPs) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/api/") { + w.Header().Set("Content-Security-Policy", policy) + w.Header().Set("X-Frame-Options", "DENY") + } + next.ServeHTTP(w, r) + }) +} + +// buildCSPPolicy constructs the Content-Security-Policy string. +// It uses the same loopback/bind-all logic as buildAllowedOrigins +// to handle IPv6 bracketing, 0.0.0.0/:: normalization, and +// public origins (proxy/TLS). +// +// The server's own origin (host:port) is included explicitly in +// all directives because WebKitGTK in a Tauri webview may not +// resolve 'self' to the Go server origin after navigating from +// tauri://localhost. Public origins and LAN IPs are restricted +// to connect-src only to limit the script execution surface. +func buildCSPPolicy(host string, port int, publicOrigins []string, bindAllIPs map[string]bool) string { + // serverOrigin is the pinned http origin for the configured + // host:port, used in all directives so resources load + // correctly regardless of how the webview resolves 'self'. + serverOrigin := "http://" + net.JoinHostPort(host, strconv.Itoa(port)) + + // connectSrcs collects additional origins for connect-src + // (fetch, SSE, WebSocket) — loopback variants, LAN IPs, + // and public/proxy origins. + connectHTTP := []string{} + connectWS := []string{} + + addConnectOrigin := func(h string) { + for _, o := range httpOrigin(h, port) { + connectHTTP = append(connectHTTP, o) + connectWS = append(connectWS, strings.Replace(o, "http://", "ws://", 1)) + } + } + + // Mirror buildAllowedOrigins: when binding to loopback, + // include the other loopback variant. When binding to all + // interfaces, include all loopback origins plus every + // concrete interface IP. + switch host { + case "127.0.0.1": + addConnectOrigin("localhost") + case "localhost": + addConnectOrigin("127.0.0.1") + case "0.0.0.0", "::": + addConnectOrigin("127.0.0.1") + addConnectOrigin("localhost") + addConnectOrigin("::1") + for ip := range bindAllIPs { + if ip != "127.0.0.1" && ip != "::1" { + addConnectOrigin(ip) + } + } + case "::1": + addConnectOrigin("127.0.0.1") + addConnectOrigin("localhost") + } + + for _, origin := range publicOrigins { + connectHTTP = append(connectHTTP, origin) + connectWS = append(connectWS, + strings.NewReplacer( + "https://", "wss://", + "http://", "ws://", + ).Replace(origin), + ) + } + + // resource-src: 'self' + pinned server origin (for all resource types) + resourceSrc := "'self' " + serverOrigin + + // connect-src: resource-src + loopback/LAN/public origins + ws variants + connectParts := []string{resourceSrc} + wsOrigin := "ws://" + net.JoinHostPort(host, strconv.Itoa(port)) + connectParts = append(connectParts, wsOrigin) + connectParts = append(connectParts, connectHTTP...) + connectParts = append(connectParts, connectWS...) + connectSrc := strings.Join(connectParts, " ") + + return fmt.Sprintf( + "default-src %[1]s; "+ + "script-src %[1]s; "+ + "connect-src %[2]s; "+ + "img-src %[1]s data:; "+ + "style-src %[1]s 'unsafe-inline' https://fonts.googleapis.com; "+ + "font-src %[1]s data: https://fonts.gstatic.com; "+ + "object-src 'none'; "+ + "base-uri 'none'; "+ + "frame-ancestors 'none'", + resourceSrc, connectSrc, + ) +} + // buildAllowedHosts returns the set of Host header values that // are legitimate for this server. This defends against DNS // rebinding attacks where an attacker's domain resolves to