diff --git a/src/distributed_planner/distributed_query_planner.rs b/src/distributed_planner/distributed_query_planner.rs index 9ffb04bb..2fd92ab7 100644 --- a/src/distributed_planner/distributed_query_planner.rs +++ b/src/distributed_planner/distributed_query_planner.rs @@ -1,4 +1,6 @@ -use crate::distributed_planner::inject_network_boundaries::inject_network_boundaries; +use crate::distributed_planner::inject_network_boundaries::{ + CardinalityBasedNetworkBoundaryBuilder, inject_network_boundaries, +}; use crate::distributed_planner::insert_broadcast::insert_broadcast_execs; use crate::distributed_planner::partial_reduce_below_network_shuffles::partial_reduce_below_network_shuffles; use crate::distributed_planner::prepare_network_boundaries::prepare_network_boundaries; @@ -93,7 +95,7 @@ impl QueryPlanner for DistributedQueryPlanner { plan = insert_broadcast_execs(plan, cfg)?; - plan = inject_network_boundaries(plan, cfg).await?; + plan = inject_network_boundaries(plan, CardinalityBasedNetworkBoundaryBuilder, cfg).await?; plan = prepare_network_boundaries(plan)?; if !plan.exists(|plan| Ok(plan.is_network_boundary()))? { diff --git a/src/distributed_planner/inject_network_boundaries.rs b/src/distributed_planner/inject_network_boundaries.rs index 2804c22b..1c2c5f72 100644 --- a/src/distributed_planner/inject_network_boundaries.rs +++ b/src/distributed_planner/inject_network_boundaries.rs @@ -2,9 +2,10 @@ use crate::TaskCountAnnotation::{Desired, Maximum}; use crate::execution_plans::{ChildWeight, ChildrenIsolatorUnionExec}; use crate::stage::LocalStage; use crate::{ - BroadcastExec, DistributedConfig, NetworkBoundaryExt, NetworkBroadcastExec, + BroadcastExec, DistributedConfig, NetworkBoundary, NetworkBoundaryExt, NetworkBroadcastExec, NetworkCoalesceExec, NetworkShuffleExec, TaskCountAnnotation, TaskEstimator, }; +use async_trait::async_trait; use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion::common::{HashMap, Result, plan_err}; use datafusion::config::ConfigOptions; @@ -131,13 +132,15 @@ use uuid::Uuid; /// boundary injection, so the head stage is closed by running one final Phase 2 pass over /// the whole plan. This guarantees every node (including head-stage nodes that never sat /// directly above a boundary) has a task count recorded. -pub(super) async fn inject_network_boundaries( +pub(crate) async fn inject_network_boundaries( plan: Arc, + nb_builder: impl NetworkBoundaryBuilder + Send + Sync, cfg: &ConfigOptions, ) -> Result> { let ctx = Context { cfg, d_cfg: DistributedConfig::from_config_options(cfg)?, + nb_builder: &nb_builder, task_counts: &Mutex::new(HashMap::new()), query_id: Uuid::new_v4(), stage_id: &AtomicUsize::new(1), @@ -150,6 +153,7 @@ pub(super) async fn inject_network_boundaries( struct Context<'a> { cfg: &'a ConfigOptions, d_cfg: &'a DistributedConfig, + nb_builder: &'a (dyn NetworkBoundaryBuilder + Send + Sync), task_counts: &'a Mutex>, query_id: Uuid, stage_id: &'a AtomicUsize, @@ -204,6 +208,13 @@ impl<'a> Context<'a> { fn fetch_add_stage_id(&self) -> usize { self.stage_id.fetch_add(1, Ordering::Acquire) } + + async fn apply_nb_builder( + &self, + nb: Arc, + ) -> Result { + self.nb_builder.build(nb, self.cfg).await + } } /// Identity key for a plan node. The pointer is only used as a hash-map key, never dereferenced, @@ -287,16 +298,14 @@ async fn _inject_network_boundaries( // count down so every node in that stage has it recorded. let plan = propagate_task_count_until_network_boundaries(&plan, task_count, ctx)?; - let f = calculate_scale_factor(&plan, ctx); - let input_stage = LocalStage { + let plan = NetworkShuffleExec::from_stage(LocalStage { query_id: ctx.query_id, num: ctx.fetch_add_stage_id(), plan, tasks: task_count.as_usize(), - }; - let plan = Arc::new(NetworkShuffleExec::from_stage(input_stage)); - let task_count = Desired((f * task_count.as_usize() as f64).ceil() as usize); - return Ok(ctx.plan_with_task_count(plan, task_count)); + }); + let result = ctx.apply_nb_builder(Arc::new(plan)).await?; + return Ok(ctx.plan_with_task_count(result.network_boundary, result.task_count_above)); } // If the parent of the current node is either a `CoalescePartitionsExec` or a // `SortPreservingMergeExec`, a network boundary below it is necessary. @@ -315,31 +324,37 @@ async fn _inject_network_boundaries( // count down so every node in that stage has it recorded. let plan = propagate_task_count_until_network_boundaries(&plan, task_count, ctx)?; - let f = calculate_scale_factor(&plan, ctx); - let input_stage = LocalStage { + let plan = NetworkBroadcastExec::from_stage(LocalStage { query_id: ctx.query_id, num: ctx.fetch_add_stage_id(), plan, tasks: task_count.as_usize(), - }; - let plan = Arc::new(NetworkBroadcastExec::from_stage(input_stage)); - let task_count = Desired((f * task_count.as_usize() as f64).ceil() as usize); - Ok(ctx.plan_with_task_count(plan, task_count)) + }); + let result = ctx.apply_nb_builder(Arc::new(plan)).await?; + return Ok(ctx.plan_with_task_count(result.network_boundary, result.task_count_above)); } else { // The subtree below this point belongs to one stage. Propagate the chosen task // count down so every node in that stage has it recorded. let plan = propagate_task_count_until_network_boundaries(&plan, task_count, ctx)?; - let input_stage = LocalStage { - query_id: ctx.query_id, - num: ctx.fetch_add_stage_id(), - plan, - tasks: task_count.as_usize(), - }; - let plan = Arc::new(NetworkCoalesceExec::from_stage(input_stage, 1)); + let plan = NetworkCoalesceExec::from_stage( + LocalStage { + query_id: ctx.query_id, + num: ctx.fetch_add_stage_id(), + plan, + tasks: task_count.as_usize(), + }, + 1, + ); + let result = ctx.apply_nb_builder(Arc::new(plan)).await?; + if !matches!(result.task_count_above, Maximum(1)) { + return plan_err!( + "A NetworkCoalesceExec must return exactly a Maximum(1) annotation above" + ); + } // The parent that triggered this branch is a `CoalescePartitionsExec` or // `SortPreservingMergeExec`, both of which fold all partitions into one — so the // stage above this boundary must run in exactly one task. - Ok(ctx.plan_with_task_count(plan, Maximum(1))) + Ok(ctx.plan_with_task_count(result.network_boundary, Maximum(1))) }; } @@ -473,6 +488,37 @@ fn propagate_task_count_until_network_boundaries( } } +pub(crate) struct NetworkBoundaryBuilderResult { + pub(crate) task_count_above: TaskCountAnnotation, + pub(crate) network_boundary: Arc, +} + +#[async_trait] +pub(crate) trait NetworkBoundaryBuilder { + async fn build( + &self, + nb: Arc, + cfg: &ConfigOptions, + ) -> Result; +} + +#[async_trait] +impl NetworkBoundaryBuilder for T +where + T: Fn(Arc, &ConfigOptions) -> Result, + T: Send + Sync, + F: Future>, + F: Send, +{ + async fn build( + &self, + nb: Arc, + cfg: &ConfigOptions, + ) -> Result { + self(nb, cfg)?.await + } +} + /// Returns a multiplicative factor describing how the data volume changes between the bottom of /// `plan` (at a network boundary or a leaf) and `plan` itself. The walk descends into `plan`'s /// children, stops at any node that is itself a network boundary (returning `1.0` there — that @@ -506,24 +552,57 @@ fn propagate_task_count_until_network_boundaries( /// /// With `cardinality_task_count_factor = 1.5`, the example above yields `sf ≈ 0.44`. The /// boundary's recorded task count above this stage will be `ceil(T_producer × sf)`. -fn calculate_scale_factor(plan: &Arc, ctx: &Context) -> f64 { - if plan.is_network_boundary() { - return 1.0; - }; +pub(crate) struct CardinalityBasedNetworkBoundaryBuilder; - let mut sf = None; - for plan in plan.children() { - sf = match sf { - None => Some(calculate_scale_factor(plan, ctx)), - Some(sf) => Some(sf.max(calculate_scale_factor(plan, ctx))), +#[async_trait] +impl NetworkBoundaryBuilder for CardinalityBasedNetworkBoundaryBuilder { + async fn build( + &self, + nb: Arc, + cfg: &ConfigOptions, + ) -> Result { + if nb.as_any().is::() { + return Ok(NetworkBoundaryBuilderResult { + task_count_above: Maximum(1), + network_boundary: nb, + }); } - } + let d_cfg = DistributedConfig::from_config_options(cfg)?; + + fn calculate_scale_factor(plan: &Arc, d_cfg: &DistributedConfig) -> f64 { + if plan.is_network_boundary() { + return 1.0; + }; - let sf = sf.unwrap_or(1.0); - match plan.cardinality_effect() { - CardinalityEffect::LowerEqual => sf / ctx.d_cfg.cardinality_task_count_factor, - CardinalityEffect::GreaterEqual => sf * ctx.d_cfg.cardinality_task_count_factor, - _ => sf, + let mut sf = None; + for plan in plan.children() { + sf = match sf { + None => Some(calculate_scale_factor(plan, d_cfg)), + Some(sf) => Some(sf.max(calculate_scale_factor(plan, d_cfg))), + } + } + + let sf = sf.unwrap_or(1.0); + match plan.cardinality_effect() { + CardinalityEffect::LowerEqual => sf / d_cfg.cardinality_task_count_factor, + CardinalityEffect::GreaterEqual => sf * d_cfg.cardinality_task_count_factor, + _ => sf, + } + } + + let input_stage = nb.input_stage(); + let Some(input_plan) = input_stage.local_plan() else { + return plan_err!( + "input_stage plan needs to be in local mode for cardinality calculation" + ); + }; + + let f = calculate_scale_factor(input_plan, d_cfg); + + Ok(NetworkBoundaryBuilderResult { + task_count_above: Desired((f * input_stage.task_count() as f64).ceil() as usize), + network_boundary: nb, + }) } } @@ -1118,6 +1197,7 @@ mod tests { task_counts: &Mutex::new(HashMap::new()), query_id: Uuid::new_v4(), stage_id: &AtomicUsize::new(1), + nb_builder: &CardinalityBasedNetworkBoundaryBuilder, }; let annotated = _inject_network_boundaries(plan, None, &ctx)