Skip to content
Open
35 changes: 30 additions & 5 deletions desktop/src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,24 @@ fn setup_menu(app: &mut App) -> Result<(), DynError> {
Ok(())
}

/// Restore input focus to the main webview after a native GTK dialog
/// is dismissed. On Linux/WebKitGTK, native dialogs can leave the
/// webview in a frozen state where it renders but does not process
/// input events.
fn restore_webview_focus(handle: &AppHandle) {
let handle = handle.clone();
// Delay focus restoration so the native GTK dialog has time to
// fully close and release window focus. Without this, set_focus()
// fires while the dialog still owns focus and the webview stays
// unresponsive.
std::thread::spawn(move || {
std::thread::sleep(Duration::from_millis(100));
if let Some(window) = handle.get_webview_window("main") {
let _ = window.set_focus();
}
});
}

static UPDATE_CHECK_ACTIVE: AtomicBool = AtomicBool::new(false);

// Guard that clears UPDATE_CHECK_ACTIVE on drop, ensuring the
Expand Down Expand Up @@ -654,11 +672,12 @@ async fn check_for_updates(handle: &AppHandle, silent: bool) {
.is_err()
{
if !silent {
let h = handle.clone();
handle
.dialog()
.message("An update check is already in progress.")
.title("Update Check")
.show(|_| {});
.show(move |_| restore_webview_focus(&h));
}
return;
}
Expand All @@ -669,11 +688,12 @@ async fn check_for_updates(handle: &AppHandle, silent: bool) {
Err(err) => {
eprintln!("[agentsview] updater unavailable: {err}");
if !silent {
let h = handle.clone();
handle
.dialog()
.message("Could not check for updates. The updater is not configured.")
.title("Update Check")
.show(|_| {});
.show(move |_| restore_webview_focus(&h));
}
return;
}
Expand All @@ -684,23 +704,25 @@ async fn check_for_updates(handle: &AppHandle, silent: bool) {
Err(err) => {
eprintln!("[agentsview] update check failed: {err}");
if !silent {
let h = handle.clone();
handle
.dialog()
.message("Could not check for updates. Please try again later.")
.title("Update Check")
.show(|_| {});
.show(move |_| restore_webview_focus(&h));
}
return;
}
};

let Some(update) = update else {
if !silent {
let h = handle.clone();
handle
.dialog()
.message("You're running the latest version.")
.title("No Updates Available")
.show(|_| {});
.show(move |_| restore_webview_focus(&h));
}
return;
};
Expand All @@ -722,14 +744,15 @@ async fn check_for_updates(handle: &AppHandle, silent: bool) {

if let Err(err) = update.download_and_install(|_, _| {}, || {}).await {
eprintln!("[agentsview] update install failed: {err}");
let h = handle.clone();
handle
.dialog()
.message(
"Failed to install the update. \
Please try downloading manually from the releases page.",
)
.title("Update Failed")
.show(|_| {});
.show(move |_| restore_webview_focus(&h));
return;
}

Expand All @@ -752,12 +775,14 @@ async fn dialog_confirm(
message: &str,
) -> bool {
let (tx, rx) = tokio::sync::oneshot::channel();
let h = handle.clone();
handle
.dialog()
.message(message)
.title(title)
.buttons(MessageDialogButtons::OkCancel)
.show(move |confirmed| {
restore_webview_focus(&h);
let _ = tx.send(confirmed);
});
rx.await.unwrap_or(false)
Expand Down
2 changes: 1 addition & 1 deletion desktop/src-tauri/tauri.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
}
],
"security": {
"csp": "default-src 'self'; connect-src 'self' http://127.0.0.1:* ws://127.0.0.1:*; img-src 'self' data:; style-src 'self' 'unsafe-inline'; font-src 'self' data:; object-src 'none'; frame-ancestors 'none'; base-uri 'none';"
"csp": "default-src 'self' http://127.0.0.1:* http://localhost:*; script-src 'self' http://127.0.0.1:* http://localhost:*; connect-src 'self' http://127.0.0.1:* http://localhost:* ws://127.0.0.1:* ws://localhost:*; img-src 'self' data: http://127.0.0.1:* http://localhost:*; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com http://127.0.0.1:* http://localhost:*; font-src 'self' data: https://fonts.gstatic.com http://127.0.0.1:* http://localhost:*; object-src 'none'; base-uri 'none'; frame-ancestors 'none'"
}
},
"bundle": {
Expand Down
154 changes: 154 additions & 0 deletions internal/server/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,160 @@ func TestMiddlewareTimeout(t *testing.T) {
}
}

func TestCSPMiddlewareSetsHeaderOnNonAPIRoutes(t *testing.T) {
t.Parallel()

tests := []struct {
name string
path string
host string
port int
publicOrigins []string
bindAllIPs map[string]bool
wantCSP bool
wantParts []string
wantAbsent []string // substrings that must NOT appear
}{
{
name: "SPA_root_gets_CSP_with_pinned_origin",
path: "/",
host: "127.0.0.1",
port: 8081,
wantCSP: true,
wantParts: []string{
"script-src 'self' http://127.0.0.1:8081",
"default-src 'self' http://127.0.0.1:8081",
"connect-src 'self' http://127.0.0.1:8081",
"ws://127.0.0.1:8081",
"style-src 'self' http://127.0.0.1:8081 'unsafe-inline' https://fonts.googleapis.com",
"font-src 'self' http://127.0.0.1:8081 data: https://fonts.gstatic.com",
"frame-ancestors 'none'",
},
},
{
name: "SPA_subpath_gets_CSP",
path: "/sessions/abc",
host: "127.0.0.1",
port: 9090,
wantCSP: true,
wantParts: []string{
"http://127.0.0.1:9090",
"ws://127.0.0.1:9090",
},
},
{
name: "API_route_no_CSP",
path: "/api/v1/sessions",
host: "127.0.0.1",
port: 8081,
wantCSP: false,
},
{
name: "API_subpath_no_CSP",
path: "/api/v1/stats",
host: "127.0.0.1",
port: 8081,
wantCSP: false,
},
{
name: "IPv6_loopback_brackets",
path: "/",
host: "::1",
port: 8081,
wantCSP: true,
wantParts: []string{
"script-src 'self' http://[::1]:8081",
"connect-src",
"ws://[::1]:8081",
"http://127.0.0.1:8081",
},
},
{
name: "BindAll_connect_src_includes_LAN_IPs",
path: "/",
host: "0.0.0.0",
port: 8080,
bindAllIPs: map[string]bool{
"127.0.0.1": true,
"::1": true,
"192.168.1.5": true,
},
wantCSP: true,
wantParts: []string{
// Pinned origin in all directives
"script-src 'self' http://0.0.0.0:8080",
// LAN IPs in connect-src
"http://192.168.1.5:8080",
"ws://192.168.1.5:8080",
"http://127.0.0.1:8080",
"http://localhost:8080",
},
wantAbsent: []string{
// LAN IPs must NOT be in script-src
"script-src 'self' http://0.0.0.0:8080 http://192",
},
},
{
name: "PublicOrigin_in_connect_src_only",
path: "/",
host: "127.0.0.1",
port: 8081,
publicOrigins: []string{"https://view.example.com"},
wantCSP: true,
wantParts: []string{
// Pinned origin in script-src
"script-src 'self' http://127.0.0.1:8081",
// Public origin in connect-src
"https://view.example.com",
"wss://view.example.com",
},
wantAbsent: []string{
// Public origin must NOT be in script-src
"script-src 'self' http://127.0.0.1:8081 https://view",
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := cspMiddleware(tt.host, tt.port, tt.publicOrigins, tt.bindAllIPs, inner)

req := httptest.NewRequest(http.MethodGet, tt.path, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

csp := w.Header().Get("Content-Security-Policy")
if tt.wantCSP {
if csp == "" {
t.Fatal("expected CSP header, got empty")
}
for _, part := range tt.wantParts {
if !strings.Contains(csp, part) {
t.Errorf("CSP missing %q; got %q", part, csp)
}
}
for _, absent := range tt.wantAbsent {
if strings.Contains(csp, absent) {
t.Errorf("CSP should not contain %q; got %q", absent, csp)
}
}
xfo := w.Header().Get("X-Frame-Options")
if xfo != "DENY" {
t.Errorf("expected X-Frame-Options DENY, got %q", xfo)
}
} else {
if csp != "" {
t.Errorf("expected no CSP header on API route, got %q", csp)
}
}
})
}
}

func TestCORSMiddlewareMergesVaryHeader(t *testing.T) {
t.Parallel()

Expand Down
Loading