diff --git a/packages/cubejs-backend-native/js/index.ts b/packages/cubejs-backend-native/js/index.ts index b5dbf3cdc4ad2..d6f5a919b0566 100644 --- a/packages/cubejs-backend-native/js/index.ts +++ b/packages/cubejs-backend-native/js/index.ts @@ -15,6 +15,8 @@ export interface BaseMeta { apiType: string, // Application name, for example Metabase appName?: string, + // Database name from the client startup message (e.g. psql dbname parameter) + database?: string, } export interface LoadRequestMeta extends BaseMeta { diff --git a/rust/cubesql/cubesql/src/compile/test/mod.rs b/rust/cubesql/cubesql/src/compile/test/mod.rs index 75bd04812e53f..7db5fcd79f2dc 100644 --- a/rust/cubesql/cubesql/src/compile/test/mod.rs +++ b/rust/cubesql/cubesql/src/compile/test/mod.rs @@ -35,6 +35,8 @@ pub mod test_cube_join_grouped; #[cfg(test)] pub mod test_cube_scan; #[cfg(test)] +pub mod test_database_meta; +#[cfg(test)] pub mod test_df_execution; #[cfg(test)] pub mod test_filters; diff --git a/rust/cubesql/cubesql/src/compile/test/test_database_meta.rs b/rust/cubesql/cubesql/src/compile/test/test_database_meta.rs new file mode 100644 index 0000000000000..b22a532ec890a --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/test/test_database_meta.rs @@ -0,0 +1,59 @@ +//! Tests that check database name propagation through LoadRequestMeta + +use pretty_assertions::assert_eq; + +use crate::compile::{ + test::{init_testing_logger, TestContext}, + DatabaseProtocol, Rewriter, +}; +use crate::transport::LoadRequestMeta; + +#[tokio::test] +async fn test_database_propagates_through_load_request_meta() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let context = TestContext::new(DatabaseProtocol::PostgreSQL).await; + + context + .execute_query( + // language=PostgreSQL + r#" +SELECT + COALESCE(customer_gender, 'N/A'), + AVG(avgPrice) +FROM + KibanaSampleDataEcommerce +WHERE + LOWER(customer_gender) = 'test' +GROUP BY 1 +; + "# + .to_string(), + ) + .await + .expect_err("Test transport does not support load with SQL"); + + let load_calls = context.load_calls().await; + assert_eq!(load_calls.len(), 1); + assert_eq!(load_calls[0].meta.database(), Some("cubedb".to_string())); +} + +#[test] +fn test_load_request_meta_database_serialization() { + let mut meta = LoadRequestMeta::new( + "postgres".to_string(), + "sql".to_string(), + Some("test-app".to_string()), + ); + + let json = serde_json::to_value(&meta).unwrap(); + assert!(json.get("database").is_none()); + + meta.set_database(Some("mydb".to_string())); + let json = serde_json::to_value(&meta).unwrap(); + assert_eq!(json["database"], "mydb"); + assert_eq!(meta.database(), Some("mydb".to_string())); +} diff --git a/rust/cubesql/cubesql/src/sql/session.rs b/rust/cubesql/cubesql/src/sql/session.rs index 610d4c77c5f19..eb2c3167bc70d 100644 --- a/rust/cubesql/cubesql/src/sql/session.rs +++ b/rust/cubesql/cubesql/src/sql/session.rs @@ -409,11 +409,13 @@ impl SessionState { None }; - LoadRequestMeta::new( + let mut meta = LoadRequestMeta::new( self.protocol.get_name().to_string(), api_type.to_string(), application_name, - ) + ); + meta.set_database(self.database()); + meta } } diff --git a/rust/cubesql/cubesql/src/transport/service.rs b/rust/cubesql/cubesql/src/transport/service.rs index 364306be5298d..f7e431fdbd682 100644 --- a/rust/cubesql/cubesql/src/transport/service.rs +++ b/rust/cubesql/cubesql/src/transport/service.rs @@ -51,6 +51,8 @@ pub struct LoadRequestMeta { // Optional fields #[serde(rename = "changeUser", skip_serializing_if = "Option::is_none")] change_user: Option, + #[serde(skip_serializing_if = "Option::is_none")] + database: Option, } impl LoadRequestMeta { @@ -61,6 +63,7 @@ impl LoadRequestMeta { api_type, app_name, change_user: None, + database: None, } } @@ -71,6 +74,14 @@ impl LoadRequestMeta { pub fn set_change_user(&mut self, change_user: Option) { self.change_user = change_user; } + + pub fn database(&self) -> Option { + self.database.clone() + } + + pub fn set_database(&mut self, database: Option) { + self.database = database; + } } #[derive(Debug, Deserialize)]