Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions integration/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ ordered-float = "4.2"
tokio-rustls = "0.26"
libc = "0.2"
rand = "0.9"
bytes = "*"
93 changes: 93 additions & 0 deletions integration/rust/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use super::setup::admin_sqlx;
use bytes::{BufMut, Bytes, BytesMut};
use sqlx::{Executor, Row};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream,
};

pub async fn assert_setting_str(name: &str, expected: &str) {
let admin = admin_sqlx().await;
Expand All @@ -17,3 +22,91 @@ pub async fn assert_setting_str(name: &str, expected: &str) {

assert!(found);
}

/// Standard Postgres protocol message.
#[derive(Debug, Clone)]
pub struct Message {
pub code: char,
pub payload: Bytes,
}

impl Message {
pub async fn read(stream: &mut (impl AsyncRead + Unpin)) -> Result<Self, std::io::Error> {
let code = stream.read_u8().await? as char;
let len = stream.read_i32().await?;
let mut payload = vec![0u8; len as usize - 4];
stream.read_exact(&mut payload).await?;

Ok(Self {
code,
payload: Bytes::from(payload),
})
}

pub async fn send(&self, stream: &mut (impl AsyncWrite + Unpin)) -> Result<(), std::io::Error> {
let mut payload = BytesMut::new();
payload.put_u8(self.code as u8);
payload.put_i32(self.payload.len() as i32 + 4);
payload.put(self.payload.clone());

stream.write_all(&payload).await?;
stream.flush().await?;

Ok(())
}

pub fn new_parse(name: &str, sql: &str) -> Self {
let mut payload = BytesMut::new();
payload.put(name.as_bytes());
payload.put_u8(0);
payload.put(sql.as_bytes());
payload.put_u8(0);
payload.put_i16(0);

Self {
payload: payload.freeze(),
code: 'P',
}
}

pub fn new_flush() -> Self {
Self {
code: 'H',
payload: Bytes::new(),
}
}
}

/// Create a startup message.
pub fn startup(user: &str, database: &str) -> Bytes {
let mut payload = BytesMut::new();
payload.put_i32(196608);
payload.put_slice("user\0".as_bytes());
payload.put_slice(user.as_bytes());
payload.put_u8(0);
payload.put_slice("database\0".as_bytes());
payload.put_slice(database.as_bytes());
payload.put_u8(0);
payload.put_u8(0);

let mut bytes = BytesMut::new();
bytes.put_i32(payload.len() as i32);
bytes.put(payload);

bytes.freeze()
}

pub async fn connect() -> TcpStream {
let mut stream = TcpStream::connect("127.0.0.1:6432").await.unwrap();
stream.write_all(&startup("pgdog", "pgdog")).await.unwrap();

loop {
let message = Message::read(&mut stream).await.unwrap();

if message.code == 'Z' {
break;
}
}

stream
}
1 change: 1 addition & 0 deletions integration/rust/tests/integration/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub mod max;
pub mod multi_set;
pub mod notify;
pub mod offset;
pub mod partial_req;
pub mod per_stmt_routing;
pub mod prepared;
pub mod reload;
Expand Down
63 changes: 63 additions & 0 deletions integration/rust/tests/integration/partial_req.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::time::Duration;

use rust::{
setup::{admin_sqlx, connection_sqlx_direct},
utils::{Message, connect},
};
use sqlx::{Executor, Pool, Postgres};
use tokio::{spawn, time::sleep};

#[tokio::test]
async fn test_partial_request_disconnect() {
let admin = admin_sqlx().await;
let direct = connection_sqlx_direct().await;

admin.execute("SET auth_type TO 'trust'").await.unwrap();

multiple_clients!(
{
let mut stream = connect().await;

Message::new_parse("test", "SELECT $1")
.send(&mut stream)
.await
.unwrap();
// Message::new_flush().send(&mut stream).await.unwrap();

drop(stream);
},
50
);

sleep(Duration::from_millis(100)).await;

let acr = active_client_read(&direct).await;
assert_eq!(acr, 0);
}

macro_rules! multiple_clients {
($code:block, $times:expr) => {{
let mut handles = vec![];

for _ in 0..$times {
let handle = spawn(async move {
$code;
});
handles.push(handle);
}

for handle in handles {
handle.await.unwrap();
}
}};
}

use multiple_clients;

async fn active_client_read(pool: &Pool<Postgres>) -> i32 {
let count: i32 = sqlx::query_scalar(
"SELECT COUNT(*)::integer FROM pg_stat_activity WHERE state != 'idle' AND wait_event = 'ClientRead'"
).fetch_one(pool).await.unwrap();

count
}
8 changes: 6 additions & 2 deletions pgdog/src/frontend/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,12 @@ impl Client {

buffer = self.buffer(client_state) => {
let event = buffer?;
if !self.client_request.messages.is_empty() {
self.client_messages(&mut query_engine).await?;

// Only send requests to the backend if they are complete.
if self.client_request.is_complete() {
if !self.client_request.messages.is_empty() {
self.client_messages(&mut query_engine).await?;
}
}

match event {
Expand Down
Loading