diff --git a/Cargo.lock b/Cargo.lock index 4f21108cd..3612c56cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3693,6 +3693,7 @@ dependencies = [ name = "rust" version = "0.1.0" dependencies = [ + "bytes", "chrono", "futures-util", "libc", diff --git a/integration/rust/Cargo.toml b/integration/rust/Cargo.toml index ef1ab54d3..8c15eaf20 100644 --- a/integration/rust/Cargo.toml +++ b/integration/rust/Cargo.toml @@ -21,3 +21,4 @@ ordered-float = "4.2" tokio-rustls = "0.26" libc = "0.2" rand = "0.9" +bytes = "*" diff --git a/integration/rust/src/utils.rs b/integration/rust/src/utils.rs index c05185344..c1b10784d 100644 --- a/integration/rust/src/utils.rs +++ b/integration/rust/src/utils.rs @@ -1,5 +1,10 @@ use super::setup::admin_sqlx; +use bytes::{BufMut, Bytes, BytesMut}; use sqlx::{Executor, Row}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + net::TcpStream, +}; pub async fn assert_setting_str(name: &str, expected: &str) { let admin = admin_sqlx().await; @@ -17,3 +22,91 @@ pub async fn assert_setting_str(name: &str, expected: &str) { assert!(found); } + +/// Standard Postgres protocol message. +#[derive(Debug, Clone)] +pub struct Message { + pub code: char, + pub payload: Bytes, +} + +impl Message { + pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result { + let code = stream.read_u8().await? as char; + let len = stream.read_i32().await?; + let mut payload = vec![0u8; len as usize - 4]; + stream.read_exact(&mut payload).await?; + + Ok(Self { + code, + payload: Bytes::from(payload), + }) + } + + pub async fn send(&self, stream: &mut (impl AsyncWrite + Unpin)) -> Result<(), std::io::Error> { + let mut payload = BytesMut::new(); + payload.put_u8(self.code as u8); + payload.put_i32(self.payload.len() as i32 + 4); + payload.put(self.payload.clone()); + + stream.write_all(&payload).await?; + stream.flush().await?; + + Ok(()) + } + + pub fn new_parse(name: &str, sql: &str) -> Self { + let mut payload = BytesMut::new(); + payload.put(name.as_bytes()); + payload.put_u8(0); + payload.put(sql.as_bytes()); + payload.put_u8(0); + payload.put_i16(0); + + Self { + payload: payload.freeze(), + code: 'P', + } + } + + pub fn new_flush() -> Self { + Self { + code: 'H', + payload: Bytes::new(), + } + } +} + +/// Create a startup message. +pub fn startup(user: &str, database: &str) -> Bytes { + let mut payload = BytesMut::new(); + payload.put_i32(196608); + payload.put_slice("user\0".as_bytes()); + payload.put_slice(user.as_bytes()); + payload.put_u8(0); + payload.put_slice("database\0".as_bytes()); + payload.put_slice(database.as_bytes()); + payload.put_u8(0); + payload.put_u8(0); + + let mut bytes = BytesMut::new(); + bytes.put_i32(payload.len() as i32); + bytes.put(payload); + + bytes.freeze() +} + +pub async fn connect() -> TcpStream { + let mut stream = TcpStream::connect("127.0.0.1:6432").await.unwrap(); + stream.write_all(&startup("pgdog", "pgdog")).await.unwrap(); + + loop { + let message = Message::read(&mut stream).await.unwrap(); + + if message.code == 'Z' { + break; + } + } + + stream +} diff --git a/integration/rust/tests/integration/mod.rs b/integration/rust/tests/integration/mod.rs index fb506ba72..cd9baa510 100644 --- a/integration/rust/tests/integration/mod.rs +++ b/integration/rust/tests/integration/mod.rs @@ -16,6 +16,7 @@ pub mod max; pub mod multi_set; pub mod notify; pub mod offset; +pub mod partial_req; pub mod per_stmt_routing; pub mod prepared; pub mod reload; diff --git a/integration/rust/tests/integration/partial_req.rs b/integration/rust/tests/integration/partial_req.rs new file mode 100644 index 000000000..94f2a4929 --- /dev/null +++ b/integration/rust/tests/integration/partial_req.rs @@ -0,0 +1,63 @@ +use std::time::Duration; + +use rust::{ + setup::{admin_sqlx, connection_sqlx_direct}, + utils::{Message, connect}, +}; +use sqlx::{Executor, Pool, Postgres}; +use tokio::{spawn, time::sleep}; + +#[tokio::test] +async fn test_partial_request_disconnect() { + let admin = admin_sqlx().await; + let direct = connection_sqlx_direct().await; + + admin.execute("SET auth_type TO 'trust'").await.unwrap(); + + multiple_clients!( + { + let mut stream = connect().await; + + Message::new_parse("test", "SELECT $1") + .send(&mut stream) + .await + .unwrap(); + // Message::new_flush().send(&mut stream).await.unwrap(); + + drop(stream); + }, + 50 + ); + + sleep(Duration::from_millis(100)).await; + + let acr = active_client_read(&direct).await; + assert_eq!(acr, 0); +} + +macro_rules! multiple_clients { + ($code:block, $times:expr) => {{ + let mut handles = vec![]; + + for _ in 0..$times { + let handle = spawn(async move { + $code; + }); + handles.push(handle); + } + + for handle in handles { + handle.await.unwrap(); + } + }}; +} + +use multiple_clients; + +async fn active_client_read(pool: &Pool) -> i32 { + let count: i32 = sqlx::query_scalar( + "SELECT COUNT(*)::integer FROM pg_stat_activity WHERE state != 'idle' AND wait_event = 'ClientRead'" + ).fetch_one(pool).await.unwrap(); + + count +} diff --git a/pgdog/src/frontend/client/mod.rs b/pgdog/src/frontend/client/mod.rs index b868a6c1d..cd6fb5f9a 100644 --- a/pgdog/src/frontend/client/mod.rs +++ b/pgdog/src/frontend/client/mod.rs @@ -418,8 +418,12 @@ impl Client { buffer = self.buffer(client_state) => { let event = buffer?; - if !self.client_request.messages.is_empty() { - self.client_messages(&mut query_engine).await?; + + // Only send requests to the backend if they are complete. + if self.client_request.is_complete() { + if !self.client_request.messages.is_empty() { + self.client_messages(&mut query_engine).await?; + } } match event {