From 19836ed7ae30ca1f8bd1ab8de308c6a6a153d4d9 Mon Sep 17 00:00:00 2001 From: Tyler Longwell Date: Wed, 8 Apr 2026 16:42:43 -0400 Subject: [PATCH] feat(sprout-proxy): add /public read-only WebSocket endpoint --- Cargo.lock | 2 + crates/sprout-proxy/Cargo.toml | 4 + crates/sprout-proxy/src/main.rs | 69 +- crates/sprout-proxy/src/server.rs | 1071 ++++++++++++++++++++++++++- crates/sprout-proxy/src/upstream.rs | 21 +- 5 files changed, 1137 insertions(+), 30 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e8681fa7..7abe9534 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3502,6 +3502,7 @@ dependencies = [ "dashmap", "futures-util", "hmac", + "http-body-util", "moka", "nostr", "rand 0.8.5", @@ -3513,6 +3514,7 @@ dependencies = [ "thiserror", "tokio", "tokio-tungstenite 0.26.2", + "tower", "tower-http", "tracing", "tracing-subscriber", diff --git a/crates/sprout-proxy/Cargo.toml b/crates/sprout-proxy/Cargo.toml index 26d1ef8e..a81500c1 100644 --- a/crates/sprout-proxy/Cargo.toml +++ b/crates/sprout-proxy/Cargo.toml @@ -30,6 +30,10 @@ axum = { workspace = true } tower-http = { workspace = true } tracing-subscriber = { workspace = true } +[dev-dependencies] +http-body-util = "0.1" +tower = { version = "0.5", features = ["util"] } + [[bin]] name = "sprout-proxy" path = "src/main.rs" diff --git a/crates/sprout-proxy/src/main.rs b/crates/sprout-proxy/src/main.rs index a293eef0..654c0ae0 100644 --- a/crates/sprout-proxy/src/main.rs +++ b/crates/sprout-proxy/src/main.rs @@ -1,10 +1,11 @@ //! sprout-proxy binary — NIP-28 guest relay proxy for standard Nostr clients. +use std::sync::atomic::AtomicUsize; use std::sync::Arc; use nostr::prelude::*; use tokio::sync::{broadcast, mpsc}; -use tracing::{error, info}; +use tracing::{error, info, warn}; use sprout_proxy::channel_map::ChannelMap; use sprout_proxy::guest_store::GuestStore; @@ -97,6 +98,69 @@ async fn main() { ); info!(channels = channel_map.len(), "channel map ready"); + // ── Parse public channel list (optional) ────────────────────────────────── + + let public_channels: Vec = std::env::var("SPROUT_PROXY_PUBLIC_CHANNELS") + .unwrap_or_default() + .split(',') + .filter(|s| !s.trim().is_empty()) + .filter_map(|s| { + let trimmed = s.trim(); + match trimmed.parse::() { + Ok(uuid) => { + if channel_map.lookup_by_uuid(&uuid).is_some() { + Some(uuid) + } else { + warn!(uuid = %trimmed, "SPROUT_PROXY_PUBLIC_CHANNELS: unknown channel UUID — skipping"); + None + } + } + Err(_) => { + warn!(value = %trimmed, "SPROUT_PROXY_PUBLIC_CHANNELS: invalid UUID — skipping"); + None + } + } + }) + .collect(); + + if public_channels.is_empty() { + info!("no public channels configured — /public route disabled"); + } else { + info!( + count = public_channels.len(), + "public channels configured for read-only access" + ); + } + + // ── Parse public connection lifetime cap ────────────────────────────────── + + let public_lifetime_secs: u64 = match std::env::var("SPROUT_PROXY_PUBLIC_LIFETIME_SECS") { + Ok(raw) => match raw.trim().parse::() { + Ok(0) => { + eprintln!( + "error: SPROUT_PROXY_PUBLIC_LIFETIME_SECS cannot be 0 \ + (would disconnect every public client immediately)" + ); + std::process::exit(1); + } + Ok(v) => v, + Err(_) => { + warn!( + value = %raw.trim(), + default = server::DEFAULT_PUBLIC_LIFETIME_SECS, + "SPROUT_PROXY_PUBLIC_LIFETIME_SECS is not a valid integer — using default" + ); + server::DEFAULT_PUBLIC_LIFETIME_SECS + } + }, + Err(_) => server::DEFAULT_PUBLIC_LIFETIME_SECS, + }; + + info!( + seconds = public_lifetime_secs, + "public connection 
lifetime cap" + ); + // ── Init translator ─────────────────────────────────────────────────────── let translator = Arc::new(Translator::new( @@ -188,6 +252,9 @@ async fn main() { upstream_events: upstream_events_tx.clone(), admin_secret, relay_url, + public_channels: Arc::new(public_channels), + public_connection_count: Arc::new(AtomicUsize::new(0)), + public_lifetime_secs, }; // ── Build router ────────────────────────────────────────────────────────── diff --git a/crates/sprout-proxy/src/server.rs b/crates/sprout-proxy/src/server.rs index 160f4641..2378b14a 100644 --- a/crates/sprout-proxy/src/server.rs +++ b/crates/sprout-proxy/src/server.rs @@ -6,6 +6,7 @@ //! from the local [`ChannelMap`]. use std::collections::{HashMap, HashSet}; +use std::sync::atomic::Ordering; use std::sync::Arc; use axum::{ @@ -54,6 +55,15 @@ pub struct ProxyState { /// This proxy's own WebSocket URL (e.g. "ws://0.0.0.0:4869"). /// Used for NIP-42 relay tag validation. pub relay_url: String, + /// UUIDs of channels accessible via the `/public` read-only endpoint. + /// Configured via `SPROUT_PROXY_PUBLIC_CHANNELS` env var. + pub public_channels: Arc>, + /// Active public (unauthenticated) connection count for rate limiting. + pub public_connection_count: Arc, + /// Maximum lifetime (seconds) for a public WebSocket connection. + /// Configured via `SPROUT_PROXY_PUBLIC_LIFETIME_SECS` env var. + /// Falls back to `DEFAULT_PUBLIC_LIFETIME_SECS` (3600) if not set. + pub public_lifetime_secs: u64, } // ─── Router ────────────────────────────────────────────────────────────────── @@ -75,8 +85,11 @@ pub struct WsParams { /// - `GET /admin/guests` — List all registered guests /// /// All `/admin/*` routes are protected by `SPROUT_PROXY_ADMIN_SECRET` if set. +/// +/// `/public` is only registered when at least one public channel is configured. 
pub fn router(state: ProxyState) -> Router { - Router::new() + let has_public = !state.public_channels.is_empty(); + let mut app = Router::new() .route("/", get(root_handler)) .route("/admin/invite", axum::routing::post(create_invite)) .route( @@ -84,8 +97,11 @@ pub fn router(state: ProxyState) -> Router { axum::routing::post(register_guest) .delete(revoke_guest) .get(list_guests), - ) - .with_state(state) + ); + if has_public { + app = app.route("/public", get(public_handler)); + } + app.with_state(state) } // ─── Root handler (NIP-11 / WebSocket) ─────────────────────────────────────── @@ -122,7 +138,7 @@ async fn root_handler( } } -fn nip11_response() -> impl IntoResponse { +fn nip11_response() -> Response { let nip11 = serde_json::json!({ "name": "sprout-proxy", "description": "Sprout NIP-28 guest proxy for standard Nostr clients", @@ -133,13 +149,17 @@ fn nip11_response() -> impl IntoResponse { "auth_required": true } }); - ( - [ - (axum::http::header::CONTENT_TYPE, "application/nostr+json"), - (axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"), - ], - serde_json::to_string_pretty(&nip11).unwrap(), - ) + match serde_json::to_string_pretty(&nip11) { + Ok(body) => ( + [ + (axum::http::header::CONTENT_TYPE, "application/nostr+json"), + (axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"), + ], + body, + ) + .into_response(), + Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + } } // ─── Constant-time string comparison ───────────────────────────────────────── @@ -351,9 +371,11 @@ async fn handle_ws(mut socket: WebSocket, state: ProxyState, token: String) { }; // FIX 1: pending_oks maps upstream_event_id_hex → client_original_event_id - // FIX 5: active_subs tracks prefixed sub IDs sent upstream for cleanup on disconnect + // FIX 5: active_subs tracks ALL prefixed sub IDs (for cap counting and cleanup). + // upstream_subs tracks only subs that were forwarded upstream (for CLOSE routing). let mut pending_oks: HashMap = HashMap::new(); let mut active_subs: HashSet = HashSet::new(); + let mut upstream_subs: HashSet = HashSet::new(); // ── 4. Subscribe to upstream broadcast ──────────────────────────────── let mut upstream_rx = state.upstream_events.subscribe(); @@ -368,12 +390,13 @@ async fn handle_ws(mut socket: WebSocket, state: ProxyState, token: String) { handle_client_message( &mut socket, &state, - &text.to_string(), + &text, &allowed_channels, &client_pubkey, &conn_prefix, &mut pending_oks, &mut active_subs, + &mut upstream_subs, ) .await; } @@ -429,6 +452,10 @@ async fn handle_ws(mut socket: WebSocket, state: ProxyState, token: String) { Ok(RelayMessage::Closed { ref subscription_id, ref message }) => { let sub_str = subscription_id.to_string(); if sub_str.starts_with(&conn_prefix) { + // Clean up tracking — upstream killed this sub, + // so free the slot before forwarding to client. + active_subs.remove(&sub_str); + upstream_subs.remove(&sub_str); let client_sub_id = SubscriptionId::new(&sub_str[conn_prefix.len() + 1..]); let out = RelayMessage::closed(client_sub_id, message.clone()); if socket.send(Message::Text(out.as_json().into())).await.is_err() { @@ -480,8 +507,8 @@ async fn handle_ws(mut socket: WebSocket, state: ProxyState, token: String) { } } - // FIX 5: On disconnect, send CLOSE for all active upstream subscriptions. - for prefixed_sub in active_subs { + // On disconnect, send CLOSE only for subs that were forwarded upstream. 
+ for prefixed_sub in upstream_subs { let sub_id = SubscriptionId::new(prefixed_sub); if let Err(e) = state.upstream.send_close(sub_id).await { warn!("upstream send_close on disconnect failed: {e}"); @@ -503,6 +530,7 @@ async fn handle_client_message( conn_prefix: &str, pending_oks: &mut HashMap, active_subs: &mut HashSet, + upstream_subs: &mut HashSet, ) { let msg = match ClientMessage::from_json(raw_msg) { Ok(m) => m, @@ -525,6 +553,7 @@ async fn handle_client_message( allowed_channels, conn_prefix, active_subs, + upstream_subs, ) .await; } @@ -589,11 +618,13 @@ async fn handle_client_message( } ClientMessage::Close(sub_id) => { let prefixed = format!("{conn_prefix}:{}", sub_id); - // FIX 5: Remove from active_subs tracking. active_subs.remove(&prefixed); - let prefixed_sub_id = SubscriptionId::new(prefixed); - if let Err(e) = state.upstream.send_close(prefixed_sub_id).await { - warn!("upstream send_close failed: {e}"); + // Only send upstream CLOSE for subs that had upstream REQs. + if upstream_subs.remove(&prefixed) { + let prefixed_sub_id = SubscriptionId::new(prefixed); + if let Err(e) = state.upstream.send_close(prefixed_sub_id).await { + warn!("upstream send_close failed: {e}"); + } } } // AUTH after initial handshake is silently ignored. @@ -753,6 +784,7 @@ fn collect_local_events( // ─── REQ handler ───────────────────────────────────────────────────────────── +#[allow(clippy::too_many_arguments)] async fn handle_req( socket: &mut WebSocket, state: &ProxyState, @@ -761,7 +793,20 @@ async fn handle_req( allowed_channels: &[Uuid], conn_prefix: &str, active_subs: &mut HashSet, + upstream_subs: &mut HashSet, ) { + // Nostr replacement semantics: a new REQ with the same sub ID replaces + // the previous one. Tear down any existing upstream subscription first. + let prefixed_existing = format!("{conn_prefix}:{}", sub_id); + if upstream_subs.remove(&prefixed_existing) { + let old_sub_id = SubscriptionId::new(&prefixed_existing); + if let Err(e) = state.upstream.send_close(old_sub_id).await { + warn!("upstream send_close for replaced sub failed: {e}"); + } + } + // Remove from active_subs too — it will be re-added if the new REQ succeeds. + active_subs.remove(&prefixed_existing); + let (owned_local_filters, owned_upstream_filters) = split_filters(&filters); // Serve local filters from ChannelMap via the extracted pure function. @@ -774,6 +819,10 @@ async fn handle_req( if owned_upstream_filters.is_empty() { // Only local filters — send EOSE immediately after serving them. + // Track the sub even for local-only REQs so the per-connection cap + // is accurate and cleanup on disconnect is complete. + let prefixed_local = format!("{conn_prefix}:{}", sub_id); + active_subs.insert(prefixed_local); let _ = send_relay_msg(socket, RelayMessage::eose(sub_id.clone())).await; return; } @@ -792,15 +841,26 @@ async fn handle_req( let prefixed_sub_id_str = format!("{conn_prefix}:{}", sub_id); let prefixed_sub_id = SubscriptionId::new(prefixed_sub_id_str.clone()); - // FIX 5: Track this subscription for cleanup on disconnect. - active_subs.insert(prefixed_sub_id_str); - - if let Err(e) = state + // Only track the sub after send_req succeeds. If the upstream send + // fails, the sub was never established — don't consume a slot. 
+ match state .upstream .send_req(prefixed_sub_id, translated_filters) .await { - warn!("upstream send_req failed: {e}"); + Ok(()) => { + active_subs.insert(prefixed_sub_id_str.clone()); + upstream_subs.insert(prefixed_sub_id_str); + } + Err(e) => { + warn!("upstream send_req failed: {e}"); + // Notify client that the subscription couldn't be established. + let _ = send_relay_msg( + socket, + RelayMessage::closed(sub_id, "error: upstream unavailable"), + ) + .await; + } } } @@ -1034,6 +1094,381 @@ async fn list_guests(State(state): State, headers: HeaderMap) -> imp .into_response() } +// ─── Public read-only endpoint ──────────────────────────────────────────────── + +/// Resource limits for unauthenticated `/public` connections. +/// These prevent DoS from anonymous internet traffic. +const MAX_PUBLIC_CONNECTIONS: usize = 100; +const MAX_PUBLIC_SUBS_PER_CONN: usize = 5; +/// Maximum number of filters allowed in a single REQ from an anonymous client. +/// Prevents expensive multi-filter fan-out queries from unauthenticated callers. +const MAX_PUBLIC_FILTERS_PER_REQ: usize = 3; +/// Maximum raw message size (bytes) accepted from an anonymous client. +/// Rejects oversized payloads before JSON parsing to limit CPU exposure. +const MAX_PUBLIC_MSG_BYTES: usize = 4096; +/// Maximum number of events returned per subscription for anonymous clients. +/// Prevents unbounded backfill queries from unauthenticated callers. +const MAX_PUBLIC_LIMIT: usize = 200; +const PUBLIC_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); +/// Default lifetime cap for public connections (1 hour). +/// Used when `ProxyState::public_lifetime_secs` is 0 or not configured. +pub const DEFAULT_PUBLIC_LIFETIME_SECS: u64 = 3600; + +/// Content-negotiate between public NIP-11 JSON and WebSocket upgrade. +/// Mirrors [`root_handler`] but uses the public NIP-11 document and +/// the read-only WebSocket handler. +async fn public_handler( + State(state): State, + headers: HeaderMap, + req: axum::extract::Request, +) -> Response { + let wants_nip11 = headers + .get("accept") + .and_then(|v| v.to_str().ok()) + .map(|v| v.contains("application/nostr+json")) + .unwrap_or(false); + + if wants_nip11 { + return public_nip11_response().into_response(); + } + + match WebSocketUpgrade::from_request(req, &state).await { + Ok(ws) => ws + .max_message_size(MAX_PUBLIC_MSG_BYTES) + .max_frame_size(MAX_PUBLIC_MSG_BYTES) + .on_upgrade(move |socket| handle_public_ws(socket, state)), + Err(_) => public_nip11_response().into_response(), + } +} + +fn public_nip11_response() -> Response { + let nip11 = serde_json::json!({ + "name": "sprout-proxy (public)", + "description": "Sprout NIP-28 public read-only relay — no authentication required", + "supported_nips": [1, 11, 28], + "software": "sprout-proxy", + "version": env!("CARGO_PKG_VERSION"), + "limitation": { + "auth_required": false, + "max_subscriptions": MAX_PUBLIC_SUBS_PER_CONN, + } + }); + match serde_json::to_string_pretty(&nip11) { + Ok(body) => ( + [ + (axum::http::header::CONTENT_TYPE, "application/nostr+json"), + (axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*"), + ], + body, + ) + .into_response(), + Err(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response(), + } +} + +/// Read-only WebSocket handler for the `/public` endpoint. +/// +/// Structurally separate from [`handle_ws`] — there is no EVENT branch, +/// no AUTH challenge, no shadow key derivation for the reader, no guest +/// or invite store interaction. The write path does not exist. 
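+///
+/// A typical anonymous exchange, as a sketch (the subscription ID and filter are
+/// illustrative; the rejection string matches the handler below):
+///
+/// ```text
+/// client → ["REQ","pub1",{"kinds":[40,42],"limit":50}]
+/// proxy  → ["EVENT","pub1",{...}]        (only events from public channels)
+/// proxy  → ["EOSE","pub1"]
+/// client → ["EVENT",{...}]               (write attempt)
+/// proxy  → ["OK","<event-id>",false,"restricted: read-only access"]
+/// ```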
+async fn handle_public_ws(mut socket: WebSocket, state: ProxyState) { + // ── Connection cap ──────────────────────────────────────────────────── + // Soft cap — Relaxed ordering is intentional; brief overshoot + // by concurrent arrivals is acceptable. + let current = state + .public_connection_count + .fetch_add(1, Ordering::Relaxed); + if current >= MAX_PUBLIC_CONNECTIONS { + state + .public_connection_count + .fetch_sub(1, Ordering::Relaxed); + let _ = send_relay_msg( + &mut socket, + RelayMessage::notice("error: too many public connections — try again later"), + ) + .await; + return; + } + + // Ensure counter is decremented on all exit paths. + let _guard = PublicConnGuard(state.public_connection_count.clone()); + + let conn_prefix = uuid::Uuid::new_v4().simple().to_string()[..8].to_string(); + let prefix_with_sep = format!("{conn_prefix}:"); + let allowed_channels: &[Uuid] = &state.public_channels; + let mut active_subs: HashSet = HashSet::new(); + let mut upstream_subs: HashSet = HashSet::new(); + let mut upstream_rx = state.upstream_events.subscribe(); + + let connected_at = tokio::time::Instant::now(); + let mut last_activity = tokio::time::Instant::now(); + + debug!("public client connected"); + + let lifetime_secs = if state.public_lifetime_secs > 0 { + state.public_lifetime_secs + } else { + DEFAULT_PUBLIC_LIFETIME_SECS + }; + let lifetime_cap = std::time::Duration::from_secs(lifetime_secs); + + // ── Main read-only message loop ─────────────────────────────────────── + loop { + tokio::select! { + // Lifetime cap — hard deadline enforced inside select! so it fires + // even when other branches are blocking. + _ = tokio::time::sleep_until(connected_at + lifetime_cap) => { + let _ = send_relay_msg( + &mut socket, + RelayMessage::notice("connection lifetime exceeded — reconnect to continue"), + ) + .await; + break; + } + + // Bidirectional idle timeout — disconnect if no inbound or outbound activity. + _ = tokio::time::sleep_until(last_activity + PUBLIC_IDLE_TIMEOUT) => { + let _ = send_relay_msg( + &mut socket, + RelayMessage::notice("idle timeout — disconnecting"), + ) + .await; + break; + } + + // Inbound from client (read-only: REQ and CLOSE only). + msg = socket.recv() => { + match msg { + Some(Ok(Message::Text(text))) => { + last_activity = tokio::time::Instant::now(); + handle_public_client_message( + &mut socket, + &state, + &text, + allowed_channels, + &conn_prefix, + &mut active_subs, + &mut upstream_subs, + ) + .await; + } + Some(Ok(Message::Close(_))) | None => break, + // Ping/Pong frames reset the idle timer — many Nostr clients + // use WebSocket keepalives to maintain quiet subscriptions. + Some(Ok(Message::Ping(_) | Message::Pong(_))) => { + last_activity = tokio::time::Instant::now(); + } + // Binary frames are not used in the Nostr protocol. + Some(Ok(Message::Binary(_))) => { + let _ = send_relay_msg( + &mut socket, + RelayMessage::notice("error: binary frames not supported"), + ) + .await; + } + _ => {} + } + } + + // Outbound from upstream relay — translate and filter. 
+ upstream = upstream_rx.recv() => { + match upstream { + Ok(text) => { + match RelayMessage::from_json(&text) { + Ok(RelayMessage::Event { subscription_id, event }) => { + let sub_str = subscription_id.to_string(); + let Some(client_sub) = sub_str.strip_prefix(&prefix_with_sep) else { + continue; + }; + let client_sub_id = SubscriptionId::new(client_sub); + match state + .translator + .translate_outbound(&event, allowed_channels) + .await + { + Ok(Some(translated)) => { + let out = RelayMessage::event(client_sub_id, translated); + if socket.send(Message::Text(out.as_json().into())).await.is_ok() { + last_activity = tokio::time::Instant::now(); + } else { + break; + } + } + Ok(None) => {} + Err(e) => { + debug!(error = %e, "dropping upstream event (not in public scope)"); + } + } + } + Ok(RelayMessage::EndOfStoredEvents(ref sub_id)) => { + let sub_str = sub_id.to_string(); + if let Some(client_sub) = sub_str.strip_prefix(&prefix_with_sep) { + let client_sub_id = SubscriptionId::new(client_sub); + let out = RelayMessage::eose(client_sub_id); + if socket.send(Message::Text(out.as_json().into())).await.is_ok() { + last_activity = tokio::time::Instant::now(); + } else { + break; + } + } + } + Ok(RelayMessage::Closed { ref subscription_id, ref message }) => { + let sub_str = subscription_id.to_string(); + if let Some(client_sub) = sub_str.strip_prefix(&prefix_with_sep) { + // Clean up tracking — upstream killed this sub, + // so free the slot before forwarding to client. + active_subs.remove(&sub_str); + upstream_subs.remove(&sub_str); + let client_sub_id = SubscriptionId::new(client_sub); + let out = RelayMessage::closed(client_sub_id, message.clone()); + if socket.send(Message::Text(out.as_json().into())).await.is_ok() { + last_activity = tokio::time::Instant::now(); + } else { + break; + } + } + } + // Public readers never send events, so no pending_oks to route. + // Drop NOTICE, AUTH, OK, and other control-plane messages. + Ok(_) => {} + Err(_) => {} + } + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + warn!(skipped = n, "public client: upstream broadcast lagged"); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + error!("upstream broadcast channel closed"); + break; + } + } + } + } + } + + // Clean up upstream subscriptions on disconnect. + for prefixed_sub in upstream_subs { + let sub_id = SubscriptionId::new(prefixed_sub); + if let Err(e) = state.upstream.send_close(sub_id).await { + warn!("upstream send_close on public disconnect failed: {e}"); + } + } + + debug!("public client disconnected"); +} + +/// Read-only client message handler. Handles REQ and CLOSE only. +/// EVENT and all other message types are rejected. +async fn handle_public_client_message( + socket: &mut WebSocket, + state: &ProxyState, + raw_msg: &str, + allowed_channels: &[Uuid], + conn_prefix: &str, + active_subs: &mut HashSet, + upstream_subs: &mut HashSet, +) { + // Reject oversized messages before JSON parsing to limit CPU exposure. + if raw_msg.len() > MAX_PUBLIC_MSG_BYTES { + let _ = send_relay_msg(socket, RelayMessage::notice("error: message too large")).await; + return; + } + + let msg = match ClientMessage::from_json(raw_msg) { + Ok(m) => m, + Err(_) => { + let _ = send_relay_msg(socket, RelayMessage::notice("error: invalid message")).await; + return; + } + }; + + match msg { + ClientMessage::Req { + subscription_id, + filters, + } => { + // Enforce per-connection subscription limit. + // Track all sub IDs (including local-only) to prevent bypass via kind:40/41 REQs. 
+ let prefixed = format!("{conn_prefix}:{}", subscription_id); + if active_subs.len() >= MAX_PUBLIC_SUBS_PER_CONN && !active_subs.contains(&prefixed) { + let _ = send_relay_msg( + socket, + RelayMessage::closed( + subscription_id, + "error: too many subscriptions — close one first", + ), + ) + .await; + return; + } + // Reject REQs with too many filters to prevent expensive fan-out queries. + if filters.len() > MAX_PUBLIC_FILTERS_PER_REQ { + let _ = send_relay_msg( + socket, + RelayMessage::closed( + subscription_id, + "error: too many filters — max 3 per REQ", + ), + ) + .await; + return; + } + // Clamp each filter's limit to prevent unbounded backfill from + // anonymous clients. + let clamped_filters: Vec = filters + .into_iter() + .map(|mut f| { + let current = f.limit.unwrap_or(MAX_PUBLIC_LIMIT); + f.limit = Some(current.min(MAX_PUBLIC_LIMIT)); + f + }) + .collect(); + handle_req( + socket, + state, + subscription_id, + clamped_filters, + allowed_channels, + conn_prefix, + active_subs, + upstream_subs, + ) + .await; + } + ClientMessage::Close(sub_id) => { + let prefixed = format!("{conn_prefix}:{}", sub_id); + active_subs.remove(&prefixed); + // Only send upstream CLOSE for subs that had upstream REQs. + if upstream_subs.remove(&prefixed) { + let prefixed_sub_id = SubscriptionId::new(prefixed); + if let Err(e) = state.upstream.send_close(prefixed_sub_id).await { + warn!("upstream send_close failed: {e}"); + } + } + } + ClientMessage::Event(event) => { + // Structurally reject — this endpoint is read-only. + let _ = send_relay_msg( + socket, + RelayMessage::ok(event.id, false, "restricted: read-only access"), + ) + .await; + } + // AUTH on a public endpoint is meaningless — ignore silently. + ClientMessage::Auth(_) => {} + _ => {} + } +} + +/// RAII guard that decrements the public connection counter on drop. +struct PublicConnGuard(Arc); + +impl Drop for PublicConnGuard { + fn drop(&mut self) { + self.0.fetch_sub(1, Ordering::Relaxed); + } +} + // ─── Tests ──────────────────────────────────────────────────────────────────── #[cfg(test)] @@ -1069,9 +1504,62 @@ mod tests { upstream_events, admin_secret: None, relay_url: "ws://127.0.0.1:4869".to_string(), + public_channels: Arc::new(Vec::new()), + public_connection_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)), + public_lifetime_secs: 3600, } } + /// Like `make_state()` but registers one channel and includes it in + /// `public_channels`, so the `/public` route is registered by `router()`. + fn make_state_with_public_channel() -> (ProxyState, Uuid) { + let keys = Keys::generate(); + let channel_map = Arc::new(crate::channel_map::ChannelMap::new(keys.clone())); + let guest_store = Arc::new(GuestStore::new()); + let invite_store = Arc::new(InviteStore::new()); + let (upstream_events, _) = broadcast::channel(16); + let shadow_keys = Arc::new( + crate::shadow_keys::ShadowKeyManager::new(b"test-salt-server-tests") + .expect("shadow key manager"), + ); + let translator = Arc::new(crate::translate::Translator::new( + shadow_keys, + channel_map.clone(), + "http://localhost:3000", + "sprout_test", + "deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef", + )); + let upstream = Arc::new(UpstreamClient::new("ws://localhost:3000", "sprout_test")); + + // Register a test channel so the map has something to serve. 
+ let dto = crate::channel_map::ChannelDto { + id: "550e8400-e29b-41d4-a716-446655440000".to_string(), + name: "test-public-channel".to_string(), + created_at: "2026-01-15T12:00:00Z".to_string(), + visibility: "open".to_string(), + description: "A test public channel".to_string(), + created_by: "0101010101010101010101010101010101010101010101010101010101010101" + .to_string(), + }; + channel_map.register(&dto).expect("register must succeed"); + let uuid: Uuid = "550e8400-e29b-41d4-a716-446655440000".parse().unwrap(); + + let state = ProxyState { + channel_map, + guest_store, + invite_store, + translator, + upstream, + upstream_events, + admin_secret: None, + relay_url: "ws://127.0.0.1:4869".to_string(), + public_channels: Arc::new(vec![uuid]), + public_connection_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)), + public_lifetime_secs: 3600, + }; + (state, uuid) + } + #[test] fn router_builds() { let state = make_state(); @@ -1085,6 +1573,226 @@ mod tests { let _ = response.into_response(); } + #[test] + fn public_nip11_json_is_valid() { + // Verify the public NIP-11 response serializes without panic + // and contains expected fields. + let resp = public_nip11_response().into_response(); + assert_eq!( + resp.headers().get("content-type").unwrap(), + "application/nostr+json" + ); + } + + #[test] + fn public_nip11_contains_expected_fields() { + let resp = public_nip11_response().into_response(); + let body = resp.into_body(); + let bytes = tokio::runtime::Runtime::new().unwrap().block_on(async { + use http_body_util::BodyExt; + body.collect().await.unwrap().to_bytes() + }); + let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); + assert_eq!(json["name"], "sprout-proxy (public)"); + assert_eq!(json["supported_nips"], serde_json::json!([1, 11, 28])); + assert_eq!(json["limitation"]["auth_required"], false); + assert_eq!( + json["limitation"]["max_subscriptions"], + MAX_PUBLIC_SUBS_PER_CONN + ); + assert!(json["version"].is_string()); + } + + #[test] + fn public_conn_guard_decrements() { + let counter = Arc::new(std::sync::atomic::AtomicUsize::new(5)); + { + let _guard = PublicConnGuard(counter.clone()); + assert_eq!(counter.load(std::sync::atomic::Ordering::Relaxed), 5); + } + // Guard dropped — counter should be decremented + assert_eq!(counter.load(std::sync::atomic::Ordering::Relaxed), 4); + } + + #[test] + #[allow(clippy::assertions_on_constants)] + fn public_constants_are_sane() { + // Verify DoS limits are within reasonable bounds + assert!(MAX_PUBLIC_CONNECTIONS > 0 && MAX_PUBLIC_CONNECTIONS <= 1000); + assert!(MAX_PUBLIC_SUBS_PER_CONN > 0 && MAX_PUBLIC_SUBS_PER_CONN <= 50); + assert!(MAX_PUBLIC_FILTERS_PER_REQ > 0 && MAX_PUBLIC_FILTERS_PER_REQ <= 20); + assert!(MAX_PUBLIC_MSG_BYTES >= 1024 && MAX_PUBLIC_MSG_BYTES <= 65536); + assert!(PUBLIC_IDLE_TIMEOUT.as_secs() >= 30); + } + + #[test] + fn public_router_includes_public_route() { + // Verify the /public route is registered when public_channels is non-empty + let (state, _uuid) = make_state_with_public_channel(); + let app = router(state); + // Build a NIP-11 request to /public + let req = axum::http::Request::builder() + .uri("/public") + .header("accept", "application/nostr+json") + .body(axum::body::Body::empty()) + .unwrap(); + let resp = tokio::runtime::Runtime::new().unwrap().block_on(async { + use tower::util::ServiceExt; + app.oneshot(req).await.unwrap() + }); + assert_eq!(resp.status(), 200); + assert_eq!( + resp.headers().get("content-type").unwrap(), + "application/nostr+json" + ); + } + + #[test] + 
#[allow(clippy::assertions_on_constants)] + fn public_msg_size_limit_is_reasonable() { + // MAX_PUBLIC_MSG_BYTES should be large enough for valid REQ/CLOSE + // messages but small enough to reject abuse. + // A typical REQ: ["REQ","sub1",{"kinds":[42],"limit":100}] is ~45 bytes. + // 4096 bytes is generous for 3 filters with complex conditions. + assert!(MAX_PUBLIC_MSG_BYTES >= 1024, "too small for valid REQs"); + assert!( + MAX_PUBLIC_MSG_BYTES <= 8192, + "too large — increases CPU exposure" + ); + } + + #[test] + #[allow(clippy::assertions_on_constants)] + fn public_max_limit_is_bounded() { + // Verify the replay cap exists and is reasonable. + // MAX_PUBLIC_LIMIT prevents unbounded backfill from anonymous clients. + assert!(MAX_PUBLIC_LIMIT > 0, "must allow some results"); + assert!( + MAX_PUBLIC_LIMIT <= 500, + "too high — enables expensive backfills" + ); + } + + #[test] + fn public_endpoint_rejects_non_websocket_post() { + let (state, _uuid) = make_state_with_public_channel(); + let app = router(state); + // POST to /public should get rejected (not a GET) + let req = axum::http::Request::builder() + .method("POST") + .uri("/public") + .body(axum::body::Body::empty()) + .unwrap(); + let resp = tokio::runtime::Runtime::new().unwrap().block_on(async { + use tower::util::ServiceExt; + app.oneshot(req).await.unwrap() + }); + // Should return 405 Method Not Allowed (axum rejects non-GET on get() routes) + assert_eq!(resp.status().as_u16(), 405); + } + + #[test] + fn public_nip11_cors_header() { + let (state, _uuid) = make_state_with_public_channel(); + let app = router(state); + let req = axum::http::Request::builder() + .uri("/public") + .header("accept", "application/nostr+json") + .body(axum::body::Body::empty()) + .unwrap(); + let resp = tokio::runtime::Runtime::new().unwrap().block_on(async { + use tower::util::ServiceExt; + app.oneshot(req).await.unwrap() + }); + assert_eq!(resp.status(), 200); + assert_eq!( + resp.headers().get("access-control-allow-origin").unwrap(), + "*" + ); + } + + #[test] + fn public_route_not_registered_when_no_channels() { + // When public_channels is empty, /public should not be registered. + let state = make_state(); + let app = router(state); + let req = axum::http::Request::builder() + .uri("/public") + .header("accept", "application/nostr+json") + .body(axum::body::Body::empty()) + .unwrap(); + let resp = tokio::runtime::Runtime::new().unwrap().block_on(async { + use tower::util::ServiceExt; + app.oneshot(req).await.unwrap() + }); + // /public is not registered → 404 + assert_eq!(resp.status().as_u16(), 404); + } + + #[test] + fn channel_isolation_public_scope() { + // Core security invariant: events from non-public channels must not + // be visible through the public endpoint's allowed_channels scope. + let keys = Keys::generate(); + let map = crate::channel_map::ChannelMap::new(keys); + + // Register two channels — one public, one private. 
+ let dto_public = crate::channel_map::ChannelDto { + id: "550e8400-e29b-41d4-a716-446655440000".to_string(), + name: "public-channel".to_string(), + created_at: "2026-01-15T12:00:00Z".to_string(), + visibility: "open".to_string(), + description: "Public".to_string(), + created_by: "0101010101010101010101010101010101010101010101010101010101010101" + .to_string(), + }; + let dto_private = crate::channel_map::ChannelDto { + id: "660e8400-e29b-41d4-a716-446655440001".to_string(), + name: "private-channel".to_string(), + created_at: "2026-01-15T12:00:00Z".to_string(), + visibility: "open".to_string(), + description: "Private".to_string(), + created_by: "0202020202020202020202020202020202020202020202020202020202020202" + .to_string(), + }; + map.register(&dto_public).expect("register public"); + map.register(&dto_private).expect("register private"); + + let public_uuid: Uuid = "550e8400-e29b-41d4-a716-446655440000".parse().unwrap(); + let private_uuid: Uuid = "660e8400-e29b-41d4-a716-446655440001".parse().unwrap(); + let map = Arc::new(map); + + // Query with only the public channel in scope (simulating /public endpoint). + let filter = Filter::new().kind(Kind::ChannelCreation); + let public_scope = vec![public_uuid]; + let events = collect_local_events(&filter, &map, &public_scope); + + // Must see exactly one channel — the public one. + assert_eq!(events.len(), 1, "public scope must yield exactly 1 channel"); + + // Verify it's the public channel, not the private one. + let content: serde_json::Value = serde_json::from_str(&events[0].content).unwrap(); + // kind:40 uses the UUID as the "name" field (display name is in kind:41). + assert_eq!( + content["name"], "550e8400-e29b-41d4-a716-446655440000", + "public scope must only expose the public channel" + ); + + // Query with both channels in scope (simulating authenticated access). + let full_scope = vec![public_uuid, private_uuid]; + let all_events = collect_local_events(&filter, &map, &full_scope); + assert_eq!( + all_events.len(), + 2, + "authenticated scope must see both channels" + ); + } + + #[test] + fn public_default_lifetime_is_one_hour() { + assert_eq!(DEFAULT_PUBLIC_LIFETIME_SECS, 3600); + } + #[test] fn default_hours_and_max_uses() { assert_eq!(default_hours(), 24); @@ -1285,4 +1993,319 @@ mod tests { let events = collect_local_events(&filter, &map, &[uuid]); assert!(events.is_empty()); } + + // ── WebSocket behavioral tests for /public endpoint ───────────────── + + /// Start the proxy router on a random port and return the bound address. + /// Uses `make_state_with_public_channel()` so the `/public` route is registered. + async fn start_test_server() -> std::net::SocketAddr { + let (state, _uuid) = make_state_with_public_channel(); + let app = router(state); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind to random port"); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.ok(); + }); + addr + } + + /// Connect to the /public WebSocket endpoint and return the stream. + async fn connect_public( + addr: std::net::SocketAddr, + ) -> tokio_tungstenite::WebSocketStream> + { + let url = format!("ws://{addr}/public"); + let (ws, _resp) = tokio_tungstenite::connect_async(&url) + .await + .expect("WebSocket connect to /public"); + ws + } + + /// Read the next text frame, with a timeout to prevent hanging tests. 
+ async fn read_text( + ws: &mut tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + ) -> String { + use futures_util::StreamExt; + let msg = tokio::time::timeout(std::time::Duration::from_secs(5), ws.next()) + .await + .expect("read timed out") + .expect("stream ended") + .expect("read error"); + match msg { + tokio_tungstenite::tungstenite::Message::Text(t) => t.to_string(), + other => panic!("expected Text frame, got {other:?}"), + } + } + + /// Send a text frame. + async fn send_text( + ws: &mut tokio_tungstenite::WebSocketStream< + tokio_tungstenite::MaybeTlsStream, + >, + msg: &str, + ) { + use futures_util::SinkExt; + ws.send(tokio_tungstenite::tungstenite::Message::Text(msg.into())) + .await + .expect("send failed"); + } + + #[tokio::test] + async fn public_ws_rejects_event() { + let addr = start_test_server().await; + let mut ws = connect_public(addr).await; + + // Build a valid EVENT message. + let keys = Keys::generate(); + let event = EventBuilder::text_note("hello", []) + .sign_with_keys(&keys) + .unwrap(); + let client_msg = ClientMessage::event(event); + send_text(&mut ws, &client_msg.as_json()).await; + + let resp = read_text(&mut ws).await; + let relay_msg: serde_json::Value = serde_json::from_str(&resp).unwrap(); + // Expect ["OK", , false, "restricted: read-only access"] + assert_eq!(relay_msg[0], "OK"); + assert_eq!(relay_msg[2], false); + assert!( + relay_msg[3].as_str().unwrap().contains("read-only"), + "expected read-only rejection, got: {resp}" + ); + } + + #[tokio::test] + async fn public_ws_rejects_oversized_message() { + use futures_util::StreamExt; + + let addr = start_test_server().await; + let mut ws = connect_public(addr).await; + + // Send a message larger than MAX_PUBLIC_MSG_BYTES (4096). + // The WebSocket layer now enforces max_message_size / max_frame_size, + // so the server resets the connection before the app layer sees the + // payload — no NOTICE is sent; the stream simply closes. + let oversized = "x".repeat(MAX_PUBLIC_MSG_BYTES + 1); + send_text(&mut ws, &oversized).await; + + let result = tokio::time::timeout(std::time::Duration::from_secs(5), ws.next()) + .await + .expect("read timed out"); + + // The connection must be terminated — either a clean Close frame or a + // protocol-level reset. Any non-error outcome that is not a Close + // frame is unexpected. + match result { + None => {} // stream ended cleanly + Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) => {} // clean close + Some(Err(_)) => {} // protocol reset — acceptable + Some(Ok(other)) => panic!("expected connection close, got: {other:?}"), + } + } + + #[tokio::test] + async fn public_ws_enforces_subscription_limit() { + let addr = start_test_server().await; + let mut ws = connect_public(addr).await; + + // Use kind:40 (ChannelCreation) — handled locally, so EOSE comes back + // immediately without needing an upstream relay. + for i in 0..MAX_PUBLIC_SUBS_PER_CONN { + let req = format!(r#"["REQ","sub-{i}",{{"kinds":[40],"limit":0}}]"#,); + send_text(&mut ws, &req).await; + // Drain the EOSE response. + let eose = read_text(&mut ws).await; + assert!( + eose.contains("EOSE"), + "expected EOSE for sub-{i}, got: {eose}" + ); + } + + // The next subscription should be rejected before it reaches handle_req. 
+ let overflow_req = r#"["REQ","sub-overflow",{"kinds":[40],"limit":0}]"#; + send_text(&mut ws, overflow_req).await; + + let resp = read_text(&mut ws).await; + let relay_msg: serde_json::Value = serde_json::from_str(&resp).unwrap(); + // Expect ["CLOSED", "sub-overflow", "error: too many subscriptions ..."] + assert_eq!(relay_msg[0], "CLOSED"); + assert_eq!(relay_msg[1], "sub-overflow"); + assert!( + relay_msg[2] + .as_str() + .unwrap() + .contains("too many subscriptions"), + "expected sub limit rejection, got: {resp}" + ); + } + + #[tokio::test] + async fn public_ws_enforces_filter_limit() { + let addr = start_test_server().await; + let mut ws = connect_public(addr).await; + + // Build a REQ with more filters than MAX_PUBLIC_FILTERS_PER_REQ (3). + // Uses kind:40 (local-only) so no upstream dependency. + let filters: Vec = (0..=MAX_PUBLIC_FILTERS_PER_REQ) + .map(|_| r#"{"kinds":[40],"limit":0}"#.to_string()) + .collect(); + let req = format!(r#"["REQ","too-many-filters",{}]"#, filters.join(",")); + send_text(&mut ws, &req).await; + + let resp = read_text(&mut ws).await; + let relay_msg: serde_json::Value = serde_json::from_str(&resp).unwrap(); + // Expect ["CLOSED", "too-many-filters", "error: too many filters ..."] + assert_eq!(relay_msg[0], "CLOSED"); + assert_eq!(relay_msg[1], "too-many-filters"); + assert!( + relay_msg[2].as_str().unwrap().contains("too many filters"), + "expected filter limit rejection, got: {resp}" + ); + } + + #[tokio::test] + async fn public_ws_close_frees_subscription_slot() { + let addr = start_test_server().await; + let mut ws = connect_public(addr).await; + + // Fill all subscription slots with kind:40 (local-only, no upstream needed). + for i in 0..MAX_PUBLIC_SUBS_PER_CONN { + let req = format!(r#"["REQ","sub-{i}",{{"kinds":[40],"limit":0}}]"#); + send_text(&mut ws, &req).await; + let eose = read_text(&mut ws).await; + assert!(eose.contains("EOSE"), "expected EOSE, got: {eose}"); + } + + // Close one subscription. + send_text(&mut ws, r#"["CLOSE","sub-0"]"#).await; + + // Small delay to let the server process the CLOSE. + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + // Now a new subscription should succeed. + let req = r#"["REQ","sub-replacement",{"kinds":[40],"limit":0}]"#; + send_text(&mut ws, req).await; + + let resp = read_text(&mut ws).await; + // Should get EOSE, not CLOSED error. + assert!( + resp.contains("EOSE"), + "expected EOSE after freeing a slot, got: {resp}" + ); + } + + #[tokio::test] + async fn public_ws_invalid_json_returns_notice() { + let addr = start_test_server().await; + let mut ws = connect_public(addr).await; + + send_text(&mut ws, "this is not json").await; + + let resp = read_text(&mut ws).await; + let relay_msg: serde_json::Value = serde_json::from_str(&resp).unwrap(); + assert_eq!(relay_msg[0], "NOTICE"); + assert!( + relay_msg[1].as_str().unwrap().contains("invalid"), + "expected invalid message notice, got: {resp}" + ); + } + + // ── Regression test infrastructure ────────────────────────────────── + + /// Start a test server and return the address, broadcast sender, and + /// upstream client. The broadcast sender lets tests inject upstream relay + /// messages (EVENT, EOSE, CLOSED). The upstream client lets tests discover + /// prefixed subscription IDs via `active_sub_ids()`. 
+ async fn start_test_server_with_upstream() -> ( + std::net::SocketAddr, + broadcast::Sender, + Arc, + ) { + let (state, _uuid) = make_state_with_public_channel(); + let events_tx = state.upstream_events.clone(); + let upstream = state.upstream.clone(); + let app = router(state); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind to random port"); + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + axum::serve(listener, app).await.ok(); + }); + (addr, events_tx, upstream) + } + + /// Regression test for the upstream CLOSED subscription slot leak. + /// + /// Bug: when upstream sends CLOSED for a subscription, the handler forwarded + /// it to the client but didn't remove the sub from `active_subs` or + /// `upstream_subs`. The dead sub permanently counted against the 5-slot limit. + /// + /// Fix: remove from both tracking sets before forwarding CLOSED to client. + #[tokio::test] + async fn public_ws_upstream_closed_frees_subscription_slot() { + let (addr, events_tx, upstream) = start_test_server_with_upstream().await; + let mut ws = connect_public(addr).await; + + // Fill 4 of 5 slots with local-only subs (kind:40 → immediate EOSE). + for i in 0..4 { + let req = format!(r#"["REQ","local-{i}",{{"kinds":[40],"limit":0}}]"#); + send_text(&mut ws, &req).await; + let eose = read_text(&mut ws).await; + assert!( + eose.contains("EOSE"), + "expected EOSE for local-{i}, got: {eose}" + ); + } + + // Fill the 5th slot with an upstream sub (kind:42). + // This goes to the upstream relay, so no EOSE comes back in tests. + send_text( + &mut ws, + r#"["REQ","upstream-sub",{"kinds":[42],"limit":0}]"#, + ) + .await; + + // Wait briefly for the server to process the REQ and register the upstream sub. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // Discover the prefixed sub ID from the upstream client's tracking. + let sub_ids = upstream.active_sub_ids(); + let prefixed = sub_ids + .iter() + .find(|id| id.ends_with(":upstream-sub")) + .expect("upstream sub should be tracked") + .clone(); + + // Inject a CLOSED from "upstream" through the broadcast channel. + let closed_msg = RelayMessage::closed( + SubscriptionId::new(&prefixed), + "subscription closed by relay", + ); + events_tx + .send(closed_msg.as_json()) + .expect("broadcast send"); + + // The client should receive the CLOSED (with prefix stripped). + let resp = read_text(&mut ws).await; + let relay_msg: serde_json::Value = serde_json::from_str(&resp).unwrap(); + assert_eq!(relay_msg[0], "CLOSED", "expected CLOSED, got: {resp}"); + assert_eq!(relay_msg[1], "upstream-sub"); + + // Now the 5th slot should be freed. Open a new local sub — should succeed. 
+ send_text( + &mut ws, + r#"["REQ","replacement-sub",{"kinds":[40],"limit":0}]"#, + ) + .await; + let resp = read_text(&mut ws).await; + assert!( + resp.contains("EOSE"), + "expected EOSE after upstream CLOSED freed the slot, got: {resp}" + ); + } } diff --git a/crates/sprout-proxy/src/upstream.rs b/crates/sprout-proxy/src/upstream.rs index 44b83f6a..f8806b7d 100644 --- a/crates/sprout-proxy/src/upstream.rs +++ b/crates/sprout-proxy/src/upstream.rs @@ -130,13 +130,13 @@ impl UpstreamClient { filters: Vec, ) -> Result<(), crate::ProxyError> { let msg = ClientMessage::req(sub_id.clone(), filters).as_json(); - self.inner - .active_subs - .insert(sub_id.to_string(), msg.clone()); self.outbound_tx - .send(msg) + .send(msg.clone()) .await - .map_err(|_| crate::ProxyError::Upstream("outbound channel closed".into())) + .map_err(|_| crate::ProxyError::Upstream("outbound channel closed".into()))?; + // Only record the subscription for reconnect replay after successful send. + self.inner.active_subs.insert(sub_id.to_string(), msg); + Ok(()) } /// Send a CLOSE to the upstream relay. @@ -179,6 +179,17 @@ impl UpstreamClient { self.inner.active_subs.len() } + /// Returns the prefixed subscription IDs currently tracked by the upstream client. + /// Used by tests to discover the server-assigned prefix for injecting upstream messages. + #[allow(dead_code)] // Used in tests + pub(crate) fn active_sub_ids(&self) -> Vec { + self.inner + .active_subs + .iter() + .map(|r| r.key().clone()) + .collect() + } + // ── Run loop ────────────────────────────────────────────────────────────── /// Run the upstream connection loop. Reconnects on disconnect with exponential