Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
mod children_helpers;
mod map_last_stream;
mod on_drop_stream;
mod once_lock;
mod task_context_helpers;
mod uuid;

pub(crate) use children_helpers::require_one_child;
pub(crate) use map_last_stream::map_last_stream;
pub(crate) use on_drop_stream::on_drop_stream;
pub(crate) use once_lock::OnceLockResult;
pub(crate) use task_context_helpers::task_ctx_with_extension;
pub(crate) use uuid::{deserialize_uuid, serialize_uuid};
5 changes: 5 additions & 0 deletions src/common/once_lock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use datafusion::error::DataFusionError;
use std::sync::{Arc, OnceLock};

/// A [OnceLock] that holds a clonable result.
pub(crate) type OnceLockResult<T> = OnceLock<Result<T, Arc<DataFusionError>>>;
4 changes: 2 additions & 2 deletions src/execution_plans/broadcast.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::common::require_one_child;
use crate::common::{OnceLockResult, require_one_child};
use crossbeam_queue::SegQueue;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::runtime::SpawnedTask;
Expand Down Expand Up @@ -73,7 +73,7 @@ pub struct BroadcastExec {
input: Arc<dyn ExecutionPlan>,
consumer_task_count: usize,
properties: Arc<PlanProperties>,
queues: Vec<OnceLock<Result<StreamAndTask, Arc<DataFusionError>>>>,
queues: Vec<OnceLockResult<StreamAndTask>>,
}

type StreamAndTask = (SegQueue<SendableRecordBatchStream>, Arc<SpawnedTask<()>>);
Expand Down
2 changes: 1 addition & 1 deletion src/execution_plans/network_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ impl ExecutionPlan for NetworkBroadcastExec {
&context,
)?;

let stream = worker_connection.stream_partition(off + partition, |_meta| {})?;
let stream = worker_connection.execute(off + partition)?;
streams.push(stream);
}

Expand Down
2 changes: 1 addition & 1 deletion src/execution_plans/network_coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ impl ExecutionPlan for NetworkCoalesceExec {
&context,
)?;

let stream = worker_connection.stream_partition(target_partition, |_meta| {})?;
let stream = worker_connection.execute(target_partition)?;

Ok(Box::pin(RecordBatchStreamAdapter::new(
self.schema(),
Expand Down
2 changes: 1 addition & 1 deletion src/execution_plans/network_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl ExecutionPlan for NetworkShuffleExec {
&context,
)?;

let stream = worker_connection.stream_partition(off + partition, |_meta| {})?;
let stream = worker_connection.execute(off + partition)?;
streams.push(stream);
}

Expand Down
1 change: 1 addition & 0 deletions src/worker/impl_coordinator_channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl Worker {
task_count: request.task_count as usize,
}))
.with_extension(Arc::new(LocalWorkerContext {
task_data_entries: Arc::clone(&self.task_data_entries),
self_url: Url::parse(&request.target_worker_url)
.map_err(|e| DataFusionError::External(Box::new(e)))?,
}))
Expand Down
278 changes: 153 additions & 125 deletions src/worker/impl_execute_task.rs

Large diffs are not rendered by default.

174 changes: 145 additions & 29 deletions src/worker/worker_connection_pool.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use crate::common::{on_drop_stream, serialize_uuid};
use crate::common::{OnceLockResult, on_drop_stream, serialize_uuid};
use crate::metrics::LatencyMetricExt;
use crate::networking::get_distributed_channel_resolver;
use crate::passthrough_headers::get_passthrough_headers;
use crate::protobuf::{datafusion_error_to_tonic_status, map_flight_to_datafusion_error};
use crate::stage::RemoteStage;
use crate::worker::generated::worker::FlightAppMetadata;
use crate::worker::generated::worker::{ExecuteTaskRequest, TaskKey};
use crate::worker::impl_execute_task::execute_local_task;
use crate::worker::worker_service::TaskDataEntries;
use crate::{BytesMetricExt, ChannelResolver, DistributedConfig};
use arrow_flight::FlightData;
use arrow_flight::decode::FlightRecordBatchStream;
Expand All @@ -14,12 +16,13 @@ use dashmap::DashMap;
use datafusion::arrow::array::RecordBatch;
use datafusion::common::instant::Instant;
use datafusion::common::runtime::SpawnedTask;
use datafusion::common::{DataFusionError, Result, internal_err};
use datafusion::common::{DataFusionError, Result, internal_datafusion_err, internal_err};
use datafusion::execution::TaskContext;
use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion::physical_expr_common::metrics::{ExecutionPlanMetricsSet, MetricValue};
use datafusion::physical_plan::metrics::{MetricBuilder, Time};
use futures::{Stream, TryStreamExt};
use futures::stream::BoxStream;
use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
use http::Extensions;
use pin_project::{pin_project, pinned_drop};
use prost::Message;
Expand All @@ -28,12 +31,11 @@ use std::fmt::{Debug, Formatter};
use std::ops::Range;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::Notify;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio_stream::StreamExt;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_util::sync::CancellationToken;
use tonic::metadata::MetadataMap;
Expand All @@ -47,9 +49,10 @@ use url::Url;
/// This information can be used for executing tasks locally bypassing gRPC comms if the tasks that
/// needs to be remotely executed happens to be owned by this same worker.
pub(crate) struct LocalWorkerContext {
/// The registry of in-flight tasks the [crate::Worker] in the current scope owns.
pub(crate) task_data_entries: Arc<TaskDataEntries>,
/// The URL of the [crate::Worker] in scope. When trying to reach to a target URL that happens
/// to be the same as this one, local comms are preferred instead.
#[allow(dead_code)]
pub(crate) self_url: Url,
}

Expand All @@ -59,7 +62,7 @@ pub(crate) struct LocalWorkerContext {
/// it will initialize the corresponding position in the vector matching the provided `target_task`
/// index.
pub(crate) struct WorkerConnectionPool {
connections: Vec<OnceLock<Result<WorkerConnection, Arc<DataFusionError>>>>,
connections: Vec<OnceLockResult<Box<dyn WorkerConnection + Sync + Send>>>,
pub(crate) metrics: ExecutionPlanMetricsSet,
}

Expand All @@ -86,35 +89,60 @@ impl WorkerConnectionPool {
target_partitions: Range<usize>,
target_task: usize,
ctx: &Arc<TaskContext>,
) -> Result<&WorkerConnection> {
) -> Result<&(dyn WorkerConnection + Sync + Send)> {
let Some(worker_connection) = self.connections.get(target_task) else {
return internal_err!(
"WorkerConnections: Task index {target_task} not found, only have {} tasks",
self.connections.len()
);
};
ctx.session_config().get_extension::<LocalWorkerContext>();

let conn = worker_connection.get_or_init(|| {
WorkerConnection::init(
input_stage,
target_partitions,
target_task,
ctx,
&self.metrics,
)
.map_err(Arc::new)
let Some(target_url) = input_stage.workers.get(target_task) else {
internal_err!("input_stage.workers[{target_task}] out of range.")?
};
if let Some(lw_ctx) = ctx.session_config().get_extension::<LocalWorkerContext>()
&& &lw_ctx.self_url == target_url
{
// Instead of making a gRPC call to ourselves, better to just use local comms.
Ok(Box::new(LocalWorkerConnection::init(
input_stage,
target_partitions,
target_task,
lw_ctx,
&self.metrics,
)) as Box<_>)
} else {
// We are trying to reach a URL different from ours, so use normal gRPC streams.
RemoteWorkerConnection::init(
input_stage,
target_partitions,
target_task,
ctx,
&self.metrics,
)
.map(|v| Box::new(v) as Box<_>)
.map_err(Arc::new)
}
});

match conn {
Ok(v) => Ok(v),
Ok(v) => Ok(v.as_ref()),
Err(err) => Err(DataFusionError::Shared(Arc::clone(err))),
}
}
}

type WorkerMsg = Result<(FlightData, FlightAppMetadata), Status>;

/// Abstraction that allows treating remote and local comms as equal. Network boundaries do not
/// care if the stream comes over the wire or locally.
pub(crate) trait WorkerConnection {
/// Streams the specified partition. Consumers do not care if the implementation pulls data
/// from in-memory or from local comms.
fn execute(&self, partition: usize) -> Result<BoxStream<'static, Result<RecordBatch>>>;
}

/// Represents a connection to one [Worker]. Network boundaries will use this for streaming
/// data from single partitions while the actual network communication is handling all the partitions
/// under the hood.
Expand All @@ -126,7 +154,7 @@ type WorkerMsg = Result<(FlightData, FlightAppMetadata), Status>;
/// the same underlying TCP connection, there do is some overhead in having one gRPC stream per
/// partition VS a single gRPC stream interleaving multiple partitions. The whole serialized plan
/// needs to be sent over the wire on every gRPC call, so the less gRPC calls we do the better.
pub(crate) struct WorkerConnection {
struct RemoteWorkerConnection {
task: Arc<SpawnedTask<()>>,
not_consumed_streams: Arc<AtomicUsize>,
cancel_token: CancellationToken,
Expand All @@ -140,7 +168,7 @@ pub(crate) struct WorkerConnection {
elapsed_compute: Time,
}

impl WorkerConnection {
impl RemoteWorkerConnection {
fn init(
input_stage: &RemoteStage,
target_partition_range: Range<usize>,
Expand Down Expand Up @@ -337,7 +365,9 @@ impl WorkerConnection {
elapsed_compute: elapsed_compute_clone,
})
}
}

impl WorkerConnection for RemoteWorkerConnection {
/// Streams the provided `partition` from the remote worker.
///
/// Note that this does not issue a network request, the actual network request happened before
Expand All @@ -348,11 +378,7 @@ impl WorkerConnection {
///
/// When the returned stream is dropped (e.g., due to query cancellation), the background task
/// pulling from the Flight stream will be cancelled promptly.
pub(crate) fn stream_partition(
&self,
partition: usize,
on_metadata: impl Fn(FlightAppMetadata) + Send + Sync + 'static,
) -> Result<impl Stream<Item = Result<RecordBatch>> + 'static> {
fn execute(&self, partition: usize) -> Result<BoxStream<'static, Result<RecordBatch>>> {
let Some((_, partition_receiver)) = self.per_partition_rx.remove(&partition) else {
return internal_err!(
"WorkerConnection has no stream for target partition {partition}. Was it already consumed?"
Expand All @@ -365,12 +391,11 @@ impl WorkerConnection {
let stream = stream.map_err(|err| FlightError::Tonic(Box::new(err)));
let reservation = Arc::clone(&self.memory_reservation);
let mem_available_notify = Arc::clone(&self.mem_available_notify);
let stream = stream.map_ok(move |(data, meta)| {
let stream = stream.map_ok(move |(data, _meta)| {
reservation.shrink(data.encoded_len());
// Wake the demux task in case it is blocked on the byte budget.
mem_available_notify.notify_one();
let _ = &task; // <- keep the task that polls data from the network alive.
on_metadata(meta);
data
});
let stream = FlightRecordBatchStream::new_from_flight_data(stream);
Expand All @@ -384,7 +409,98 @@ impl WorkerConnection {
if remaining_streams == 0 {
cancel_token.cancel();
}
}))
})
.boxed())
}
}

/// Equivalent to [RemoteWorkerConnection], but that pulls data from the local registry of tasks
/// rather than doing it across a gRPC interface.
pub(crate) struct LocalWorkerConnection {
partition_start: usize,
local_streams: Vec<Mutex<Option<BoxStream<'static, Result<RecordBatch>>>>>,
}

impl LocalWorkerConnection {
fn init(
input_stage: &RemoteStage,
target_partition_range: Range<usize>,
target_task: usize,
lw_ctx: Arc<LocalWorkerContext>,
metrics: &ExecutionPlanMetricsSet,
) -> Self {
MetricBuilder::new(metrics)
.global_counter("local_connections_used")
.add(1);

let task_key = TaskKey {
query_id: serialize_uuid(&input_stage.query_id),
stage_id: input_stage.num as u64,
task_number: target_task as u64,
};

let partition_start = target_partition_range.start;
let mut local_streams = Vec::with_capacity(target_partition_range.len());
for partition_i in target_partition_range {
let request = ExecuteTaskRequest {
task_key: Some(task_key.clone()),
target_partition_start: partition_i as u64,
target_partition_end: (partition_i + 1) as u64,
};

let task_data_entries = Arc::clone(&lw_ctx.task_data_entries);

// The relevant entry from `task_data_entries` needs to be eagerly retrieved, it cannot be
// left for until someone decides to start polling the returned `BoxStream`, otherwise,
// there's risk that the entry is evicted by Moka's TTL, and by the time the returned stream
// is polled, the entry might not be there.
//
// Note that this does not start polling the returned streams, it just instantiates them.
let streams_future = SpawnedTask::spawn(async move {
let (streams, _) = execute_local_task(&task_data_entries, request).await?;
Ok::<_, DataFusionError>(streams)
});

let stream = async move {
let mut streams = streams_future
.await
.map_err(|err| internal_datafusion_err!("{err}"))??;
if streams.len() != 1 {
return internal_err!("Expected exactly 1 local stream");
}
Ok(streams.swap_remove(0))
}
.try_flatten_stream()
.boxed();

local_streams.push(Mutex::new(Some(stream)));
}

Self {
partition_start,
local_streams,
}
}
}

impl WorkerConnection for LocalWorkerConnection {
fn execute(&self, partition: usize) -> Result<BoxStream<'static, Result<RecordBatch>>> {
let Some(relative_i) = partition.checked_sub(self.partition_start) else {
return internal_err!(
"LocalWorkerConnection received an invalid partition {partition}, the starting partition is {}",
self.partition_start
);
};
let Some(slot) = self.local_streams.get(relative_i) else {
return internal_err!(
"LocalWorkerConnection has no stream for partition {partition}. Was it already consumed?"
);
};
slot.lock().unwrap().take().ok_or_else(|| {
internal_datafusion_err!(
"LocalWorkerConnection stream for partition {partition} was already consumed"
)
})
}
}

Expand All @@ -408,7 +524,7 @@ impl Clone for WorkerConnectionPool {
}
}

impl Debug for WorkerConnection {
impl Debug for RemoteWorkerConnection {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("WorkerConnection").finish()
}
Expand Down
Loading
Loading