diff --git a/src/config.rs b/src/config.rs index 523c0286..31dc91d1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -293,6 +293,12 @@ pub struct Config { /// Maximum time in milliseconds for a single push RPC to the worker service. This should be greater than the worker's internal timeout. pub push_timeout_ms: u64, + /// The size of a batch of status updates. Only active in push mode. + pub status_flush_batch_size: usize, + + /// Maximum milliseconds to wait before flushing a batch of status updates. + pub status_flush_interval_ms: u64, + /// The hostname used to construct `callback_url` for task push requests. pub callback_addr: String, @@ -383,6 +389,8 @@ impl Default for Config { push_queue_size: 1, push_queue_timeout_ms: 5000, push_timeout_ms: 30000, + status_flush_batch_size: 1, + status_flush_interval_ms: 100, callback_addr: "0.0.0.0".into(), callback_port: 50051, worker_map: [("sentry".into(), "http://127.0.0.1:50052".into())].into(), diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 9fe35be6..716220f0 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -109,13 +109,19 @@ impl FetchPool { } _ = async { + let start = Instant::now(); + debug!("Fetching next batch of pending activations..."); metrics::counter!("fetch.loop.count").increment(1); - let start = Instant::now(); let mut backoff = false; - match store.claim_activations_for_push(limit, bucket).await { + let start_claim = Instant::now(); + let result = store.claim_activations_for_push(limit, bucket).await; + metrics::histogram!("fetch.claim_activations_for_push.duration") + .record(start_claim.elapsed()); + + match result { Ok(activations) if activations.is_empty() => { metrics::counter!("fetch.empty").increment(1); debug!("No pending activations"); diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 0d11da63..6a5117c6 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -122,6 +122,18 @@ impl InflightActivationStore for MockStore { unimplemented!() } + async fn 
set_status_batch( + &self, + _ids: &[String], + _status: InflightActivationStatus, + ) -> Result<(), Error> { + Ok(()) + } + + async fn delete_activation_batch(&self, _ids: &[String]) -> Result<u64, Error> { + Ok(0) + } + + async fn set_processing_deadline( + &self, + _id: &str, diff --git a/src/flusher.rs b/src/flusher.rs new file mode 100644 index 00000000..63ebe685 --- /dev/null +++ b/src/flusher.rs @@ -0,0 +1,62 @@ +use std::future::Future; +use std::pin::Pin; +use std::time::Duration; + +use anyhow::Result; +use tokio::sync::mpsc::Receiver; + +/// Run flusher that receives values of type T from a channel and flushes +/// them using the provided async `flush` function either when the batch is +/// full or when the max flush interval has elapsed. +pub async fn run_flusher<T, F>( + mut rx: Receiver<T>, + batch_size: usize, + interval_ms: u64, + mut flush: F, +) -> Result<()> +where + F: for<'a> FnMut(&'a mut Vec<T>) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>, +{ + let batch_size = batch_size.max(1); + let interval_ms = interval_ms.max(1); + + let period = Duration::from_millis(interval_ms); + let mut interval = tokio::time::interval(period); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + let mut buffer: Vec<T> = Vec::with_capacity(batch_size); + + loop { + tokio::select!
{ + msg = rx.recv() => { + match msg { + Some(v) => { + buffer.push(v); + + while let Ok(update) = rx.try_recv() { + buffer.push(update); + } + + if buffer.len() >= batch_size { + flush(&mut buffer).await; + } + } + + None => { + // Channel closed (shutdown), flush remaining and exit + flush(&mut buffer).await; + break; + } + } + } + + _ = interval.tick() => { + if !buffer.is_empty() { + flush(&mut buffer).await; + } + } + } + } + + Ok(()) +} diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index abaa773f..6cb960f6 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -1,5 +1,7 @@ pub mod auth_middleware; pub mod metrics_middleware; pub mod server; +pub mod status_flusher; + #[cfg(test)] mod server_tests; diff --git a/src/grpc/server.rs b/src/grpc/server.rs index cf4a6bdd..af612aee 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -1,6 +1,8 @@ +use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; +use anyhow::Result; use chrono::Utc; use prost::Message; use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerService; @@ -8,8 +10,9 @@ use sentry_protos::taskbroker::v1::{ FetchNextTask, GetTaskRequest, GetTaskResponse, SetTaskStatusRequest, SetTaskStatusResponse, TaskActivation, TaskActivationStatus, }; +use tokio::sync::mpsc; use tonic::{Request, Response, Status}; -use tracing::{error, instrument, warn}; +use tracing::{debug, error, instrument, warn}; use crate::config::{Config, DeliveryMode}; use crate::store::activation::InflightActivationStatus; @@ -18,6 +21,7 @@ use crate::store::traits::InflightActivationStore; pub struct TaskbrokerServer { pub store: Arc<dyn InflightActivationStore>, pub config: Arc<Config>, + pub status_tx: Option<mpsc::Sender<StatusUpdate>>, } #[tonic::async_trait] impl ConsumerService for TaskbrokerServer { @@ -97,10 +101,20 @@ impl ConsumerService for TaskbrokerServer { "Invalid status, expects 3 (Failure), 4 (Retry), or 5 (Complete), but got: {status:?}" ))); } + if status == InflightActivationStatus::Failure { metrics::counter!("grpc_server.set_status.failure").increment(1); } + if let Some(ref tx) =
self.status_tx { + tx.send((id, status)) + .await + .map_err(|_| Status::internal("Status update channel closed"))?; + + metrics::histogram!("grpc_server.set_status.duration").record(start_time.elapsed()); + return Ok(Response::new(SetTaskStatusResponse { task: None })); + } + match self.store.set_status(&id, status).await { Ok(Some(_)) => metrics::counter!( "grpc_server.set_status", @@ -194,3 +208,58 @@ impl ConsumerService for TaskbrokerServer { res } } + +pub type StatusUpdate = (String, InflightActivationStatus); + +pub async fn flush_status_updates( + store: Arc<dyn InflightActivationStore>, + buffer: &mut Vec<StatusUpdate>, +) { + if buffer.is_empty() { + return; + } + + let updates = std::mem::take(buffer); + let mut by_status: HashMap<InflightActivationStatus, Vec<String>> = HashMap::new(); + + for (id, status) in updates { + by_status.entry(status).or_default().push(id); + } + + let mut success = 0; + let mut fail = 0; + + for (status, ids) in by_status { + let count = ids.len() as u64; + + match store + .set_status_batch(&ids, status) + .await + .map(|()| ids.len() as u64) + { + Ok(count) => { + success += count; + debug!(?status, ?count, "Flushed status batch"); + } + + Err(e) => { + fail += count; + + error!( + ?status, + ?count, + error = ?e, + "Failed to flush status batch" + ); + + // Push failed updates back into the buffer so they can be retried on next flush + for id in ids { + buffer.push((id, status)); + } + } + } + } + + metrics::gauge!("grpc_server.flush_status_updates.success").set(success as f64); + metrics::gauge!("grpc_server.flush_status_updates.fail").set(fail as f64); +}
namespace: None, application: None, }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_err()); @@ -42,11 +48,17 @@ async fn test_get_task(#[case] adapter: &str) { let store = create_test_store(adapter).await; let config = create_config(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + status_tx: None, + }; + let request = GetTaskRequest { namespace: None, application: None, }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_err()); let e = response.unwrap_err(); @@ -63,12 +75,18 @@ async fn test_set_task_status(#[case] adapter: &str) { let store = create_test_store(adapter).await; let config = create_config(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + status_tx: None, + }; + let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 5, // Complete fetch_next_task: None, }; + let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); let resp = response.unwrap(); @@ -84,12 +102,18 @@ async fn test_set_task_status_invalid(#[case] adapter: &str) { let store = create_test_store(adapter).await; let config = create_config(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + status_tx: None, + }; + let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 1, // Invalid fetch_next_task: None, }; + let response = service.set_task_status(Request::new(request)).await; assert!(response.is_err()); let e = response.unwrap_err(); @@ -115,11 +139,14 @@ async fn test_get_task_success(#[case] adapter: &str) { let service = TaskbrokerServer { store: store.clone(), config, + status_tx: None, }; + let request = GetTaskRequest { namespace: None, application: None, }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_ok()); let resp = 
response.unwrap(); @@ -149,11 +176,17 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + status_tx: None, + }; + let request = GetTaskRequest { namespace: None, application: Some("hammers".into()), }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_ok()); let resp = response.unwrap(); @@ -177,11 +210,17 @@ async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + status_tx: None, + }; + let request = GetTaskRequest { namespace: Some(namespace), application: None, }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_err()); @@ -201,12 +240,17 @@ async fn test_set_task_status_success(#[case] adapter: &str) { let activations = make_activations(2); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + status_tx: None, + }; let request = GetTaskRequest { namespace: None, application: None, }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_ok()); let resp = response.unwrap(); @@ -248,7 +292,12 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + status_tx: None, + }; + let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete @@ -257,6 +306,7 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { namespace: None, }), }; + let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); @@ 
-287,7 +337,12 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + status_tx: None, + }; + // Request a task from an application without any activations. let request = SetTaskStatusRequest { id: "id_0".to_string(), @@ -297,6 +352,7 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { namespace: None, }), }; + let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); assert!(response.unwrap().get_ref().task.is_none()); @@ -316,7 +372,12 @@ async fn test_set_task_status_with_namespace_requires_application(#[case] adapte store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + status_tx: None, + }; + let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete @@ -325,6 +386,7 @@ async fn test_set_task_status_with_namespace_requires_application(#[case] adapte namespace: Some(namespace), }), }; + let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); assert!( diff --git a/src/grpc/status_flusher.rs b/src/grpc/status_flusher.rs new file mode 100644 index 00000000..11d08ba2 --- /dev/null +++ b/src/grpc/status_flusher.rs @@ -0,0 +1,136 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use anyhow::Result; +use tokio::sync::mpsc::Receiver; +use tracing::{debug, error}; + +use crate::config::Config; +use crate::store::activation::InflightActivationStatus; +use crate::store::traits::InflightActivationStore; + +pub type StatusUpdate = (String, InflightActivationStatus); + +/// Run the status flusher task. 
Receives (id, status) from the channel and +/// flushes to the store in batches, either when the batch is full or when +/// the max flush interval has elapsed. +pub async fn run_status_flusher( + mut rx: Receiver<StatusUpdate>, + store: Arc<dyn InflightActivationStore>, + config: Arc<Config>, +) -> Result<()> { + let batch_size = config.status_flush_batch_size.max(1); + let flush_interval = Duration::from_millis(config.status_flush_interval_ms); + let mut interval = tokio::time::interval(flush_interval); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + let mut buffer: Vec<StatusUpdate> = Vec::with_capacity(batch_size); + + loop { + tokio::select! { + msg = rx.recv() => { + match msg { + Some((id, status)) => { + buffer.push((id, status)); + + while let Ok(update) = rx.try_recv() { + buffer.push(update); + } + + if buffer.len() >= batch_size { + metrics::histogram!("status_flush.batch_size").record(buffer.len() as f64); + flush_buffer(&store, &mut buffer).await; + } + } + + None => { + // Channel closed (shutdown), flush remaining and exit.
+ if !buffer.is_empty() { + flush_buffer(&store, &mut buffer).await; + } + + debug!("Status flusher shutting down..."); + break; + } + } + } + + _ = interval.tick() => { + if !buffer.is_empty() { + flush_buffer(&store, &mut buffer).await; + } + } + } + } + + Ok(()) +} + +async fn flush_buffer(store: &Arc<dyn InflightActivationStore>, buffer: &mut Vec<StatusUpdate>) { + if buffer.is_empty() { + return; + } + + let start = Instant::now(); + + let updates = std::mem::take(buffer); + let mut by_status: HashMap<InflightActivationStatus, Vec<String>> = HashMap::new(); + + for (id, status) in updates { + by_status.entry(status).or_default().push(id); + } + + for (status, ids) in by_status { + let count = ids.len() as u64; + + let result = if status == InflightActivationStatus::Complete { + let start = Instant::now(); + let result = store.delete_activation_batch(&ids).await; + + metrics::histogram!("status_flush.delete_activation_batch.duration") + .record(start.elapsed()); + + if result.is_err() { + metrics::counter!("status_flush.delete_activation_batch.error").increment(1); + } + + result + } else { + let start = Instant::now(); + let result = store + .set_status_batch(&ids, status) + .await + .map(|()| ids.len() as u64); + + metrics::histogram!("status_flush.set_status_batch.duration").record(start.elapsed()); + + if result.is_err() { + metrics::counter!("status_flush.set_status_batch.error").increment(1); + } + + result + }; + + if let Err(e) = result { + error!( + ?status, + ?count, + error = ?e, + "Failed to flush status batch" + ); + + // Push failed updates back into the buffer so they can be retried on next flush + for id in ids { + buffer.push((id, status)); + } + } else { + metrics::counter!("status_flush.activations", "status" => status.to_string()) + .increment(count); + + debug!(?status, ?count, "Flushed status batch"); + } + } + + metrics::histogram!("status_flush.flush.duration").record(start.elapsed()); +} diff --git a/src/lib.rs b/src/lib.rs index 6ce53cd1..89a17421 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ use
std::fs; pub mod config; pub mod fetch; +pub mod flusher; pub mod grpc; pub mod kafka; pub mod logging; diff --git a/src/main.rs b/src/main.rs index 4c174420..00f30557 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,12 +12,11 @@ use tonic::transport::Server; use tonic_health::ServingStatus; use tracing::{debug, error, info, warn}; -use taskbroker::SERVICE_NAME; use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; use taskbroker::fetch::FetchPool; use taskbroker::grpc::auth_middleware::AuthLayer; use taskbroker::grpc::metrics_middleware::MetricsLayer; -use taskbroker::grpc::server::TaskbrokerServer; +use taskbroker::grpc::server::{TaskbrokerServer, flush_status_updates}; use taskbroker::kafka::admin::create_missing_topics; use taskbroker::kafka::consumer::start_consumer; use taskbroker::kafka::deserialize_activation; @@ -41,6 +40,7 @@ use taskbroker::store::adapters::sqlite::{InflightActivationStoreConfig, SqliteA use taskbroker::store::traits::InflightActivationStore; use taskbroker::upkeep::upkeep; use taskbroker::{Args, get_version}; +use taskbroker::{SERVICE_NAME, flusher}; async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { @@ -192,10 +192,33 @@ async fn main() -> Result<(), Error> { } }); + // Status flusher + let (status_tx, status_flush_task) = if config.delivery_mode == DeliveryMode::Push { + let (tx, rx) = tokio::sync::mpsc::channel(config.status_flush_batch_size); + + let flusher_store = store.clone(); + let flusher_config = config.clone(); + + let handle = tokio::spawn(async move { + flusher::run_flusher( + rx, + flusher_config.status_flush_batch_size, + flusher_config.status_flush_interval_ms, + move |buffer| Box::pin(flush_status_updates(flusher_store.clone(), buffer)), + ) + .await + }); + + (Some(tx), Some(handle)) + } else { + (None, None) + }; + // GRPC server let grpc_server_task = tokio::spawn({ let grpc_store = store.clone(); let grpc_config = config.clone(); + let grpc_status_tx = status_tx.clone(); 
async move { let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) @@ -212,6 +235,7 @@ async fn main() -> Result<(), Error> { .add_service(ConsumerServiceServer::new(TaskbrokerServer { store: grpc_store, config: grpc_config, + status_tx: grpc_status_tx, })) .add_service(health_service.clone()) .serve(addr); @@ -278,6 +302,10 @@ async fn main() -> Result<(), Error> { departure = departure.on_completion(log_task_completion("fetch_task", task)); } + if let Some(task) = status_flush_task { + departure = departure.on_completion(log_task_completion("status_flush_task", task)); + } + departure.await; Ok(()) } diff --git a/src/push/mod.rs b/src/push/mod.rs index 5696ce77..a2d2bcb8 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -168,7 +168,7 @@ impl PushPool { Ok(a) => a, // Channel closed - Err(_) => break + Err(_) => break, }; let id = activation.id.clone(); @@ -183,7 +183,7 @@ impl PushPool { "Task application has no worker pool mapping" ); - continue + continue; }; match push_task( @@ -211,7 +211,12 @@ impl PushPool { } } - if let Err(e) = store.mark_activation_processing(&id).await { + let start = Instant::now(); + let result = store.mark_activation_processing(&id).await; + metrics::histogram!("push.mark_activation_processing.duration") + .record(start.elapsed()); + + if let Err(e) = result { metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); error!( @@ -267,7 +272,12 @@ impl PushPool { metrics::counter!("push.delivery", "result" => "ok").increment(1); debug!(task_id = %id, "Activation sent to worker"); - if let Err(e) = store.mark_activation_processing(&id).await { + let start = Instant::now(); + let result = store.mark_activation_processing(&id).await; + metrics::histogram!("push.mark_activation_processing.duration") + .record(start.elapsed()); + + if let Err(e) = result { metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); error!( diff --git a/src/store/activation.rs 
b/src/store/activation.rs index f0b5a411..d13560a4 100644 --- a/src/store/activation.rs +++ b/src/store/activation.rs @@ -8,7 +8,7 @@ use sqlx::Type; /// The members of this enum should be a superset of the members /// of `InflightActivationStatus` in `sentry_protos`. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Type)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Type, Hash)] pub enum InflightActivationStatus { /// Unused but necessary to align with sentry-protos Unspecified, diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 81d720fd..f1947dd1 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -637,6 +637,27 @@ impl InflightActivationStore for PostgresActivationStore { Ok(Some(row.into())) } + #[instrument(skip_all)] + async fn set_status_batch( + &self, + ids: &[String], + status: InflightActivationStatus, + ) -> Result<(), Error> { + if ids.is_empty() { + return Ok(()); + } + + let mut conn = self.acquire_write_conn_metric("set_status_batch").await?; + + sqlx::query("UPDATE inflight_taskactivations SET status = $1 WHERE id = ANY($2)") + .bind(status.to_string()) + .bind(ids) + .execute(&mut *conn) + .await?; + + Ok(()) + } + #[instrument(skip_all)] async fn set_processing_deadline( &self, @@ -664,6 +685,28 @@ Ok(()) } + #[instrument(skip_all)] + async fn delete_activation_batch(&self, ids: &[String]) -> Result<u64, Error> { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self + .acquire_write_conn_metric("delete_activations_by_id") + .await?; + + let mut query_builder = + QueryBuilder::new("DELETE FROM inflight_taskactivations WHERE id = ANY("); + + query_builder.push_bind(ids); + query_builder.push(")"); + + self.add_partition_condition(&mut query_builder, false); + + let result = query_builder.build().execute(&mut *conn).await?; + Ok(result.rows_affected()) + } + #[instrument(skip_all)] async fn get_retry_activations(&self) -> Result<Vec<InflightActivation>,
Error> { let mut query_builder = QueryBuilder::new( diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index 8692ac0c..05098cdb 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -706,6 +706,33 @@ impl InflightActivationStore for SqliteActivationStore { Ok(Some(row.into())) } + #[instrument(skip_all)] + async fn set_status_batch( + &self, + ids: &[String], + status: InflightActivationStatus, + ) -> Result<(), Error> { + if ids.is_empty() { + return Ok(()); + } + + let mut conn = self.acquire_write_conn_metric("set_status_batch").await?; + + let placeholders: Vec<String> = (0..ids.len()).map(|i| format!("?{}", i + 2)).collect(); + let sql = format!( + "UPDATE inflight_taskactivations SET status = ?1 WHERE id IN ({})", + placeholders.join(", ") + ); + + let mut q = sqlx::query(&sql).bind(status); + for id in ids { + q = q.bind(id); + } + + q.execute(&mut *conn).await?; + Ok(()) + } + #[instrument(skip_all)] async fn set_processing_deadline( + &self, @@ -733,6 +760,32 @@ Ok(()) } + #[instrument(skip_all)] + async fn delete_activation_batch(&self, ids: &[String]) -> Result<u64, Error> { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self + .acquire_write_conn_metric("delete_activations_by_id") + .await?; + + let placeholders: Vec<String> = (0..ids.len()).map(|i| format!("?{}", i + 1)).collect(); + + let sql = format!( + "DELETE FROM inflight_taskactivations WHERE id IN ({})", + placeholders.join(", ") + ); + + let mut q = sqlx::query(&sql); + for id in ids { + q = q.bind(id); + } + + let result = q.execute(&mut *conn).await?; + Ok(result.rows_affected()) + } + #[instrument(skip_all)] async fn get_retry_activations(&self) -> Result<Vec<InflightActivation>, Error> { Ok(sqlx::query_as( diff --git a/src/store/traits.rs b/src/store/traits.rs index c21a9d41..86def9f0 100644 --- a/src/store/traits.rs +++ b/src/store/traits.rs @@ -78,6 +78,13 @@ pub trait InflightActivationStore: Send + Sync { status:
InflightActivationStatus, ) -> Result<Option<InflightActivation>, Error>; + /// Update the status of multiple activations in one batch. + async fn set_status_batch( + &self, + ids: &[String], + status: InflightActivationStatus, + ) -> Result<(), Error>; + /// COUNT OPERATIONS /// Get the age of the oldest pending activation in seconds async fn pending_activation_max_lag(&self, now: &DateTime<Utc>) -> f64; @@ -126,6 +133,9 @@ /// Delete an activation by id async fn delete_activation(&self, id: &str) -> Result<(), Error>; + /// Delete several activations by ID. + async fn delete_activation_batch(&self, ids: &[String]) -> Result<u64, Error>; + /// DATABASE OPERATIONS /// Trigger incremental vacuum to reclaim free pages in the database async fn vacuum_db(&self) -> Result<(), Error>;