diff --git a/Cargo.toml b/Cargo.toml index 70de9a2..5460836 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ exclude = [ ] [dependencies] -duroxide = { version = "0.1.28", features = ["provider-test"] } +duroxide = { git = "https://github.com/microsoft/duroxide", branch = "waldemort/batched-worker-fetch", features = ["provider-test"] } async-trait = "0.1" tokio = { version = "1", features = ["full"] } sqlx = { version = "0.8", features = ["runtime-tokio-native-tls", "postgres", "chrono"], default-features = false } diff --git a/migrations/0022_add_batched_worker_fetch.sql b/migrations/0022_add_batched_worker_fetch.sql new file mode 100644 index 0000000..1f35e35 --- /dev/null +++ b/migrations/0022_add_batched_worker_fetch.sql @@ -0,0 +1,178 @@ +-- Migration: 0022_add_batched_worker_fetch.sql +-- Description: Adds provider-native batched worker fetch with bounded session claims. + +DO $$ +DECLARE + v_schema_name TEXT := current_schema(); +BEGIN + EXECUTE format(' + CREATE OR REPLACE FUNCTION %I.fetch_work_items( + p_now_ms BIGINT, + p_lock_timeout_ms BIGINT, + p_owner_id TEXT DEFAULT NULL, + p_session_lock_timeout_ms BIGINT DEFAULT NULL, + p_tag_filter TEXT[] DEFAULT NULL, + p_tag_mode TEXT DEFAULT ''default_only'', + p_max_items INTEGER DEFAULT 1, + p_max_new_sessions INTEGER DEFAULT 0 + ) + RETURNS TABLE( + out_work_item TEXT, + out_lock_token TEXT, + out_attempt_count INTEGER + ) AS $fetch_workers$ + DECLARE + v_candidate RECORD; + v_lock_token TEXT; + v_session_locked_until BIGINT; + v_rows_affected INTEGER; + v_returned INTEGER := 0; + v_new_sessions_claimed INTEGER := 0; + v_claimed_session_ids TEXT[] := ARRAY[]::TEXT[]; + v_result_ids BIGINT[] := ARRAY[]::BIGINT[]; + v_result_work_items TEXT[] := ARRAY[]::TEXT[]; + v_result_lock_tokens TEXT[] := ARRAY[]::TEXT[]; + v_result_attempt_counts INTEGER[] := ARRAY[]::INTEGER[]; + v_scan_limit INTEGER; + BEGIN + IF current_setting(''transaction_isolation'') <> ''read committed'' THEN + RAISE EXCEPTION ''fetch_work_items requires READ COMMITTED isolation''; + END IF; + + IF p_max_items IS NULL OR p_max_items <= 0 THEN + RETURN; + END IF; + + -- none mode: return immediately with no results + IF p_tag_mode = ''none'' THEN + RETURN; + END IF; + + v_scan_limit := LEAST(GREATEST(p_max_items * 4, p_max_items), 1024); + + FOR v_candidate IN + SELECT q.id, q.session_id, s.worker_id AS active_worker_id + FROM %I.worker_queue q + LEFT JOIN %I.sessions s ON s.session_id = q.session_id AND s.locked_until > p_now_ms + WHERE q.visible_at <= TO_TIMESTAMP(p_now_ms / 1000.0) + AND (q.lock_token IS NULL OR q.locked_until <= p_now_ms) + AND ( + (p_owner_id IS NOT NULL AND ( + q.session_id IS NULL + OR s.worker_id = p_owner_id + OR s.session_id IS NULL + )) + OR (p_owner_id IS NULL AND q.session_id IS NULL) + ) + AND ( + CASE p_tag_mode + WHEN ''default_only'' THEN q.tag IS NULL + WHEN ''tags'' THEN q.tag = ANY(p_tag_filter) + WHEN ''default_and'' THEN (q.tag IS NULL OR q.tag = ANY(p_tag_filter)) + WHEN ''any'' THEN TRUE + ELSE FALSE + END + ) + ORDER BY q.session_id NULLS FIRST, q.id + LIMIT v_scan_limit + FOR UPDATE OF q SKIP LOCKED + LOOP + EXIT WHEN v_returned >= p_max_items; + + IF v_candidate.session_id IS NOT NULL THEN + IF p_owner_id IS NULL THEN + CONTINUE; + END IF; + + IF v_candidate.active_worker_id IS DISTINCT FROM p_owner_id + AND NOT (v_candidate.session_id = ANY(v_claimed_session_ids)) THEN + IF v_new_sessions_claimed >= COALESCE(p_max_new_sessions, 0) THEN + CONTINUE; + END IF; + v_new_sessions_claimed := v_new_sessions_claimed + 1; + END IF; + + v_session_locked_until := p_now_ms + COALESCE(p_session_lock_timeout_ms, p_lock_timeout_ms); + + INSERT INTO %I.sessions (session_id, worker_id, locked_until, last_activity_at) + VALUES (v_candidate.session_id, p_owner_id, v_session_locked_until, p_now_ms) + ON CONFLICT (session_id) DO UPDATE + SET worker_id = p_owner_id, + locked_until = v_session_locked_until, + last_activity_at = p_now_ms + WHERE %I.sessions.locked_until <= p_now_ms OR %I.sessions.worker_id = p_owner_id; + + GET DIAGNOSTICS v_rows_affected = ROW_COUNT; + IF v_rows_affected = 0 THEN + IF v_candidate.active_worker_id IS DISTINCT FROM p_owner_id + AND NOT (v_candidate.session_id = ANY(v_claimed_session_ids)) THEN + v_new_sessions_claimed := GREATEST(0, v_new_sessions_claimed - 1); + END IF; + CONTINUE; + END IF; + IF NOT (v_candidate.session_id = ANY(v_claimed_session_ids)) THEN + v_claimed_session_ids := array_append(v_claimed_session_ids, v_candidate.session_id); + END IF; + END IF; + + v_lock_token := ''lock_'' || gen_random_uuid()::TEXT; + + UPDATE %I.worker_queue q + SET lock_token = v_lock_token, + locked_until = p_now_ms + p_lock_timeout_ms, + attempt_count = q.attempt_count + 1 + WHERE q.id = v_candidate.id + AND (q.lock_token IS NULL OR q.locked_until <= p_now_ms) + RETURNING q.work_item, q.attempt_count + INTO out_work_item, out_attempt_count; + + GET DIAGNOSTICS v_rows_affected = ROW_COUNT; + IF v_rows_affected = 0 THEN + CONTINUE; + END IF; + + v_result_ids := array_append(v_result_ids, v_candidate.id); + v_result_work_items := array_append(v_result_work_items, out_work_item); + v_result_lock_tokens := array_append(v_result_lock_tokens, v_lock_token); + v_result_attempt_counts := array_append(v_result_attempt_counts, out_attempt_count); + v_returned := v_returned + 1; + END LOOP; + + RETURN QUERY + SELECT u.work_item, u.lock_token, u.attempt_count + FROM unnest( + v_result_ids, + v_result_work_items, + v_result_lock_tokens, + v_result_attempt_counts + ) AS u(id, work_item, lock_token, attempt_count) + ORDER BY u.id; + END; + $fetch_workers$ LANGUAGE plpgsql; + ', v_schema_name, v_schema_name, v_schema_name, v_schema_name, v_schema_name, v_schema_name, v_schema_name); + + EXECUTE format(' + CREATE INDEX IF NOT EXISTS idx_worker_queue_fetch_default_unlocked + ON %I.worker_queue (visible_at, id) + WHERE session_id IS NULL AND lock_token IS NULL + ', v_schema_name); + + EXECUTE format(' + CREATE INDEX IF NOT EXISTS idx_worker_queue_fetch_session_unlocked + ON %I.worker_queue (session_id, visible_at, id) + WHERE lock_token IS NULL + ', v_schema_name); + + EXECUTE format(' + CREATE INDEX IF NOT EXISTS idx_worker_queue_fetch_tag_unlocked + ON %I.worker_queue (tag, visible_at, id) + WHERE lock_token IS NULL + ', v_schema_name); + + EXECUTE format(' + CREATE INDEX IF NOT EXISTS idx_worker_queue_lock_expiry + ON %I.worker_queue (locked_until) + WHERE lock_token IS NOT NULL + ', v_schema_name); +END; +$$; diff --git a/pg-stress/Cargo.toml b/pg-stress/Cargo.toml index 7ebd1e8..c704dfa 100644 --- a/pg-stress/Cargo.toml +++ b/pg-stress/Cargo.toml @@ -9,7 +9,7 @@ name = "pg-stress" path = "src/bin/pg-stress.rs" [dependencies] -duroxide = { version = "0.1.28", features = ["provider-test"] } +duroxide = { git = "https://github.com/microsoft/duroxide", branch = "waldemort/batched-worker-fetch", features = ["provider-test"] } duroxide-pg = { path = ".." } tokio = { version = "1", features = ["full"] } tracing = "0.1" diff --git a/src/provider.rs b/src/provider.rs index 4df0dbc..3dc585f 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -229,6 +229,10 @@ impl Provider for PostgresProvider { env!("CARGO_PKG_VERSION") } + fn supports_batched_work_item_fetch(&self) -> bool { + true + } + #[instrument(skip(self), target = "duroxide::providers::postgres")] async fn fetch_orchestration_item( &self, @@ -944,6 +948,71 @@ impl Provider for PostgresProvider { Ok(Some((work_item, lock_token, attempt_count as u32))) } + #[instrument(skip(self), target = "duroxide::providers::postgres")] + async fn fetch_work_items( + &self, + lock_timeout: Duration, + _poll_timeout: Duration, + session: Option<&SessionFetchConfig>, + tag_filter: &TagFilter, + max_items: usize, + max_new_sessions: usize, + ) -> Result, ProviderError> { + if max_items == 0 || matches!(tag_filter, TagFilter::None) { + return Ok(Vec::new()); + } + + let start = std::time::Instant::now(); + let lock_timeout_ms = lock_timeout.as_millis() as i64; + let (owner_id, session_lock_timeout_ms): (Option<&str>, Option) = match session { + Some(config) => ( + Some(&config.owner_id), + Some(config.lock_timeout.as_millis() as i64), + ), + None => (None, None), + }; + let (tag_mode, tag_names) = Self::tag_filter_to_sql(tag_filter); + + let rows = sqlx::query_as::<_, (String, String, i32)>(&format!( + "SELECT * FROM {}.fetch_work_items($1, $2, $3, $4, $5, $6, $7, $8)", + self.schema_name + )) + .bind(Self::now_millis()) + .bind(lock_timeout_ms) + .bind(owner_id) + .bind(session_lock_timeout_ms) + .bind(&tag_names) + .bind(tag_mode) + .bind(max_items.min(i32::MAX as usize) as i32) + .bind(max_new_sessions.min(i32::MAX as usize) as i32) + .fetch_all(&*self.pool) + .await + .map_err(|e| Self::sqlx_to_provider_error("fetch_work_items", e))?; + + let mut items = Vec::with_capacity(rows.len()); + for (work_item_json, lock_token, attempt_count) in rows { + let work_item: WorkItem = serde_json::from_str(&work_item_json).map_err(|e| { + ProviderError::permanent( + "fetch_work_items", + format!("Failed to deserialize worker item: {e}"), + ) + })?; + items.push((work_item, lock_token, attempt_count as u32)); + } + + debug!( + target = "duroxide::providers::postgres", + operation = "fetch_work_items", + returned = items.len(), + requested = max_items, + max_new_sessions = max_new_sessions, + duration_ms = start.elapsed().as_millis() as u64, + "Fetched activity work-item batch via stored procedure" + ); + + Ok(items) + } + #[instrument(skip(self), fields(token = %token), target = "duroxide::providers::postgres")] async fn ack_work_item( &self, @@ -1496,7 +1565,10 @@ impl Provider for PostgresProvider { } #[instrument(skip(self), fields(instance = %instance), target = "duroxide::providers::postgres")] - async fn get_instance_stats(&self, instance: &str) -> Result, ProviderError> { + async fn get_instance_stats( + &self, + instance: &str, + ) -> Result, ProviderError> { let row: Option<(bool, i64, i64, i64, i64, i64)> = sqlx::query_as(&format!( "SELECT * FROM {}.get_instance_stats($1)", self.schema_name @@ -1507,15 +1579,20 @@ impl Provider for PostgresProvider { .map_err(|e| Self::sqlx_to_provider_error("get_instance_stats", e))?; match row { - Some((true, history_event_count, history_size_bytes, queue_pending_count, kv_user_key_count, kv_total_value_bytes)) => { - Ok(Some(SystemStats { - history_event_count: history_event_count as u64, - history_size_bytes: history_size_bytes as u64, - queue_pending_count: queue_pending_count as u64, - kv_user_key_count: kv_user_key_count as u64, - kv_total_value_bytes: kv_total_value_bytes as u64, - })) - } + Some(( + true, + history_event_count, + history_size_bytes, + queue_pending_count, + kv_user_key_count, + kv_total_value_bytes, + )) => Ok(Some(SystemStats { + history_event_count: history_event_count as u64, + history_size_bytes: history_size_bytes as u64, + queue_pending_count: queue_pending_count as u64, + kv_user_key_count: kv_user_key_count as u64, + kv_total_value_bytes: kv_total_value_bytes as u64, + })), _ => Ok(None), } } diff --git a/tests/basic_tests.rs b/tests/basic_tests.rs index 655de71..1267aba 100644 --- a/tests/basic_tests.rs +++ b/tests/basic_tests.rs @@ -1,4 +1,4 @@ -use duroxide::providers::{ExecutionMetadata, Provider, TagFilter, WorkItem}; +use duroxide::providers::{ExecutionMetadata, Provider, SessionFetchConfig, TagFilter, WorkItem}; use duroxide::{Event, EventKind, INITIAL_EVENT_ID, INITIAL_EXECUTION_ID}; use duroxide_pg::PostgresProvider; use tracing_subscriber::EnvFilter; @@ -381,6 +381,83 @@ async fn test_fetch_orchestration_item_empty_queue() { provider.cleanup_schema().await.expect("Failed to cleanup"); } +#[tokio::test] +async fn test_batched_worker_fetch_respects_session_claim_quota() { + init_test_logging(); + let database_url = get_database_url(); + let schema_name = get_test_schema(); + + let provider = PostgresProvider::new_with_schema(&database_url, Some(&schema_name)) + .await + .expect("Failed to create provider"); + + for (id, session_id) in [(1, "session-a"), (2, "session-a"), (3, "session-b")] { + provider + .enqueue_for_worker(WorkItem::ActivityExecute { + instance: format!("batch-instance-{id}"), + execution_id: 1, + id, + name: "BatchedActivity".to_string(), + input: format!("input-{id}"), + session_id: Some(session_id.to_string()), + tag: None, + }) + .await + .expect("Failed to enqueue worker work"); + } + + let session = SessionFetchConfig { + owner_id: "batch-worker".to_string(), + lock_timeout: std::time::Duration::from_secs(30), + }; + + let first_batch = provider + .fetch_work_items( + std::time::Duration::from_secs(30), + std::time::Duration::ZERO, + Some(&session), + &TagFilter::default(), + 3, + 1, + ) + .await + .expect("batched fetch should succeed"); + + assert_eq!( + first_batch.len(), + 2, + "quota=1 should allow multiple rows from the one newly claimed session" + ); + assert!(first_batch.iter().all(|(item, _, attempt_count)| { + matches!( + item, + WorkItem::ActivityExecute { + session_id: Some(session_id), + .. + } if session_id == "session-a" + ) && *attempt_count == 1 + })); + + let second_batch = provider + .fetch_work_items( + std::time::Duration::from_secs(30), + std::time::Duration::ZERO, + Some(&session), + &TagFilter::default(), + 3, + 0, + ) + .await + .expect("batched fetch should succeed"); + + assert!( + second_batch.is_empty(), + "quota=0 should not claim a second session" + ); + + provider.cleanup_schema().await.expect("Failed to cleanup"); +} + #[tokio::test] async fn test_management_capability() { init_test_logging();