Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions crates/api/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ pub struct StatsResponse {
pub active_flows: u64,
}

/// Node status response.
/// Node info response.
#[derive(Debug, Serialize)]
pub struct StatusResponse {
pub struct InfoResponse {
pub name: String,
pub version: String,
pub role: String,
Expand Down Expand Up @@ -133,9 +133,9 @@ pub struct PingResponseBody {
pub role: String,
}

/// Set hint request body.
/// Hint set request body.
#[derive(Debug, Deserialize)]
pub struct SetHintRequestBody {
pub struct HintSetRequestBody {
pub level: String,
pub role: String,
}
Expand Down Expand Up @@ -183,7 +183,7 @@ pub async fn events(
)
}

pub async fn info(State(state): State<ApiState>) -> Result<Json<StatusResponse>, StatusCode> {
pub async fn info(State(state): State<ApiState>) -> Result<Json<InfoResponse>, StatusCode> {
let resp = state
.ipc
.lock()
Expand All @@ -195,7 +195,7 @@ pub async fn info(State(state): State<ApiState>) -> Result<Json<StatusResponse>,
match resp.response {
Some(management_response::Response::Info(s)) => {
let role = s.role().to_string();
Ok(Json(StatusResponse {
Ok(Json(InfoResponse {
name: s.package_name,
version: s.version,
role,
Expand Down Expand Up @@ -280,7 +280,7 @@ pub async fn peers(State(state): State<ApiState>) -> Result<Json<PeersResponse>,
}
}

pub async fn disconnect_peer(
pub async fn peer_disconnect(
State(state): State<ApiState>,
Path(name): Path<String>,
) -> (StatusCode, Json<SuccessResponse>) {
Expand Down Expand Up @@ -644,7 +644,7 @@ pub async fn ping(State(state): State<ApiState>) -> Result<Json<PingResponseBody
}
}

pub async fn ping_peer(
pub async fn peer_ping(
State(state): State<ApiState>,
Path(peer): Path<String>,
) -> Result<Json<PingResponseBody>, StatusCode> {
Expand Down Expand Up @@ -704,9 +704,9 @@ pub async fn shutdown(State(state): State<ApiState>) -> (StatusCode, Json<Succes
}
}

pub async fn set_hint(
pub async fn hint_set(
State(state): State<ApiState>,
Json(req): Json<SetHintRequestBody>,
Json(req): Json<HintSetRequestBody>,
) -> (StatusCode, Json<SuccessResponse>) {
let level = match req.level.as_str() {
"prefer" => HintLevel::Prefer,
Expand Down Expand Up @@ -787,7 +787,7 @@ pub async fn set_hint(
}
}

pub async fn clear_hints(State(state): State<ApiState>) -> (StatusCode, Json<SuccessResponse>) {
pub async fn hint_set_auto(State(state): State<ApiState>) -> (StatusCode, Json<SuccessResponse>) {
let resp = state
.ipc
.lock()
Expand Down
6 changes: 3 additions & 3 deletions crates/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub fn router(state: State) -> Router {
.route("/info", get(handlers::info))
.route("/stats", get(handlers::stats))
.route("/peers", get(handlers::peers))
.route("/peers/{name}", delete(handlers::disconnect_peer))
.route("/peers/{name}", delete(handlers::peer_disconnect))
.route(
"/routes",
get(handlers::list_routes).post(handlers::add_route),
Expand All @@ -76,11 +76,11 @@ pub fn router(state: State) -> Router {
.route("/listen", post(handlers::listen))
.route("/disconnect", post(handlers::disconnect))
.route("/ping", get(handlers::ping))
.route("/ping/{peer}", get(handlers::ping_peer))
.route("/ping/{peer}", get(handlers::peer_ping))
.route("/shutdown", post(handlers::shutdown))
.route(
"/hints",
put(handlers::set_hint).delete(handlers::clear_hints),
put(handlers::hint_set).delete(handlers::hint_set_auto),
)
.layer(middleware::from_fn(move |req, next| {
let auth = auth.clone();
Expand Down
88 changes: 81 additions & 7 deletions crates/core/src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ pub struct ConnectResult<T: wallhack_transport::Transport + ?Sized> {
control_tx: mpsc::Sender<ControlMessage>,
/// Receiver for the server's `Handshake` (delivered via the control loop).
peer_handshake_rx: Option<oneshot::Receiver<Handshake>>,
/// Pong-derived latency measurements from the control loop (milliseconds).
latency_rx: Option<mpsc::Receiver<f64>>,
}

impl<T: wallhack_transport::Transport + ?Sized> ConnectResult<T> {
Expand All @@ -57,7 +55,6 @@ impl<T: wallhack_transport::Transport + ?Sized> ConnectResult<T> {
tasks: ConnectionTasks,
control_tx: mpsc::Sender<ControlMessage>,
peer_handshake_rx: Option<oneshot::Receiver<Handshake>>,
latency_rx: Option<mpsc::Receiver<f64>>,
) -> Self {
Self {
channels,
Expand All @@ -66,7 +63,6 @@ impl<T: wallhack_transport::Transport + ?Sized> ConnectResult<T> {
transport,
control_tx,
peer_handshake_rx,
latency_rx,
}
}

Expand Down Expand Up @@ -110,8 +106,6 @@ pub struct ErasedConnectResult {
pub control_tx: mpsc::Sender<ControlMessage>,
pub peer_handshake_rx: Option<oneshot::Receiver<Handshake>>,
pub peer_addr: String,
/// Pong-derived latency measurements from the control loop (milliseconds).
pub latency_rx: Option<mpsc::Receiver<f64>>,
}

impl<T> ConnectResult<T>
Expand All @@ -135,11 +129,91 @@ where
channels: self.channels,
tasks: self.tasks,
control_tx: self.control_tx,
latency_rx: self.latency_rx,
}
}
}

/// Spawn the control and data-in tasks shared by all client transports.
///
/// Called after the transport is established and the handshake has been
/// queued on `control_tx`. Creates the handshake oneshot, control loop
/// task, incoming data task, and returns a fully wired `ConnectResult`.
pub fn spawn_client_tasks<T: wallhack_transport::Transport + 'static>(
transport: Arc<T>,
control_tx: mpsc::Sender<ControlMessage>,
control_rx: mpsc::Receiver<ControlMessage>,
peer_registry: Option<std::sync::Arc<crate::control::peers::Registry>>,
remote_addr: String,
) -> ConnectResult<T>
where
T::SendStream: 'static,
T::RecvStream: 'static,
T::BiStream: Send + 'static,
{
use crate::transport::protocol;

let (handshake_tx, handshake_rx) = oneshot::channel::<Handshake>();

let control_handle = {
let transport = Arc::clone(&transport);
tokio::spawn(async move {
let mut channels = protocol::ControlChannels {
outgoing_rx: control_rx,
handshake_tx: Some(handshake_tx),
control_response_tx: None,
peer_registry,
peer_name: None,
};
match protocol::run_control_stream_initiator(
&*transport,
&mut channels,
None,
std::time::Duration::from_secs(30),
)
.await
{
Ok(exit) => tracing::debug!("Control stream finished: {exit:?}"),
Err(e) => tracing::debug!("Control stream error: {e}"),
}
})
};

let channels = DataChannels::new();

let incoming_handle = {
let transport = Arc::clone(&transport);
let instructions_tx = channels.instructions_tx.clone();
let responses_tx = channels.responses_tx.clone();
tokio::spawn(async move {
match transport.accept_uni().await {
Ok(Some(mut recv)) => {
if let Err(e) =
protocol::run_data_in(&mut recv, &instructions_tx, &responses_tx).await
{
tracing::debug!("Data-in handler finished: {e}");
}
}
Ok(None) => tracing::debug!("Transport closed before data-in stream accepted"),
Err(e) => tracing::debug!("Failed to accept data-in stream: {e}"),
}
})
};

let tasks = ConnectionTasks {
incoming: incoming_handle,
control: control_handle,
};

ConnectResult::new(
transport,
channels,
remote_addr,
tasks,
control_tx,
Some(handshake_rx),
)
}

pub trait Client {
type Error: std::error::Error + std::fmt::Debug + Send + Sync + 'static;
type Transport: wallhack_transport::Transport;
Expand Down
86 changes: 11 additions & 75 deletions crates/core/src/client/quic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,16 @@ use std::sync::Arc;

use quinn::{IdleTimeout, VarInt, crypto::rustls::QuicClientConfig};
use tokio::time::Instant;
use wallhack_transport::Transport;

use crate::{
ClientConfig, NodeRole,
client::tls_config,
psk::HandshakeExt,
server::server::DataChannels,
transport::{protocol, quic::QuicTransport},
ClientConfig, NodeRole, client::tls_config, psk::HandshakeExt, transport::quic::QuicTransport,
};
use wallhack_wire::{
control::{ControlMessage, control_message},
data::Handshake,
};

use super::client::{Client, ConnectResult, ConnectionTasks};
use super::client::{Client, ConnectResult};

#[derive(thiserror::Error, Debug)]
pub enum Error {
Expand Down Expand Up @@ -64,6 +59,9 @@ pub struct QuicClient {
name: Option<String>,
psk: Option<zeroize::Zeroizing<String>>,
local_handshake: Option<Handshake>,
/// Peer registry for direct latency updates in the control loop.
/// Set by the daemon mode before calling `connect()`.
pub peer_registry: Option<std::sync::Arc<crate::control::peers::Registry>>,
}

impl Client for QuicClient {
Expand Down Expand Up @@ -101,6 +99,7 @@ impl Client for QuicClient {
name: config.name,
psk: config.psk,
local_handshake: config.local_handshake,
peer_registry: None,
})
}

Expand Down Expand Up @@ -168,75 +167,12 @@ impl Client for QuicClient {
})?;
}

// Create oneshot for receiving server's Handshake via the control loop.
let (handshake_tx, handshake_rx) = tokio::sync::oneshot::channel::<Handshake>();
let (latency_tx, latency_rx) = tokio::sync::mpsc::channel::<f64>(4);

// Spawn control stream task
let control_handle = {
let transport = Arc::clone(&transport);
tokio::spawn(async move {
let mut channels = protocol::ControlChannels {
outgoing_rx: control_rx,
handshake_tx: Some(handshake_tx),
latency_tx: Some(latency_tx),
control_response_tx: None,
role_transition_tx: None,
peer_registry: None,
};
match protocol::run_control_stream_initiator(
&*transport,
&mut channels,
None, // client doesn't handle ControlRequests
std::time::Duration::from_secs(30),
)
.await
{
Ok(exit) => tracing::debug!("Control stream finished: {exit:?}"),
Err(e) => tracing::debug!("Control stream error: {e}"),
}
})
};

let channels = DataChannels::new();

// Incoming data task: accept uni stream from peer, dispatch messages.
let incoming_handle = {
let transport = Arc::clone(&transport);
let instructions_tx = channels.instructions_tx.clone();
let responses_tx = channels.responses_tx.clone();
tokio::spawn(async move {
match transport.accept_uni().await {
Ok(Some(mut recv)) => {
if let Err(e) =
protocol::run_data_in(&mut recv, &instructions_tx, &responses_tx).await
{
tracing::debug!("Data-in handler finished: {e}");
}
}
Ok(None) => tracing::debug!("Transport closed before data-in stream accepted"),
Err(e) => tracing::debug!("Failed to accept data-in stream: {e}"),
}
})
};

// Outgoing data task is NOT spawned here; the caller opens the uni stream
// and drives run_send_instructions / run_send_responses as appropriate for
// its role, consuming the receiver from DataChannels.

let tasks = ConnectionTasks {
incoming: incoming_handle,
control: control_handle,
};

Ok(ConnectResult::new(
Arc::clone(&transport),
channels,
remote_addr,
tasks,
Ok(super::client::spawn_client_tasks(
transport,
control_tx,
Some(handshake_rx),
Some(latency_rx),
control_rx,
self.peer_registry.clone(),
remote_addr,
))
}

Expand Down
Loading