From a455d0a8cd64d27468b9b1e88c383a7775a2304c Mon Sep 17 00:00:00 2001 From: Kurt Volmar <39112522+kurtvolmar@users.noreply.github.com> Date: Wed, 14 Jan 2026 12:01:17 +0200 Subject: [PATCH] implement extensibility with NetworkBoundaryStategy --- Cargo.toml | 5 + src/distributed_ext.rs | 55 ++ src/distributed_planner/distributed_config.rs | 16 + .../distributed_physical_optimizer_rule.rs | 13 + src/distributed_planner/mod.rs | 12 + .../network_boundary_strategy.rs | 622 ++++++++++++++++++ src/distributed_planner/plan_annotator.rs | 50 +- src/flight_service/do_get.rs | 2 +- src/flight_service/mod.rs | 2 + src/lib.rs | 13 +- src/protobuf/mod.rs | 17 +- 11 files changed, 790 insertions(+), 17 deletions(-) create mode 100644 src/distributed_planner/network_boundary_strategy.rs diff --git a/Cargo.toml b/Cargo.toml index 159823d0..6630e8dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,3 +86,8 @@ pretty_assertions = "1.4" reqwest = "0.12" zip = "6.0" test-case = "3.3.1" + +# @NetworkBoundaryStrategy: custom execution plan flag +[[example]] +name = "custom_execution_plan" +required-features = ["integration"] diff --git a/src/distributed_ext.rs b/src/distributed_ext.rs index 7efea431..fbefb15b 100644 --- a/src/distributed_ext.rs +++ b/src/distributed_ext.rs @@ -16,6 +16,9 @@ use delegate::delegate; use http::HeaderMap; use std::sync::Arc; +// @NetworkBoundaryStrategy: separte import for set_distributed_network_boundary_strategy +use crate::distributed_planner::set_distributed_network_boundary_strategy; + /// Extends DataFusion with distributed capabilities. pub trait DistributedExt: Sized { /// Adds the provided [ConfigExtension] to the distributed context. The [ConfigExtension] will @@ -325,6 +328,25 @@ pub trait DistributedExt: Sized { estimator: T, ); + // @NetworkBoundaryStrategy: trait method to register a custom network boundary strategy. + /// Adds a distributed network boundary strategy. [NetworkBoundaryStrategy]s are executed on each node + /// sequentially until one returns an annotation for a network boundary. + + fn with_distributed_network_boundary_strategy< + T: crate::distributed_planner::NetworkBoundaryStrategy + 'static, + >( + self, + strategy: T, + ) -> Self; + + /// Same as [DistributedExt::with_distributed_network_boundary_strategy] but with an in-place mutation. + fn set_distributed_network_boundary_strategy< + T: crate::distributed_planner::NetworkBoundaryStrategy + 'static, + >( + &mut self, + strategy: T, + ); + /// Sets the maximum number of files each task in a stage with a FileScanConfig node will /// handle. Reducing this number will increment the amount of tasks. By default, this /// is close to the number of cores in the machine. @@ -564,6 +586,16 @@ impl DistributedExt for SessionConfig { set_distributed_task_estimator(self, estimator) } + // @NetworkBoundaryStrategy: SessionConfig impl of set_distributed_network_boundary_strategy. + fn set_distributed_network_boundary_strategy< + T: crate::distributed_planner::NetworkBoundaryStrategy + 'static, + >( + &mut self, + strategy: T, + ) { + set_distributed_network_boundary_strategy(self, strategy) + } + fn set_distributed_files_per_task( &mut self, files_per_task: usize, @@ -662,6 +694,11 @@ impl DistributedExt for SessionConfig { #[expr($;self)] fn with_distributed_task_estimator(mut self, estimator: T) -> Self; + // @NetworkBoundaryStrategy: SessionStateBuilder delegate for with_distributed_network_boundary_strategy. + #[call(set_distributed_network_boundary_strategy)] + #[expr($;self)] + fn with_distributed_network_boundary_strategy(mut self, strategy: T) -> Self; + #[call(set_distributed_files_per_task)] #[expr($?;Ok(self))] fn with_distributed_files_per_task(mut self, files_per_task: usize) -> Result; @@ -735,6 +772,12 @@ impl DistributedExt for SessionStateBuilder { #[expr($;self)] fn with_distributed_task_estimator(mut self, estimator: T) -> Self; + // @NetworkBoundaryStrategy: SessionStateBuilder delegate for with_distributed_network_boundary_strategy. + fn set_distributed_network_boundary_strategy(&mut self, strategy: T); + #[call(set_distributed_network_boundary_strategy)] + #[expr($;self)] + fn with_distributed_network_boundary_strategy(mut self, strategy: T) -> Self; + fn set_distributed_files_per_task(&mut self, files_per_task: usize) -> Result<(), DataFusionError>; #[call(set_distributed_files_per_task)] #[expr($?;Ok(self))] @@ -816,6 +859,12 @@ impl DistributedExt for SessionState { #[expr($;self)] fn with_distributed_task_estimator(mut self, estimator: T) -> Self; + // @NetworkBoundaryStrategy: SessionState delegate for with_distributed_network_boundary_strategy. + fn set_distributed_network_boundary_strategy(&mut self, strategy: T); + #[call(set_distributed_network_boundary_strategy)] + #[expr($;self)] + fn with_distributed_network_boundary_strategy(mut self, strategy: T) -> Self; + fn set_distributed_files_per_task(&mut self, files_per_task: usize) -> Result<(), DataFusionError>; #[call(set_distributed_files_per_task)] #[expr($?;Ok(self))] @@ -897,6 +946,12 @@ impl DistributedExt for SessionContext { #[expr($;self)] fn with_distributed_task_estimator(self, estimator: T) -> Self; + // @NetworkBoundaryStrategy: SessionContext delegate for with_distributed_network_boundary_strategy. + fn set_distributed_network_boundary_strategy(&mut self, strategy: T); + #[call(set_distributed_network_boundary_strategy)] + #[expr($;self)] + fn with_distributed_network_boundary_strategy(self, strategy: T) -> Self; + fn set_distributed_files_per_task(&mut self, files_per_task: usize) -> Result<(), DataFusionError>; #[call(set_distributed_files_per_task)] #[expr($?;Ok(self))] diff --git a/src/distributed_planner/distributed_config.rs b/src/distributed_planner/distributed_config.rs index f3c3f8c6..e7ca7e55 100644 --- a/src/distributed_planner/distributed_config.rs +++ b/src/distributed_planner/distributed_config.rs @@ -57,6 +57,9 @@ extensions_options! { /// [WorkerResolver] implementation that tells the distributed planner information about /// the available workers ready to execute distributed tasks. pub(crate) __private_worker_resolver: WorkerResolverExtension, default = WorkerResolverExtension::not_implemented() + /// @NetworkBoundaryStrategy: Collection of [NetworkBoundaryStrategy]s that will be applied to plan nodes to + /// determine if a network boundary is needed and what type it should be. + pub(crate) __private_network_boundary_strategy: crate::distributed_planner::network_boundary_strategy::CombinedNetworkBoundaryStrategy, default = crate::distributed_planner::network_boundary_strategy::CombinedNetworkBoundaryStrategy::default() } } @@ -166,3 +169,16 @@ impl Debug for CombinedTaskEstimator { write!(f, "TaskEstimators") } } + +// @NetworkBoundaryStrategy: ConfigField impl required so CombinedNetworkBoundaryStrategy can be stored in ConfigOptions extensions. +impl ConfigField + for crate::distributed_planner::network_boundary_strategy::CombinedNetworkBoundaryStrategy +{ + fn visit(&self, _: &mut V, _: &str, _: &'static str) { + // nothing to do. + } + + fn set(&mut self, _: &str, _: &str) -> datafusion::common::Result<()> { + not_impl_err!("Not implemented") + } +} diff --git a/src/distributed_planner/distributed_physical_optimizer_rule.rs b/src/distributed_planner/distributed_physical_optimizer_rule.rs index 1f31e171..f81ee5ba 100644 --- a/src/distributed_planner/distributed_physical_optimizer_rule.rs +++ b/src/distributed_planner/distributed_physical_optimizer_rule.rs @@ -164,6 +164,19 @@ fn distribute_plan( stage_id.add_assign(1); Ok(node) } + // @NetworkBoundaryStrategy: Extension boundaries are handled by custom NetworkBoundaryStrategy implementations. + PlanOrNetworkBoundary::Extension(_) => { + crate::distributed_planner::network_boundary_strategy::apply_extension_boundary( + d_cfg, + &annotated_plan.plan_or_nb, + new_children, + query_id, + stage_id, + task_count, + max_child_task_count, + cfg, + ) + } } } diff --git a/src/distributed_planner/mod.rs b/src/distributed_planner/mod.rs index 2f9b2895..a56e60db 100644 --- a/src/distributed_planner/mod.rs +++ b/src/distributed_planner/mod.rs @@ -12,3 +12,15 @@ pub use distributed_physical_optimizer_rule::DistributedPhysicalOptimizerRule; pub use network_boundary::{NetworkBoundary, NetworkBoundaryExt}; pub(crate) use task_estimator::set_distributed_task_estimator; pub use task_estimator::{TaskCountAnnotation, TaskEstimation, TaskEstimator}; + +// @NetworkBoundaryStrategy: new module and re-exports for pluggable network boundary strategies. +mod network_boundary_strategy; + +pub(crate) use network_boundary_strategy::set_distributed_network_boundary_strategy; + +#[rustfmt::skip] +pub use network_boundary_strategy::{ + CombinedNetworkBoundaryStrategy, NetworkBoundaryAnnotation, NetworkBoundaryContext, + NetworkBoundaryStrategy, +}; +pub use plan_annotator::PlanOrNetworkBoundary; diff --git a/src/distributed_planner/network_boundary_strategy.rs b/src/distributed_planner/network_boundary_strategy.rs new file mode 100644 index 00000000..a4121091 --- /dev/null +++ b/src/distributed_planner/network_boundary_strategy.rs @@ -0,0 +1,622 @@ +// @NetworkBoundaryStrategy: new module for pluggable network boundary strategies so custom strategies can override or extend default boundary placement. +use crate::DistributedConfig; +use crate::common::require_one_child; +use crate::distributed_planner::plan_annotator::PlanOrNetworkBoundary; +use datafusion::common::plan_err; +use datafusion::config::ConfigOptions; +use datafusion::error::{DataFusionError, Result}; +use datafusion::physical_plan::ExecutionPlan; +use std::fmt::Debug; +use std::ops::AddAssign; +use std::sync::Arc; +use uuid::Uuid; + +/// Annotation metadata about a network boundary for a plan node. +/// +/// Returned by [`NetworkBoundaryStrategy::annotate_network_boundary`] to describe +/// what kind of network boundary (if any) is needed and optional hints about +/// the output task count. +#[derive(Debug, Clone)] +pub struct NetworkBoundaryAnnotation { + /// The type of network boundary required (Shuffle, Coalesce, Broadcast, Extension). + /// If None, no network boundary is needed for this plan node. + pub required_network_boundary: Option, + + /// Optional hint for the output task count after this boundary is applied. + /// + /// - If `None`, DFD will calculate the output task count using default cardinality scaling. + /// - If `Some(n)`, DFD will use `n` as the output task count for this stage. + /// + /// This allows strategies which know their output task count to override the generic + /// cardinality-based calculation in the plan annotator. + pub output_tasks: Option, +} + +/// Context provided to network boundary strategies when deciding how to place boundaries. +/// +/// This struct contains all the information a strategy needs to make decisions about +/// how to transform the plan at a network boundary point. +#[derive(Debug)] +pub struct NetworkBoundaryContext<'a> { + /// The type of network boundary required at this point. + pub boundary_type: &'a PlanOrNetworkBoundary, + /// The already-distributed children to be wrapped by the network boundary. + pub new_children: Arc, + /// The query ID for this execution. + pub query_id: Uuid, + /// The stage ID for the boundary being created. + pub stage_id: usize, + /// Number of tasks in the current stage (above the boundary). + pub task_count: usize, + /// Number of tasks in the input stage (below the boundary). + pub input_task_count: usize, + /// The DataFusion configuration options. + pub config: &'a ConfigOptions, +} + +/// Strategy for placing network boundaries in a distributed execution plan. +/// +/// When a network boundary is needed (e.g., after hash repartition or before coalesce), +/// strategies are invoked in order. The first strategy to return annotation with a boundary wins. +/// +/// Strategies should return `None` to defer to the next strategy in the chain. +/// Custom strategies can be registered to override default behavior. +pub trait NetworkBoundaryStrategy: Debug + Send + Sync { + /// Annotates a plan node with network boundary metadata. + /// + /// Returns `Some(NetworkBoundaryAnnotation)` if this strategy detects a boundary is needed, + /// or `None` to defer to the next strategy. + /// + /// The annotation can optionally include an `output_tasks` hint to override DFD's + /// default task count calculation. + fn annotate_network_boundary( + &self, + plan: &dyn ExecutionPlan, + ) -> Option; + + /// Apply this strategy to place a network boundary. Return `Ok(None)` to defer to next strategy. + fn apply_boundary( + &self, + context: &NetworkBoundaryContext<'_>, + ) -> Result>>; +} + +/// Combines multiple [`NetworkBoundaryStrategy`] implementations. +/// +/// Strategies are tried in order for both annotation and boundary application. +/// The first strategy to return an annotation with a boundary will be used. +#[derive(Clone)] +pub struct CombinedNetworkBoundaryStrategy { + pub(crate) strategies: Vec>, +} + +impl From>> for CombinedNetworkBoundaryStrategy { + fn from(strategies: Vec>) -> Self { + Self { strategies } + } +} + +impl Default for CombinedNetworkBoundaryStrategy { + fn default() -> Self { + Self { strategies: vec![] } + } +} + +impl Debug for CombinedNetworkBoundaryStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CombinedNetworkBoundaryStrategy") + .field("strategies_count", &self.strategies.len()) + .finish() + } +} + +impl NetworkBoundaryStrategy for CombinedNetworkBoundaryStrategy { + fn annotate_network_boundary( + &self, + plan: &dyn ExecutionPlan, + ) -> Option { + for strategy in &self.strategies { + if let Some(annotation) = strategy.annotate_network_boundary(plan) { + return Some(annotation); + } + } + None + } + + fn apply_boundary( + &self, + context: &NetworkBoundaryContext<'_>, + ) -> Result>> { + for strategy in &self.strategies { + if let Some(result) = strategy.apply_boundary(context)? { + return Ok(Some(result)); + } + } + Ok(None) + } +} + +/// Applies an extension network boundary by building context and delegating to the configured +/// strategy. Used by the distributed physical optimizer rule when it encounters +/// `PlanOrNetworkBoundary::Extension`. +pub(crate) fn apply_extension_boundary( + d_cfg: &DistributedConfig, + boundary_type: &PlanOrNetworkBoundary, + new_children: Vec>, + query_id: Uuid, + stage_id: &mut usize, + task_count: usize, + max_child_task_count: Option, + cfg: &ConfigOptions, +) -> Result, DataFusionError> { + let context = NetworkBoundaryContext { + boundary_type, + new_children: require_one_child(new_children)?, + query_id, + stage_id: *stage_id, + task_count, + input_task_count: max_child_task_count.unwrap_or(1), + config: cfg, + }; + match d_cfg + .__private_network_boundary_strategy + .apply_boundary(&context)? + { + Some(custom_plan) => { + // TODO: revisit the stage incrementation, since a strategy can insert 0 to many stages. + stage_id.add_assign(1); + Ok(custom_plan) + } + None => plan_err!("Extension boundary not handled by any strategy"), + } +} + +/// Runs network boundary strategies after default detection. Returns an optional boundary type +/// and an optional task count. When the task count is Some, the annotator should use it and return early. +pub(crate) fn apply_network_boundary_strategy( + d_cfg: &DistributedConfig, + plan: &Arc, +) -> ( + Option, + Option, +) { + let strategy_annotation = d_cfg + .__private_network_boundary_strategy + .annotate_network_boundary(plan.as_ref()); + let boundary = strategy_annotation + .as_ref() + .and_then(|a| a.required_network_boundary.clone()); + let task_count = + strategy_annotation.and_then(|a| a.output_tasks.map(crate::TaskCountAnnotation::Desired)); + (boundary, task_count) +} + +/// Helper function to add a network boundary strategy to the session config. +/// This is used by the DistributedExt trait implementation. +pub(crate) fn set_distributed_network_boundary_strategy( + cfg: &mut datafusion::prelude::SessionConfig, + strategy: impl NetworkBoundaryStrategy + 'static, +) { + use crate::config_extension_ext::set_distributed_option_extension; + use crate::distributed_planner::DistributedConfig; + + let opts = cfg.options_mut(); + if let Some(distributed_cfg) = opts.extensions.get_mut::() { + distributed_cfg + .__private_network_boundary_strategy + .strategies + .push(Arc::new(strategy)); + } else { + let mut combined = CombinedNetworkBoundaryStrategy::default(); + combined.strategies.push(Arc::new(strategy)); + set_distributed_option_extension( + cfg, + DistributedConfig { + __private_network_boundary_strategy: combined, + ..Default::default() + }, + ); + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + use crate::distributed_planner::insert_broadcast::insert_broadcast_execs; + use crate::distributed_planner::plan_annotator::{PlanOrNetworkBoundary, annotate_plan}; + use crate::test_utils::in_memory_channel_resolver::InMemoryWorkerResolver; + use crate::test_utils::plans::{TestPlanOptions, base_session_builder, context_with_query}; + use crate::{ + DistributedConfig, DistributedExt, DistributedPhysicalOptimizerRule, assert_snapshot, + display_plan_ascii, + }; + use datafusion::arrow::datatypes::{DataType, Field, Schema}; + use datafusion::execution::SessionStateBuilder; + use datafusion::physical_plan::displayable; + use datafusion::physical_plan::empty::EmptyExec; + + #[test] + fn test_combined_network_strategy_first_strategy_wins() { + use datafusion::physical_plan::Partitioning; + use datafusion::physical_plan::repartition::RepartitionExec; + + let mut combined = CombinedNetworkBoundaryStrategy::default(); + // Add two PassthroughStrategy instances with different extension names + combined.strategies.insert( + 0, + Arc::new(PassthroughStrategy::new( + |plan| { + plan.as_any() + .downcast_ref::() + .map(|repartition| { + matches!(repartition.partitioning(), Partitioning::Hash(_, _)) + }) + .unwrap_or(false) + }, + "wrap_hash_repartition_0", + 3, + )), + ); + combined.strategies.insert( + 1, + Arc::new(PassthroughStrategy::new( + |plan| { + plan.as_any() + .downcast_ref::() + .map(|repartition| { + matches!(repartition.partitioning(), Partitioning::Hash(_, _)) + }) + .unwrap_or(false) + }, + "wrap_hash_repartition_1", + 3, + )), + ); + + // Test with a Hash RepartitionExec plan + let plan = hash_repartition_plan(); + let result = combined.annotate_network_boundary(plan.as_ref()); + + // First strategy should win and return Extension("wrap_hash_repartition_0") + assert!(matches!( + result + .as_ref() + .and_then(|a| a.required_network_boundary.as_ref()), + Some(PlanOrNetworkBoundary::Extension("wrap_hash_repartition_0")) + )); + assert_eq!(result.as_ref().and_then(|a| a.output_tasks), Some(3)); + } + + #[test] + fn test_combined_network_strategy_continues_until_match() { + use datafusion::physical_plan::Partitioning; + use datafusion::physical_plan::repartition::RepartitionExec; + + let mut combined = CombinedNetworkBoundaryStrategy::default(); + // First strategy matches EmptyExec (won't match Hash RepartitionExec) + combined.strategies.insert( + 0, + Arc::new(PassthroughStrategy::new( + |plan| plan.as_any().is::(), + "wrap_empty_exec", + 3, + )), + ); + // Second strategy matches Hash RepartitionExec + combined.strategies.insert( + 1, + Arc::new(PassthroughStrategy::new( + |plan| { + plan.as_any() + .downcast_ref::() + .map(|repartition| { + matches!(repartition.partitioning(), Partitioning::Hash(_, _)) + }) + .unwrap_or(false) + }, + "wrap_hash_repartition", + 3, + )), + ); + + // Test with Hash RepartitionExec - first strategy won't match, second strategy should match + let plan = hash_repartition_plan(); + let result = combined.annotate_network_boundary(plan.as_ref()); + assert!(matches!( + result + .as_ref() + .and_then(|a| a.required_network_boundary.as_ref()), + Some(PlanOrNetworkBoundary::Extension("wrap_hash_repartition")) + )); + } + + #[tokio::test] + async fn test_extension_boundary_strategy() { + use datafusion::physical_plan::Partitioning; + use datafusion::physical_plan::repartition::RepartitionExec; + + let query = r#" + SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*) + "#; + let annotated = annotate_test_plan(query, TestPlanOptions::default(), |b| { + let mut state = b.build(); + let config = state.config_mut(); + let d_cfg = DistributedConfig::from_config_options_mut(config.options_mut()).unwrap(); + d_cfg.__private_network_boundary_strategy.strategies = + vec![Arc::new(PassthroughStrategy::new( + |plan| { + plan.as_any() + .downcast_ref::() + .map(|repartition| { + matches!(repartition.partitioning(), Partitioning::Hash(_, _)) + }) + .unwrap_or(false) + }, + "wrap_hash_repartition", + 3, + )) as Arc]; + SessionStateBuilder::new_from_existing(state) + }) + .await; + assert_snapshot!(annotated, @r" + ProjectionExec: task_count=Maximum(1) + SortPreservingMergeExec: task_count=Maximum(1) + [NetworkBoundary] Coalesce: task_count=Maximum(1) + SortExec: task_count=Desired(3) + ProjectionExec: task_count=Desired(3) + AggregateExec: task_count=Desired(3) + [NetworkBoundary] Extension(wrap_hash_repartition): task_count=Desired(3) + RepartitionExec: task_count=Desired(3) + AggregateExec: task_count=Desired(3) + DataSourceExec: task_count=Desired(3) + ") + } + + #[tokio::test] + async fn test_extension_boundary_with_custom_strategy() { + use datafusion::physical_plan::Partitioning; + use datafusion::physical_plan::repartition::RepartitionExec; + + let query = r#" + SELECT count(*), "RainToday" FROM weather GROUP BY "RainToday" ORDER BY count(*) + "#; + let plan = explain_test_plan(query, TestPlanOptions::default(), true, |b| { + let mut state = b.build(); + let config = state.config_mut(); + let d_cfg = DistributedConfig::from_config_options_mut(config.options_mut()).unwrap(); + d_cfg.__private_network_boundary_strategy.strategies = + vec![Arc::new(PassthroughStrategy::new( + |plan| { + plan.as_any() + .downcast_ref::() + .map(|repartition| { + matches!(repartition.partitioning(), Partitioning::Hash(_, _)) + }) + .unwrap_or(false) + }, + "wrap_hash_repartition", + 3, + )) as Arc]; + SessionStateBuilder::new_from_existing(state) + .with_distributed_worker_resolver(InMemoryWorkerResolver::new(3)) + }) + .await; + + assert_snapshot!(plan, @r" + ┌───── DistributedExec ── Tasks: t0:[p0] + │ ProjectionExec: expr=[count(*)@0 as count(*), RainToday@1 as RainToday] + │ SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] + │ [Stage 2] => NetworkCoalesceExec: output_partitions=12, input_tasks=3 + └────────────────────────────────────────────────── + ┌───── Stage 2 ── Tasks: t0:[p0..p3] t1:[p0..p3] t2:[p0..p3] + │ SortExec: expr=[count(*)@0 ASC NULLS LAST], preserve_partitioning=[true] + │ ProjectionExec: expr=[count(Int64(1))@1 as count(*), RainToday@0 as RainToday, count(Int64(1))@1 as count(Int64(1))] + │ AggregateExec: mode=FinalPartitioned, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + │ PassthroughExec + │ RepartitionExec: partitioning=Hash([RainToday@0], 4), input_partitions=1 + │ AggregateExec: mode=Partial, gby=[RainToday@0 as RainToday], aggr=[count(Int64(1))] + │ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0] + │ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[RainToday], file_type=parquet + └────────────────────────────────────────────────── + "); + } + + // --- Helpers (duplicated from plan_annotator and distributed_physical_optimizer_rule + // test modules to avoid tail-of-file merge conflicts when upstream adds tests there) --- + + fn hash_repartition_plan() -> Arc { + use datafusion::physical_expr::expressions::Column as PhysicalColumn; + use datafusion::physical_plan::Partitioning; + use datafusion::physical_plan::repartition::RepartitionExec; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let child = Arc::new(EmptyExec::new(schema.clone())); + let partitioning = Partitioning::Hash( + vec![Arc::new(PhysicalColumn::new("a", 0)) + as Arc], + 4, + ); + Arc::new(RepartitionExec::try_new(child, partitioning).unwrap()) + } + + /// Test strategy that wraps specific ExecutionPlan nodes with a PassthroughExec. + /// Useful for testing custom boundary insertion without transformation. + #[derive(Clone)] + pub struct PassthroughStrategy { + /// Function to check if a plan matches the target type + matcher: Arc bool + Send + Sync>, + /// Extension name to use for the boundary + extension_name: &'static str, + /// Number of output tasks + output_tasks: usize, + } + + impl std::fmt::Debug for PassthroughStrategy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PassthroughStrategy") + .field("extension_name", &self.extension_name) + .field("output_tasks", &self.output_tasks) + .finish() + } + } + + impl PassthroughStrategy { + /// Creates a new PassthroughStrategy with a custom matcher function + pub fn new(matcher: F, extension_name: &'static str, output_tasks: usize) -> Self + where + F: Fn(&dyn ExecutionPlan) -> bool + Send + Sync + 'static, + { + Self { + matcher: Arc::new(matcher), + extension_name, + output_tasks, + } + } + } + + impl NetworkBoundaryStrategy for PassthroughStrategy { + fn annotate_network_boundary( + &self, + plan: &dyn ExecutionPlan, + ) -> Option { + if (self.matcher)(plan) { + Some(NetworkBoundaryAnnotation { + required_network_boundary: Some(PlanOrNetworkBoundary::Extension( + self.extension_name, + )), + output_tasks: Some(self.output_tasks), + }) + } else { + None + } + } + + fn apply_boundary( + &self, + context: &NetworkBoundaryContext<'_>, + ) -> Result>> { + if let PlanOrNetworkBoundary::Extension(name) = context.boundary_type { + if *name == self.extension_name { + // Wrap with PassthroughExec to demonstrate custom boundary insertion + return Ok(Some(Arc::new(PassthroughExec::new( + context.new_children.clone(), + )))); + } + } + Ok(None) + } + } + + /// A no-op ExecutionPlan that passes through data from its child unchanged. + /// Useful for testing scenarios where you need a wrapper node that doesn't transform data. + #[derive(Debug)] + pub struct PassthroughExec { + child: Arc, + } + + impl PassthroughExec { + pub fn new(child: Arc) -> Self { + Self { child } + } + } + + impl ExecutionPlan for PassthroughExec { + fn name(&self) -> &str { + "PassthroughExec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + self.child.properties() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(Self::new(children[0].clone()))) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + // Simply delegate to the child + self.child.execute(partition, context) + } + } + + impl std::fmt::Display for PassthroughExec { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "PassthroughExec") + } + } + + impl datafusion::physical_plan::DisplayAs for PassthroughExec { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "PassthroughExec") + } + } + + async fn annotate_test_plan( + query: &str, + options: TestPlanOptions, + configure: impl FnOnce(SessionStateBuilder) -> SessionStateBuilder, + ) -> String { + let builder = base_session_builder( + options.target_partitions, + options.num_workers, + options.broadcast_enabled, + ); + let builder = configure(builder); + let (ctx, query) = context_with_query(builder, query).await; + let df = ctx.sql(&query).await.unwrap(); + let mut plan = df.create_physical_plan().await.unwrap(); + + plan = insert_broadcast_execs(plan, ctx.state_ref().read().config_options().as_ref()) + .expect("failed to insert broadcasts"); + + let annotated = annotate_plan(plan, ctx.state_ref().read().config_options().as_ref()) + .expect("failed to annotate plan"); + format!("{annotated:?}") + } + + async fn explain_test_plan( + query: &str, + options: TestPlanOptions, + use_optimizer: bool, + configure: impl FnOnce(SessionStateBuilder) -> SessionStateBuilder, + ) -> String { + let mut builder = base_session_builder( + options.target_partitions, + options.num_workers, + options.broadcast_enabled, + ); + if use_optimizer { + builder = + builder.with_physical_optimizer_rule(Arc::new(DistributedPhysicalOptimizerRule)); + } + let builder = configure(builder); + let (ctx, query) = context_with_query(builder, query).await; + let df = ctx.sql(&query).await.unwrap(); + let physical_plan = df.create_physical_plan().await.unwrap(); + + if use_optimizer { + display_plan_ascii(physical_plan.as_ref(), false) + } else { + format!("{}", displayable(physical_plan.as_ref()).indent(true)) + } + } +} diff --git a/src/distributed_planner/plan_annotator.rs b/src/distributed_planner/plan_annotator.rs index f54b4271..f9d07c50 100644 --- a/src/distributed_planner/plan_annotator.rs +++ b/src/distributed_planner/plan_annotator.rs @@ -16,11 +16,14 @@ use std::sync::Arc; /// Annotation attached to a single [ExecutionPlan] that determines the kind of network boundary /// needed just below itself. -pub(super) enum PlanOrNetworkBoundary { +// @NetworkBoundaryStrategy: Added Clone derive, made pub, added Extension variant +#[derive(Clone)] +pub enum PlanOrNetworkBoundary { Plan(Arc), Shuffle, Coalesce, Broadcast, + Extension(&'static str), } impl Debug for PlanOrNetworkBoundary { @@ -30,13 +33,18 @@ impl Debug for PlanOrNetworkBoundary { Self::Shuffle => write!(f, "[NetworkBoundary] Shuffle"), Self::Coalesce => write!(f, "[NetworkBoundary] Coalesce"), Self::Broadcast => write!(f, "[NetworkBoundary] Broadcast"), + Self::Extension(name) => write!(f, "[NetworkBoundary] Extension({})", name), // @NetworkBoundaryStrategy: Extension variant } } } impl PlanOrNetworkBoundary { - fn is_network_boundary(&self) -> bool { - matches!(self, Self::Shuffle | Self::Coalesce | Self::Broadcast) + // @NetworkBoundaryStrategy: Made pub(super) and added Extension to matches + pub(super) fn is_network_boundary(&self) -> bool { + matches!( + self, + Self::Shuffle | Self::Coalesce | Self::Broadcast | Self::Extension(_) // @NetworkBoundaryStrategy: Added Extension + ) } } @@ -252,7 +260,7 @@ fn _annotate_plan( annotation = AnnotatedPlan { plan_or_nb: PlanOrNetworkBoundary::Shuffle, children: vec![annotation], - task_count, + task_count: task_count.clone(), }; } } else if let Some(parent) = parent @@ -269,17 +277,46 @@ fn _annotate_plan( annotation = AnnotatedPlan { plan_or_nb: PlanOrNetworkBoundary::Broadcast, children: vec![annotation], - task_count, + task_count: task_count.clone(), }; } else { annotation = AnnotatedPlan { plan_or_nb: PlanOrNetworkBoundary::Coalesce, children: vec![annotation], - task_count, + task_count: task_count.clone(), }; } } + // @NetworkBoundaryStrategy: strategies last—overwrite annotation when a strategy matches + let (strategy_boundary, strategy_task_count) = + crate::distributed_planner::network_boundary_strategy::apply_network_boundary_strategy( + d_cfg, &plan, + ); + + // If a strategy matches, overwrite the annotation. + if let Some(boundary_type) = strategy_boundary { + match strategy_task_count { + // Strategy set an explicit task count; return early so we don't propagate into children. + Some(task_count_from_strategy) => { + annotation = AnnotatedPlan { + plan_or_nb: boundary_type, + children: annotation.children, + task_count: task_count_from_strategy, + }; + return Ok(annotation); + } + // Strategy did not set task count; fall through so propagation runs below. + None => { + annotation = AnnotatedPlan { + plan_or_nb: boundary_type, + children: annotation.children, + task_count: task_count.clone(), + }; + } + } + } + // The plan needs a NetworkBoundary. At this point we have all the info we need for choosing // the right size for the stage below, so what we need to do is take the calculated final // task count and propagate to all the children that will eventually be part of the stage. @@ -308,6 +345,7 @@ fn _annotate_plan( // assigned a task count, we do not want to overwrite it. PlanOrNetworkBoundary::Shuffle => return Ok(()), PlanOrNetworkBoundary::Coalesce => return Ok(()), + PlanOrNetworkBoundary::Extension(_) => return Ok(()), // @NetworkBoundaryStrategy: Extension variant }; if d_cfg.children_isolator_unions && plan.as_any().is::() { diff --git a/src/flight_service/do_get.rs b/src/flight_service/do_get.rs index 34cc861a..de4c26b4 100644 --- a/src/flight_service/do_get.rs +++ b/src/flight_service/do_get.rs @@ -73,7 +73,7 @@ pub struct TaskData { } impl Worker { - pub(super) async fn get( + pub async fn get( &self, request: Request, ) -> Result::DoGetStream>, Status> { diff --git a/src/flight_service/mod.rs b/src/flight_service/mod.rs index 565633ed..b5ee0d68 100644 --- a/src/flight_service/mod.rs +++ b/src/flight_service/mod.rs @@ -6,6 +6,8 @@ mod worker_connection_pool; pub(crate) use worker_connection_pool::WorkerConnectionPool; +pub use do_get::DoGet; + pub use session_builder::{ DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, WorkerQueryContext, WorkerSessionBuilder, diff --git a/src/lib.rs b/src/lib.rs index b15c743f..b489deeb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,7 @@ mod stage; mod distributed_planner; mod networking; mod observability; -mod protobuf; +pub mod protobuf; #[cfg(any(feature = "integration", test))] pub mod test_utils; @@ -22,19 +22,26 @@ pub use distributed_planner::{ DistributedConfig, DistributedPhysicalOptimizerRule, NetworkBoundary, NetworkBoundaryExt, TaskCountAnnotation, TaskEstimation, TaskEstimator, }; +// @NetworkBoundaryStrategy: import in separate block to avoid conflict with other imports +#[rustfmt::skip] +pub use distributed_planner::{ + CombinedNetworkBoundaryStrategy, NetworkBoundaryAnnotation, NetworkBoundaryContext, + NetworkBoundaryStrategy, PlanOrNetworkBoundary, +}; pub use execution_plans::{ BroadcastExec, DistributedExec, NetworkBroadcastExec, NetworkCoalesceExec, NetworkShuffleExec, PartitionIsolatorExec, }; pub use flight_service::{ - DefaultSessionBuilder, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, TaskData, - Worker, WorkerQueryContext, WorkerSessionBuilder, + DefaultSessionBuilder, DoGet, MappedWorkerSessionBuilder, MappedWorkerSessionBuilderExt, + TaskData, Worker, WorkerQueryContext, WorkerSessionBuilder, }; pub use metrics::{DistributedMetricsFormat, rewrite_distributed_plan_with_metrics}; pub use networking::{ BoxCloneSyncChannel, ChannelResolver, DefaultChannelResolver, WorkerResolver, create_flight_client, get_distributed_channel_resolver, get_distributed_worker_resolver, }; +pub use protobuf::{AppMetadata, DistributedCodec, FlightAppMetadata}; pub use stage::{ DistributedTaskContext, ExecutionTask, Stage, display_plan_ascii, display_plan_graphviz, explain_analyze, diff --git a/src/protobuf/mod.rs b/src/protobuf/mod.rs index 81b444f1..c9941dce 100644 --- a/src/protobuf/mod.rs +++ b/src/protobuf/mod.rs @@ -1,11 +1,14 @@ -mod app_metadata; -mod distributed_codec; -mod errors; -mod user_codec; +pub mod app_metadata; +pub mod distributed_codec; +pub mod errors; +pub mod user_codec; -pub(crate) use app_metadata::{AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics}; -pub(crate) use distributed_codec::{DistributedCodec, StageKey}; -pub(crate) use errors::{datafusion_error_to_tonic_status, map_flight_to_datafusion_error}; +pub use app_metadata::{AppMetadata, FlightAppMetadata, MetricsCollection, TaskMetrics}; +pub use distributed_codec::{DistributedCodec, StageKey}; +pub use errors::{ + datafusion_error_to_tonic_status, map_flight_to_datafusion_error, + map_status_to_datafusion_error, +}; pub(crate) use user_codec::{ get_distributed_user_codecs, set_distributed_user_codec, set_distributed_user_codec_arc, };