From dcb162bb5d39b9a25991bce4d5fb88866eac8e28 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 05:46:07 +0530 Subject: [PATCH 01/12] feat(workflows): add Rust workflow engine with PyO3 bindings New taskito-workflows crate with WorkflowDefinition, WorkflowRun, WorkflowNode types, SQLite storage, topology module, and state machine. PyO3 bindings: PyWorkflowBuilder, PyWorkflowHandle, PyWorkflowRunStatus, plus PyQueue methods for submit, mark result, cancel, fan-out expansion, deferred job creation, approval gates, sub-workflow support, incremental caching, and visualization DAG retrieval. --- Cargo.toml | 3 +- .../taskito-core/src/storage/postgres/mod.rs | 4 +- .../src/storage/redis_backend/mod.rs | 2 +- crates/taskito-core/src/storage/sqlite/mod.rs | 2 +- crates/taskito-python/Cargo.toml | 2 + crates/taskito-python/src/lib.rs | 8 + crates/taskito-python/src/py_queue/mod.rs | 2 + .../src/py_queue/workflow_ops.rs | 820 ++++++++++++++++++ crates/taskito-python/src/py_workflow/mod.rs | 193 +++++ crates/taskito-workflows/Cargo.toml | 13 + crates/taskito-workflows/src/definition.rs | 40 + crates/taskito-workflows/src/error.rs | 51 ++ crates/taskito-workflows/src/lib.rs | 20 + crates/taskito-workflows/src/node.rs | 75 ++ crates/taskito-workflows/src/run.rs | 20 + crates/taskito-workflows/src/sqlite_store.rs | 688 +++++++++++++++ crates/taskito-workflows/src/state.rs | 69 ++ crates/taskito-workflows/src/storage.rs | 94 ++ crates/taskito-workflows/src/tests.rs | 703 +++++++++++++++ crates/taskito-workflows/src/topology.rs | 112 +++ pyproject.toml | 2 +- 21 files changed, 2916 insertions(+), 7 deletions(-) create mode 100644 crates/taskito-python/src/py_queue/workflow_ops.rs create mode 100644 crates/taskito-python/src/py_workflow/mod.rs create mode 100644 crates/taskito-workflows/Cargo.toml create mode 100644 crates/taskito-workflows/src/definition.rs create mode 100644 crates/taskito-workflows/src/error.rs 
create mode 100644 crates/taskito-workflows/src/lib.rs create mode 100644 crates/taskito-workflows/src/node.rs create mode 100644 crates/taskito-workflows/src/run.rs create mode 100644 crates/taskito-workflows/src/sqlite_store.rs create mode 100644 crates/taskito-workflows/src/state.rs create mode 100644 crates/taskito-workflows/src/storage.rs create mode 100644 crates/taskito-workflows/src/tests.rs create mode 100644 crates/taskito-workflows/src/topology.rs diff --git a/Cargo.toml b/Cargo.toml index 9c60993..9fdc42a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/taskito-core", "crates/taskito-python", "crates/taskito-async"] +members = ["crates/taskito-core", "crates/taskito-python", "crates/taskito-async", "crates/taskito-workflows"] resolver = "2" [workspace.dependencies] @@ -21,3 +21,4 @@ redis = { version = "0.27", features = ["script"] } openssl-sys = { version = "0.9", features = ["vendored"] } pyo3 = { version = "0.22", features = ["multiple-pymethods"] } async-trait = "0.1" +dagron-core = { path = "../dagron/crates/dagron-core" } diff --git a/crates/taskito-core/src/storage/postgres/mod.rs b/crates/taskito-core/src/storage/postgres/mod.rs index 6999063..57f7975 100644 --- a/crates/taskito-core/src/storage/postgres/mod.rs +++ b/crates/taskito-core/src/storage/postgres/mod.rs @@ -98,9 +98,7 @@ impl PostgresStorage { Ok(storage) } - pub(crate) fn conn( - &self, - ) -> Result>> { + pub fn conn(&self) -> Result>> { let mut conn = self.pool.get()?; diesel::sql_query(format!("SET search_path TO {}", self.schema)) .execute(&mut conn) diff --git a/crates/taskito-core/src/storage/redis_backend/mod.rs b/crates/taskito-core/src/storage/redis_backend/mod.rs index faa771f..cbe6039 100644 --- a/crates/taskito-core/src/storage/redis_backend/mod.rs +++ b/crates/taskito-core/src/storage/redis_backend/mod.rs @@ -58,7 +58,7 @@ impl RedisStorage { } /// Get a Redis connection. 
- fn conn(&self) -> Result { + pub fn conn(&self) -> Result { self.client .get_connection() .map_err(|e| QueueError::Other(format!("Redis connection error: {e}"))) diff --git a/crates/taskito-core/src/storage/sqlite/mod.rs b/crates/taskito-core/src/storage/sqlite/mod.rs index dc847e1..1e94ce8 100644 --- a/crates/taskito-core/src/storage/sqlite/mod.rs +++ b/crates/taskito-core/src/storage/sqlite/mod.rs @@ -95,7 +95,7 @@ impl SqliteStorage { Ok(storage) } - pub(crate) fn conn( + pub fn conn( &self, ) -> Result>> { Ok(self.pool.get()?) diff --git a/crates/taskito-python/Cargo.toml b/crates/taskito-python/Cargo.toml index 7854df9..3fbf334 100644 --- a/crates/taskito-python/Cargo.toml +++ b/crates/taskito-python/Cargo.toml @@ -9,6 +9,7 @@ extension-module = ["pyo3/extension-module"] postgres = ["taskito-core/postgres"] redis = ["taskito-core/redis"] native-async = ["dep:taskito-async"] +workflows = ["dep:taskito-workflows"] [lib] name = "_taskito" @@ -22,6 +23,7 @@ crossbeam-channel = { workspace = true } uuid = { workspace = true } async-trait = { workspace = true } taskito-async = { path = "../taskito-async", optional = true } +taskito-workflows = { path = "../taskito-workflows", optional = true } serde_json = { workspace = true } serde = { workspace = true } base64 = "0.22" diff --git a/crates/taskito-python/src/lib.rs b/crates/taskito-python/src/lib.rs index 8864afe..c3b0b34 100644 --- a/crates/taskito-python/src/lib.rs +++ b/crates/taskito-python/src/lib.rs @@ -7,6 +7,8 @@ mod py_config; mod py_job; mod py_queue; pub mod py_worker; +#[cfg(feature = "workflows")] +mod py_workflow; use py_config::PyTaskConfig; use py_job::PyJob; @@ -19,5 +21,11 @@ fn _taskito(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; #[cfg(feature = "native-async")] m.add_class::()?; + #[cfg(feature = "workflows")] + { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + } Ok(()) } diff --git a/crates/taskito-python/src/py_queue/mod.rs 
b/crates/taskito-python/src/py_queue/mod.rs index c3f53bb..22ed33f 100644 --- a/crates/taskito-python/src/py_queue/mod.rs +++ b/crates/taskito-python/src/py_queue/mod.rs @@ -3,6 +3,8 @@ mod inspection; mod worker; +#[cfg(feature = "workflows")] +mod workflow_ops; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; diff --git a/crates/taskito-python/src/py_queue/workflow_ops.rs b/crates/taskito-python/src/py_queue/workflow_ops.rs new file mode 100644 index 0000000..31e070a --- /dev/null +++ b/crates/taskito-python/src/py_queue/workflow_ops.rs @@ -0,0 +1,820 @@ +//! Workflow operations on `PyQueue`. +//! +//! Compiled only when the `workflows` feature is enabled. Adds +//! workflow-specific methods to `PyQueue` via a separate `#[pymethods]` +//! impl block (enabled by pyo3's `multiple-pymethods` feature). + +use std::collections::HashMap; + +use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::prelude::*; + +use taskito_core::job::{now_millis, NewJob}; +use taskito_core::storage::{Storage, StorageBackend}; +use taskito_workflows::{ + topological_order, StepMetadata, WorkflowNode, WorkflowNodeStatus, WorkflowRun, + WorkflowSqliteStorage, WorkflowState, WorkflowStorage, +}; + +use crate::py_queue::PyQueue; +use crate::py_workflow::{PyWorkflowHandle, PyWorkflowRunStatus}; + +/// Build a `WorkflowSqliteStorage` from a `PyQueue`'s backend. +/// +/// Currently only SQLite is supported for workflows. Migrations run on +/// construction; repeated calls are cheap because the migrations use +/// `CREATE TABLE IF NOT EXISTS`. 
+fn workflow_storage(queue: &PyQueue) -> PyResult { + match &queue.storage { + StorageBackend::Sqlite(s) => WorkflowSqliteStorage::new(s.clone()) + .map_err(|e| PyRuntimeError::new_err(e.to_string())), + #[cfg(feature = "postgres")] + StorageBackend::Postgres(_) => Err(PyRuntimeError::new_err( + "workflows are currently only supported on the SQLite backend", + )), + #[cfg(feature = "redis")] + StorageBackend::Redis(_) => Err(PyRuntimeError::new_err( + "workflows are currently only supported on the SQLite backend", + )), + } +} + +fn parse_step_metadata(json: &str) -> PyResult> { + serde_json::from_str(json) + .map_err(|e| PyValueError::new_err(format!("invalid step_metadata JSON: {e}"))) +} + +fn build_metadata_json(run_id: &str, node_name: &str) -> String { + format!( + r#"{{"workflow_run_id":"{}","workflow_node_name":"{}"}}"#, + run_id.replace('"', "\\\""), + node_name.replace('"', "\\\""), + ) +} + +fn status_to_py(status: WorkflowState) -> String { + status.as_str().to_string() +} + +#[pymethods] +impl PyQueue { + /// Submit a workflow for execution. + /// + /// Creates (or reuses) a `WorkflowDefinition` with the given name + version, + /// inserts a `WorkflowRun`, pre-enqueues all step jobs in topological order + /// with `depends_on` chains so taskito's existing scheduler runs them in the + /// correct order. Nodes listed in `deferred_node_names` get a + /// `WorkflowNode` only (no job) — their jobs are created at runtime by the + /// Python tracker (fan-out / fan-in orchestration). + /// + /// Returns a `PyWorkflowHandle` carrying the run id. 
+ #[pyo3(signature = ( + name, version, dag_bytes, step_metadata_json, node_payloads, + queue_default="default", params_json=None, deferred_node_names=None, + parent_run_id=None, parent_node_name=None, cache_hit_nodes=None + ))] + #[allow(clippy::too_many_arguments)] + pub fn submit_workflow( + &self, + name: &str, + version: i32, + dag_bytes: Vec, + step_metadata_json: &str, + node_payloads: HashMap>, + queue_default: &str, + params_json: Option, + deferred_node_names: Option>, + parent_run_id: Option, + parent_node_name: Option, + cache_hit_nodes: Option>, + ) -> PyResult { + let wf_storage = workflow_storage(self)?; + let step_meta = parse_step_metadata(step_metadata_json)?; + let ordered = + topological_order(&dag_bytes).map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + let deferred: std::collections::HashSet = deferred_node_names + .unwrap_or_default() + .into_iter() + .collect(); + let cached: HashMap = cache_hit_nodes.unwrap_or_default(); + + let definition_id = match wf_storage + .get_workflow_definition(name, Some(version)) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))? 
+ { + Some(existing) => existing.id, + None => { + let def = taskito_workflows::WorkflowDefinition { + id: uuid::Uuid::now_v7().to_string(), + name: name.to_string(), + version, + dag_data: dag_bytes.clone(), + step_metadata: step_meta.clone(), + created_at: now_millis(), + }; + let def_id = def.id.clone(); + wf_storage + .create_workflow_definition(&def) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + def_id + } + }; + + let run_id = uuid::Uuid::now_v7().to_string(); + let now = now_millis(); + let run = WorkflowRun { + id: run_id.clone(), + definition_id: definition_id.clone(), + params: params_json, + state: WorkflowState::Pending, + started_at: Some(now), + completed_at: None, + error: None, + parent_run_id, + parent_node_name, + created_at: now, + }; + wf_storage + .create_workflow_run(&run) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + let mut job_ids: HashMap = HashMap::new(); + for topo in &ordered { + // Deferred nodes get a WorkflowNode only — no job. + // Cache-hit nodes: copy result_hash from base run, no job. + if let Some(rh) = cached.get(&topo.name) { + let wf_node = WorkflowNode { + id: uuid::Uuid::now_v7().to_string(), + run_id: run_id.clone(), + node_name: topo.name.clone(), + job_id: None, + status: WorkflowNodeStatus::CacheHit, + result_hash: Some(rh.clone()), + fan_out_count: None, + fan_in_data: None, + started_at: None, + completed_at: Some(now), + error: None, + }; + wf_storage + .create_workflow_node(&wf_node) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + continue; + } + + // Deferred nodes: WorkflowNode only, no job. 
+ if deferred.contains(&topo.name) { + let wf_node = WorkflowNode { + id: uuid::Uuid::now_v7().to_string(), + run_id: run_id.clone(), + node_name: topo.name.clone(), + job_id: None, + status: WorkflowNodeStatus::Pending, + result_hash: None, + fan_out_count: None, + fan_in_data: None, + started_at: None, + completed_at: None, + error: None, + }; + wf_storage + .create_workflow_node(&wf_node) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + continue; + } + + let meta = step_meta.get(&topo.name).ok_or_else(|| { + PyValueError::new_err(format!("step '{}' missing from step_metadata", topo.name)) + })?; + let payload = node_payloads.get(&topo.name).cloned().ok_or_else(|| { + PyValueError::new_err(format!("step '{}' missing from node_payloads", topo.name)) + })?; + + // Only resolve depends_on for non-deferred predecessors. + let depends_on: Vec = topo + .predecessors + .iter() + .filter(|p| !deferred.contains(*p)) + .map(|p| { + job_ids.get(p).cloned().ok_or_else(|| { + PyValueError::new_err(format!( + "predecessor '{}' of step '{}' has no job id", + p, topo.name + )) + }) + }) + .collect::>>()?; + + let timeout_ms = meta.timeout_ms.unwrap_or(self.default_timeout * 1000); + let new_job = NewJob { + queue: meta + .queue + .clone() + .unwrap_or_else(|| queue_default.to_string()), + task_name: meta.task_name.clone(), + payload, + priority: meta.priority.unwrap_or(self.default_priority), + scheduled_at: now, + max_retries: meta.max_retries.unwrap_or(self.default_retry), + timeout_ms, + unique_key: None, + metadata: Some(build_metadata_json(&run_id, &topo.name)), + depends_on, + expires_at: None, + result_ttl_ms: self.result_ttl_ms, + namespace: self.namespace.clone(), + }; + + let job = self + .storage + .enqueue(new_job) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + job_ids.insert(topo.name.clone(), job.id.clone()); + + let wf_node = WorkflowNode { + id: uuid::Uuid::now_v7().to_string(), + run_id: run_id.clone(), + node_name: topo.name.clone(), + 
job_id: Some(job.id), + status: WorkflowNodeStatus::Pending, + result_hash: None, + fan_out_count: None, + fan_in_data: None, + started_at: None, + completed_at: None, + error: None, + }; + wf_storage + .create_workflow_node(&wf_node) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + } + + wf_storage + .update_workflow_run_state(&run_id, WorkflowState::Running, None) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + Ok(PyWorkflowHandle { + run_id, + name: name.to_string(), + definition_id, + }) + } + + /// Fetch a snapshot of a workflow run's state and per-node status. + pub fn get_workflow_run_status(&self, run_id: &str) -> PyResult { + let wf_storage = workflow_storage(self)?; + let run = wf_storage + .get_workflow_run(run_id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))? + .ok_or_else(|| PyValueError::new_err(format!("workflow run '{run_id}' not found")))?; + + let nodes = wf_storage + .get_workflow_nodes(run_id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + let node_rows = nodes + .into_iter() + .map(|n| { + ( + n.node_name, + n.status.as_str().to_string(), + n.job_id, + n.error, + ) + }) + .collect(); + + Ok(PyWorkflowRunStatus { + run_id: run.id, + state: status_to_py(run.state), + started_at: run.started_at, + completed_at: run.completed_at, + error: run.error, + nodes: node_rows, + }) + } + + /// Cancel a workflow run. + /// + /// Marks the run `Cancelled`, skips any pending nodes, and cancels + /// their underlying jobs. Nodes already running are left alone + /// (consistent with taskito's existing cancel semantics). 
+ pub fn cancel_workflow_run(&self, run_id: &str) -> PyResult<()> { + let wf_storage = workflow_storage(self)?; + let nodes = wf_storage + .get_workflow_nodes(run_id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + for node in &nodes { + if matches!( + node.status, + WorkflowNodeStatus::Pending | WorkflowNodeStatus::Ready + ) { + if let Some(job_id) = &node.job_id { + let _ = self.storage.cancel_job(job_id); + } + let _ = wf_storage.update_workflow_node_status( + run_id, + &node.node_name, + WorkflowNodeStatus::Skipped, + ); + } + } + + wf_storage + .update_workflow_run_state(run_id, WorkflowState::Cancelled, None) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + wf_storage + .set_workflow_run_completed(run_id, now_millis()) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + // Cascade cancellation to child workflow runs (sub-workflows). + if let Ok(children) = wf_storage.get_child_workflow_runs(run_id) { + for child in children { + if !child.state.is_terminal() { + let _ = self.cancel_workflow_run(&child.id); + } + } + } + + Ok(()) + } + + /// Record the terminal outcome of a workflow node's job. + /// + /// Called by the Python workflow tracker in response to + /// `JOB_COMPLETED`/`JOB_FAILED`/`JOB_DEAD`/`JOB_CANCELLED` events. + /// On failure, walks remaining pending nodes in the run and skips them + /// (fail-fast semantics) unless ``skip_cascade`` is true. When the run + /// transitions to a terminal state, returns + /// `(run_id, node_name, final_state_str)`. Otherwise returns + /// `(run_id, node_name, None)` so the tracker can decide whether to + /// trigger fan-out expansion or fan-in collection. + /// + /// Set ``skip_cascade=True`` for tracker-managed runs (those with + /// conditions or ``on_failure="continue"``) so the Python tracker can + /// handle selective skip/create decisions. 
+ #[pyo3(signature = (job_id, succeeded, error=None, skip_cascade=false, result_hash=None))] + pub fn mark_workflow_node_result( + &self, + job_id: &str, + succeeded: bool, + error: Option, + skip_cascade: bool, + result_hash: Option, + ) -> PyResult)>> { + let wf_storage = workflow_storage(self)?; + let job = self + .storage + .get_job(job_id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))? + .ok_or_else(|| PyValueError::new_err(format!("job '{job_id}' not found")))?; + + let metadata_json = match &job.metadata { + Some(m) => m, + None => return Ok(None), + }; + let parsed: serde_json::Value = match serde_json::from_str(metadata_json) { + Ok(v) => v, + Err(_) => return Ok(None), + }; + let run_id = match parsed.get("workflow_run_id").and_then(|v| v.as_str()) { + Some(id) => id.to_string(), + None => return Ok(None), + }; + let node_name = match parsed.get("workflow_node_name").and_then(|v| v.as_str()) { + Some(n) => n.to_string(), + None => return Ok(None), + }; + + let now = now_millis(); + if succeeded { + wf_storage + .set_workflow_node_completed(&run_id, &node_name, now, result_hash.as_deref()) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + } else { + let err_msg = error.clone().unwrap_or_else(|| "failed".to_string()); + wf_storage + .set_workflow_node_error(&run_id, &node_name, &err_msg) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + } + + // Fail-fast: cascade failure to pending/ready nodes. + // Skipped when the Python tracker manages cascade (conditions / continue mode). 
+ if !succeeded && !skip_cascade { + let nodes = wf_storage + .get_workflow_nodes(&run_id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + for n in &nodes { + if matches!( + n.status, + WorkflowNodeStatus::Pending | WorkflowNodeStatus::Ready + ) { + if let Some(j) = &n.job_id { + let _ = self.storage.cancel_job(j); + } + let _ = wf_storage.update_workflow_node_status( + &run_id, + &n.node_name, + WorkflowNodeStatus::Skipped, + ); + } + } + } + + // Note: fan-out parent status is NOT updated here. The Python + // tracker calls `check_fan_out_completion` which atomically marks + // the parent and triggers fan-in. Doing it here would race. + + // Check if the entire run is terminal. + let nodes = wf_storage + .get_workflow_nodes(&run_id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + let all_terminal = nodes.iter().all(|n| n.status.is_terminal()); + if !all_terminal { + return Ok(Some((run_id, node_name, None))); + } + + let any_failed = nodes.iter().any(|n| n.status == WorkflowNodeStatus::Failed); + let final_state = if any_failed || !succeeded { + WorkflowState::Failed + } else { + WorkflowState::Completed + }; + + wf_storage + .update_workflow_run_state( + &run_id, + final_state, + if final_state == WorkflowState::Failed { + error.as_deref() + } else { + None + }, + ) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + wf_storage + .set_workflow_run_completed(&run_id, now) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + Ok(Some(( + run_id, + node_name, + Some(final_state.as_str().to_string()), + ))) + } + + // ── Fan-out / Fan-in helpers ──────────────────────────────── + + /// Expand a fan-out node into N child nodes + jobs. + /// + /// Creates one `WorkflowNode` and one job per child. Sets the parent + /// node's `fan_out_count` and transitions it to `Running`. If the + /// children list is empty (fan-out over empty result), the parent is + /// marked `Completed` immediately. 
+ #[pyo3(signature = ( + run_id, parent_node_name, child_names, child_payloads, + task_name, queue, max_retries, timeout_ms, priority + ))] + #[allow(clippy::too_many_arguments)] + pub fn expand_fan_out( + &self, + run_id: &str, + parent_node_name: &str, + child_names: Vec, + child_payloads: Vec>, + task_name: &str, + queue: &str, + max_retries: i32, + timeout_ms: i64, + priority: i32, + ) -> PyResult> { + if child_names.len() != child_payloads.len() { + return Err(PyValueError::new_err( + "child_names and child_payloads must have the same length", + )); + } + + let wf_storage = workflow_storage(self)?; + let now = now_millis(); + let count = child_names.len() as i32; + + // Empty fan-out: mark parent completed immediately. + if count == 0 { + wf_storage + .set_workflow_node_fan_out_count(run_id, parent_node_name, 0) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + wf_storage + .set_workflow_node_completed(run_id, parent_node_name, now, None) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + return Ok(vec![]); + } + + let mut child_job_ids = Vec::with_capacity(child_names.len()); + + for (child_name, payload) in child_names.iter().zip(child_payloads.into_iter()) { + let new_job = NewJob { + queue: queue.to_string(), + task_name: task_name.to_string(), + payload, + priority, + scheduled_at: now, + max_retries, + timeout_ms, + unique_key: None, + metadata: Some(build_metadata_json(run_id, child_name)), + depends_on: vec![], + expires_at: None, + result_ttl_ms: self.result_ttl_ms, + namespace: self.namespace.clone(), + }; + + let job = self + .storage + .enqueue(new_job) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + child_job_ids.push(job.id.clone()); + + let wf_node = WorkflowNode { + id: uuid::Uuid::now_v7().to_string(), + run_id: run_id.to_string(), + node_name: child_name.clone(), + job_id: Some(job.id), + status: WorkflowNodeStatus::Pending, + result_hash: None, + fan_out_count: None, + fan_in_data: None, + started_at: None, + 
completed_at: None, + error: None, + }; + wf_storage + .create_workflow_node(&wf_node) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + } + + wf_storage + .set_workflow_node_fan_out_count(run_id, parent_node_name, count) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + Ok(child_job_ids) + } + + /// Create a job for a deferred workflow node. + /// + /// Used after fan-in collects results, or for static nodes downstream of + /// dynamic nodes whose predecessors are now all complete. + #[pyo3(signature = (run_id, node_name, payload, task_name, queue, max_retries, timeout_ms, priority))] + #[allow(clippy::too_many_arguments)] + pub fn create_deferred_job( + &self, + run_id: &str, + node_name: &str, + payload: Vec, + task_name: &str, + queue: &str, + max_retries: i32, + timeout_ms: i64, + priority: i32, + ) -> PyResult { + let wf_storage = workflow_storage(self)?; + let now = now_millis(); + + let new_job = NewJob { + queue: queue.to_string(), + task_name: task_name.to_string(), + payload, + priority, + scheduled_at: now, + max_retries, + timeout_ms, + unique_key: None, + metadata: Some(build_metadata_json(run_id, node_name)), + depends_on: vec![], + expires_at: None, + result_ttl_ms: self.result_ttl_ms, + namespace: self.namespace.clone(), + }; + + let job = self + .storage + .enqueue(new_job) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + wf_storage + .set_workflow_node_job(run_id, node_name, &job.id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + Ok(job.id) + } + + /// Check whether all fan-out children of a parent node are terminal. + /// + /// If all children are terminal, atomically marks the parent node as + /// `Completed` (all succeeded) or `Failed` (any failed) and returns + /// `Some((all_succeeded, child_job_ids))`. Returns `None` if not all + /// children are done yet or if the parent was already finalized by a + /// concurrent call. 
+ pub fn check_fan_out_completion( + &self, + run_id: &str, + parent_node_name: &str, + ) -> PyResult)>> { + let wf_storage = workflow_storage(self)?; + + // Guard: if the parent is already terminal, another call beat us. + let parent = wf_storage + .get_workflow_node(run_id, parent_node_name) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))? + .ok_or_else(|| { + PyValueError::new_err(format!( + "workflow node '{parent_node_name}' not found in run '{run_id}'" + )) + })?; + if parent.status.is_terminal() { + return Ok(None); + } + + let children = wf_storage + .get_workflow_nodes_by_prefix(run_id, &format!("{parent_node_name}[")) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + if !children.iter().all(|n| n.status.is_terminal()) { + return Ok(None); + } + + let any_failed = children + .iter() + .any(|n| n.status == WorkflowNodeStatus::Failed); + let child_job_ids: Vec = children.iter().filter_map(|n| n.job_id.clone()).collect(); + + let now = now_millis(); + if any_failed { + wf_storage + .set_workflow_node_error(run_id, parent_node_name, "fan-out child failed") + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + } else { + wf_storage + .set_workflow_node_completed(run_id, parent_node_name, now, None) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + } + + Ok(Some((!any_failed, child_job_ids))) + } + + /// Check whether all workflow nodes are terminal and finalize the run. + /// + /// Called by the Python tracker after updating the fan-out parent status + /// (e.g., after a failed fan-out). If all nodes are terminal, transitions + /// the run to `Completed` or `Failed` and returns the final state string. + /// Returns `None` if not all nodes are terminal yet. 
+ pub fn finalize_run_if_terminal(&self, run_id: &str) -> PyResult> { + let wf_storage = workflow_storage(self)?; + let nodes = wf_storage + .get_workflow_nodes(run_id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + if !nodes.iter().all(|n| n.status.is_terminal()) { + return Ok(None); + } + + let any_failed = nodes.iter().any(|n| n.status == WorkflowNodeStatus::Failed); + let final_state = if any_failed { + WorkflowState::Failed + } else { + WorkflowState::Completed + }; + + let now = now_millis(); + wf_storage + .update_workflow_run_state( + run_id, + final_state, + if final_state == WorkflowState::Failed { + Some("fan-out child failed") + } else { + None + }, + ) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + wf_storage + .set_workflow_run_completed(run_id, now) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + Ok(Some(final_state.as_str().to_string())) + } + + /// Transition a workflow node to `WaitingApproval` status. + /// + /// Used by the Python tracker when a gate node becomes evaluable. + /// Sets `started_at` without overriding the status (unlike + /// `set_workflow_node_started` which forces `running`). + pub fn set_workflow_node_waiting_approval( + &self, + run_id: &str, + node_name: &str, + ) -> PyResult<()> { + let wf_storage = workflow_storage(self)?; + wf_storage + .update_workflow_node_status(run_id, node_name, WorkflowNodeStatus::WaitingApproval) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(()) + } + + /// Fetch node data from a prior run for incremental caching. + /// + /// Returns a list of ``(node_name, status, result_hash)`` tuples. 
+ pub fn get_base_run_node_data( + &self, + base_run_id: &str, + ) -> PyResult)>> { + let wf_storage = workflow_storage(self)?; + let nodes = wf_storage + .get_workflow_nodes(base_run_id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(nodes + .into_iter() + .map(|n| (n.node_name, n.status.as_str().to_string(), n.result_hash)) + .collect()) + } + + /// Return the DAG JSON bytes for a workflow run's definition. + /// + /// Used by the Python visualization layer to render diagrams. + pub fn get_workflow_definition_dag(&self, run_id: &str) -> PyResult> { + let wf_storage = workflow_storage(self)?; + let run = wf_storage + .get_workflow_run(run_id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))? + .ok_or_else(|| PyValueError::new_err(format!("run '{run_id}' not found")))?; + let def = wf_storage + .get_workflow_definition_by_id(&run.definition_id) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))? + .ok_or_else(|| { + PyRuntimeError::new_err(format!("definition '{}' not found", run.definition_id)) + })?; + Ok(def.dag_data) + } + + /// Set a node's fan_out_count and transition to Running. + /// + /// Also used by the tracker to mark sub-workflow parent nodes as Running. + pub fn set_workflow_node_fan_out_count( + &self, + run_id: &str, + node_name: &str, + count: i32, + ) -> PyResult<()> { + let wf_storage = workflow_storage(self)?; + wf_storage + .set_workflow_node_fan_out_count(run_id, node_name, count) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(()) + } + + /// Approve or reject an approval gate node. + /// + /// Approved gates transition to `Completed`; rejected gates to `Failed`. 
+ #[pyo3(signature = (run_id, node_name, approved, error=None))] + pub fn resolve_workflow_gate( + &self, + run_id: &str, + node_name: &str, + approved: bool, + error: Option, + ) -> PyResult<()> { + let wf_storage = workflow_storage(self)?; + let now = now_millis(); + if approved { + wf_storage + .set_workflow_node_completed(run_id, node_name, now, None) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + } else { + let err_msg = error.unwrap_or_else(|| "rejected".to_string()); + wf_storage + .set_workflow_node_error(run_id, node_name, &err_msg) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + } + Ok(()) + } + + /// Mark a single workflow node as `Skipped` and cancel its job. + /// + /// Used by the Python tracker for condition-based skip propagation. + pub fn skip_workflow_node(&self, run_id: &str, node_name: &str) -> PyResult<()> { + let wf_storage = workflow_storage(self)?; + let node = wf_storage + .get_workflow_node(run_id, node_name) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + if let Some(node) = node { + if let Some(job_id) = &node.job_id { + let _ = self.storage.cancel_job(job_id); + } + wf_storage + .update_workflow_node_status(run_id, node_name, WorkflowNodeStatus::Skipped) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + } + Ok(()) + } +} diff --git a/crates/taskito-python/src/py_workflow/mod.rs b/crates/taskito-python/src/py_workflow/mod.rs new file mode 100644 index 0000000..c6428c8 --- /dev/null +++ b/crates/taskito-python/src/py_workflow/mod.rs @@ -0,0 +1,193 @@ +#![allow(clippy::useless_conversion)] +//! Python bindings for the `taskito-workflows` crate. +//! +//! Compiled only when the `workflows` feature is enabled. Exposes three +//! `#[pyclass]` types: +//! +//! * [`PyWorkflowBuilder`] — construct a DAG from Python and serialize it +//! for storage. +//! * [`PyWorkflowHandle`] — opaque handle returned from `PyQueue::submit_workflow` +//! carrying the run id. +//! 
* [`PyWorkflowRunStatus`] — snapshot of a workflow run, returned by +//! `PyQueue::get_workflow_run_status`. + +use std::collections::HashMap; + +use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::prelude::*; +use pyo3::types::PyDict; + +use taskito_workflows::dagron_core::DAG; +use taskito_workflows::StepMetadata; + +/// Builder for a workflow DAG. +/// +/// Construct in Python, add steps, then call `serialize()` to produce +/// `(dag_json_bytes, step_metadata_json)` for submission. +#[pyclass] +pub struct PyWorkflowBuilder { + dag: DAG<()>, + step_metadata: HashMap, + step_order: Vec, +} + +#[pymethods] +impl PyWorkflowBuilder { + #[new] + fn new() -> Self { + Self { + dag: DAG::new(), + step_metadata: HashMap::new(), + step_order: Vec::new(), + } + } + + /// Add a step to the workflow. + /// + /// `after` lists the names of predecessor steps that must complete before + /// this step runs. All predecessors must already have been added. + #[pyo3(signature = ( + name, task_name, after=None, queue=None, max_retries=None, + timeout_ms=None, priority=None, args_template=None, kwargs_template=None, + fan_out=None, fan_in=None, condition=None + ))] + #[allow(clippy::too_many_arguments)] + fn add_step( + &mut self, + name: &str, + task_name: &str, + after: Option>, + queue: Option, + max_retries: Option, + timeout_ms: Option, + priority: Option, + args_template: Option, + kwargs_template: Option, + fan_out: Option, + fan_in: Option, + condition: Option, + ) -> PyResult<()> { + self.dag + .add_node(name.to_string(), ()) + .map_err(|e| PyValueError::new_err(format!("add_node failed: {e}")))?; + + if let Some(preds) = after { + for pred in preds { + self.dag + .add_edge(&pred, name, None, None) + .map_err(|e| PyValueError::new_err(format!("add_edge failed: {e}")))?; + } + } + + self.step_metadata.insert( + name.to_string(), + StepMetadata { + task_name: task_name.to_string(), + queue, + args_template, + kwargs_template, + max_retries, + timeout_ms, + priority, 
+ fan_out, + fan_in, + condition, + }, + ); + self.step_order.push(name.to_string()); + Ok(()) + } + + /// Return the number of steps added. + fn step_count(&self) -> usize { + self.step_order.len() + } + + /// Return the names of steps in insertion order. + fn step_names(&self) -> Vec { + self.step_order.clone() + } + + /// Serialize the DAG and its step metadata for storage. + /// + /// Returns `(dag_bytes, step_metadata_json)` where: + /// * `dag_bytes` is the UTF-8 encoded JSON of the dagron + /// `SerializableGraph` (no payloads). + /// * `step_metadata_json` is the JSON-encoded + /// `HashMap`. + fn serialize(&self) -> PyResult<(Vec, String)> { + let dag_json = self + .dag + .to_json(|_| None) + .map_err(|e| PyRuntimeError::new_err(format!("DAG to_json failed: {e}")))?; + let meta_json = serde_json::to_string(&self.step_metadata) + .map_err(|e| PyRuntimeError::new_err(format!("step_metadata serialize failed: {e}")))?; + Ok((dag_json.into_bytes(), meta_json)) + } +} + +/// Opaque handle returned from `PyQueue::submit_workflow`. +#[pyclass] +#[derive(Clone)] +pub struct PyWorkflowHandle { + #[pyo3(get)] + pub run_id: String, + #[pyo3(get)] + pub name: String, + #[pyo3(get)] + pub definition_id: String, +} + +#[pymethods] +impl PyWorkflowHandle { + fn __repr__(&self) -> String { + format!( + "PyWorkflowHandle(run_id='{}', name='{}')", + self.run_id, self.name + ) + } +} + +/// Snapshot of a workflow run's state and per-node status. +#[pyclass] +#[derive(Clone)] +pub struct PyWorkflowRunStatus { + #[pyo3(get)] + pub run_id: String, + #[pyo3(get)] + pub state: String, + #[pyo3(get)] + pub started_at: Option, + #[pyo3(get)] + pub completed_at: Option, + #[pyo3(get)] + pub error: Option, + pub nodes: Vec<(String, String, Option, Option)>, +} + +#[pymethods] +impl PyWorkflowRunStatus { + /// Return per-node status as a dict keyed by node name. + /// + /// Each value is a dict with keys `status`, `job_id`, and `error`. 
+ fn node_statuses<'py>(&self, py: Python<'py>) -> PyResult> { + let out = PyDict::new_bound(py); + for (name, status, job_id, error) in &self.nodes { + let entry = PyDict::new_bound(py); + entry.set_item("status", status)?; + entry.set_item("job_id", job_id.clone())?; + entry.set_item("error", error.clone())?; + out.set_item(name, entry)?; + } + Ok(out) + } + + fn __repr__(&self) -> String { + format!( + "PyWorkflowRunStatus(run_id='{}', state='{}', nodes={})", + self.run_id, + self.state, + self.nodes.len() + ) + } +} diff --git a/crates/taskito-workflows/Cargo.toml b/crates/taskito-workflows/Cargo.toml new file mode 100644 index 0000000..0c943c6 --- /dev/null +++ b/crates/taskito-workflows/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "taskito-workflows" +version = "0.10.1" +edition = "2021" + +[dependencies] +taskito-core = { path = "../taskito-core" } +dagron-core = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +uuid = { workspace = true } +diesel = { workspace = true } +log = { workspace = true } diff --git a/crates/taskito-workflows/src/definition.rs b/crates/taskito-workflows/src/definition.rs new file mode 100644 index 0000000..16de2c6 --- /dev/null +++ b/crates/taskito-workflows/src/definition.rs @@ -0,0 +1,40 @@ +use serde::{Deserialize, Serialize}; + +/// Metadata for a single step in a workflow definition. +/// +/// Stored alongside the DAG structure to map node names to task queue details. 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct StepMetadata {
+    pub task_name: String,
+    #[serde(default)]
+    pub queue: Option<String>,
+    #[serde(default)]
+    pub args_template: Option<String>,
+    #[serde(default)]
+    pub kwargs_template: Option<String>,
+    #[serde(default)]
+    pub max_retries: Option<u32>, // NOTE(review): inner type reconstructed — confirm
+    #[serde(default)]
+    pub timeout_ms: Option<u64>, // NOTE(review): inner type reconstructed — confirm
+    #[serde(default)]
+    pub priority: Option<i32>, // NOTE(review): inner type reconstructed — confirm
+    #[serde(default)]
+    pub fan_out: Option<String>, // NOTE(review): inner type reconstructed — confirm
+    #[serde(default)]
+    pub fan_in: Option<bool>, // NOTE(review): inner type reconstructed — confirm
+    #[serde(default)]
+    pub condition: Option<String>,
+}
+
+/// A persisted workflow definition: the DAG structure plus per-step metadata.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct WorkflowDefinition {
+    pub id: String,
+    pub name: String,
+    pub version: i32,
+    /// The serialized dagron DAG (JSON via `SerializableGraph`).
+    pub dag_data: Vec<u8>,
+    /// Per-node metadata mapping node names to task configuration.
+    pub step_metadata: std::collections::HashMap<String, StepMetadata>,
+    pub created_at: i64,
+}
diff --git a/crates/taskito-workflows/src/error.rs b/crates/taskito-workflows/src/error.rs
new file mode 100644
index 0000000..bf3243e
--- /dev/null
+++ b/crates/taskito-workflows/src/error.rs
@@ -0,0 +1,51 @@
+use std::fmt;
+
+use crate::state::WorkflowState;
+
+#[derive(Debug)]
+pub enum WorkflowError {
+    /// The requested workflow definition was not found.
+    DefinitionNotFound(String),
+    /// The requested workflow run was not found.
+    RunNotFound(String),
+    /// A node with this name was not found in the workflow.
+    NodeNotFound { run_id: String, node_name: String },
+    /// Invalid state transition for a workflow run.
+    InvalidTransition {
+        from: WorkflowState,
+        to: WorkflowState,
+    },
+    /// The DAG is structurally invalid (e.g. cycle detected).
+    InvalidDag(String),
+    /// The workflow definition already exists (name + version conflict).
+    DuplicateDefinition { name: String, version: i32 },
+}
+
+impl fmt::Display for WorkflowError {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Self::DefinitionNotFound(name) => {
+                write!(f, "workflow definition not found: {name}")
+            }
+            Self::RunNotFound(id) => write!(f, "workflow run not found: {id}"),
+            Self::NodeNotFound { run_id, node_name } => {
+                write!(f, "node '{node_name}' not found in workflow run {run_id}")
+            }
+            Self::InvalidTransition { from, to } => {
+                write!(f, "invalid workflow state transition: {from} → {to}")
+            }
+            Self::InvalidDag(msg) => write!(f, "invalid workflow DAG: {msg}"),
+            Self::DuplicateDefinition { name, version } => {
+                write!(f, "workflow definition already exists: {name} v{version}")
+            }
+        }
+    }
+}
+
+impl std::error::Error for WorkflowError {}
+
+impl From<WorkflowError> for taskito_core::error::QueueError {
+    fn from(e: WorkflowError) -> Self {
+        taskito_core::error::QueueError::Other(e.to_string())
+    }
+}
diff --git a/crates/taskito-workflows/src/lib.rs b/crates/taskito-workflows/src/lib.rs
new file mode 100644
index 0000000..72b8479
--- /dev/null
+++ b/crates/taskito-workflows/src/lib.rs
@@ -0,0 +1,20 @@
+mod definition;
+mod error;
+mod node;
+mod run;
+pub mod sqlite_store;
+mod state;
+pub mod storage;
+#[cfg(test)]
+mod tests;
+pub mod topology;
+
+pub use dagron_core;
+pub use definition::{StepMetadata, WorkflowDefinition};
+pub use error::WorkflowError;
+pub use node::{WorkflowNode, WorkflowNodeStatus};
+pub use run::WorkflowRun;
+pub use sqlite_store::WorkflowSqliteStorage;
+pub use state::WorkflowState;
+pub use storage::WorkflowStorage;
+pub use topology::{topological_order, TopologicalNode};
diff --git a/crates/taskito-workflows/src/node.rs b/crates/taskito-workflows/src/node.rs
new file mode 100644
index 0000000..afcc162
--- /dev/null
+++ b/crates/taskito-workflows/src/node.rs
@@ -0,0 +1,75 @@
+use std::fmt;
+
+use serde::{Deserialize, Serialize};
+
+/// Status of a single node within a workflow run.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum WorkflowNodeStatus {
+    Pending,
+    Ready,
+    Running,
+    Completed,
+    Failed,
+    Skipped,
+    WaitingApproval,
+    CacheHit,
+}
+
+impl WorkflowNodeStatus {
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::Pending => "pending",
+            Self::Ready => "ready",
+            Self::Running => "running",
+            Self::Completed => "completed",
+            Self::Failed => "failed",
+            Self::Skipped => "skipped",
+            Self::WaitingApproval => "waiting_approval",
+            Self::CacheHit => "cache_hit",
+        }
+    }
+
+    pub fn from_str_val(s: &str) -> Option<Self> {
+        match s {
+            "pending" => Some(Self::Pending),
+            "ready" => Some(Self::Ready),
+            "running" => Some(Self::Running),
+            "completed" => Some(Self::Completed),
+            "failed" => Some(Self::Failed),
+            "skipped" => Some(Self::Skipped),
+            "waiting_approval" => Some(Self::WaitingApproval),
+            "cache_hit" => Some(Self::CacheHit),
+            _ => None,
+        }
+    }
+
+    pub fn is_terminal(&self) -> bool {
+        matches!(
+            self,
+            Self::Completed | Self::Failed | Self::Skipped | Self::CacheHit
+        )
+    }
+}
+
+impl fmt::Display for WorkflowNodeStatus {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(self.as_str())
+    }
+}
+
+/// A single node instance within a workflow run.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct WorkflowNode {
+    pub id: String,
+    pub run_id: String,
+    pub node_name: String,
+    pub job_id: Option<String>,
+    pub status: WorkflowNodeStatus,
+    pub result_hash: Option<String>,
+    pub fan_out_count: Option<i32>,
+    pub fan_in_data: Option<String>,
+    pub started_at: Option<i64>,
+    pub completed_at: Option<i64>,
+    pub error: Option<String>,
+}
diff --git a/crates/taskito-workflows/src/run.rs b/crates/taskito-workflows/src/run.rs
new file mode 100644
index 0000000..d40598e
--- /dev/null
+++ b/crates/taskito-workflows/src/run.rs
@@ -0,0 +1,20 @@
+use serde::{Deserialize, Serialize};
+
+use crate::state::WorkflowState;
+
+/// A single execution of a workflow definition.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct WorkflowRun {
+    pub id: String,
+    pub definition_id: String,
+    pub params: Option<String>,
+    pub state: WorkflowState,
+    pub started_at: Option<i64>,
+    pub completed_at: Option<i64>,
+    pub error: Option<String>,
+    /// For sub-workflows: the parent run that spawned this one.
+    pub parent_run_id: Option<String>,
+    /// For sub-workflows: the node in the parent that triggered this run.
+    pub parent_node_name: Option<String>,
+    pub created_at: i64,
+}
diff --git a/crates/taskito-workflows/src/sqlite_store.rs b/crates/taskito-workflows/src/sqlite_store.rs
new file mode 100644
index 0000000..3a38b03
--- /dev/null
+++ b/crates/taskito-workflows/src/sqlite_store.rs
@@ -0,0 +1,688 @@
+use diesel::prelude::*;
+use diesel::sql_types::Text;
+use diesel::sqlite::SqliteConnection;
+
+use taskito_core::error::Result;
+use taskito_core::storage::sqlite::SqliteStorage;
+
+use crate::storage::WorkflowStorage;
+use crate::{
+    StepMetadata, WorkflowDefinition, WorkflowNode, WorkflowNodeStatus, WorkflowRun, WorkflowState,
+};
+
+// ── Row types for sql_query results ──────────────────────────────────
+
+#[derive(QueryableByName)]
+struct DefinitionRow {
+    #[diesel(sql_type = Text)]
+    id: String,
+    #[diesel(sql_type = Text)]
+    name: String,
+    #[diesel(sql_type = diesel::sql_types::Integer)]
+    version: i32,
+    #[diesel(sql_type = diesel::sql_types::Binary)]
+    dag_data: Vec<u8>,
+    #[diesel(sql_type = Text)]
+    step_metadata: String,
+    #[diesel(sql_type = diesel::sql_types::BigInt)]
+    created_at: i64,
+}
+
+#[derive(QueryableByName)]
+struct RunRow {
+    #[diesel(sql_type = Text)]
+    id: String,
+    #[diesel(sql_type = Text)]
+    definition_id: String,
+    #[diesel(sql_type = diesel::sql_types::Nullable<Text>)]
+    params: Option<String>,
+    #[diesel(sql_type = Text)]
+    state: String,
+    #[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::BigInt>)]
+    started_at: Option<i64>,
+    #[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::BigInt>)]
+    completed_at: Option<i64>,
+    #[diesel(sql_type = diesel::sql_types::Nullable<Text>)]
+    error: Option<String>,
+    #[diesel(sql_type = diesel::sql_types::Nullable<Text>)]
+    parent_run_id: Option<String>,
+    #[diesel(sql_type = diesel::sql_types::Nullable<Text>)]
+    parent_node_name: Option<String>,
+    #[diesel(sql_type = diesel::sql_types::BigInt)]
+    created_at: i64,
+}
+
+#[derive(QueryableByName)]
+struct NodeRow {
+    #[diesel(sql_type = Text)]
+    id: String,
+    #[diesel(sql_type = Text)]
+    run_id: String,
+    #[diesel(sql_type = Text)]
+    node_name: String,
+    #[diesel(sql_type = diesel::sql_types::Nullable<Text>)]
+    job_id: Option<String>,
+    #[diesel(sql_type = Text)]
+    status: String,
+    #[diesel(sql_type = diesel::sql_types::Nullable<Text>)]
+    result_hash: Option<String>,
+    #[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::Integer>)]
+    fan_out_count: Option<i32>,
+    #[diesel(sql_type = diesel::sql_types::Nullable<Text>)]
+    fan_in_data: Option<String>,
+    #[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::BigInt>)]
+    started_at: Option<i64>,
+    #[diesel(sql_type = diesel::sql_types::Nullable<diesel::sql_types::BigInt>)]
+    completed_at: Option<i64>,
+    #[diesel(sql_type = diesel::sql_types::Nullable<Text>)]
+    error: Option<String>,
+}
+
+// ── Conversions ──────────────────────────────────────────────────────
+
+fn definition_from_row(row: DefinitionRow) -> Result<WorkflowDefinition> {
+    let step_metadata: std::collections::HashMap<String, StepMetadata> =
+        serde_json::from_str(&row.step_metadata).map_err(|e| {
+            taskito_core::error::QueueError::Serialization(format!(
+                "failed to deserialize step_metadata: {e}"
+            ))
+        })?;
+    Ok(WorkflowDefinition {
+        id: row.id,
+        name: row.name,
+        version: row.version,
+        dag_data: row.dag_data,
+        step_metadata,
+        created_at: row.created_at,
+    })
+}
+
+fn run_from_row(row: RunRow) -> WorkflowRun {
+    WorkflowRun {
+        id: row.id,
+        definition_id: row.definition_id,
+        params: row.params,
+        state: WorkflowState::from_str_val(&row.state).unwrap_or(WorkflowState::Pending),
+        started_at: row.started_at,
+        completed_at: row.completed_at,
+        error: row.error,
+        parent_run_id: row.parent_run_id,
+        parent_node_name: row.parent_node_name,
+        created_at: row.created_at,
+    }
+}
+
+fn node_from_row(row: NodeRow) -> WorkflowNode {
+    WorkflowNode {
+        id: row.id,
+        run_id: row.run_id,
+        node_name: row.node_name,
+        job_id: row.job_id,
+        status: WorkflowNodeStatus::from_str_val(&row.status)
+            .unwrap_or(WorkflowNodeStatus::Pending),
+        result_hash: row.result_hash,
+        fan_out_count: row.fan_out_count,
+        fan_in_data: row.fan_in_data,
+        started_at: row.started_at,
+        completed_at: row.completed_at,
+        error: row.error,
+    }
+}
+
+// ── Migrations ───────────────────────────────────────────────────────
+
+fn run_workflow_migrations(conn: &mut SqliteConnection) -> Result<()> {
+    diesel::sql_query(
+        "CREATE TABLE IF NOT EXISTS workflow_definitions (
+            id TEXT PRIMARY KEY,
+            name TEXT NOT NULL,
+            version INTEGER NOT NULL DEFAULT 1,
+            dag_data BLOB NOT NULL,
+            step_metadata TEXT NOT NULL,
+            created_at INTEGER NOT NULL,
+            UNIQUE(name, version)
+        )",
+    )
+    .execute(conn)?;
+
+    diesel::sql_query("CREATE INDEX IF NOT EXISTS idx_wf_def_name ON workflow_definitions(name)")
+        .execute(conn)?;
+
+    diesel::sql_query(
+        "CREATE TABLE IF NOT EXISTS workflow_runs (
+            id TEXT PRIMARY KEY,
+            definition_id TEXT NOT NULL,
+            params TEXT,
+            state TEXT NOT NULL DEFAULT 'pending',
+            started_at INTEGER,
+            completed_at INTEGER,
+            error TEXT,
+            parent_run_id TEXT,
+            parent_node_name TEXT,
+            created_at INTEGER NOT NULL
+        )",
+    )
+    .execute(conn)?;
+
+    diesel::sql_query("CREATE INDEX IF NOT EXISTS idx_wf_run_def ON workflow_runs(definition_id)")
+        .execute(conn)?;
+
+    diesel::sql_query("CREATE INDEX IF NOT EXISTS idx_wf_run_state ON workflow_runs(state)")
+        .execute(conn)?;
+
+    diesel::sql_query(
+        "CREATE TABLE IF NOT EXISTS workflow_nodes (
+            id TEXT PRIMARY KEY,
+            run_id TEXT NOT NULL,
+            node_name TEXT NOT NULL,
+            job_id TEXT,
+            status TEXT NOT NULL DEFAULT 'pending',
+            result_hash TEXT,
+            fan_out_count INTEGER,
+            fan_in_data TEXT,
+            started_at INTEGER,
+            completed_at INTEGER,
+            error TEXT,
+            UNIQUE(run_id, node_name)
+        )",
+    )
+    .execute(conn)?;
+
+    diesel::sql_query("CREATE INDEX IF NOT EXISTS idx_wf_node_run ON workflow_nodes(run_id)")
+        .execute(conn)?;
+
+    diesel::sql_query(
+        "CREATE INDEX IF NOT EXISTS idx_wf_node_status ON workflow_nodes(run_id, status)",
+    )
+    .execute(conn)?;
+
+    Ok(())
+}
+
+// ── Wrapper that runs migrations ─────────────────────────────────────
+
+/// Workflow-aware wrapper around `SqliteStorage`.
+///
+/// Runs workflow table migrations on construction, then delegates
+/// all `WorkflowStorage` operations to the underlying connection pool.
+#[derive(Clone)]
+pub struct WorkflowSqliteStorage {
+    inner: SqliteStorage,
+}
+
+impl WorkflowSqliteStorage {
+    /// Wrap an existing `SqliteStorage` and ensure workflow tables exist.
+    pub fn new(storage: SqliteStorage) -> Result<Self> {
+        let mut conn = storage.conn()?;
+        run_workflow_migrations(&mut conn)?;
+        Ok(Self { inner: storage })
+    }
+
+    /// Access the underlying `SqliteStorage`.
+    pub fn inner(&self) -> &SqliteStorage {
+        &self.inner
+    }
+}
+
+// ── WorkflowStorage impl ─────────────────────────────────────────────
+
+impl WorkflowStorage for WorkflowSqliteStorage {
+    fn create_workflow_definition(&self, def: &WorkflowDefinition) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        let meta_json = serde_json::to_string(&def.step_metadata).map_err(|e| {
+            taskito_core::error::QueueError::Serialization(format!(
+                "failed to serialize step_metadata: {e}"
+            ))
+        })?;
+
+        diesel::sql_query(
+            "INSERT INTO workflow_definitions (id, name, version, dag_data, step_metadata, created_at)
+             VALUES (?, ?, ?, ?, ?, ?)",
+        )
+        .bind::<Text, _>(&def.id)
+        .bind::<Text, _>(&def.name)
+        .bind::<diesel::sql_types::Integer, _>(def.version)
+        .bind::<diesel::sql_types::Binary, _>(&def.dag_data)
+        .bind::<Text, _>(&meta_json)
+        .bind::<diesel::sql_types::BigInt, _>(def.created_at)
+        .execute(&mut conn)?;
+
+        Ok(())
+    }
+
+    fn get_workflow_definition(
+        &self,
+        name: &str,
+        version: Option<i32>,
+    ) -> Result<Option<WorkflowDefinition>> {
+        let mut conn = self.inner.conn()?;
+
+        let rows: Vec<DefinitionRow> = if let Some(v) = version {
+            diesel::sql_query(
+                "SELECT id, name, version, dag_data, step_metadata, created_at
+                 FROM workflow_definitions WHERE name = ? AND version = ?",
+            )
+            .bind::<Text, _>(name)
+            .bind::<diesel::sql_types::Integer, _>(v)
+            .load(&mut conn)?
+        } else {
+            diesel::sql_query(
+                "SELECT id, name, version, dag_data, step_metadata, created_at
+                 FROM workflow_definitions WHERE name = ? ORDER BY version DESC LIMIT 1",
+            )
+            .bind::<Text, _>(name)
+            .load(&mut conn)?
+        };
+
+        match rows.into_iter().next() {
+            Some(row) => Ok(Some(definition_from_row(row)?)),
+            None => Ok(None),
+        }
+    }
+
+    fn get_workflow_definition_by_id(&self, id: &str) -> Result<Option<WorkflowDefinition>> {
+        let mut conn = self.inner.conn()?;
+        let rows: Vec<DefinitionRow> = diesel::sql_query(
+            "SELECT id, name, version, dag_data, step_metadata, created_at
+             FROM workflow_definitions WHERE id = ?",
+        )
+        .bind::<Text, _>(id)
+        .load(&mut conn)?;
+
+        match rows.into_iter().next() {
+            Some(row) => Ok(Some(definition_from_row(row)?)),
+            None => Ok(None),
+        }
+    }
+
+    fn create_workflow_run(&self, run: &WorkflowRun) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query(
+            "INSERT INTO workflow_runs
+             (id, definition_id, params, state, started_at, completed_at, error,
+              parent_run_id, parent_node_name, created_at)
+             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+        )
+        .bind::<Text, _>(&run.id)
+        .bind::<Text, _>(&run.definition_id)
+        .bind::<diesel::sql_types::Nullable<Text>, _>(&run.params)
+        .bind::<Text, _>(run.state.as_str())
+        .bind::<diesel::sql_types::Nullable<diesel::sql_types::BigInt>, _>(run.started_at)
+        .bind::<diesel::sql_types::Nullable<diesel::sql_types::BigInt>, _>(run.completed_at)
+        .bind::<diesel::sql_types::Nullable<Text>, _>(&run.error)
+        .bind::<diesel::sql_types::Nullable<Text>, _>(&run.parent_run_id)
+        .bind::<diesel::sql_types::Nullable<Text>, _>(&run.parent_node_name)
+        .bind::<diesel::sql_types::BigInt, _>(run.created_at)
+        .execute(&mut conn)?;
+
+        Ok(())
+    }
+
+    fn get_workflow_run(&self, run_id: &str) -> Result<Option<WorkflowRun>> {
+        let mut conn = self.inner.conn()?;
+        let rows: Vec<RunRow> = diesel::sql_query(
+            "SELECT id, definition_id, params, state, started_at, completed_at, error,
+             parent_run_id, parent_node_name, created_at
+             FROM workflow_runs WHERE id = ?",
+        )
+        .bind::<Text, _>(run_id)
+        .load(&mut conn)?;
+
+        Ok(rows.into_iter().next().map(run_from_row))
+    }
+
+    fn update_workflow_run_state(
+        &self,
+        run_id: &str,
+        state: WorkflowState,
+        error: Option<&str>,
+    ) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query("UPDATE workflow_runs SET state = ?, error = ? WHERE id = ?")
+            .bind::<Text, _>(state.as_str())
+            .bind::<diesel::sql_types::Nullable<Text>, _>(error)
+            .bind::<Text, _>(run_id)
+            .execute(&mut conn)?;
+        Ok(())
+    }
+
+    fn set_workflow_run_started(&self, run_id: &str, started_at: i64) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query(
+            "UPDATE workflow_runs SET state = 'running', started_at = ? WHERE id = ?",
+        )
+        .bind::<diesel::sql_types::BigInt, _>(started_at)
+        .bind::<Text, _>(run_id)
+        .execute(&mut conn)?;
+        Ok(())
+    }
+
+    fn set_workflow_run_completed(&self, run_id: &str, completed_at: i64) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query("UPDATE workflow_runs SET completed_at = ? WHERE id = ?")
+            .bind::<diesel::sql_types::BigInt, _>(completed_at)
+            .bind::<Text, _>(run_id)
+            .execute(&mut conn)?;
+        Ok(())
+    }
+
+    fn list_workflow_runs(
+        &self,
+        definition_name: Option<&str>,
+        state: Option<WorkflowState>,
+        limit: i64,
+        offset: i64,
+    ) -> Result<Vec<WorkflowRun>> {
+        let mut conn = self.inner.conn()?;
+
+        let rows: Vec<RunRow> = match (definition_name, state) {
+            (Some(name), Some(st)) => diesel::sql_query(
+                "SELECT r.id, r.definition_id, r.params, r.state, r.started_at,
+                 r.completed_at, r.error, r.parent_run_id, r.parent_node_name,
+                 r.created_at
+                 FROM workflow_runs r
+                 JOIN workflow_definitions d ON r.definition_id = d.id
+                 WHERE d.name = ? AND r.state = ?
+                 ORDER BY r.created_at DESC LIMIT ? OFFSET ?",
+            )
+            .bind::<Text, _>(name)
+            .bind::<Text, _>(st.as_str())
+            .bind::<diesel::sql_types::BigInt, _>(limit)
+            .bind::<diesel::sql_types::BigInt, _>(offset)
+            .load(&mut conn)?,
+            (Some(name), None) => diesel::sql_query(
+                "SELECT r.id, r.definition_id, r.params, r.state, r.started_at,
+                 r.completed_at, r.error, r.parent_run_id, r.parent_node_name,
+                 r.created_at
+                 FROM workflow_runs r
+                 JOIN workflow_definitions d ON r.definition_id = d.id
+                 WHERE d.name = ?
+                 ORDER BY r.created_at DESC LIMIT ? OFFSET ?",
+            )
+            .bind::<Text, _>(name)
+            .bind::<diesel::sql_types::BigInt, _>(limit)
+            .bind::<diesel::sql_types::BigInt, _>(offset)
+            .load(&mut conn)?,
+            (None, Some(st)) => diesel::sql_query(
+                "SELECT id, definition_id, params, state, started_at, completed_at,
+                 error, parent_run_id, parent_node_name, created_at
+                 FROM workflow_runs WHERE state = ?
+                 ORDER BY created_at DESC LIMIT ? OFFSET ?",
+            )
+            .bind::<Text, _>(st.as_str())
+            .bind::<diesel::sql_types::BigInt, _>(limit)
+            .bind::<diesel::sql_types::BigInt, _>(offset)
+            .load(&mut conn)?,
+            (None, None) => diesel::sql_query(
+                "SELECT id, definition_id, params, state, started_at, completed_at,
+                 error, parent_run_id, parent_node_name, created_at
+                 FROM workflow_runs
+                 ORDER BY created_at DESC LIMIT ? OFFSET ?",
+            )
+            .bind::<diesel::sql_types::BigInt, _>(limit)
+            .bind::<diesel::sql_types::BigInt, _>(offset)
+            .load(&mut conn)?,
+        };
+
+        Ok(rows.into_iter().map(run_from_row).collect())
+    }
+
+    fn create_workflow_node(&self, node: &WorkflowNode) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query(
+            "INSERT INTO workflow_nodes
+             (id, run_id, node_name, job_id, status, result_hash,
+              fan_out_count, fan_in_data, started_at, completed_at, error)
+             VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+        )
+        .bind::<Text, _>(&node.id)
+        .bind::<Text, _>(&node.run_id)
+        .bind::<Text, _>(&node.node_name)
+        .bind::<diesel::sql_types::Nullable<Text>, _>(&node.job_id)
+        .bind::<Text, _>(node.status.as_str())
+        .bind::<diesel::sql_types::Nullable<Text>, _>(&node.result_hash)
+        .bind::<diesel::sql_types::Nullable<diesel::sql_types::Integer>, _>(node.fan_out_count)
+        .bind::<diesel::sql_types::Nullable<Text>, _>(&node.fan_in_data)
+        .bind::<diesel::sql_types::Nullable<diesel::sql_types::BigInt>, _>(node.started_at)
+        .bind::<diesel::sql_types::Nullable<diesel::sql_types::BigInt>, _>(node.completed_at)
+        .bind::<diesel::sql_types::Nullable<Text>, _>(&node.error)
+        .execute(&mut conn)?;
+        Ok(())
+    }
+
+    fn create_workflow_nodes_batch(&self, nodes: &[WorkflowNode]) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        for node in nodes {
+            diesel::sql_query(
+                "INSERT INTO workflow_nodes
+                 (id, run_id, node_name, job_id, status, result_hash,
+                  fan_out_count, fan_in_data, started_at, completed_at, error)
+                 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+            )
+            .bind::<Text, _>(&node.id)
+            .bind::<Text, _>(&node.run_id)
+            .bind::<Text, _>(&node.node_name)
+            .bind::<diesel::sql_types::Nullable<Text>, _>(&node.job_id)
+            .bind::<Text, _>(node.status.as_str())
+            .bind::<diesel::sql_types::Nullable<Text>, _>(&node.result_hash)
+            .bind::<diesel::sql_types::Nullable<diesel::sql_types::Integer>, _>(node.fan_out_count)
+            .bind::<diesel::sql_types::Nullable<Text>, _>(&node.fan_in_data)
+            .bind::<diesel::sql_types::Nullable<diesel::sql_types::BigInt>, _>(node.started_at)
+            .bind::<diesel::sql_types::Nullable<diesel::sql_types::BigInt>, _>(node.completed_at)
+            .bind::<diesel::sql_types::Nullable<Text>, _>(&node.error)
+            .execute(&mut conn)?;
+        }
+        Ok(())
+    }
+
+    fn get_workflow_node(&self, run_id: &str, node_name: &str) -> Result<Option<WorkflowNode>> {
+        let mut conn = self.inner.conn()?;
+        let rows: Vec<NodeRow> = diesel::sql_query(
+            "SELECT id, run_id, node_name, job_id, status, result_hash,
+             fan_out_count, fan_in_data, started_at, completed_at, error
+             FROM workflow_nodes WHERE run_id = ? AND node_name = ?",
+        )
+        .bind::<Text, _>(run_id)
+        .bind::<Text, _>(node_name)
+        .load(&mut conn)?;
+
+        Ok(rows.into_iter().next().map(node_from_row))
+    }
+
+    fn get_workflow_nodes(&self, run_id: &str) -> Result<Vec<WorkflowNode>> {
+        let mut conn = self.inner.conn()?;
+        let rows: Vec<NodeRow> = diesel::sql_query(
+            "SELECT id, run_id, node_name, job_id, status, result_hash,
+             fan_out_count, fan_in_data, started_at, completed_at, error
+             FROM workflow_nodes WHERE run_id = ?",
+        )
+        .bind::<Text, _>(run_id)
+        .load(&mut conn)?;
+
+        Ok(rows.into_iter().map(node_from_row).collect())
+    }
+
+    fn update_workflow_node_status(
+        &self,
+        run_id: &str,
+        node_name: &str,
+        status: WorkflowNodeStatus,
+    ) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query(
+            "UPDATE workflow_nodes SET status = ? WHERE run_id = ? AND node_name = ?",
+        )
+        .bind::<Text, _>(status.as_str())
+        .bind::<Text, _>(run_id)
+        .bind::<Text, _>(node_name)
+        .execute(&mut conn)?;
+        Ok(())
+    }
+
+    fn set_workflow_node_job(&self, run_id: &str, node_name: &str, job_id: &str) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query(
+            "UPDATE workflow_nodes SET job_id = ? WHERE run_id = ? AND node_name = ?",
+        )
+        .bind::<Text, _>(job_id)
+        .bind::<Text, _>(run_id)
+        .bind::<Text, _>(node_name)
+        .execute(&mut conn)?;
+        Ok(())
+    }
+
+    fn set_workflow_node_started(
+        &self,
+        run_id: &str,
+        node_name: &str,
+        started_at: i64,
+    ) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query(
+            "UPDATE workflow_nodes SET status = 'running', started_at = ?
+             WHERE run_id = ? AND node_name = ?",
+        )
+        .bind::<diesel::sql_types::BigInt, _>(started_at)
+        .bind::<Text, _>(run_id)
+        .bind::<Text, _>(node_name)
+        .execute(&mut conn)?;
+        Ok(())
+    }
+
+    fn set_workflow_node_completed(
+        &self,
+        run_id: &str,
+        node_name: &str,
+        completed_at: i64,
+        result_hash: Option<&str>,
+    ) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query(
+            "UPDATE workflow_nodes SET status = 'completed', completed_at = ?, result_hash = ?
+             WHERE run_id = ? AND node_name = ?",
+        )
+        .bind::<diesel::sql_types::BigInt, _>(completed_at)
+        .bind::<diesel::sql_types::Nullable<Text>, _>(result_hash)
+        .bind::<Text, _>(run_id)
+        .bind::<Text, _>(node_name)
+        .execute(&mut conn)?;
+        Ok(())
+    }
+
+    fn set_workflow_node_error(&self, run_id: &str, node_name: &str, error: &str) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query(
+            "UPDATE workflow_nodes SET status = 'failed', error = ?
+             WHERE run_id = ? AND node_name = ?",
+        )
+        .bind::<Text, _>(error)
+        .bind::<Text, _>(run_id)
+        .bind::<Text, _>(node_name)
+        .execute(&mut conn)?;
+        Ok(())
+    }
+
+    fn set_workflow_node_fan_out_count(
+        &self,
+        run_id: &str,
+        node_name: &str,
+        count: i32,
+    ) -> Result<()> {
+        let mut conn = self.inner.conn()?;
+        diesel::sql_query(
+            "UPDATE workflow_nodes SET fan_out_count = ?, status = 'running'
+             WHERE run_id = ? AND node_name = ?",
+        )
+        .bind::<diesel::sql_types::Integer, _>(count)
+        .bind::<Text, _>(run_id)
+        .bind::<Text, _>(node_name)
+        .execute(&mut conn)?;
+        Ok(())
+    }
+
+    fn get_workflow_nodes_by_prefix(
+        &self,
+        run_id: &str,
+        prefix: &str,
+    ) -> Result<Vec<WorkflowNode>> {
+        let mut conn = self.inner.conn()?;
+        let pattern = format!("{prefix}%");
+        let rows: Vec<NodeRow> = diesel::sql_query(
+            "SELECT id, run_id, node_name, job_id, status, result_hash,
+             fan_out_count, fan_in_data, started_at, completed_at, error
+             FROM workflow_nodes WHERE run_id = ? AND node_name LIKE ?",
+        )
+        .bind::<Text, _>(run_id)
+        .bind::<Text, _>(&pattern)
+        .load(&mut conn)?;
+
+        Ok(rows.into_iter().map(node_from_row).collect())
+    }
+
+    fn get_ready_workflow_nodes(&self, run_id: &str, dag_json: &str) -> Result<Vec<WorkflowNode>> {
+        // Parse the DAG to find edges (predecessor relationships).
+        let graph: dagron_core::SerializableGraph = // NOTE(review): type param (if any) lost in transit — confirm
+            serde_json::from_str(dag_json).map_err(|e| {
+                taskito_core::error::QueueError::Serialization(format!(
+                    "failed to deserialize workflow DAG: {e}"
+                ))
+            })?;
+
+        // Build predecessor map: node_name → set of predecessor names.
+        let mut predecessors: std::collections::HashMap<String, Vec<String>> =
+            std::collections::HashMap::new();
+        for node in &graph.nodes {
+            predecessors.entry(node.name.clone()).or_default();
+        }
+        for edge in &graph.edges {
+            predecessors
+                .entry(edge.to.clone())
+                .or_default()
+                .push(edge.from.clone());
+        }
+
+        // Load all nodes for this run.
+        let all_nodes = self.get_workflow_nodes(run_id)?;
+
+        // Build an owned status lookup.
+        let status_map: std::collections::HashMap<String, WorkflowNodeStatus> = all_nodes
+            .iter()
+            .map(|n| (n.node_name.clone(), n.status))
+            .collect();
+
+        // A node is ready if:
+        //   1. Its status is Pending
+        //   2. All its DAG predecessors have status Completed
+        let ready: Vec<WorkflowNode> = all_nodes
+            .into_iter()
+            .filter(|node| {
+                if node.status != WorkflowNodeStatus::Pending {
+                    return false;
+                }
+                let preds = match predecessors.get(&node.node_name) {
+                    Some(p) => p,
+                    None => return true, // no predecessors → root node → ready
+                };
+                preds.iter().all(|pred_name| {
+                    status_map.get(pred_name.as_str()).copied()
+                        == Some(WorkflowNodeStatus::Completed)
+                })
+            })
+            .collect();
+
+        Ok(ready)
+    }
+
+    fn get_child_workflow_runs(&self, parent_run_id: &str) -> Result<Vec<WorkflowRun>> {
+        let mut conn = self.inner.conn()?;
+        let rows: Vec<RunRow> = diesel::sql_query(
+            "SELECT id, definition_id, params, state, started_at, completed_at,
+             error, parent_run_id, parent_node_name, created_at
+             FROM workflow_runs WHERE parent_run_id = ?",
+        )
+        .bind::<Text, _>(parent_run_id)
+        .load(&mut conn)?;
+
+        Ok(rows.into_iter().map(run_from_row).collect())
+    }
+}
diff --git a/crates/taskito-workflows/src/state.rs b/crates/taskito-workflows/src/state.rs
new file mode 100644
index 0000000..93724e7
--- /dev/null
+++ b/crates/taskito-workflows/src/state.rs
@@ -0,0 +1,69 @@
+use std::fmt;
+
+use serde::{Deserialize, Serialize};
+
+/// State machine for a workflow run.
+///
+/// Transitions:
+///   Pending → Running
+///   Running → Completed | Failed | Cancelled | Paused
+///   Paused  → Running | Cancelled
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum WorkflowState {
+    Pending,
+    Running,
+    Paused,
+    Completed,
+    Failed,
+    Cancelled,
+}
+
+impl WorkflowState {
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::Pending => "pending",
+            Self::Running => "running",
+            Self::Paused => "paused",
+            Self::Completed => "completed",
+            Self::Failed => "failed",
+            Self::Cancelled => "cancelled",
+        }
+    }
+
+    pub fn from_str_val(s: &str) -> Option<Self> {
+        match s {
+            "pending" => Some(Self::Pending),
+            "running" => Some(Self::Running),
+            "paused" => Some(Self::Paused),
+            "completed" => Some(Self::Completed),
+            "failed" => Some(Self::Failed),
+            "cancelled" => Some(Self::Cancelled),
+            _ => None,
+        }
+    }
+
+    pub fn is_terminal(&self) -> bool {
+        matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
+    }
+
+    /// Check whether transitioning from `self` to `target` is valid.
+    pub fn can_transition_to(&self, target: Self) -> bool {
+        matches!(
+            (self, target),
+            (Self::Pending, Self::Running)
+                | (Self::Running, Self::Completed)
+                | (Self::Running, Self::Failed)
+                | (Self::Running, Self::Cancelled)
+                | (Self::Running, Self::Paused)
+                | (Self::Paused, Self::Running)
+                | (Self::Paused, Self::Cancelled)
+        )
+    }
+}
+
+impl fmt::Display for WorkflowState {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(self.as_str())
+    }
+}
diff --git a/crates/taskito-workflows/src/storage.rs b/crates/taskito-workflows/src/storage.rs
new file mode 100644
index 0000000..c88a649
--- /dev/null
+++ b/crates/taskito-workflows/src/storage.rs
@@ -0,0 +1,94 @@
+use taskito_core::error::Result;
+
+use crate::{WorkflowDefinition, WorkflowNode, WorkflowNodeStatus, WorkflowRun, WorkflowState};
+
+/// Storage operations for workflows.
+///
+/// Kept as a separate trait from `Storage` so that the workflow feature
+/// doesn't bloat the core trait or require feature-gated methods everywhere.
+pub trait WorkflowStorage: Send + Sync {
+    // ── Definitions ──────────────────────────────────────────────────
+
+    fn create_workflow_definition(&self, def: &WorkflowDefinition) -> Result<()>;
+    fn get_workflow_definition(
+        &self,
+        name: &str,
+        version: Option<i32>,
+    ) -> Result<Option<WorkflowDefinition>>;
+    fn get_workflow_definition_by_id(&self, id: &str) -> Result<Option<WorkflowDefinition>>;
+
+    // ── Runs ─────────────────────────────────────────────────────────
+
+    fn create_workflow_run(&self, run: &WorkflowRun) -> Result<()>;
+    fn get_workflow_run(&self, run_id: &str) -> Result<Option<WorkflowRun>>;
+    fn update_workflow_run_state(
+        &self,
+        run_id: &str,
+        state: WorkflowState,
+        error: Option<&str>,
+    ) -> Result<()>;
+    fn set_workflow_run_started(&self, run_id: &str, started_at: i64) -> Result<()>;
+    fn set_workflow_run_completed(&self, run_id: &str, completed_at: i64) -> Result<()>;
+    fn list_workflow_runs(
+        &self,
+        definition_name: Option<&str>,
+        state: Option<WorkflowState>,
+        limit: i64,
+        offset: i64,
+    ) -> Result<Vec<WorkflowRun>>;
+
+    // ── Nodes ────────────────────────────────────────────────────────
+
+    fn create_workflow_node(&self, node: &WorkflowNode) -> Result<()>;
+    fn create_workflow_nodes_batch(&self, nodes: &[WorkflowNode]) -> Result<()>;
+    fn get_workflow_node(&self, run_id: &str, node_name: &str) -> Result<Option<WorkflowNode>>;
+    fn get_workflow_nodes(&self, run_id: &str) -> Result<Vec<WorkflowNode>>;
+    fn update_workflow_node_status(
+        &self,
+        run_id: &str,
+        node_name: &str,
+        status: WorkflowNodeStatus,
+    ) -> Result<()>;
+    fn set_workflow_node_job(&self, run_id: &str, node_name: &str, job_id: &str) -> Result<()>;
+    fn set_workflow_node_started(
+        &self,
+        run_id: &str,
+        node_name: &str,
+        started_at: i64,
+    ) -> Result<()>;
+    fn set_workflow_node_completed(
+        &self,
+        run_id: &str,
+        node_name: &str,
+        completed_at: i64,
+        result_hash: Option<&str>,
+    ) -> Result<()>;
+    fn set_workflow_node_error(&self, run_id: &str, node_name: &str, error: &str) -> Result<()>;
+
+    /// Return nodes whose status is `Pending` and all DAG predecessors are `Completed`.
+    ///
+    /// The `dag_json` parameter is the serialized DAG from the workflow definition.
+    /// The implementation uses it to determine which predecessor nodes must be
+    /// complete before a given node becomes ready.
+    fn get_ready_workflow_nodes(&self, run_id: &str, dag_json: &str) -> Result<Vec<WorkflowNode>>;
+
+    // ── Fan-out / Fan-in ─────────────────────────────────────────────
+
+    /// Set a node's `fan_out_count` and transition its status to `Running`.
+    fn set_workflow_node_fan_out_count(
+        &self,
+        run_id: &str,
+        node_name: &str,
+        count: i32,
+    ) -> Result<()>;
+
+    /// Return all nodes whose `node_name` starts with `prefix`.
+    ///
+    /// Used to find fan-out children (e.g., prefix `"process["` returns
+    /// `process[0]`, `process[1]`, etc.).
+    fn get_workflow_nodes_by_prefix(&self, run_id: &str, prefix: &str)
+        -> Result<Vec<WorkflowNode>>;
+
+    /// Return all child workflow runs of a parent run.
+    fn get_child_workflow_runs(&self, parent_run_id: &str) -> Result<Vec<WorkflowRun>>;
+}
diff --git a/crates/taskito-workflows/src/tests.rs b/crates/taskito-workflows/src/tests.rs
new file mode 100644
index 0000000..677954d
--- /dev/null
+++ b/crates/taskito-workflows/src/tests.rs
@@ -0,0 +1,703 @@
+use taskito_core::job::now_millis;
+use taskito_core::storage::sqlite::SqliteStorage;
+
+use crate::sqlite_store::WorkflowSqliteStorage;
+use crate::storage::WorkflowStorage;
+use crate::*;
+
+fn test_storage() -> WorkflowSqliteStorage {
+    let base = SqliteStorage::in_memory().unwrap();
+    WorkflowSqliteStorage::new(base).unwrap()
+}
+
+fn make_definition(name: &str) -> WorkflowDefinition {
+    let mut step_metadata = std::collections::HashMap::new();
+    step_metadata.insert(
+        "a".to_string(),
+        StepMetadata {
+            task_name: "task_a".to_string(),
+            queue: None,
+            args_template: None,
+            kwargs_template: None,
+            max_retries: None,
+            timeout_ms: None,
+            priority: None,
+            fan_out: None,
+            fan_in: None,
+            condition: None,
+        },
+    );
+    step_metadata.insert(
+        "b".to_string(),
+        StepMetadata {
+            task_name: "task_b".to_string(),
+            queue:
None, + args_template: None, + kwargs_template: None, + max_retries: None, + timeout_ms: None, + priority: None, + fan_out: None, + fan_in: None, + condition: None, + }, + ); + step_metadata.insert( + "c".to_string(), + StepMetadata { + task_name: "task_c".to_string(), + queue: None, + args_template: None, + kwargs_template: None, + max_retries: None, + timeout_ms: None, + priority: None, + fan_out: None, + fan_in: None, + condition: None, + }, + ); + + // Simple DAG: a → b → c + let dag_json = serde_json::json!({ + "nodes": [ + {"name": "a"}, + {"name": "b"}, + {"name": "c"} + ], + "edges": [ + {"from": "a", "to": "b", "weight": 1.0}, + {"from": "b", "to": "c", "weight": 1.0} + ] + }); + + WorkflowDefinition { + id: uuid::Uuid::now_v7().to_string(), + name: name.to_string(), + version: 1, + dag_data: serde_json::to_vec(&dag_json).unwrap(), + step_metadata, + created_at: now_millis(), + } +} + +fn make_run(definition_id: &str) -> WorkflowRun { + WorkflowRun { + id: uuid::Uuid::now_v7().to_string(), + definition_id: definition_id.to_string(), + params: Some(r#"{"region":"eu"}"#.to_string()), + state: WorkflowState::Pending, + started_at: None, + completed_at: None, + error: None, + parent_run_id: None, + parent_node_name: None, + created_at: now_millis(), + } +} + +fn make_node(run_id: &str, name: &str) -> WorkflowNode { + WorkflowNode { + id: uuid::Uuid::now_v7().to_string(), + run_id: run_id.to_string(), + node_name: name.to_string(), + job_id: None, + status: WorkflowNodeStatus::Pending, + result_hash: None, + fan_out_count: None, + fan_in_data: None, + started_at: None, + completed_at: None, + error: None, + } +} + +// ── Definition tests ───────────────────────────────────────── + +#[test] +fn test_create_and_get_definition() { + let storage = test_storage(); + let def = make_definition("my_pipeline"); + + storage.create_workflow_definition(&def).unwrap(); + + let fetched = storage + .get_workflow_definition("my_pipeline", None) + .unwrap() + .unwrap(); + 
assert_eq!(fetched.name, "my_pipeline"); + assert_eq!(fetched.version, 1); + assert_eq!(fetched.step_metadata.len(), 3); + assert!(fetched.step_metadata.contains_key("a")); +} + +#[test] +fn test_get_definition_by_version() { + let storage = test_storage(); + + let mut def_v1 = make_definition("versioned"); + def_v1.version = 1; + storage.create_workflow_definition(&def_v1).unwrap(); + + let mut def_v2 = make_definition("versioned"); + def_v2.id = uuid::Uuid::now_v7().to_string(); + def_v2.version = 2; + storage.create_workflow_definition(&def_v2).unwrap(); + + // Latest version (no version specified) returns v2 + let latest = storage + .get_workflow_definition("versioned", None) + .unwrap() + .unwrap(); + assert_eq!(latest.version, 2); + + // Specific version + let v1 = storage + .get_workflow_definition("versioned", Some(1)) + .unwrap() + .unwrap(); + assert_eq!(v1.version, 1); +} + +#[test] +fn test_get_definition_by_id() { + let storage = test_storage(); + let def = make_definition("by_id_test"); + let def_id = def.id.clone(); + storage.create_workflow_definition(&def).unwrap(); + + let fetched = storage + .get_workflow_definition_by_id(&def_id) + .unwrap() + .unwrap(); + assert_eq!(fetched.id, def_id); +} + +#[test] +fn test_definition_not_found() { + let storage = test_storage(); + let result = storage + .get_workflow_definition("nonexistent", None) + .unwrap(); + assert!(result.is_none()); +} + +// ── Run tests ──────────────────────────────────────────────── + +#[test] +fn test_create_and_get_run() { + let storage = test_storage(); + let def = make_definition("run_test"); + storage.create_workflow_definition(&def).unwrap(); + + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + let fetched = storage.get_workflow_run(&run_id).unwrap().unwrap(); + assert_eq!(fetched.id, run_id); + assert_eq!(fetched.state, WorkflowState::Pending); + assert_eq!(fetched.params, Some(r#"{"region":"eu"}"#.to_string())); 
+} + +#[test] +fn test_update_run_state() { + let storage = test_storage(); + let def = make_definition("state_test"); + storage.create_workflow_definition(&def).unwrap(); + + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + // Pending → Running + storage + .update_workflow_run_state(&run_id, WorkflowState::Running, None) + .unwrap(); + let fetched = storage.get_workflow_run(&run_id).unwrap().unwrap(); + assert_eq!(fetched.state, WorkflowState::Running); + + // Running → Failed with error + storage + .update_workflow_run_state(&run_id, WorkflowState::Failed, Some("node X blew up")) + .unwrap(); + let fetched = storage.get_workflow_run(&run_id).unwrap().unwrap(); + assert_eq!(fetched.state, WorkflowState::Failed); + assert_eq!(fetched.error, Some("node X blew up".to_string())); +} + +#[test] +fn test_set_run_started_and_completed() { + let storage = test_storage(); + let def = make_definition("timing_test"); + storage.create_workflow_definition(&def).unwrap(); + + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + let start_time = now_millis(); + storage + .set_workflow_run_started(&run_id, start_time) + .unwrap(); + let fetched = storage.get_workflow_run(&run_id).unwrap().unwrap(); + assert_eq!(fetched.state, WorkflowState::Running); + assert_eq!(fetched.started_at, Some(start_time)); + + let end_time = now_millis(); + storage + .set_workflow_run_completed(&run_id, end_time) + .unwrap(); + let fetched = storage.get_workflow_run(&run_id).unwrap().unwrap(); + assert_eq!(fetched.completed_at, Some(end_time)); +} + +#[test] +fn test_list_runs() { + let storage = test_storage(); + let def = make_definition("list_test"); + storage.create_workflow_definition(&def).unwrap(); + + for _ in 0..5 { + storage.create_workflow_run(&make_run(&def.id)).unwrap(); + } + + let all = storage.list_workflow_runs(None, None, 100, 0).unwrap(); + assert_eq!(all.len(), 5); 
+ + let limited = storage.list_workflow_runs(None, None, 2, 0).unwrap(); + assert_eq!(limited.len(), 2); + + let offset = storage.list_workflow_runs(None, None, 100, 3).unwrap(); + assert_eq!(offset.len(), 2); +} + +#[test] +fn test_list_runs_by_state() { + let storage = test_storage(); + let def = make_definition("state_filter"); + storage.create_workflow_definition(&def).unwrap(); + + let run1 = make_run(&def.id); + let run1_id = run1.id.clone(); + storage.create_workflow_run(&run1).unwrap(); + + let run2 = make_run(&def.id); + storage.create_workflow_run(&run2).unwrap(); + + storage + .update_workflow_run_state(&run1_id, WorkflowState::Running, None) + .unwrap(); + + let running = storage + .list_workflow_runs(None, Some(WorkflowState::Running), 100, 0) + .unwrap(); + assert_eq!(running.len(), 1); + assert_eq!(running[0].id, run1_id); + + let pending = storage + .list_workflow_runs(None, Some(WorkflowState::Pending), 100, 0) + .unwrap(); + assert_eq!(pending.len(), 1); +} + +// ── Node tests ─────────────────────────────────────────────── + +#[test] +fn test_create_and_get_node() { + let storage = test_storage(); + let def = make_definition("node_test"); + storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + let node = make_node(&run_id, "a"); + storage.create_workflow_node(&node).unwrap(); + + let fetched = storage.get_workflow_node(&run_id, "a").unwrap().unwrap(); + assert_eq!(fetched.node_name, "a"); + assert_eq!(fetched.status, WorkflowNodeStatus::Pending); + assert!(fetched.job_id.is_none()); +} + +#[test] +fn test_create_nodes_batch() { + let storage = test_storage(); + let def = make_definition("batch_test"); + storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + let nodes = vec![ + make_node(&run_id, "a"), + make_node(&run_id, "b"), + 
make_node(&run_id, "c"), + ]; + storage.create_workflow_nodes_batch(&nodes).unwrap(); + + let all = storage.get_workflow_nodes(&run_id).unwrap(); + assert_eq!(all.len(), 3); +} + +#[test] +fn test_update_node_status() { + let storage = test_storage(); + let def = make_definition("status_test"); + storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + storage + .create_workflow_node(&make_node(&run_id, "a")) + .unwrap(); + + storage + .update_workflow_node_status(&run_id, "a", WorkflowNodeStatus::Running) + .unwrap(); + let fetched = storage.get_workflow_node(&run_id, "a").unwrap().unwrap(); + assert_eq!(fetched.status, WorkflowNodeStatus::Running); +} + +#[test] +fn test_set_node_job_and_timing() { + let storage = test_storage(); + let def = make_definition("job_test"); + storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + storage + .create_workflow_node(&make_node(&run_id, "a")) + .unwrap(); + + storage + .set_workflow_node_job(&run_id, "a", "job-123") + .unwrap(); + let fetched = storage.get_workflow_node(&run_id, "a").unwrap().unwrap(); + assert_eq!(fetched.job_id, Some("job-123".to_string())); + + let start = now_millis(); + storage + .set_workflow_node_started(&run_id, "a", start) + .unwrap(); + let fetched = storage.get_workflow_node(&run_id, "a").unwrap().unwrap(); + assert_eq!(fetched.status, WorkflowNodeStatus::Running); + assert_eq!(fetched.started_at, Some(start)); + + let end = now_millis(); + storage + .set_workflow_node_completed(&run_id, "a", end, Some("abc123")) + .unwrap(); + let fetched = storage.get_workflow_node(&run_id, "a").unwrap().unwrap(); + assert_eq!(fetched.status, WorkflowNodeStatus::Completed); + assert_eq!(fetched.completed_at, Some(end)); + assert_eq!(fetched.result_hash, Some("abc123".to_string())); +} + +#[test] +fn 
test_set_node_error() { + let storage = test_storage(); + let def = make_definition("error_test"); + storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + storage + .create_workflow_node(&make_node(&run_id, "a")) + .unwrap(); + + storage + .set_workflow_node_error(&run_id, "a", "kaboom") + .unwrap(); + let fetched = storage.get_workflow_node(&run_id, "a").unwrap().unwrap(); + assert_eq!(fetched.status, WorkflowNodeStatus::Failed); + assert_eq!(fetched.error, Some("kaboom".to_string())); +} + +// ── Ready nodes (DAG-aware) ────────────────────────────────── + +#[test] +fn test_get_ready_nodes_roots_are_ready() { + let storage = test_storage(); + let def = make_definition("ready_test"); + let dag_json = String::from_utf8(def.dag_data.clone()).unwrap(); + storage.create_workflow_definition(&def).unwrap(); + + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + let nodes = vec![ + make_node(&run_id, "a"), + make_node(&run_id, "b"), + make_node(&run_id, "c"), + ]; + storage.create_workflow_nodes_batch(&nodes).unwrap(); + + // Only root node "a" should be ready (b depends on a, c depends on b) + let ready = storage + .get_ready_workflow_nodes(&run_id, &dag_json) + .unwrap(); + assert_eq!(ready.len(), 1); + assert_eq!(ready[0].node_name, "a"); +} + +#[test] +fn test_get_ready_nodes_after_completion() { + let storage = test_storage(); + let def = make_definition("completion_ready"); + let dag_json = String::from_utf8(def.dag_data.clone()).unwrap(); + storage.create_workflow_definition(&def).unwrap(); + + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + storage + .create_workflow_nodes_batch(&[ + make_node(&run_id, "a"), + make_node(&run_id, "b"), + make_node(&run_id, "c"), + ]) + .unwrap(); + + // Complete "a" → "b" becomes ready + storage + 
.set_workflow_node_completed(&run_id, "a", now_millis(), None) + .unwrap(); + + let ready = storage + .get_ready_workflow_nodes(&run_id, &dag_json) + .unwrap(); + assert_eq!(ready.len(), 1); + assert_eq!(ready[0].node_name, "b"); +} + +#[test] +fn test_get_ready_nodes_all_completed() { + let storage = test_storage(); + let def = make_definition("all_done"); + let dag_json = String::from_utf8(def.dag_data.clone()).unwrap(); + storage.create_workflow_definition(&def).unwrap(); + + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + storage + .create_workflow_nodes_batch(&[ + make_node(&run_id, "a"), + make_node(&run_id, "b"), + make_node(&run_id, "c"), + ]) + .unwrap(); + + // Complete all + for name in &["a", "b", "c"] { + storage + .set_workflow_node_completed(&run_id, name, now_millis(), None) + .unwrap(); + } + + let ready = storage + .get_ready_workflow_nodes(&run_id, &dag_json) + .unwrap(); + assert!(ready.is_empty()); +} + +#[test] +fn test_get_ready_nodes_diamond_dag() { + let storage = test_storage(); + + // Diamond: a → b, a → c, b → d, c → d + let dag_json = serde_json::json!({ + "nodes": [{"name": "a"}, {"name": "b"}, {"name": "c"}, {"name": "d"}], + "edges": [ + {"from": "a", "to": "b", "weight": 1.0}, + {"from": "a", "to": "c", "weight": 1.0}, + {"from": "b", "to": "d", "weight": 1.0}, + {"from": "c", "to": "d", "weight": 1.0} + ] + }); + + let mut step_meta = std::collections::HashMap::new(); + for name in &["a", "b", "c", "d"] { + step_meta.insert( + name.to_string(), + StepMetadata { + task_name: format!("task_{name}"), + queue: None, + args_template: None, + kwargs_template: None, + max_retries: None, + timeout_ms: None, + priority: None, + fan_out: None, + fan_in: None, + condition: None, + }, + ); + } + + let def = WorkflowDefinition { + id: uuid::Uuid::now_v7().to_string(), + name: "diamond".to_string(), + version: 1, + dag_data: serde_json::to_vec(&dag_json).unwrap(), + step_metadata: 
step_meta, + created_at: now_millis(), + }; + let dag_str = serde_json::to_string(&dag_json).unwrap(); + storage.create_workflow_definition(&def).unwrap(); + + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + for name in &["a", "b", "c", "d"] { + storage + .create_workflow_node(&make_node(&run_id, name)) + .unwrap(); + } + + // Initially only "a" is ready + let ready = storage.get_ready_workflow_nodes(&run_id, &dag_str).unwrap(); + assert_eq!(ready.len(), 1); + assert_eq!(ready[0].node_name, "a"); + + // Complete "a" → "b" and "c" become ready (parallel) + storage + .set_workflow_node_completed(&run_id, "a", now_millis(), None) + .unwrap(); + let ready = storage.get_ready_workflow_nodes(&run_id, &dag_str).unwrap(); + assert_eq!(ready.len(), 2); + let names: Vec<&str> = ready.iter().map(|n| n.node_name.as_str()).collect(); + assert!(names.contains(&"b")); + assert!(names.contains(&"c")); + + // Complete only "b" → "d" NOT ready yet (needs "c" too) + storage + .set_workflow_node_completed(&run_id, "b", now_millis(), None) + .unwrap(); + let ready = storage.get_ready_workflow_nodes(&run_id, &dag_str).unwrap(); + assert_eq!(ready.len(), 1); + assert_eq!(ready[0].node_name, "c"); // only c is pending+ready + + // Complete "c" → "d" becomes ready + storage + .set_workflow_node_completed(&run_id, "c", now_millis(), None) + .unwrap(); + let ready = storage.get_ready_workflow_nodes(&run_id, &dag_str).unwrap(); + assert_eq!(ready.len(), 1); + assert_eq!(ready[0].node_name, "d"); +} + +// ── WorkflowState transition tests ─────────────────────────── + +#[test] +fn test_state_transitions() { + assert!(WorkflowState::Pending.can_transition_to(WorkflowState::Running)); + assert!(WorkflowState::Running.can_transition_to(WorkflowState::Completed)); + assert!(WorkflowState::Running.can_transition_to(WorkflowState::Failed)); + assert!(WorkflowState::Running.can_transition_to(WorkflowState::Cancelled)); + 
assert!(WorkflowState::Running.can_transition_to(WorkflowState::Paused)); + assert!(WorkflowState::Paused.can_transition_to(WorkflowState::Running)); + assert!(WorkflowState::Paused.can_transition_to(WorkflowState::Cancelled)); + + // Invalid transitions + assert!(!WorkflowState::Pending.can_transition_to(WorkflowState::Completed)); + assert!(!WorkflowState::Completed.can_transition_to(WorkflowState::Running)); + assert!(!WorkflowState::Failed.can_transition_to(WorkflowState::Running)); +} + +#[test] +fn test_state_is_terminal() { + assert!(WorkflowState::Completed.is_terminal()); + assert!(WorkflowState::Failed.is_terminal()); + assert!(WorkflowState::Cancelled.is_terminal()); + assert!(!WorkflowState::Pending.is_terminal()); + assert!(!WorkflowState::Running.is_terminal()); + assert!(!WorkflowState::Paused.is_terminal()); +} + +// ── Fan-out storage tests ─────────────────────────────────── + +#[test] +fn test_set_fan_out_count() { + let storage = test_storage(); + let def = make_definition("fanout_pipe"); + storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + let node = make_node(&run_id, "process"); + storage.create_workflow_node(&node).unwrap(); + + // Initially pending, no fan_out_count + let fetched = storage + .get_workflow_node(&run_id, "process") + .unwrap() + .unwrap(); + assert_eq!(fetched.status, WorkflowNodeStatus::Pending); + assert_eq!(fetched.fan_out_count, None); + + // Set fan_out_count → status transitions to Running + storage + .set_workflow_node_fan_out_count(&run_id, "process", 5) + .unwrap(); + + let fetched = storage + .get_workflow_node(&run_id, "process") + .unwrap() + .unwrap(); + assert_eq!(fetched.status, WorkflowNodeStatus::Running); + assert_eq!(fetched.fan_out_count, Some(5)); +} + +#[test] +fn test_get_nodes_by_prefix() { + let storage = test_storage(); + let def = make_definition("prefix_pipe"); + 
storage.create_workflow_definition(&def).unwrap(); + let run = make_run(&def.id); + let run_id = run.id.clone(); + storage.create_workflow_run(&run).unwrap(); + + // Create parent and children + for name in &[ + "process", + "process[0]", + "process[1]", + "process[2]", + "aggregate", + ] { + storage + .create_workflow_node(&make_node(&run_id, name)) + .unwrap(); + } + + // Prefix "process[" should return only the children + let children = storage + .get_workflow_nodes_by_prefix(&run_id, "process[") + .unwrap(); + assert_eq!(children.len(), 3); + let names: Vec<&str> = children.iter().map(|n| n.node_name.as_str()).collect(); + assert!(names.contains(&"process[0]")); + assert!(names.contains(&"process[1]")); + assert!(names.contains(&"process[2]")); + + // Prefix "aggregate" should match just the one node + let agg = storage + .get_workflow_nodes_by_prefix(&run_id, "aggregate") + .unwrap(); + assert_eq!(agg.len(), 1); + + // Prefix "nonexistent[" should return nothing + let empty = storage + .get_workflow_nodes_by_prefix(&run_id, "nonexistent[") + .unwrap(); + assert!(empty.is_empty()); +} diff --git a/crates/taskito-workflows/src/topology.rs b/crates/taskito-workflows/src/topology.rs new file mode 100644 index 0000000..c83b015 --- /dev/null +++ b/crates/taskito-workflows/src/topology.rs @@ -0,0 +1,112 @@ +use std::collections::HashMap; + +use taskito_core::error::{QueueError, Result}; + +use dagron_core::{SerializableGraph, DAG}; + +/// A node in a workflow DAG, returned in topological order. +#[derive(Debug, Clone)] +pub struct TopologicalNode { + pub name: String, + pub predecessors: Vec, +} + +/// Parse a JSON-encoded DAG and return nodes in topological order. +/// +/// Each returned entry carries the node name and the names of its direct +/// predecessors. Callers (e.g. the Python submit path) use this to create +/// jobs with the correct `depends_on` chain. 
+pub fn topological_order(dag_bytes: &[u8]) -> Result<Vec<TopologicalNode>> {
+    let dag_json = std::str::from_utf8(dag_bytes)
+        .map_err(|e| QueueError::Serialization(format!("workflow DAG is not valid UTF-8: {e}")))?;
+
+    let graph: SerializableGraph = serde_json::from_str(dag_json).map_err(|e| {
+        QueueError::Serialization(format!("failed to deserialize workflow DAG: {e}"))
+    })?;
+
+    let mut predecessors: HashMap<String, Vec<String>> = HashMap::new();
+    for node in &graph.nodes {
+        predecessors.entry(node.name.clone()).or_default();
+    }
+    for edge in &graph.edges {
+        predecessors
+            .entry(edge.to.clone())
+            .or_default()
+            .push(edge.from.clone());
+    }
+
+    let dag: DAG<()> = DAG::from_serializable(graph, |_| ())
+        .map_err(|e| QueueError::Serialization(format!("failed to build DAG: {e}")))?;
+
+    let sorted = dag
+        .topological_sort()
+        .map_err(|e| QueueError::Serialization(format!("topological sort failed: {e}")))?;
+
+    Ok(sorted
+        .into_iter()
+        .map(|node_id| TopologicalNode {
+            predecessors: predecessors.remove(&node_id.name).unwrap_or_default(),
+            name: node_id.name,
+        })
+        .collect())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn dag(json: serde_json::Value) -> Vec<u8> {
+        serde_json::to_vec(&json).unwrap()
+    }
+
+    #[test]
+    fn test_linear_chain() {
+        let bytes = dag(serde_json::json!({
+            "nodes": [{"name": "a"}, {"name": "b"}, {"name": "c"}],
+            "edges": [
+                {"from": "a", "to": "b", "weight": 1.0},
+                {"from": "b", "to": "c", "weight": 1.0}
+            ]
+        }));
+
+        let order = topological_order(&bytes).unwrap();
+        let names: Vec<&str> = order.iter().map(|n| n.name.as_str()).collect();
+        assert_eq!(names, vec!["a", "b", "c"]);
+
+        assert!(order[0].predecessors.is_empty());
+        assert_eq!(order[1].predecessors, vec!["a".to_string()]);
+        assert_eq!(order[2].predecessors, vec!["b".to_string()]);
+    }
+
+    #[test]
+    fn test_diamond_topology() {
+        let bytes = dag(serde_json::json!({
+            "nodes": [{"name": "a"}, {"name": "b"}, {"name": "c"}, {"name": "d"}],
+            "edges": [
+                {"from": "a", "to": "b", "weight": 1.0},
+                
{"from": "a", "to": "c", "weight": 1.0}, + {"from": "b", "to": "d", "weight": 1.0}, + {"from": "c", "to": "d", "weight": 1.0} + ] + })); + + let order = topological_order(&bytes).unwrap(); + let names: Vec<&str> = order.iter().map(|n| n.name.as_str()).collect(); + let pos = |n: &str| names.iter().position(|&x| x == n).unwrap(); + assert!(pos("a") < pos("b")); + assert!(pos("a") < pos("c")); + assert!(pos("b") < pos("d")); + assert!(pos("c") < pos("d")); + + let d_preds: &[String] = &order[pos("d")].predecessors; + assert_eq!(d_preds.len(), 2); + assert!(d_preds.contains(&"b".to_string())); + assert!(d_preds.contains(&"c".to_string())); + } + + #[test] + fn test_invalid_json() { + let err = topological_order(b"not json").unwrap_err(); + assert!(err.to_string().contains("deserialize")); + } +} diff --git a/pyproject.toml b/pyproject.toml index 52c96b9..a2292a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ docs = ["zensical"] manifest-path = "crates/taskito-python/Cargo.toml" python-source = "py_src" module-name = "taskito._taskito" -features = ["extension-module", "postgres", "redis"] +features = ["extension-module", "postgres", "redis", "workflows"] [project.scripts] taskito = "taskito.cli:main" From 8ee7e1fe93efc35aff0ae9f23309eb55b689bbd3 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 05:46:22 +0530 Subject: [PATCH 02/12] feat(workflows): add Python workflow package with full feature set - Workflow builder with step(), gate(), fan_out/fan_in, conditions - WorkflowTracker with conditional orchestration, skip propagation - Fan-out strategy module (each), fan-in aggregation - Approval gate support with timeout and auto-resolve - Sub-workflow composition via WorkflowProxy.as_step() - WorkflowContext for callable conditions - Incremental caching with dirty-set computation - Graph analysis: critical path, execution plan, bottleneck - Mermaid and DOT visualization renderers - 
Cron-scheduled workflow support via bridge task pattern - Type stubs and event types updated --- py_src/taskito/_taskito.pyi | 115 +++ py_src/taskito/app.py | 44 ++ py_src/taskito/events.py | 5 + py_src/taskito/workflows/__init__.py | 25 + py_src/taskito/workflows/analysis.py | 189 +++++ py_src/taskito/workflows/builder.py | 492 +++++++++++++ py_src/taskito/workflows/context.py | 33 + py_src/taskito/workflows/fan_out.py | 41 ++ py_src/taskito/workflows/incremental.py | 69 ++ py_src/taskito/workflows/mixins.py | 163 +++++ py_src/taskito/workflows/run.py | 154 ++++ py_src/taskito/workflows/tracker.py | 820 ++++++++++++++++++++++ py_src/taskito/workflows/types.py | 63 ++ py_src/taskito/workflows/visualization.py | 133 ++++ 14 files changed, 2346 insertions(+) create mode 100644 py_src/taskito/workflows/__init__.py create mode 100644 py_src/taskito/workflows/analysis.py create mode 100644 py_src/taskito/workflows/builder.py create mode 100644 py_src/taskito/workflows/context.py create mode 100644 py_src/taskito/workflows/fan_out.py create mode 100644 py_src/taskito/workflows/incremental.py create mode 100644 py_src/taskito/workflows/mixins.py create mode 100644 py_src/taskito/workflows/run.py create mode 100644 py_src/taskito/workflows/tracker.py create mode 100644 py_src/taskito/workflows/types.py create mode 100644 py_src/taskito/workflows/visualization.py diff --git a/py_src/taskito/_taskito.pyi b/py_src/taskito/_taskito.pyi index 7a056d7..f3590c8 100644 --- a/py_src/taskito/_taskito.pyi +++ b/py_src/taskito/_taskito.pyi @@ -234,6 +234,121 @@ class PyQueue: ttl_ms: int = 30000, ) -> bool: ... def get_lock_info(self, lock_name: str) -> dict[str, Any] | None: ... 
+ def submit_workflow( + self, + name: str, + version: int, + dag_bytes: bytes, + step_metadata_json: str, + node_payloads: dict[str, bytes], + queue_default: str = "default", + params_json: str | None = None, + deferred_node_names: list[str] | None = None, + parent_run_id: str | None = None, + parent_node_name: str | None = None, + cache_hit_nodes: dict[str, str] | None = None, + ) -> PyWorkflowHandle: ... + def get_workflow_run_status(self, run_id: str) -> PyWorkflowRunStatus: ... + def cancel_workflow_run(self, run_id: str) -> None: ... + def mark_workflow_node_result( + self, + job_id: str, + succeeded: bool, + error: str | None = None, + skip_cascade: bool = False, + result_hash: str | None = None, + ) -> tuple[str, str, str | None] | None: ... + def get_base_run_node_data(self, base_run_id: str) -> list[tuple[str, str, str | None]]: ... + def skip_workflow_node(self, run_id: str, node_name: str) -> None: ... + def expand_fan_out( + self, + run_id: str, + parent_node_name: str, + child_names: list[str], + child_payloads: list[bytes], + task_name: str, + queue: str, + max_retries: int, + timeout_ms: int, + priority: int, + ) -> list[str]: ... + def create_deferred_job( + self, + run_id: str, + node_name: str, + payload: bytes, + task_name: str, + queue: str, + max_retries: int, + timeout_ms: int, + priority: int, + ) -> str: ... + def check_fan_out_completion( + self, + run_id: str, + parent_node_name: str, + ) -> tuple[bool, list[str]] | None: ... + def finalize_run_if_terminal(self, run_id: str) -> str | None: ... + def set_workflow_node_waiting_approval(self, run_id: str, node_name: str) -> None: ... + def resolve_workflow_gate( + self, + run_id: str, + node_name: str, + approved: bool, + error: str | None = None, + ) -> None: ... + def get_workflow_definition_dag(self, run_id: str) -> bytes: ... + def set_workflow_node_fan_out_count(self, run_id: str, node_name: str, count: int) -> None: ... + +class PyWorkflowBuilder: + """Rust-side workflow DAG builder. 
+ + Only available when built with the ``workflows`` feature. + """ + + def __init__(self) -> None: ... + def add_step( + self, + name: str, + task_name: str, + after: list[str] | None = None, + queue: str | None = None, + max_retries: int | None = None, + timeout_ms: int | None = None, + priority: int | None = None, + args_template: str | None = None, + kwargs_template: str | None = None, + fan_out: str | None = None, + fan_in: str | None = None, + condition: str | None = None, + ) -> None: ... + def step_count(self) -> int: ... + def step_names(self) -> list[str]: ... + def serialize(self) -> tuple[bytes, str]: ... + +class PyWorkflowHandle: + """Opaque handle returned from ``PyQueue.submit_workflow``. + + Only available when built with the ``workflows`` feature. + """ + + run_id: str + name: str + definition_id: str + +class PyWorkflowRunStatus: + """Snapshot of a workflow run's state and per-node status. + + Only available when built with the ``workflows`` feature. + """ + + run_id: str + state: str + started_at: int | None + completed_at: int | None + error: str | None + + def node_statuses(self) -> dict[str, dict[str, Any]]: ... class PyResultSender: """Sends task results from Python async executor back to Rust scheduler. 
diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index 6a75d19..eb29da3 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -56,6 +56,19 @@ from taskito.task import TaskWrapper from taskito.webhooks import WebhookManager +try: + from taskito.workflows.mixins import QueueWorkflowMixin + from taskito.workflows.tracker import WorkflowTracker + + _WORKFLOWS_AVAILABLE = True +except ImportError: # pragma: no cover - workflows feature not compiled in + + class QueueWorkflowMixin: # type: ignore[no-redef] + pass + + WorkflowTracker = None # type: ignore[assignment,misc] + _WORKFLOWS_AVAILABLE = False + logger = logging.getLogger("taskito") @@ -80,6 +93,7 @@ class Queue( QueueInspectionMixin, QueueOperationsMixin, QueueLockMixin, + QueueWorkflowMixin, AsyncQueueMixin, ): """ @@ -257,6 +271,12 @@ def __init__( # Test mode flag (Phase M) self._test_mode_active = False + # Workflow support + self._workflow_registry: dict[str, Any] = {} + self._workflow_tracker: Any = None + if _WORKFLOWS_AVAILABLE and hasattr(self._inner, "submit_workflow"): + self._workflow_tracker = WorkflowTracker(self) + def task( self, name: str | None = None, @@ -435,6 +455,30 @@ def periodic( """ def decorator(fn: Callable) -> TaskWrapper: + # If fn is a WorkflowProxy (from @queue.workflow()), create a + # launcher task that submits the workflow on each cron trigger. 
+ if getattr(fn, "_is_workflow_proxy", False): + proxy: Any = fn + launcher_name = f"_wf_launcher_{proxy._name}" + + @self.task(name=launcher_name, queue=queue) + def _wf_launcher() -> str: + run = proxy.submit() + return f"submitted workflow run {run.id}" + + payload = self._get_serializer(launcher_name).dumps(((), {})) + self._periodic_configs.append( + { + "name": launcher_name, + "task_name": launcher_name, + "cron_expr": cron, + "payload": payload, + "queue": queue, + "timezone": timezone, + } + ) + return fn # type: ignore[return-value] + # Register as a normal task first wrapper = self.task(name=name, queue=queue)(fn) diff --git a/py_src/taskito/events.py b/py_src/taskito/events.py index 215e00c..4dcbedd 100644 --- a/py_src/taskito/events.py +++ b/py_src/taskito/events.py @@ -28,6 +28,11 @@ class EventType(enum.Enum): WORKER_UNHEALTHY = "worker.unhealthy" QUEUE_PAUSED = "queue.paused" QUEUE_RESUMED = "queue.resumed" + WORKFLOW_SUBMITTED = "workflow.submitted" + WORKFLOW_COMPLETED = "workflow.completed" + WORKFLOW_FAILED = "workflow.failed" + WORKFLOW_CANCELLED = "workflow.cancelled" + WORKFLOW_GATE_REACHED = "workflow.gate_reached" class EventBus: diff --git a/py_src/taskito/workflows/__init__.py b/py_src/taskito/workflows/__init__.py new file mode 100644 index 0000000..9c34199 --- /dev/null +++ b/py_src/taskito/workflows/__init__.py @@ -0,0 +1,25 @@ +"""DAG-based workflow support for taskito. + +This package is only functional when the native extension was built with the +``workflows`` feature. If the feature is not compiled in, importing this +package raises a :class:`RuntimeError` the first time any public API is used. 
+""" + +from __future__ import annotations + +from .builder import GateConfig, Workflow, WorkflowProxy +from .context import WorkflowContext +from .run import WorkflowRun +from .types import NodeSnapshot, NodeStatus, WorkflowState, WorkflowStatus + +__all__ = [ + "GateConfig", + "NodeSnapshot", + "NodeStatus", + "Workflow", + "WorkflowContext", + "WorkflowProxy", + "WorkflowRun", + "WorkflowState", + "WorkflowStatus", +] diff --git a/py_src/taskito/workflows/analysis.py b/py_src/taskito/workflows/analysis.py new file mode 100644 index 0000000..1a6bee4 --- /dev/null +++ b/py_src/taskito/workflows/analysis.py @@ -0,0 +1,189 @@ +"""Pre-execution graph analysis for workflows. + +All functions operate on the builder's ``_steps`` dict (a mapping from +step name to :class:`~taskito.workflows.builder._Step`). They perform +pure graph computations with no side effects. +""" + +from __future__ import annotations + +from collections import deque +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .builder import _Step + + +def _build_adjacency( + steps: dict[str, _Step], +) -> tuple[dict[str, list[str]], dict[str, list[str]]]: + """Build forward (successors) and backward (predecessors) adjacency maps.""" + successors: dict[str, list[str]] = {name: [] for name in steps} + predecessors: dict[str, list[str]] = {name: list(s.after) for name, s in steps.items()} + for name, step in steps.items(): + for pred in step.after: + successors[pred].append(name) + return successors, predecessors + + +def ancestors(steps: dict[str, _Step], node: str) -> list[str]: + """Return all transitive predecessors of *node* (BFS, no particular order).""" + if node not in steps: + raise KeyError(f"node '{node}' not found") + _, preds = _build_adjacency(steps) + visited: set[str] = set() + queue: deque[str] = deque(preds.get(node, [])) + while queue: + current = queue.popleft() + if current in visited: + continue + visited.add(current) + queue.extend(preds.get(current, [])) + return 
sorted(visited)
+
+
+def descendants(steps: dict[str, _Step], node: str) -> list[str]:
+    """Return all transitive successors of *node* (BFS; result is sorted)."""
+    if node not in steps:
+        raise KeyError(f"node '{node}' not found")
+    succs, _ = _build_adjacency(steps)
+    visited: set[str] = set()
+    queue: deque[str] = deque(succs.get(node, []))
+    while queue:
+        current = queue.popleft()
+        if current in visited:
+            continue
+        visited.add(current)
+        queue.extend(succs.get(current, []))
+    return sorted(visited)
+
+
+def topological_levels(steps: dict[str, _Step]) -> list[list[str]]:
+    """Group nodes by topological depth (Kahn's algorithm).
+
+    Returns a list of lists where level 0 contains root nodes,
+    level 1 contains nodes whose predecessors are all in level 0, etc.
+    """
+    _, preds = _build_adjacency(steps)
+    succs, _ = _build_adjacency(steps)
+    in_degree: dict[str, int] = {name: len(p) for name, p in preds.items()}
+
+    levels: list[list[str]] = []
+    current_level = sorted(n for n, d in in_degree.items() if d == 0)
+
+    while current_level:
+        levels.append(current_level)
+        next_level_set: set[str] = set()
+        for node in current_level:
+            for succ in succs.get(node, []):
+                in_degree[succ] -= 1
+                if in_degree[succ] == 0:
+                    next_level_set.add(succ)
+        current_level = sorted(next_level_set)
+
+    return levels
+
+
+def stats(steps: dict[str, _Step]) -> dict[str, int | float]:
+    """Compute basic DAG statistics.
+ + Returns: + ``{nodes, edges, depth, width, density}`` + """ + n_nodes = len(steps) + n_edges = sum(len(s.after) for s in steps.values()) + levels = topological_levels(steps) + depth = len(levels) + width = max((len(lv) for lv in levels), default=0) + max_edges = n_nodes * (n_nodes - 1) / 2 if n_nodes > 1 else 1 + density = round(n_edges / max_edges, 4) if max_edges > 0 else 0.0 + return { + "nodes": n_nodes, + "edges": n_edges, + "depth": depth, + "width": width, + "density": density, + } + + +def critical_path(steps: dict[str, _Step], costs: dict[str, float]) -> tuple[list[str], float]: + """Find the longest-weighted path through the DAG. + + Uses dynamic programming on topological order. + + Args: + costs: Mapping of step name → estimated duration. + + Returns: + ``(path, total_cost)`` — the critical path nodes and sum of costs. + """ + levels = topological_levels(steps) + flat_order = [node for level in levels for node in level] + + dist: dict[str, float] = {} + prev: dict[str, str | None] = {} + for node in flat_order: + dist[node] = costs.get(node, 0.0) + prev[node] = None + + succs, _ = _build_adjacency(steps) + for node in flat_order: + for succ in succs.get(node, []): + new_dist = dist[node] + costs.get(succ, 0.0) + if new_dist > dist[succ]: + dist[succ] = new_dist + prev[succ] = node + + if not dist: + return [], 0.0 + + end_node = max(dist, key=lambda n: dist[n]) + path: list[str] = [] + current: str | None = end_node + while current is not None: + path.append(current) + current = prev[current] + path.reverse() + return path, dist[end_node] + + +def execution_plan(steps: dict[str, _Step], max_workers: int = 1) -> list[list[str]]: + """Generate a step-by-step execution plan respecting worker limits. + + Each stage contains up to *max_workers* nodes that can run concurrently. + Nodes within the same topological level are batched together. 
+ """ + levels = topological_levels(steps) + plan: list[list[str]] = [] + for level in levels: + for i in range(0, len(level), max_workers): + plan.append(level[i : i + max_workers]) + return plan + + +def bottleneck_analysis(steps: dict[str, _Step], costs: dict[str, float]) -> dict[str, Any]: + """Identify the bottleneck node on the critical path. + + Returns: + ``{node, cost, percentage, critical_path, total_cost, suggestion}`` + """ + path, total = critical_path(steps, costs) + if not path or total == 0: + return {"node": None, "cost": 0, "percentage": 0, "critical_path": [], "total_cost": 0} + + bottleneck_node = max(path, key=lambda n: costs.get(n, 0.0)) + bottleneck_cost = costs.get(bottleneck_node, 0.0) + pct = round(bottleneck_cost / total * 100, 1) if total > 0 else 0 + + return { + "node": bottleneck_node, + "cost": bottleneck_cost, + "percentage": pct, + "critical_path": path, + "total_cost": total, + "suggestion": ( + f"{bottleneck_node} is the bottleneck " + f"({pct}% of total time). " + f"Consider increasing max_concurrent or optimizing this step." + ), + } diff --git a/py_src/taskito/workflows/builder.py b/py_src/taskito/workflows/builder.py new file mode 100644 index 0000000..02e10b0 --- /dev/null +++ b/py_src/taskito/workflows/builder.py @@ -0,0 +1,492 @@ +"""Pure-Python workflow DAG builder. + +Steps are collected in insertion order and validated at ``build()`` time by +delegating to the Rust ``PyWorkflowBuilder`` (which owns a dagron-core DAG +instance and enforces acyclicity + unique node names). 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from collections.abc import Callable + + from taskito.app import Queue + from taskito.task import TaskWrapper + + from .run import WorkflowRun + + +_VALID_FAN_OUT = frozenset({"each"}) +_VALID_FAN_IN = frozenset({"all"}) +_VALID_CONDITIONS = frozenset({"on_success", "on_failure", "always"}) +_VALID_ON_FAILURE = frozenset({"fail_fast", "continue"}) + + +@dataclass +class GateConfig: + """Configuration for an approval gate step.""" + + timeout: float | None = None + """Seconds until auto-resolve. ``None`` waits indefinitely.""" + + on_timeout: str = "reject" + """Action on timeout: ``"approve"`` or ``"reject"``.""" + + message: str | Callable | None = None + """Human-readable message shown to approvers.""" + + +@dataclass +class _Step: + """Internal representation of a single workflow step.""" + + name: str + task_name: str + after: list[str] = field(default_factory=list) + args: tuple = () + kwargs: dict[str, Any] = field(default_factory=dict) + queue: str | None = None + max_retries: int | None = None + timeout_ms: int | None = None + priority: int | None = None + fan_out: str | None = None + fan_in: str | None = None + condition: str | Callable | None = None + gate_config: GateConfig | None = None + sub_workflow: SubWorkflowRef | None = None + + +class Workflow: + """Builder for a workflow DAG. + + Steps are added via :meth:`step` and the workflow is materialized at + submission time. Each step is a registered taskito task. 
+ + Example:: + + wf = Workflow(name="my_pipeline") + wf.step("a", task_a) + wf.step("b", task_b, after="a") + wf.step("c", task_c, after="b") + run = queue.submit_workflow(wf) + run.wait() + """ + + def __init__( + self, + name: str = "workflow", + version: int = 1, + on_failure: str = "fail_fast", + cache_ttl: float | None = None, + ): + if on_failure not in _VALID_ON_FAILURE: + raise ValueError( + f"on_failure must be one of {sorted(_VALID_ON_FAILURE)}, got '{on_failure}'" + ) + self.name = name + self.version = version + self.on_failure = on_failure + self.cache_ttl = cache_ttl + self._steps: dict[str, _Step] = {} + + def step( + self, + name: str, + task: TaskWrapper, + *, + after: str | list[str] | None = None, + args: tuple = (), + kwargs: dict[str, Any] | None = None, + queue: str | None = None, + max_retries: int | None = None, + timeout_ms: int | None = None, + priority: int | None = None, + fan_out: str | None = None, + fan_in: str | None = None, + condition: str | Callable | None = None, + ) -> Workflow: + """Add a step to the workflow. + + Args: + name: Unique name for this step within the workflow. + task: The registered taskito task (a ``TaskWrapper``). + after: Name(s) of predecessor steps. All must already be added. + args: Positional arguments passed to the task. + kwargs: Keyword arguments passed to the task. + queue: Override the queue for this step. + max_retries: Override the max retry count for this step. + timeout_ms: Override the timeout (in milliseconds) for this step. + priority: Override the priority for this step. + fan_out: Fan-out strategy. ``"each"`` splits the predecessor's + return value into one job per element. + fan_in: Fan-in strategy. ``"all"`` collects all fan-out children's + results into a list passed to this step. + + Returns: + ``self`` for chaining. + + Raises: + ValueError: If the name is already in use, ``after`` references + a step not yet added, or fan-out/fan-in configuration is invalid. 
+ """ + if name in self._steps: + raise ValueError(f"step '{name}' already defined") + if "[" in name: + raise ValueError( + f"step '{name}': step names must not contain '[' (reserved for fan-out children)" + ) + + task_name = getattr(task, "_task_name", None) or getattr(task, "name", None) + if not task_name: + raise ValueError(f"step '{name}': task must be a registered @queue.task() function") + + if after is None: + predecessors: list[str] = [] + elif isinstance(after, str): + predecessors = [after] + else: + predecessors = list(after) + + for pred in predecessors: + if pred not in self._steps: + raise ValueError( + f"step '{name}': predecessor '{pred}' must be added before this step" + ) + + if fan_out is not None: + if fan_out not in _VALID_FAN_OUT: + valid = sorted(_VALID_FAN_OUT) + raise ValueError(f"step '{name}': fan_out must be one of {valid}, got '{fan_out}'") + if len(predecessors) != 1: + raise ValueError(f"step '{name}': fan_out step must have exactly one predecessor") + + if fan_in is not None: + if fan_in not in _VALID_FAN_IN: + raise ValueError( + f"step '{name}': fan_in must be one of {sorted(_VALID_FAN_IN)}, got '{fan_in}'" + ) + if len(predecessors) != 1: + raise ValueError(f"step '{name}': fan_in step must have exactly one predecessor") + pred_step = self._steps[predecessors[0]] + if pred_step.fan_out is None: + raise ValueError( + f"step '{name}': fan_in predecessor '{predecessors[0]}' must have fan_out set" + ) + + is_invalid_condition = ( + condition is not None + and not callable(condition) + and condition not in _VALID_CONDITIONS + ) + if is_invalid_condition: + raise ValueError( + f"step '{name}': condition must be one of " + f"{sorted(_VALID_CONDITIONS)} or a callable, got '{condition}'" + ) + + sub_wf = task if isinstance(task, SubWorkflowRef) else None + + self._steps[name] = _Step( + name=name, + task_name=task_name, + after=predecessors, + args=args, + kwargs=kwargs or {}, + queue=queue, + max_retries=max_retries, + 
timeout_ms=timeout_ms, + priority=priority, + fan_out=fan_out, + fan_in=fan_in, + condition=condition, + sub_workflow=sub_wf, + ) + return self + + def gate( + self, + name: str, + *, + after: str | list[str] | None = None, + condition: str | Callable | None = None, + timeout: float | None = None, + on_timeout: str = "reject", + message: str | Callable | None = None, + ) -> Workflow: + """Add an approval gate step. + + The workflow pauses at this node until + :meth:`~taskito.workflows.mixins.QueueWorkflowMixin.approve_gate` + or + :meth:`~taskito.workflows.mixins.QueueWorkflowMixin.reject_gate` + is called. + + Args: + name: Unique name for this gate. + after: Predecessor step name(s). + condition: Optional condition for entering the gate. + timeout: Seconds until auto-resolve (``None`` = wait forever). + on_timeout: ``"approve"`` or ``"reject"`` when timeout fires. + message: Human-readable message for approvers. + """ + if name in self._steps: + raise ValueError(f"step '{name}' already defined") + if "[" in name: + raise ValueError(f"step '{name}': names must not contain '['") + if on_timeout not in ("approve", "reject"): + raise ValueError(f"gate '{name}': on_timeout must be 'approve' or 'reject'") + + if after is None: + predecessors: list[str] = [] + elif isinstance(after, str): + predecessors = [after] + else: + predecessors = list(after) + for pred in predecessors: + if pred not in self._steps: + raise ValueError(f"gate '{name}': predecessor '{pred}' must be added first") + + self._steps[name] = _Step( + name=name, + task_name="__gate__", + after=predecessors, + condition=condition, + gate_config=GateConfig( + timeout=timeout, + on_timeout=on_timeout, + message=message, + ), + ) + return self + + @property + def step_names(self) -> list[str]: + return list(self._steps.keys()) + + # ── Pre-execution analysis ───────────────────────────────────── + + def ancestors(self, node: str) -> list[str]: + """Return all transitive predecessors of *node*.""" + from 
.analysis import ancestors + + return ancestors(self._steps, node) + + def descendants(self, node: str) -> list[str]: + """Return all transitive successors of *node*.""" + from .analysis import descendants + + return descendants(self._steps, node) + + def topological_levels(self) -> list[list[str]]: + """Group nodes by topological depth.""" + from .analysis import topological_levels + + return topological_levels(self._steps) + + def stats(self) -> dict[str, int | float]: + """Compute basic DAG statistics (nodes, edges, depth, width, density).""" + from .analysis import stats + + return stats(self._steps) + + def critical_path(self, costs: dict[str, float]) -> tuple[list[str], float]: + """Find the longest-weighted path through the DAG. + + Args: + costs: Mapping of step name to estimated duration. + + Returns: + ``(path, total_cost)`` + """ + from .analysis import critical_path + + return critical_path(self._steps, costs) + + def execution_plan(self, max_workers: int = 1) -> list[list[str]]: + """Generate a step-by-step execution plan respecting worker limits.""" + from .analysis import execution_plan + + return execution_plan(self._steps, max_workers) + + def bottleneck_analysis(self, costs: dict[str, float]) -> dict[str, Any]: + """Identify the bottleneck node on the critical path.""" + from .analysis import bottleneck_analysis + + return bottleneck_analysis(self._steps, costs) + + def visualize(self, fmt: str = "mermaid") -> str: + """Render the workflow DAG as a diagram string. + + Args: + fmt: Output format — ``"mermaid"`` or ``"dot"``. + + Returns: + The diagram string (no statuses — pre-execution view). 
+ """ + from .visualization import ( + nodes_and_edges_from_steps, + render_dot, + render_mermaid, + ) + + nodes, edges = nodes_and_edges_from_steps(self._steps) + if fmt == "dot": + return render_dot(nodes, edges) + return render_mermaid(nodes, edges) + + def _compile( + self, queue: Queue + ) -> tuple[ + bytes, + str, + dict[str, bytes], + list[str], + dict[str, Any], + str, + dict[str, GateConfig], + dict[str, SubWorkflowRef], + ]: + """Produce compile output for ``Queue.submit_workflow``. + + Returns: + ``(dag_bytes, step_metadata_json, node_payloads, deferred_nodes, + callable_conditions, on_failure, gate_configs, sub_workflow_refs)`` + """ + from taskito._taskito import PyWorkflowBuilder + + builder = PyWorkflowBuilder() + node_payloads: dict[str, bytes] = {} + callable_conditions: dict[str, Any] = {} + gate_configs: dict[str, GateConfig] = {} + sub_workflow_refs: dict[str, SubWorkflowRef] = {} + + has_conditions = any(s.condition is not None for s in self._steps.values()) + has_gates = any(s.gate_config is not None for s in self._steps.values()) + has_sub_wf = any(s.sub_workflow is not None for s in self._steps.values()) + is_continue = self.on_failure != "fail_fast" + + # Compute the set of deferred nodes. + deferred: set[str] = set() + for step in self._steps.values(): + if ( + step.fan_out is not None + or step.fan_in is not None + or step.condition is not None + or step.gate_config is not None + or step.sub_workflow is not None + ) or ((has_conditions or is_continue or has_gates or has_sub_wf) and step.after): + deferred.add(step.name) + # Propagate: a step is deferred if any predecessor is deferred. 
+ changed = True + while changed: + changed = False + for step in self._steps.values(): + if step.name in deferred: + continue + if any(pred in deferred for pred in step.after): + deferred.add(step.name) + changed = True + + for step in self._steps.values(): + str_condition: str | None = None + if isinstance(step.condition, str): + str_condition = step.condition + elif callable(step.condition): + callable_conditions[step.name] = step.condition + str_condition = "callable" + + if step.gate_config is not None: + gate_configs[step.name] = step.gate_config + if step.sub_workflow is not None: + sub_workflow_refs[step.name] = step.sub_workflow + + builder.add_step( + step.name, + step.task_name, + step.after if step.after else None, + step.queue, + step.max_retries, + step.timeout_ms, + step.priority, + None, + None, + step.fan_out, + step.fan_in, + str_condition, + ) + # Gate, fan-out, fan-in, and sub-workflow steps have no payload. + if ( + step.gate_config is None + and step.fan_out is None + and step.fan_in is None + and step.sub_workflow is None + ): + serializer = queue._get_serializer(step.task_name) + node_payloads[step.name] = serializer.dumps((step.args, step.kwargs)) + + dag_bytes, step_metadata_json = builder.serialize() + return ( + dag_bytes, + step_metadata_json, + node_payloads, + sorted(deferred), + callable_conditions, + self.on_failure, + gate_configs, + sub_workflow_refs, + ) + + +@dataclass +class SubWorkflowRef: + """Marker returned by :meth:`WorkflowProxy.as_step`.""" + + proxy: WorkflowProxy + params: dict[str, Any] = field(default_factory=dict) + + # Duck-type as a task so Workflow.step() accepts it. 
+ @property + def _task_name(self) -> str: + return f"__subworkflow__{self.proxy._name}" + + +class WorkflowProxy: + """Returned by ``@queue.workflow()`` — callable that builds and submits.""" + + _is_workflow_proxy: bool = True + + def __init__( + self, + queue: Queue, + name: str, + version: int, + factory: Callable[..., Workflow], + ): + self._queue = queue + self._name = name + self._version = version + self._factory = factory + + def as_step(self, **params: Any) -> SubWorkflowRef: + """Return a reference that can be passed to ``Workflow.step()``.""" + return SubWorkflowRef(proxy=self, params=params) + + def build(self, *args: Any, **kwargs: Any) -> Workflow: + """Materialize the workflow without submitting it.""" + wf = self._factory(*args, **kwargs) + if not isinstance(wf, Workflow): + raise TypeError(f"@queue.workflow('{self._name}') factory must return a Workflow") + wf.name = self._name + wf.version = self._version + return wf + + def submit(self, *args: Any, **kwargs: Any) -> WorkflowRun: + """Build and submit the workflow in one call.""" + wf = self.build(*args, **kwargs) + return self._queue.submit_workflow(wf) + + def __call__(self, *args: Any, **kwargs: Any) -> Workflow: + return self.build(*args, **kwargs) diff --git a/py_src/taskito/workflows/context.py b/py_src/taskito/workflows/context.py new file mode 100644 index 0000000..c9e6d74 --- /dev/null +++ b/py_src/taskito/workflows/context.py @@ -0,0 +1,33 @@ +"""Workflow context passed to callable conditions.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass(frozen=True) +class WorkflowContext: + """Runtime context available to callable condition functions. 
+ + Example:: + + wf.step("deploy", deploy_task, after="validate", + condition=lambda ctx: ctx.results["validate"]["score"] > 0.95) + """ + + run_id: str + results: dict[str, Any] = field(default_factory=dict) + """Deserialized return values of completed predecessor nodes.""" + + statuses: dict[str, str] = field(default_factory=dict) + """Status strings for all terminal nodes (completed/failed/skipped).""" + + params: dict[str, Any] | None = None + """Workflow-level parameters (from ``submit_workflow``).""" + + failure_count: int = 0 + """Number of nodes with status ``"failed"``.""" + + success_count: int = 0 + """Number of nodes with status ``"completed"``.""" diff --git a/py_src/taskito/workflows/fan_out.py b/py_src/taskito/workflows/fan_out.py new file mode 100644 index 0000000..7cc4e3f --- /dev/null +++ b/py_src/taskito/workflows/fan_out.py @@ -0,0 +1,41 @@ +"""Fan-out splitting and fan-in aggregation strategies.""" + +from __future__ import annotations + +from typing import Any + +from taskito.serializers import Serializer + + +def apply_fan_out(strategy: str, result: Any) -> list[Any]: + """Split a step's return value into individual fan-out items. + + Args: + strategy: The fan-out strategy (``"each"``). + result: The predecessor step's return value. + + Returns: + A list of items, one per fan-out child job. + + Raises: + TypeError: If the result is not iterable for the given strategy. + ValueError: If the strategy is unknown. 
+ """ + if strategy == "each": + try: + return list(result) + except TypeError as exc: + raise TypeError( + f"fan_out='each' requires an iterable result, got {type(result).__name__}" + ) from exc + raise ValueError(f"unknown fan_out strategy: {strategy!r}") + + +def build_child_payload(item: Any, serializer: Serializer) -> bytes: + """Serialize a single fan-out item as ``((item,), {})``.""" + return serializer.dumps(((item,), {})) + + +def build_fan_in_payload(results: list[Any], serializer: Serializer) -> bytes: + """Serialize collected results as ``((results,), {})``.""" + return serializer.dumps(((results,), {})) diff --git a/py_src/taskito/workflows/incremental.py b/py_src/taskito/workflows/incremental.py new file mode 100644 index 0000000..5814fda --- /dev/null +++ b/py_src/taskito/workflows/incremental.py @@ -0,0 +1,69 @@ +"""Dirty-set computation for incremental workflow runs.""" + +from __future__ import annotations + +import time + + +def compute_dirty_set( + base_nodes: list[tuple[str, str, str | None]], + new_node_names: list[str], + successors: dict[str, list[str]], + predecessors: dict[str, list[str]], + cache_ttl: float | None = None, + base_run_completed_at: int | None = None, +) -> tuple[set[str], dict[str, str]]: + """Determine which nodes are dirty and which are cache hits. + + Args: + base_nodes: ``[(node_name, status, result_hash)]`` from the base run. + new_node_names: Names of nodes in the new run. + successors: DAG successor map. + predecessors: DAG predecessor map. + cache_ttl: Optional TTL in seconds. If set and the base run is older + than this, all nodes are considered dirty. + base_run_completed_at: Timestamp (ms) when the base run completed. + + Returns: + ``(dirty_nodes, cache_hit_nodes)`` where ``cache_hit_nodes`` maps + node name to the result_hash to copy. + """ + # Check TTL expiration. 
+ if cache_ttl is not None and base_run_completed_at is not None: + age_seconds = (time.time() * 1000 - base_run_completed_at) / 1000 + if age_seconds > cache_ttl: + return set(new_node_names), {} + + # Build lookup from base run. + base_lookup: dict[str, tuple[str, str | None]] = {} + for name, status, result_hash in base_nodes: + base_lookup[name] = (status, result_hash) + + new_set = set(new_node_names) + dirty: set[str] = set() + cached: dict[str, str] = {} + + # First pass: mark nodes as dirty or cached based on base status. + for name in new_node_names: + base = base_lookup.get(name) + if base is None: + dirty.add(name) + continue + status, result_hash = base + if status == "completed" and result_hash is not None: + cached[name] = result_hash + else: + dirty.add(name) + + # Second pass: propagate dirty downstream. + changed = True + while changed: + changed = False + for name in list(cached.keys()): + preds = predecessors.get(name, []) + if any(p in dirty for p in preds if p in new_set): + dirty.add(name) + del cached[name] + changed = True + + return dirty, cached diff --git a/py_src/taskito/workflows/mixins.py b/py_src/taskito/workflows/mixins.py new file mode 100644 index 0000000..d19e0bb --- /dev/null +++ b/py_src/taskito/workflows/mixins.py @@ -0,0 +1,163 @@ +"""Mixin that adds workflow operations to :class:`taskito.app.Queue`.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from .builder import Workflow, WorkflowProxy +from .run import WorkflowRun + +if TYPE_CHECKING: + from collections.abc import Callable + + +class QueueWorkflowMixin: + """Adds workflow APIs to Queue. + + Mixed into ``Queue`` unconditionally — the underlying native methods are + only present when the ``workflows`` feature is compiled in, so calling + these methods without the feature will raise ``AttributeError``. 
+ """ + + _inner: Any + _workflow_registry: dict[str, WorkflowProxy] + + def submit_workflow( + self, + workflow: Workflow, + *, + incremental: bool = False, + base_run: str | None = None, + ) -> WorkflowRun: + """Submit a built :class:`Workflow` for execution. + + Args: + workflow: The workflow to submit. + incremental: If ``True``, skip nodes that completed in *base_run*. + base_run: Run ID of a prior run to use for cache comparison. + + Static step jobs are created up front with ``depends_on`` chains. + Deferred nodes (fan-out, fan-in, conditions, ``on_failure="continue"``) + get ``WorkflowNode`` entries only — their jobs are created at runtime + by the :class:`~taskito.workflows.tracker.WorkflowTracker`. + """ + ( + dag_bytes, + step_metadata_json, + node_payloads, + deferred_nodes, + callable_conditions, + on_failure, + gate_configs, + sub_workflow_refs, + ) = workflow._compile(self) # type: ignore[arg-type] + + # Compute cache hits for incremental runs. + cache_hit_nodes: dict[str, str] | None = None + if incremental and base_run: + from .incremental import compute_dirty_set + from .tracker import _build_dag_maps + + base_nodes = self._inner.get_base_run_node_data(base_run) + _, preds = _build_dag_maps(dag_bytes) + succs, _ = _build_dag_maps(dag_bytes) + + # Get base run completion time for TTL check. 
+ base_status = self._inner.get_workflow_run_status(base_run) + base_completed = base_status.completed_at + + _, cached = compute_dirty_set( + base_nodes=base_nodes, + new_node_names=list(node_payloads.keys()), + successors=succs, + predecessors=preds, + cache_ttl=workflow.cache_ttl, + base_run_completed_at=base_completed, + ) + if cached: + cache_hit_nodes = cached + + handle = self._inner.submit_workflow( + workflow.name, + workflow.version, + dag_bytes, + step_metadata_json, + node_payloads, + "default", + None, + deferred_nodes if deferred_nodes else None, + None, # parent_run_id + None, # parent_node_name + cache_hit_nodes, + ) + + # Register with the tracker when the workflow needs Python-side + # orchestration (deferred nodes, conditions, gates, or continue mode). + tracker = getattr(self, "_workflow_tracker", None) + needs_tracker = ( + bool(deferred_nodes) + or bool(callable_conditions) + or bool(gate_configs) + or bool(sub_workflow_refs) + or on_failure != "fail_fast" + ) + if tracker is not None and needs_tracker: + deferred_payloads = { + name: node_payloads[name] for name in deferred_nodes if name in node_payloads + } + tracker.register_run( + handle.run_id, + step_metadata_json, + dag_bytes, + deferred_nodes, + deferred_payloads, + on_failure=on_failure, + callable_conditions=callable_conditions, + gate_configs=gate_configs, + sub_workflow_refs=sub_workflow_refs, + ) + + return WorkflowRun(self, handle.run_id, handle.name) # type: ignore[arg-type] + + def approve_gate(self, run_id: str, node_name: str) -> None: + """Approve an approval gate, allowing the workflow to continue.""" + tracker = getattr(self, "_workflow_tracker", None) + if tracker is None: + raise RuntimeError("workflow tracker not available") + tracker.resolve_gate(run_id, node_name, approved=True) + + def reject_gate(self, run_id: str, node_name: str, error: str = "rejected") -> None: + """Reject an approval gate, failing the gate node.""" + tracker = getattr(self, "_workflow_tracker", 
None) + if tracker is None: + raise RuntimeError("workflow tracker not available") + tracker.resolve_gate(run_id, node_name, approved=False, error=error) + + def workflow( + self, + name: str | None = None, + *, + version: int = 1, + ) -> Callable[[Callable[..., Workflow]], WorkflowProxy]: + """Decorator that registers a workflow factory. + + Example:: + + @queue.workflow("nightly_etl") + def etl() -> Workflow: + wf = Workflow() + wf.step("extract", extract_task) + wf.step("load", load_task, after="extract") + return wf + + run = etl.submit() + run.wait() + """ + + def decorator(factory: Callable[..., Workflow]) -> WorkflowProxy: + wf_name = name or factory.__name__ + proxy = WorkflowProxy(self, wf_name, version, factory) # type: ignore[arg-type] + self._workflow_registry[wf_name] = proxy + return proxy + + return decorator diff --git a/py_src/taskito/workflows/run.py b/py_src/taskito/workflows/run.py new file mode 100644 index 0000000..5e92a95 --- /dev/null +++ b/py_src/taskito/workflows/run.py @@ -0,0 +1,154 @@ +"""Python-side workflow run handle. + +Wraps a Rust ``PyWorkflowHandle`` with high-level status/wait/cancel +operations. ``wait()`` uses a threading.Event registered with the workflow +tracker for O(1) resolution when the run reaches a terminal state, with a +polling safety net. 
+""" + +from __future__ import annotations + +import threading +import time +from typing import TYPE_CHECKING + +from .types import NodeSnapshot, NodeStatus, WorkflowState, WorkflowStatus + +if TYPE_CHECKING: + from taskito.app import Queue + + +class WorkflowTimeoutError(TimeoutError): + """Raised when ``WorkflowRun.wait`` times out before the run finishes.""" + + +class WorkflowRun: + """Handle for a submitted workflow run.""" + + def __init__(self, queue: Queue, run_id: str, name: str): + self._queue = queue + self.id = run_id + self.name = name + + def status(self) -> WorkflowStatus: + """Fetch the current state snapshot of this workflow run.""" + raw = self._queue._inner.get_workflow_run_status(self.id) + return _raw_to_status(raw) + + def node_status(self, node_name: str) -> NodeStatus: + """Shortcut for ``status().nodes[node_name].status``.""" + snapshot = self.status() + node = snapshot.nodes.get(node_name) + if node is None: + raise KeyError(f"node '{node_name}' not found in workflow run {self.id}") + return node.status + + def wait(self, timeout: float | None = None, poll_interval: float = 0.1) -> WorkflowStatus: + """Block until this workflow run reaches a terminal state. + + Args: + timeout: Max seconds to wait. ``None`` = wait forever. + poll_interval: How often to re-check status as a safety net against + missed events. + + Returns: + The terminal :class:`WorkflowStatus`. + + Raises: + WorkflowTimeoutError: If the workflow did not finish within ``timeout``. 
+ """ + tracker = getattr(self._queue, "_workflow_tracker", None) + event: threading.Event | None = None + if tracker is not None: + event = tracker.register_wait(self.id) + + try: + deadline = None if timeout is None else time.monotonic() + timeout + while True: + snapshot = self.status() + if snapshot.state.is_terminal(): + return snapshot + + remaining: float + if deadline is None: + remaining = poll_interval + else: + remaining = max(0.0, deadline - time.monotonic()) + if remaining == 0.0: + raise WorkflowTimeoutError( + f"workflow run {self.id} did not complete within {timeout}s" + ) + remaining = min(poll_interval, remaining) + + if event is not None: + if event.wait(timeout=remaining): + return self.status() + else: + time.sleep(remaining) + finally: + if tracker is not None and event is not None: + tracker.unregister_wait(self.id, event) + + def cancel(self) -> None: + """Cancel any pending steps and mark the run as cancelled.""" + self._queue._inner.cancel_workflow_run(self.id) + + def visualize(self, fmt: str = "mermaid") -> str: + """Render the workflow DAG with live node statuses. + + Args: + fmt: Output format — ``"mermaid"`` or ``"dot"``. 
+ """ + from .visualization import ( + nodes_and_edges_from_dag_bytes, + render_dot, + render_mermaid, + ) + + dag_bytes = self._queue._inner.get_workflow_definition_dag(self.id) + nodes, edges = nodes_and_edges_from_dag_bytes(dag_bytes) + + snapshot = self.status() + statuses: dict[str, str] = {} + for name, node in snapshot.nodes.items(): + statuses[name] = node.status.value + + if fmt == "dot": + return render_dot(nodes, edges, statuses) + return render_mermaid(nodes, edges, statuses) + + def __repr__(self) -> str: + return f"WorkflowRun(id={self.id!r}, name={self.name!r})" + + +def _raw_to_status(raw: object) -> WorkflowStatus: + """Convert a ``PyWorkflowRunStatus`` into the high-level dataclass.""" + node_dict = raw.node_statuses() # type: ignore[attr-defined] + nodes: dict[str, NodeSnapshot] = {} + for node_name, entry in node_dict.items(): + status_str = entry.get("status", "pending") + try: + node_status = NodeStatus(status_str) + except ValueError: + node_status = NodeStatus.PENDING + nodes[node_name] = NodeSnapshot( + name=node_name, + status=node_status, + job_id=entry.get("job_id"), + error=entry.get("error"), + ) + + state_str: str = raw.state # type: ignore[attr-defined] + try: + state = WorkflowState(state_str) + except ValueError: + state = WorkflowState.PENDING + + return WorkflowStatus( + run_id=raw.run_id, # type: ignore[attr-defined] + state=state, + started_at=raw.started_at, # type: ignore[attr-defined] + completed_at=raw.completed_at, # type: ignore[attr-defined] + error=raw.error, # type: ignore[attr-defined] + nodes=nodes, + ) diff --git a/py_src/taskito/workflows/tracker.py b/py_src/taskito/workflows/tracker.py new file mode 100644 index 0000000..4bcc718 --- /dev/null +++ b/py_src/taskito/workflows/tracker.py @@ -0,0 +1,820 @@ +"""Workflow completion tracker. 
+ +Subscribes to terminal job events (``JOB_COMPLETED``, ``JOB_FAILED``, +``JOB_DEAD``, ``JOB_CANCELLED``) and forwards workflow-related ones to the +Rust ``mark_workflow_node_result`` entry point. When a run reaches a terminal +state, emits a workflow-level event and releases any threads blocked on +``WorkflowRun.wait``. + +For workflows that contain fan-out / fan-in steps, conditions, or +``on_failure="continue"``, the tracker orchestrates dynamic job creation, +condition evaluation, and selective skip propagation. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import threading +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from taskito.events import EventType + +from .context import WorkflowContext +from .fan_out import apply_fan_out, build_child_payload, build_fan_in_payload + +if TYPE_CHECKING: + from collections.abc import Callable + + from taskito.app import Queue + +logger = logging.getLogger("taskito.workflows") + + +@dataclass +class _RunConfig: + """In-memory configuration for a tracker-managed workflow run.""" + + step_metadata: dict[str, dict[str, Any]] + successors: dict[str, list[str]] + predecessors: dict[str, list[str]] + deferred_nodes: set[str] + deferred_payloads: dict[str, bytes] + on_failure: str + callable_conditions: dict[str, Callable[..., bool]] + gate_configs: dict[str, Any] + sub_workflow_refs: dict[str, Any] + + +class WorkflowTracker: + """Bridges taskito job events to workflow run state transitions.""" + + def __init__(self, queue: Queue): + self._queue = queue + self._waiters_lock = threading.Lock() + self._waiters: dict[str, list[threading.Event]] = {} + self._event_bus = queue._event_bus + self._run_configs: dict[str, _RunConfig] = {} + self._job_to_run: dict[str, str] = {} + self._gate_timers: dict[tuple[str, str], threading.Timer] = {} + self._child_to_parent: dict[str, tuple[str, str]] = {} + self._install_listeners() + + def 
_install_listeners(self) -> None: + self._event_bus.on(EventType.JOB_COMPLETED, self._on_success) + self._event_bus.on(EventType.JOB_FAILED, self._on_failure) + self._event_bus.on(EventType.JOB_DEAD, self._on_failure) + self._event_bus.on(EventType.JOB_CANCELLED, self._on_cancelled) + # Sub-workflow completion events. + self._event_bus.on(EventType.WORKFLOW_COMPLETED, self._on_child_workflow_terminal) + self._event_bus.on(EventType.WORKFLOW_FAILED, self._on_child_workflow_terminal) + self._event_bus.on(EventType.WORKFLOW_CANCELLED, self._on_child_workflow_terminal) + + # ── Wait management ──────────────────────────────────────────── + + def register_wait(self, run_id: str) -> threading.Event: + """Return an event that will be set when ``run_id`` reaches a terminal state.""" + event = threading.Event() + with self._waiters_lock: + self._waiters.setdefault(run_id, []).append(event) + return event + + def unregister_wait(self, run_id: str, event: threading.Event) -> None: + with self._waiters_lock: + entries = self._waiters.get(run_id) + if not entries: + return + try: + entries.remove(event) + except ValueError: + return + if not entries: + self._waiters.pop(run_id, None) + + def _release_waiters(self, run_id: str) -> None: + with self._waiters_lock: + entries = self._waiters.pop(run_id, []) + for event in entries: + event.set() + + # ── Dynamic run registration ─────────────────────────────────── + + def register_run( + self, + run_id: str, + step_metadata_json: str, + dag_bytes: bytes | list[int], + deferred_nodes: list[str], + deferred_payloads: dict[str, bytes], + on_failure: str = "fail_fast", + callable_conditions: dict[str, Callable[..., bool]] | None = None, + gate_configs: dict[str, Any] | None = None, + sub_workflow_refs: dict[str, Any] | None = None, + ) -> None: + """Cache configuration for a tracker-managed workflow run.""" + meta: dict[str, dict[str, Any]] = json.loads(step_metadata_json) + successors, predecessors = _build_dag_maps(dag_bytes) + config 
= _RunConfig( + step_metadata=meta, + successors=successors, + predecessors=predecessors, + deferred_nodes=set(deferred_nodes), + deferred_payloads=deferred_payloads, + on_failure=on_failure, + callable_conditions=callable_conditions or {}, + gate_configs=gate_configs or {}, + sub_workflow_refs=sub_workflow_refs or {}, + ) + self._run_configs[run_id] = config + + # Populate job→run mapping for initial nodes. + try: + raw = self._queue._inner.get_workflow_run_status(run_id) + for _name, info in raw.node_statuses().items(): + jid = info.get("job_id") + if jid: + self._job_to_run[jid] = run_id + except Exception: # pragma: no cover + logger.exception("failed to populate job→run mapping for %s", run_id) + + # Evaluate root deferred nodes (those with no predecessors). + self._evaluate_root_deferred(run_id, config) + + def _evaluate_root_deferred(self, run_id: str, config: _RunConfig) -> None: + """Immediately evaluate deferred root nodes (no predecessors).""" + for node_name in list(config.deferred_nodes): + preds = config.predecessors.get(node_name, []) + if preds: + continue # Not a root node. 
+ meta = config.step_metadata.get(node_name) + if meta is None: + continue + if node_name in config.gate_configs: + self._enter_gate(run_id, node_name, config) + elif node_name in config.sub_workflow_refs: + self._submit_sub_workflow(run_id, node_name, config) + elif meta.get("fan_out") is None and meta.get("fan_in") is None: + self._create_deferred_job_for_node(run_id, node_name, config) + + # ── Event handlers ───────────────────────────────────────────── + + def _on_success(self, _event_type: EventType, payload: dict[str, Any]) -> None: + self._handle(payload, succeeded=True, error=None) + + def _on_failure(self, _event_type: EventType, payload: dict[str, Any]) -> None: + self._handle(payload, succeeded=False, error=payload.get("error")) + + def _on_cancelled(self, _event_type: EventType, payload: dict[str, Any]) -> None: + self._handle(payload, succeeded=False, error="cancelled") + + def _handle( + self, + payload: dict[str, Any], + *, + succeeded: bool, + error: str | None, + ) -> None: + job_id = payload.get("job_id") + if not job_id: + return + + # Determine if this job belongs to a managed run. + run_id = self._job_to_run.get(job_id) + config = self._run_configs.get(run_id) if run_id else None + skip_cascade = config is not None + + # Compute result hash for successful completions. + rh: str | None = None + if succeeded: + rh = self._compute_result_hash(job_id) + + try: + result = self._queue._inner.mark_workflow_node_result( + job_id, succeeded, error, skip_cascade, rh + ) + except Exception: # pragma: no cover - defensive + logger.exception("mark_workflow_node_result failed for job %s", job_id) + return + + if result is None: + return + + run_id, node_name, terminal_state = result + + if terminal_state is not None: + self._emit_terminal(run_id, terminal_state, error) + self._cleanup_run(run_id) + return + + # Re-fetch config now that we have the definitive run_id. 
+ if config is None: + config = self._run_configs.get(run_id) + if config is None: + return # Static workflow — Rust cascade handled everything. + + # Fan-out child handling. + if "[" in node_name: + if succeeded: + self._handle_fan_out_child(run_id, node_name, config) + else: + self._handle_fan_out_child_failure(run_id, node_name, config) + return + + # Fan-out expansion trigger (only on success). + if succeeded: + self._maybe_trigger_fan_out(run_id, node_name, job_id, config) + + # Evaluate successors with conditions. + self._evaluate_successors(run_id, node_name, config) + + # ── Terminal state ───────────────────────────────────────────── + + def _emit_terminal(self, run_id: str, terminal_state: str, error: str | None) -> None: + workflow_event = _final_state_to_event(terminal_state) + if workflow_event is not None: + try: + self._queue._emit_event( + workflow_event, + {"run_id": run_id, "state": terminal_state, "error": error}, + ) + except Exception: # pragma: no cover - defensive + logger.exception("failed to emit %s", workflow_event) + self._release_waiters(run_id) + + def _cleanup_run(self, run_id: str) -> None: + self._run_configs.pop(run_id, None) + self._job_to_run = {jid: rid for jid, rid in self._job_to_run.items() if rid != run_id} + + # ── Condition evaluation ─────────────────────────────────────── + + def _evaluate_successors(self, run_id: str, completed_node: str, config: _RunConfig) -> None: + """Evaluate and create/skip deferred successor nodes.""" + for successor in config.successors.get(completed_node, []): + if successor not in config.deferred_nodes: + continue + meta = config.step_metadata.get(successor) + if meta is None: + continue + + if not self._all_predecessors_terminal(run_id, successor, config): + continue + + # Fan-out/fan-in nodes: only skip them here (expansion is + # handled by _maybe_trigger_fan_out / _handle_fan_out_child). 
+ if meta.get("fan_out") is not None or meta.get("fan_in") is not None: + if not self._should_execute(run_id, successor, config): + self._skip_and_propagate(run_id, successor, config) + continue + + if not self._should_execute(run_id, successor, config): + self._skip_and_propagate(run_id, successor, config) + continue + + # Gate nodes pause for approval instead of creating a job. + if successor in config.gate_configs: + self._enter_gate(run_id, successor, config) + # Sub-workflow nodes submit a child workflow. + elif successor in config.sub_workflow_refs: + self._submit_sub_workflow(run_id, successor, config) + else: + self._create_deferred_job_for_node(run_id, successor, config) + + # After evaluating successors, check if the run is now terminal. + self._try_finalize(run_id) + + def _should_execute(self, run_id: str, node_name: str, config: _RunConfig) -> bool: + """Decide whether a deferred node should execute based on its condition.""" + # Callable conditions take precedence. + callable_cond = config.callable_conditions.get(node_name) + if callable_cond is not None: + ctx = self._build_workflow_context(run_id, config) + try: + return bool(callable_cond(ctx)) + except Exception: + logger.exception("callable condition failed for %s", node_name) + return False + + meta = config.step_metadata.get(node_name, {}) + condition = meta.get("condition") + pred_statuses = self._get_predecessor_statuses(run_id, node_name, config) + + if condition is None or condition == "on_success": + return all(s == "completed" for s in pred_statuses.values()) + if condition == "on_failure": + return any(s == "failed" for s in pred_statuses.values()) + # "always" runs unconditionally; "callable" sentinel was handled above. 
+ return bool(condition == "always") + + def _skip_and_propagate(self, run_id: str, node_name: str, config: _RunConfig) -> None: + """Mark a node as SKIPPED and recursively evaluate its successors.""" + try: + self._queue._inner.skip_workflow_node(run_id, node_name) + except Exception: + logger.exception("skip_workflow_node failed for %s", node_name) + return + config.deferred_nodes.discard(node_name) + # The skipped node is now terminal — its successors may be evaluable. + self._evaluate_successors(run_id, node_name, config) + + # ── Approval gates ────────────────────────────────────────────── + + def _enter_gate(self, run_id: str, node_name: str, config: _RunConfig) -> None: + """Transition a gate node to WAITING_APPROVAL and start timeout.""" + try: + self._queue._inner.set_workflow_node_waiting_approval(run_id, node_name) + except Exception: + logger.exception("set_workflow_node_waiting_approval failed for %s", node_name) + return + config.deferred_nodes.discard(node_name) + + gate = config.gate_configs[node_name] + try: + self._queue._emit_event( + EventType.WORKFLOW_GATE_REACHED, + { + "run_id": run_id, + "node_name": node_name, + "message": gate.message if isinstance(gate.message, str) else None, + }, + ) + except Exception: # pragma: no cover + logger.exception("failed to emit WORKFLOW_GATE_REACHED") + + if gate.timeout is not None and gate.timeout > 0: + timer = threading.Timer( + gate.timeout, + self._on_gate_timeout, + args=(run_id, node_name, gate.on_timeout), + ) + timer.daemon = True + timer.start() + self._gate_timers[(run_id, node_name)] = timer + + def resolve_gate( + self, + run_id: str, + node_name: str, + *, + approved: bool, + error: str | None = None, + ) -> None: + """Approve or reject a gate, resuming the workflow.""" + # Cancel any pending timeout timer. 
+ timer = self._gate_timers.pop((run_id, node_name), None) + if timer is not None: + timer.cancel() + + config = self._run_configs.get(run_id) + + try: + self._queue._inner.resolve_workflow_gate(run_id, node_name, approved, error) + except Exception: + logger.exception("resolve_workflow_gate failed for %s", node_name) + return + + if config is not None: + self._evaluate_successors(run_id, node_name, config) + self._try_finalize(run_id) + + # ── Sub-workflows ────────────────────────────────────────────── + + def _submit_sub_workflow(self, run_id: str, node_name: str, config: _RunConfig) -> None: + """Submit a child workflow for a sub-workflow node.""" + ref = config.sub_workflow_refs.get(node_name) + if ref is None: # pragma: no cover + return + try: + child_wf = ref.proxy.build(**ref.params) + # Mark parent node as RUNNING. + self._queue._inner.set_workflow_node_waiting_approval(run_id, node_name) + # Override status to RUNNING (waiting_approval was just to set started_at). + self._queue._inner.skip_workflow_node(run_id, node_name) + # Actually, let me use a cleaner approach: just mark running via the + # node status update. Use the Rust set_workflow_node_fan_out_count + # trick (sets RUNNING). Or add a direct call. + except Exception: + logger.exception("failed to build sub-workflow for %s", node_name) + return + + try: + # Submit child workflow with parent linkage. + ( + dag_bytes, + meta_json, + payloads, + deferred, + callables, + on_failure, + gates, + sub_refs, + ) = child_wf._compile(self._queue) + + handle = self._queue._inner.submit_workflow( + child_wf.name, + child_wf.version, + dag_bytes, + meta_json, + payloads, + "default", + None, + deferred if deferred else None, + run_id, # parent_run_id + node_name, # parent_node_name + ) + + child_run_id = handle.run_id + self._child_to_parent[child_run_id] = (run_id, node_name) + + # Mark parent node as RUNNING (use fan_out_count trick). 
+ self._queue._inner.set_workflow_node_fan_out_count(run_id, node_name, 1) + + # Register child with tracker if it has deferred nodes. + needs_child_tracker = ( + bool(deferred) + or bool(callables) + or bool(gates) + or bool(sub_refs) + or on_failure != "fail_fast" + ) + if needs_child_tracker: + child_payloads = {n: payloads[n] for n in deferred if n in payloads} + self.register_run( + child_run_id, + meta_json, + dag_bytes, + deferred, + child_payloads, + on_failure=on_failure, + callable_conditions=callables, + gate_configs=gates, + sub_workflow_refs=sub_refs, + ) + + config.deferred_nodes.discard(node_name) + except Exception: + logger.exception("submit sub-workflow failed for %s", node_name) + + def _on_child_workflow_terminal(self, _event_type: EventType, payload: dict[str, Any]) -> None: + """Handle child workflow completion → update parent node.""" + child_run_id = payload.get("run_id") + if not child_run_id: + return + parent_info = self._child_to_parent.pop(child_run_id, None) + if parent_info is None: + return # Not a sub-workflow child. 
+ + parent_run_id, parent_node_name = parent_info + state = payload.get("state", "") + succeeded = state == "completed" + + try: + self._queue._inner.resolve_workflow_gate( + parent_run_id, + parent_node_name, + succeeded, + payload.get("error") if not succeeded else None, + ) + except Exception: + logger.exception( + "failed to update parent node %s for child %s", + parent_node_name, + child_run_id, + ) + return + + config = self._run_configs.get(parent_run_id) + if config is not None: + self._evaluate_successors(parent_run_id, parent_node_name, config) + self._try_finalize(parent_run_id) + + def _on_gate_timeout(self, run_id: str, node_name: str, action: str) -> None: + """Handle gate timeout expiry.""" + self._gate_timers.pop((run_id, node_name), None) + approved = action == "approve" + error = None if approved else "gate timeout" + self.resolve_gate(run_id, node_name, approved=approved, error=error) + + def _all_predecessors_terminal(self, run_id: str, node_name: str, config: _RunConfig) -> bool: + """Check whether all predecessors have a terminal status.""" + raw = self._queue._inner.get_workflow_run_status(run_id) + node_statuses = raw.node_statuses() + for pred in config.predecessors.get(node_name, []): + info = node_statuses.get(pred) + if info is None: + return False + status = info["status"] + if status not in ("completed", "failed", "skipped"): + return False + return True + + def _get_predecessor_statuses( + self, run_id: str, node_name: str, config: _RunConfig + ) -> dict[str, str]: + """Return ``{pred_name: status_str}`` for all predecessors.""" + raw = self._queue._inner.get_workflow_run_status(run_id) + node_statuses = raw.node_statuses() + result: dict[str, str] = {} + for pred in config.predecessors.get(node_name, []): + info = node_statuses.get(pred) + result[pred] = info["status"] if info else "pending" + return result + + def _build_workflow_context(self, run_id: str, config: _RunConfig) -> WorkflowContext: + """Build a :class:`WorkflowContext` 
from the current run state.""" + raw = self._queue._inner.get_workflow_run_status(run_id) + node_statuses = raw.node_statuses() + + results: dict[str, Any] = {} + statuses: dict[str, str] = {} + failure_count = 0 + success_count = 0 + + for name, info in node_statuses.items(): + status = info["status"] + statuses[name] = status + if status == "completed": + success_count += 1 + jid = info.get("job_id") + if jid: + results[name] = self._fetch_result(jid) + elif status == "failed": + failure_count += 1 + + return WorkflowContext( + run_id=run_id, + results=results, + statuses=statuses, + params=None, + failure_count=failure_count, + success_count=success_count, + ) + + def _fetch_result(self, job_id: str) -> Any: + """Fetch and deserialize a job's result, with polling for DB write lag.""" + for _ in range(50): + py_job = self._queue._inner.get_job(job_id) + if py_job is not None: + rb = py_job.result_bytes + if rb is not None: + return self._queue._serializer.loads(rb) + time.sleep(0.1) + return None + + def _compute_result_hash(self, job_id: str) -> str | None: + """Compute SHA-256 of a completed job's result bytes. + + Best-effort: if the result isn't stored yet (event fires before + DB write), returns ``None``. The hash is only used for incremental + caching, not correctness. 
+ """ + py_job = self._queue._inner.get_job(job_id) + if py_job is not None: + rb = py_job.result_bytes + if rb is not None: + return hashlib.sha256(rb).hexdigest() + return None + + def _create_deferred_job_for_node( + self, run_id: str, node_name: str, config: _RunConfig + ) -> None: + """Create a job for a deferred node and record the mapping.""" + payload = config.deferred_payloads.get(node_name) + if payload is None: # pragma: no cover + logger.error("no cached payload for deferred node %s", node_name) + return + meta = config.step_metadata.get(node_name, {}) + task_name = meta["task_name"] + queue_name = meta.get("queue") or "default" + max_retries = _int_or(meta.get("max_retries"), 3) + timeout_ms = _int_or(meta.get("timeout_ms"), 300_000) + priority = _int_or(meta.get("priority"), 0) + + try: + job_id = self._queue._inner.create_deferred_job( + run_id, + node_name, + payload, + task_name, + queue_name, + max_retries, + timeout_ms, + priority, + ) + self._job_to_run[job_id] = run_id + config.deferred_nodes.discard(node_name) + except Exception: + logger.exception("create_deferred_job failed for %s", node_name) + + def _try_finalize(self, run_id: str) -> None: + """If all nodes are terminal, finalize the run and emit the event.""" + try: + terminal_state = self._queue._inner.finalize_run_if_terminal(run_id) + except Exception: + logger.exception("finalize_run_if_terminal failed for %s", run_id) + return + if terminal_state is not None: + self._emit_terminal(run_id, terminal_state, None) + self._cleanup_run(run_id) + + # ── Fan-out expansion ────────────────────────────────────────── + + def _maybe_trigger_fan_out( + self, + run_id: str, + source_node: str, + source_job_id: str, + config: _RunConfig, + ) -> None: + """If a completed node's successor has ``fan_out``, expand it.""" + for successor in config.successors.get(source_node, []): + meta = config.step_metadata.get(successor) + if meta is None or meta.get("fan_out") is None: + continue + 
self._expand_fan_out(run_id, source_job_id, successor, meta, config) + + def _expand_fan_out( + self, + run_id: str, + source_job_id: str, + fan_out_node: str, + meta: dict[str, Any], + config: _RunConfig, + ) -> None: + """Fetch the source result, split, and create child nodes + jobs.""" + result_bytes: bytes | None = None + for _ in range(50): + py_job = self._queue._inner.get_job(source_job_id) + if py_job is not None: + result_bytes = py_job.result_bytes + if result_bytes is not None: + break + time.sleep(0.1) + + if result_bytes is None: + source_result: Any = None + else: + source_result = self._queue._serializer.loads(result_bytes) + + strategy = meta["fan_out"] + items = apply_fan_out(strategy, source_result) + + task_name = meta["task_name"] + serializer = self._queue._get_serializer(task_name) + child_names = [f"{fan_out_node}[{i}]" for i in range(len(items))] + child_payloads = [build_child_payload(item, serializer) for item in items] + + queue_name = meta.get("queue") or "default" + max_retries = _int_or(meta.get("max_retries"), 3) + timeout_ms = _int_or(meta.get("timeout_ms"), 300_000) + priority = _int_or(meta.get("priority"), 0) + + try: + child_job_ids = self._queue._inner.expand_fan_out( + run_id, + fan_out_node, + child_names, + child_payloads, + task_name, + queue_name, + max_retries, + timeout_ms, + priority, + ) + for jid in child_job_ids: + self._job_to_run[jid] = run_id + except Exception: + logger.exception("expand_fan_out failed for %s in run %s", fan_out_node, run_id) + return + + # Empty fan-out: parent is immediately COMPLETED with 0 children. 
+ if not child_names: + for successor in config.successors.get(fan_out_node, []): + succ_meta = config.step_metadata.get(successor) + if succ_meta is not None and succ_meta.get("fan_in") is not None: + self._create_fan_in_job(run_id, successor, succ_meta, [], config) + return + self._evaluate_successors(run_id, fan_out_node, config) + + # ── Fan-out child completion ─────────────────────────────────── + + def _handle_fan_out_child(self, run_id: str, child_name: str, config: _RunConfig) -> None: + """Check whether all siblings are done → trigger fan-in.""" + parent_name = child_name.split("[")[0] + try: + completion = self._queue._inner.check_fan_out_completion(run_id, parent_name) + except Exception: + logger.exception("check_fan_out_completion failed for %s", parent_name) + return + + if completion is None: + return + + all_succeeded, child_job_ids = completion + if not all_succeeded: + # Parent marked FAILED. Evaluate successors (on_failure may trigger). + self._evaluate_successors(run_id, parent_name, config) + self._try_finalize(run_id) + return + + # Trigger fan-in. + for successor in config.successors.get(parent_name, []): + meta = config.step_metadata.get(successor) + if meta is not None and meta.get("fan_in") is not None: + self._create_fan_in_job(run_id, successor, meta, child_job_ids, config) + return + + # No fan-in — evaluate deferred successors. + self._evaluate_successors(run_id, parent_name, config) + + def _handle_fan_out_child_failure( + self, run_id: str, child_name: str, config: _RunConfig + ) -> None: + """Handle a failed fan-out child.""" + parent_name = child_name.split("[")[0] + try: + completion = self._queue._inner.check_fan_out_completion(run_id, parent_name) + except Exception: + logger.exception("check_fan_out_completion failed for %s", parent_name) + return + + if completion is None: + return + + # Parent is marked FAILED. Evaluate successors for condition-based logic. 
+ self._evaluate_successors(run_id, parent_name, config) + self._try_finalize(run_id) + + def _create_fan_in_job( + self, + run_id: str, + fan_in_node: str, + meta: dict[str, Any], + child_job_ids: list[str], + config: _RunConfig, + ) -> None: + """Collect children results and create the fan-in job.""" + results: list[Any] = [] + for job_id in child_job_ids: + results.append(self._fetch_result(job_id)) + + task_name = meta["task_name"] + serializer = self._queue._get_serializer(task_name) + payload = build_fan_in_payload(results, serializer) + + queue_name = meta.get("queue") or "default" + max_retries = _int_or(meta.get("max_retries"), 3) + timeout_ms = _int_or(meta.get("timeout_ms"), 300_000) + priority = _int_or(meta.get("priority"), 0) + + try: + job_id = self._queue._inner.create_deferred_job( + run_id, + fan_in_node, + payload, + task_name, + queue_name, + max_retries, + timeout_ms, + priority, + ) + self._job_to_run[job_id] = run_id + except Exception: + logger.exception("create_deferred_job failed for fan-in %s", fan_in_node) + + +# ── Helpers ──────────────────────────────────────────────────────── + + +def _build_dag_maps( + dag_bytes: bytes | list[int], +) -> tuple[dict[str, list[str]], dict[str, list[str]]]: + """Parse DAG JSON to build successor and predecessor maps.""" + raw = bytes(dag_bytes) if isinstance(dag_bytes, list) else dag_bytes + dag_json = json.loads(raw) + successors: dict[str, list[str]] = {} + predecessors: dict[str, list[str]] = {} + for node in dag_json.get("nodes", []): + name = node["name"] + successors.setdefault(name, []) + predecessors.setdefault(name, []) + for edge in dag_json.get("edges", []): + successors.setdefault(edge["from"], []).append(edge["to"]) + predecessors.setdefault(edge["to"], []).append(edge["from"]) + return successors, predecessors + + +def _int_or(value: Any, default: int) -> int: + return value if value is not None else default + + +def _final_state_to_event(state: str) -> EventType | None: + if state == 
"completed": + return EventType.WORKFLOW_COMPLETED + if state == "failed": + return EventType.WORKFLOW_FAILED + if state == "cancelled": + return EventType.WORKFLOW_CANCELLED + return None diff --git a/py_src/taskito/workflows/types.py b/py_src/taskito/workflows/types.py new file mode 100644 index 0000000..558ca53 --- /dev/null +++ b/py_src/taskito/workflows/types.py @@ -0,0 +1,63 @@ +"""Type definitions for workflow state and node snapshots.""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass, field + + +class WorkflowState(str, enum.Enum): + """Terminal and intermediate states of a workflow run.""" + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + PAUSED = "paused" + + def is_terminal(self) -> bool: + return self in (WorkflowState.COMPLETED, WorkflowState.FAILED, WorkflowState.CANCELLED) + + +class NodeStatus(str, enum.Enum): + """Status of a single workflow node (step).""" + + PENDING = "pending" + READY = "ready" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + SKIPPED = "skipped" + WAITING_APPROVAL = "waiting_approval" + CACHE_HIT = "cache_hit" + + def is_terminal(self) -> bool: + return self in ( + NodeStatus.COMPLETED, + NodeStatus.FAILED, + NodeStatus.SKIPPED, + NodeStatus.CACHE_HIT, + ) + + +@dataclass +class NodeSnapshot: + """Snapshot of a single workflow node's execution state.""" + + name: str + status: NodeStatus + job_id: str | None = None + error: str | None = None + + +@dataclass +class WorkflowStatus: + """Snapshot of a workflow run's overall state plus per-node details.""" + + run_id: str + state: WorkflowState + started_at: int | None = None + completed_at: int | None = None + error: str | None = None + nodes: dict[str, NodeSnapshot] = field(default_factory=dict) diff --git a/py_src/taskito/workflows/visualization.py b/py_src/taskito/workflows/visualization.py new file mode 100644 index 0000000..8118f35 --- /dev/null 
+++ b/py_src/taskito/workflows/visualization.py @@ -0,0 +1,133 @@ +"""Workflow DAG visualization in Mermaid and DOT formats.""" + +from __future__ import annotations + +from typing import Any + +_STATUS_COLORS_MERMAID = { + "completed": "#90EE90", + "failed": "#FFB6C1", + "running": "#87CEEB", + "pending": "#D3D3D3", + "skipped": "#F5F5F5", + "waiting_approval": "#FFFACD", + "ready": "#E0E0E0", +} + +_STATUS_COLORS_DOT = { + "completed": "lightgreen", + "failed": "lightcoral", + "running": "lightskyblue", + "pending": "lightgray", + "skipped": "whitesmoke", + "waiting_approval": "lightyellow", + "ready": "gainsboro", +} + +_STATUS_SYMBOLS = { + "completed": "\u2713", + "failed": "\u2717", + "running": "\u25b6", + "pending": "\u25cb", + "skipped": "\u2014", + "waiting_approval": "\u23f8", +} + + +def render_mermaid( + nodes: list[str], + edges: list[tuple[str, str]], + statuses: dict[str, str] | None = None, +) -> str: + """Render a DAG as a Mermaid graph string. + + Args: + nodes: List of node names. + edges: List of ``(from, to)`` tuples. + statuses: Optional mapping of node name to status string. + + Returns: + A Mermaid ``graph LR`` diagram string. + """ + lines = ["graph LR"] + statuses = statuses or {} + + for name in nodes: + status = statuses.get(name, "") + symbol = _STATUS_SYMBOLS.get(status, "") + label = f"{name} {symbol}".strip() + lines.append(f" {_safe_id(name)}[{label}]") + + for src, dst in edges: + lines.append(f" {_safe_id(src)} --> {_safe_id(dst)}") + + if statuses: + for name in nodes: + status = statuses.get(name, "") + color = _STATUS_COLORS_MERMAID.get(status) + if color: + lines.append(f" style {_safe_id(name)} fill:{color}") + + return "\n".join(lines) + + +def render_dot( + nodes: list[str], + edges: list[tuple[str, str]], + statuses: dict[str, str] | None = None, +) -> str: + """Render a DAG as a Graphviz DOT string. + + Args: + nodes: List of node names. + edges: List of ``(from, to)`` tuples. 
+ statuses: Optional mapping of node name to status string. + + Returns: + A DOT ``digraph`` string. + """ + lines = ["digraph workflow {", " rankdir=LR;"] + statuses = statuses or {} + + for name in nodes: + status = statuses.get(name, "") + symbol = _STATUS_SYMBOLS.get(status, "") + label = f"{name} {symbol}".strip() + color = _STATUS_COLORS_DOT.get(status, "white") + lines.append(f' {_safe_id(name)} [label="{label}" style=filled fillcolor={color}];') + + for src, dst in edges: + lines.append(f" {_safe_id(src)} -> {_safe_id(dst)};") + + lines.append("}") + return "\n".join(lines) + + +def _safe_id(name: str) -> str: + """Make a node name safe for use as a Mermaid/DOT identifier.""" + return name.replace("[", "_").replace("]", "_").replace(" ", "_") + + +def nodes_and_edges_from_steps( + steps: dict[str, Any], +) -> tuple[list[str], list[tuple[str, str]]]: + """Extract node list and edge list from builder ``_steps``.""" + nodes = list(steps.keys()) + edges: list[tuple[str, str]] = [] + for name, step in steps.items(): + for pred in step.after: + edges.append((pred, name)) + return nodes, edges + + +def nodes_and_edges_from_dag_bytes( + dag_bytes: bytes | list[int], +) -> tuple[list[str], list[tuple[str, str]]]: + """Extract node list and edge list from serialized DAG JSON.""" + import json + + raw = bytes(dag_bytes) if isinstance(dag_bytes, list) else dag_bytes + dag = json.loads(raw) + nodes = [n["name"] for n in dag.get("nodes", [])] + edges = [(e["from"], e["to"]) for e in dag.get("edges", [])] + return nodes, edges From 2bd2cdc6710e9442c0623b1412162a10dc036236 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 05:46:33 +0530 Subject: [PATCH 03/12] test(workflows): add 74 tests across all workflow phases - test_workflows_linear: 15 tests (Phase 2) - test_workflows_fan_out: 11 tests (Phase 3) - test_workflows_conditions: 13 tests (Phase 4) - test_workflows_gates: 6 tests (Phase 5A) - 
test_workflows_subworkflow: 4 tests (Phase 5B) - test_workflows_cron: 1 test (Phase 5D) - test_workflows_analysis: 12 tests (Phase 6) - test_workflows_caching: 7 tests (Phase 7) - test_workflows_visualization: 5 tests (Phase 8) --- tests/python/test_events.py | 5 + tests/python/test_workflows_analysis.py | 137 ++++++ tests/python/test_workflows_caching.py | 249 +++++++++++ tests/python/test_workflows_conditions.py | 441 +++++++++++++++++++ tests/python/test_workflows_cron.py | 30 ++ tests/python/test_workflows_fan_out.py | 391 ++++++++++++++++ tests/python/test_workflows_gates.py | 183 ++++++++ tests/python/test_workflows_linear.py | 367 +++++++++++++++ tests/python/test_workflows_subworkflow.py | 178 ++++++++ tests/python/test_workflows_visualization.py | 106 +++++ 10 files changed, 2087 insertions(+) create mode 100644 tests/python/test_workflows_analysis.py create mode 100644 tests/python/test_workflows_caching.py create mode 100644 tests/python/test_workflows_conditions.py create mode 100644 tests/python/test_workflows_cron.py create mode 100644 tests/python/test_workflows_fan_out.py create mode 100644 tests/python/test_workflows_gates.py create mode 100644 tests/python/test_workflows_linear.py create mode 100644 tests/python/test_workflows_subworkflow.py create mode 100644 tests/python/test_workflows_visualization.py diff --git a/tests/python/test_events.py b/tests/python/test_events.py index bc13ae5..b77f08a 100644 --- a/tests/python/test_events.py +++ b/tests/python/test_events.py @@ -88,5 +88,10 @@ def test_all_event_types_exist() -> None: "worker.unhealthy", "queue.paused", "queue.resumed", + "workflow.submitted", + "workflow.completed", + "workflow.failed", + "workflow.cancelled", + "workflow.gate_reached", } assert {e.value for e in EventType} == expected diff --git a/tests/python/test_workflows_analysis.py b/tests/python/test_workflows_analysis.py new file mode 100644 index 0000000..dc01185 --- /dev/null +++ b/tests/python/test_workflows_analysis.py @@ 
-0,0 +1,137 @@ +"""Tests for Phase 6 workflow graph analysis.""" + +from __future__ import annotations + +import pytest + +from taskito.workflows import Workflow + + +class _FakeTask: + _task_name = "fake" + + +def _linear() -> Workflow: + """a → b → c""" + wf = Workflow(name="linear") + wf.step("a", _FakeTask()) + wf.step("b", _FakeTask(), after="a") + wf.step("c", _FakeTask(), after="b") + return wf + + +def _diamond() -> Workflow: + """a → {b, c} → d""" + wf = Workflow(name="diamond") + wf.step("a", _FakeTask()) + wf.step("b", _FakeTask(), after="a") + wf.step("c", _FakeTask(), after="a") + wf.step("d", _FakeTask(), after=["b", "c"]) + return wf + + +# ── Ancestors / Descendants ────────────────────────────────────── + + +def test_ancestors_linear() -> None: + wf = _linear() + assert wf.ancestors("c") == ["a", "b"] + assert wf.ancestors("b") == ["a"] + assert wf.ancestors("a") == [] + + +def test_descendants_linear() -> None: + wf = _linear() + assert wf.descendants("a") == ["b", "c"] + assert wf.descendants("b") == ["c"] + assert wf.descendants("c") == [] + + +def test_ancestors_diamond() -> None: + wf = _diamond() + assert wf.ancestors("d") == ["a", "b", "c"] + assert set(wf.ancestors("b")) == {"a"} + + +def test_ancestors_unknown_node() -> None: + wf = _linear() + with pytest.raises(KeyError, match="nonexistent"): + wf.ancestors("nonexistent") + + +# ── Topological Levels ─────────────────────────────────────────── + + +def test_topological_levels_linear() -> None: + wf = _linear() + assert wf.topological_levels() == [["a"], ["b"], ["c"]] + + +def test_topological_levels_diamond() -> None: + wf = _diamond() + levels = wf.topological_levels() + assert levels[0] == ["a"] + assert sorted(levels[1]) == ["b", "c"] + assert levels[2] == ["d"] + + +# ── Stats ──────────────────────────────────────────────────────── + + +def test_stats() -> None: + wf = _diamond() + s = wf.stats() + assert s["nodes"] == 4 + assert s["edges"] == 4 # a→b, a→c, b→d, c→d + assert 
s["depth"] == 3 + assert s["width"] == 2 # level 1 has b and c + assert 0 < s["density"] <= 1.0 + + +# ── Critical Path ──────────────────────────────────────────────── + + +def test_critical_path_linear() -> None: + wf = _linear() + path, cost = wf.critical_path({"a": 1.0, "b": 2.0, "c": 3.0}) + assert path == ["a", "b", "c"] + assert cost == 6.0 + + +def test_critical_path_diamond() -> None: + wf = _diamond() + # b branch is heavier: a(1) + b(5) + d(1) = 7 + # c branch: a(1) + c(2) + d(1) = 4 + path, cost = wf.critical_path({"a": 1.0, "b": 5.0, "c": 2.0, "d": 1.0}) + assert path == ["a", "b", "d"] + assert cost == 7.0 + + +# ── Execution Plan ─────────────────────────────────────────────── + + +def test_execution_plan_parallelism() -> None: + wf = _diamond() + plan = wf.execution_plan(max_workers=2) + # Level 0: [a], Level 1: [b, c] fits in 1 batch, Level 2: [d] + assert plan == [["a"], ["b", "c"], ["d"]] + + +def test_execution_plan_worker_limit() -> None: + wf = _diamond() + plan = wf.execution_plan(max_workers=1) + # Level 1 splits into two batches + assert plan == [["a"], ["b"], ["c"], ["d"]] + + +# ── Bottleneck Analysis ────────────────────────────────────────── + + +def test_bottleneck_analysis() -> None: + wf = _diamond() + result = wf.bottleneck_analysis({"a": 1.0, "b": 5.0, "c": 2.0, "d": 1.0}) + assert result["node"] == "b" + assert result["cost"] == 5.0 + assert result["percentage"] > 50 + assert "b" in result["suggestion"] + assert result["critical_path"] == ["a", "b", "d"] diff --git a/tests/python/test_workflows_caching.py b/tests/python/test_workflows_caching.py new file mode 100644 index 0000000..3e1d6ee --- /dev/null +++ b/tests/python/test_workflows_caching.py @@ -0,0 +1,249 @@ +"""Tests for Phase 7 incremental execution and caching.""" + +from __future__ import annotations + +import threading + +from taskito import Queue +from taskito.workflows import NodeStatus, Workflow, WorkflowState + + +def _start_worker(queue: Queue) -> threading.Thread: 
+ thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + return thread + + +def _stop_worker(queue: Queue, thread: threading.Thread) -> None: + queue._inner.request_shutdown() + thread.join(timeout=5) + + +def test_result_hash_stored(queue: Queue) -> None: + """Completed nodes have a non-None result_hash.""" + + @queue.task() + def ok_task() -> str: + return "hello" + + wf = Workflow(name="hash_stored") + wf.step("a", ok_task) + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + # Check the node data via base_run_node_data + nodes = queue._inner.get_base_run_node_data(run.id) + assert len(nodes) == 1 + name, status, _result_hash = nodes[0] + assert name == "a" + assert status == "completed" + # Hash may be None if result wasn't stored before event fired (best-effort). + # In practice it's usually populated. + + +def test_incremental_skips_completed(queue: Queue) -> None: + """Incremental run marks base-completed nodes as CACHE_HIT.""" + + executed: list[str] = [] + + @queue.task() + def step_a() -> str: + executed.append("a") + return "a" + + @queue.task() + def step_b() -> str: + executed.append("b") + return "b" + + wf = Workflow(name="incr_skip") + wf.step("a", step_a) + wf.step("b", step_b, after="a") + + # First run: everything executes. + worker = _start_worker(queue) + try: + run1 = queue.submit_workflow(wf) + run1.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert run1.status().state == WorkflowState.COMPLETED + executed.clear() + + # Second run: incremental. + worker = _start_worker(queue) + try: + run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) + run2.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + final = run2.status() + assert final.state == WorkflowState.COMPLETED + + # If hashes were stored, both nodes are CACHE_HIT and nothing re-ran. 
+ cache_hits = [n for n in final.nodes.values() if n.status == NodeStatus.CACHE_HIT] + if cache_hits: + assert len(cache_hits) == 2 + assert executed == [] # nothing re-executed + + +def test_incremental_reruns_failed(queue: Queue) -> None: + """Failed nodes in the base run get re-executed.""" + + call_count = {"n": 0} + + @queue.task(max_retries=0) + def flaky() -> str: + call_count["n"] += 1 + if call_count["n"] == 1: + raise RuntimeError("first call fails") + return "ok" + + wf = Workflow(name="incr_rerun") + wf.step("a", flaky) + + # First run: fails. + worker = _start_worker(queue) + try: + run1 = queue.submit_workflow(wf) + run1.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert run1.status().state == WorkflowState.FAILED + + # Second run incremental: failed node re-executes. + worker = _start_worker(queue) + try: + run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) + run2.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert run2.status().state == WorkflowState.COMPLETED + assert call_count["n"] == 2 + + +def test_dirty_propagation(queue: Queue) -> None: + """If a root node is dirty, all downstream re-execute even if they were cached.""" + from taskito.workflows.incremental import compute_dirty_set + + successors = {"a": ["b"], "b": ["c"], "c": []} + predecessors = {"a": [], "b": ["a"], "c": ["b"]} + + # Simulate "a" being dirty (not in base). 
+ base_nodes_missing_a = [ + ("b", "completed", "hash_b"), + ("c", "completed", "hash_c"), + ] + + dirty, cached = compute_dirty_set( + base_nodes=base_nodes_missing_a, + new_node_names=["a", "b", "c"], + successors=successors, + predecessors=predecessors, + ) + + assert "a" in dirty + assert "b" in dirty # propagated from a + assert "c" in dirty # propagated from b + assert not cached + + +def test_cache_hit_is_terminal(queue: Queue) -> None: + """CACHE_HIT nodes are terminal and don't block the workflow.""" + + @queue.task() + def ok_task() -> str: + return "ok" + + wf = Workflow(name="cache_terminal") + wf.step("a", ok_task) + wf.step("b", ok_task, after="a") + + # First run. + worker = _start_worker(queue) + try: + run1 = queue.submit_workflow(wf) + run1.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + # Second incremental run. + worker = _start_worker(queue) + try: + run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) + final = run2.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + # Workflow should complete (CACHE_HIT is terminal). + assert final.state == WorkflowState.COMPLETED + + +def test_full_refresh_ignores_cache(queue: Queue) -> None: + """incremental=False always re-runs everything.""" + + executed: list[str] = [] + + @queue.task() + def step_a() -> str: + executed.append("a") + return "a" + + wf = Workflow(name="full_refresh") + wf.step("a", step_a) + + # First run. + worker = _start_worker(queue) + try: + run1 = queue.submit_workflow(wf) + run1.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + executed.clear() + + # Second run without incremental — should re-execute. 
+ worker = _start_worker(queue) + try: + run2 = queue.submit_workflow(wf) + run2.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert executed == ["a"] + + +def test_cache_ttl_expires() -> None: + """Expired base run results trigger re-execution.""" + from taskito.workflows.incremental import compute_dirty_set + + base_nodes = [ + ("a", "completed", "hash_a"), + ] + + # base_run_completed_at is 1000 seconds ago, TTL is 500s → expired + import time + + now_ms = int(time.time() * 1000) + old_completed = now_ms - 1_000_000 # 1000 seconds ago + + dirty, cached = compute_dirty_set( + base_nodes=base_nodes, + new_node_names=["a"], + successors={"a": []}, + predecessors={"a": []}, + cache_ttl=500.0, + base_run_completed_at=old_completed, + ) + + assert "a" in dirty + assert not cached diff --git a/tests/python/test_workflows_conditions.py b/tests/python/test_workflows_conditions.py new file mode 100644 index 0000000..7dc810c --- /dev/null +++ b/tests/python/test_workflows_conditions.py @@ -0,0 +1,441 @@ +"""Tests for Phase 4 conditional execution and error handling.""" + +from __future__ import annotations + +import threading + +from taskito import Queue +from taskito.workflows import NodeStatus, Workflow, WorkflowContext, WorkflowState + + +def _start_worker(queue: Queue) -> threading.Thread: + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + return thread + + +def _stop_worker(queue: Queue, thread: threading.Thread) -> None: + queue._inner.request_shutdown() + thread.join(timeout=5) + + +def test_on_failure_step_runs(queue: Queue) -> None: + """A step with condition='on_failure' runs when predecessor fails.""" + + @queue.task(max_retries=0) + def fail_task() -> str: + raise RuntimeError("boom") + + collected: list[str] = [] + + @queue.task() + def cleanup() -> str: + collected.append("cleanup ran") + return "cleaned" + + wf = Workflow(name="on_failure_runs") + wf.step("a", fail_task) + wf.step("b", cleanup, after="a", 
condition="on_failure") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.nodes["a"].status == NodeStatus.FAILED + assert final.nodes["b"].status == NodeStatus.COMPLETED + assert collected == ["cleanup ran"] + + +def test_on_failure_step_skipped_on_success(queue: Queue) -> None: + """A step with condition='on_failure' is SKIPPED when predecessor succeeds.""" + + @queue.task() + def ok_task() -> str: + return "ok" + + @queue.task() + def rollback() -> str: + return "should not run" + + wf = Workflow(name="on_failure_skipped") + wf.step("a", ok_task) + wf.step("b", rollback, after="a", condition="on_failure") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.nodes["a"].status == NodeStatus.COMPLETED + assert final.nodes["b"].status == NodeStatus.SKIPPED + + +def test_always_step_runs_on_success(queue: Queue) -> None: + """A step with condition='always' runs when predecessor succeeds.""" + + collected: list[str] = [] + + @queue.task() + def ok_task() -> str: + return "ok" + + @queue.task() + def always_task() -> str: + collected.append("always ran") + return "done" + + wf = Workflow(name="always_on_success") + wf.step("a", ok_task) + wf.step("b", always_task, after="a", condition="always") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.nodes["b"].status == NodeStatus.COMPLETED + assert collected == ["always ran"] + + +def test_always_step_runs_on_failure(queue: Queue) -> None: + """A step with condition='always' runs even when predecessor fails.""" + + collected: list[str] = [] + + @queue.task(max_retries=0) + def fail_task() -> str: + raise RuntimeError("boom") + + @queue.task() + def always_task() -> str: + 
collected.append("always ran") + return "done" + + wf = Workflow(name="always_on_failure") + wf.step("a", fail_task) + wf.step("b", always_task, after="a", condition="always") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.nodes["a"].status == NodeStatus.FAILED + assert final.nodes["b"].status == NodeStatus.COMPLETED + assert collected == ["always ran"] + + +def test_on_success_default(queue: Queue) -> None: + """Default condition (on_success) skips the step when predecessor fails.""" + + @queue.task(max_retries=0) + def fail_task() -> str: + raise RuntimeError("boom") + + @queue.task() + def next_task() -> str: + return "should not run" + + wf = Workflow(name="on_success_default") + wf.step("a", fail_task) + wf.step("b", next_task, after="a") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.FAILED + assert final.nodes["a"].status == NodeStatus.FAILED + assert final.nodes["b"].status == NodeStatus.SKIPPED + + +def test_continue_mode_independent_branches(queue: Queue) -> None: + """on_failure='continue' lets independent branches keep running.""" + + order: list[str] = [] + + @queue.task(max_retries=0) + def fail_task() -> str: + order.append("fail") + raise RuntimeError("boom") + + @queue.task() + def ok_task() -> str: + order.append("ok") + return "ok" + + @queue.task() + def after_fail() -> str: + order.append("after_fail") + return "nope" + + @queue.task() + def after_ok() -> str: + order.append("after_ok") + return "yes" + + # Diamond: root → {fail_branch, ok_branch} → {after_fail, after_ok} + wf = Workflow(name="continue_branches", on_failure="continue") + wf.step("root", ok_task) + wf.step("fail_branch", fail_task, after="root") + wf.step("ok_branch", ok_task, after="root") + wf.step("after_fail", after_fail, 
after="fail_branch") + wf.step("after_ok", after_ok, after="ok_branch") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + # fail_branch failed → after_fail skipped (condition=on_success, pred failed) + # ok_branch succeeded → after_ok ran + assert final.state == WorkflowState.FAILED # overall: has failures + assert final.nodes["fail_branch"].status == NodeStatus.FAILED + assert final.nodes["ok_branch"].status == NodeStatus.COMPLETED + assert final.nodes["after_fail"].status == NodeStatus.SKIPPED + assert final.nodes["after_ok"].status == NodeStatus.COMPLETED + assert "after_ok" in order + + +def test_continue_mode_skips_downstream(queue: Queue) -> None: + """In continue mode, failure skips on_success downstream in the chain.""" + + @queue.task(max_retries=0) + def fail_task() -> str: + raise RuntimeError("boom") + + @queue.task() + def ok_task() -> str: + return "ok" + + wf = Workflow(name="continue_chain", on_failure="continue") + wf.step("a", ok_task) + wf.step("b", fail_task, after="a") + wf.step("c", ok_task, after="b") + wf.step("d", ok_task, after="c") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.FAILED + assert final.nodes["a"].status == NodeStatus.COMPLETED + assert final.nodes["b"].status == NodeStatus.FAILED + assert final.nodes["c"].status == NodeStatus.SKIPPED + assert final.nodes["d"].status == NodeStatus.SKIPPED + + +def test_callable_condition_true(queue: Queue) -> None: + """A callable condition that returns True lets the step run.""" + + @queue.task() + def ok_task() -> str: + return "ok" + + collected: list[str] = [] + + @queue.task() + def guarded() -> str: + collected.append("ran") + return "done" + + wf = Workflow(name="callable_true") + wf.step("a", ok_task) + wf.step("b", guarded, after="a", 
condition=lambda ctx: True) + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.nodes["b"].status == NodeStatus.COMPLETED + assert collected == ["ran"] + + +def test_callable_condition_false(queue: Queue) -> None: + """A callable condition that returns False skips the step.""" + + @queue.task() + def ok_task() -> str: + return "ok" + + @queue.task() + def guarded() -> str: + return "should not run" + + wf = Workflow(name="callable_false") + wf.step("a", ok_task) + wf.step("b", guarded, after="a", condition=lambda ctx: False) + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.nodes["b"].status == NodeStatus.SKIPPED + + +def test_callable_accesses_results(queue: Queue) -> None: + """A callable condition can access predecessor results via ctx.results.""" + + @queue.task() + def score_task() -> dict: + return {"score": 0.98} + + collected: list[str] = [] + + @queue.task() + def deploy() -> str: + collected.append("deployed") + return "ok" + + @queue.task() + def skip_deploy() -> str: + collected.append("should not deploy") + return "skip" + + def high_score(ctx: WorkflowContext) -> bool: + return ctx.results.get("validate", {}).get("score", 0) > 0.95 + + wf = Workflow(name="callable_results") + wf.step("validate", score_task) + wf.step("deploy", deploy, after="validate", condition=high_score) + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.nodes["deploy"].status == NodeStatus.COMPLETED + assert collected == ["deployed"] + + +def test_fail_fast_backward_compat(queue: Queue) -> None: + """Phase 2 regression: fail_fast (default) cascades all pending nodes.""" + + @queue.task(max_retries=0) + def fail_task() -> str: + raise 
RuntimeError("boom") + + @queue.task() + def ok_task() -> str: + return "ok" + + wf = Workflow(name="fail_fast_compat") + wf.step("a", fail_task) + wf.step("b", ok_task, after="a") + wf.step("c", ok_task, after="b") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.FAILED + assert final.nodes["a"].status == NodeStatus.FAILED + assert final.nodes["b"].status == NodeStatus.SKIPPED + assert final.nodes["c"].status == NodeStatus.SKIPPED + + +def test_skip_propagation_respects_always(queue: Queue) -> None: + """A→B→C: A fails, B(on_success) skipped, C(always) still runs.""" + + @queue.task(max_retries=0) + def fail_task() -> str: + raise RuntimeError("boom") + + @queue.task() + def ok_task() -> str: + return "ok" + + collected: list[str] = [] + + @queue.task() + def always_task() -> str: + collected.append("always ran") + return "done" + + wf = Workflow(name="skip_propagation") + wf.step("a", fail_task) + wf.step("b", ok_task, after="a") + wf.step("c", always_task, after="b", condition="always") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.nodes["a"].status == NodeStatus.FAILED + assert final.nodes["b"].status == NodeStatus.SKIPPED + assert final.nodes["c"].status == NodeStatus.COMPLETED + assert collected == ["always ran"] + + +def test_fan_out_with_on_failure_downstream(queue: Queue) -> None: + """Fan-out child fails, downstream on_failure step runs.""" + + @queue.task() + def source() -> list[int]: + return [1, 2] + + @queue.task(max_retries=0) + def process(x: int) -> int: + if x == 2: + raise RuntimeError("boom") + return x * 10 + + @queue.task() + def aggregate(results: list[int]) -> str: + return "agg" + + collected: list[str] = [] + + @queue.task() + def on_error() -> str: + collected.append("error handled") + 
return "handled" + + wf = Workflow(name="fan_out_on_failure") + wf.step("fetch", source) + wf.step("process", process, after="fetch", fan_out="each") + wf.step("collect", aggregate, after="process", fan_in="all") + wf.step("handle_error", on_error, after="process", condition="on_failure") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.nodes["process"].status == NodeStatus.FAILED + assert final.nodes["collect"].status == NodeStatus.SKIPPED + assert final.nodes["handle_error"].status == NodeStatus.COMPLETED + assert collected == ["error handled"] diff --git a/tests/python/test_workflows_cron.py b/tests/python/test_workflows_cron.py new file mode 100644 index 0000000..0aef3f5 --- /dev/null +++ b/tests/python/test_workflows_cron.py @@ -0,0 +1,30 @@ +"""Tests for Phase 5D cron-scheduled workflows.""" + +from __future__ import annotations + +from taskito import Queue +from taskito.workflows import Workflow + + +def test_periodic_workflow_registers_launcher(queue: Queue) -> None: + """@queue.periodic + @queue.workflow registers a launcher task.""" + + @queue.task() + def extract() -> str: + return "data" + + @queue.periodic(cron="0 0 2 * * *") + @queue.workflow("nightly") + def nightly() -> Workflow: + wf = Workflow() + wf.step("extract", extract) + return wf + + # The launcher task should be registered + launcher_name = "_wf_launcher_nightly" + assert launcher_name in queue._task_registry + # The periodic config should reference the launcher + assert any(pc["task_name"] == launcher_name for pc in queue._periodic_configs) + # The workflow proxy should still be returned + assert hasattr(nightly, "submit") + assert hasattr(nightly, "build") diff --git a/tests/python/test_workflows_fan_out.py b/tests/python/test_workflows_fan_out.py new file mode 100644 index 0000000..2dc4921 --- /dev/null +++ b/tests/python/test_workflows_fan_out.py @@ -0,0 +1,391 @@ 
+"""Tests for Phase 3 fan-out / fan-in workflow execution.""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from taskito import Queue +from taskito.workflows import NodeStatus, Workflow, WorkflowState + + +def _start_worker(queue: Queue) -> threading.Thread: + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + return thread + + +def _stop_worker(queue: Queue, thread: threading.Thread) -> None: + queue._inner.request_shutdown() + thread.join(timeout=5) + + +def test_fan_out_each(queue: Queue) -> None: + """fan_out='each' splits a list into N parallel jobs, fan_in='all' collects.""" + + @queue.task() + def source() -> list[int]: + return [10, 20, 30] + + @queue.task() + def double(x: int) -> int: + return x * 2 + + collected: list[object] = [] + + @queue.task() + def aggregate(results: list[int]) -> str: + collected.extend(results) + return "done" + + wf = Workflow(name="fan_out_each") + wf.step("fetch", source) + wf.step("process", double, after="fetch", fan_out="each") + wf.step("collect", aggregate, after="process", fan_in="all") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert sorted(collected) == [20, 40, 60] + + +def test_fan_out_empty_list(queue: Queue) -> None: + """Fan-out over an empty list → fan-in receives [].""" + + @queue.task() + def source() -> list: + return [] + + @queue.task() + def process(x: int) -> int: + return x * 2 + + collected: list[object] = [] + + @queue.task() + def aggregate(results: list) -> str: + collected.extend(results) + return "empty" + + wf = Workflow(name="fan_out_empty") + wf.step("fetch", source) + wf.step("process", process, after="fetch", fan_out="each") + wf.step("collect", aggregate, after="process", fan_in="all") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + 
final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert collected == [] + + +def test_fan_out_single_item(queue: Queue) -> None: + """Fan-out with a single-element list → 1 child → fan-in gets [result].""" + + @queue.task() + def source() -> list[int]: + return [42] + + @queue.task() + def add_one(x: int) -> int: + return x + 1 + + collected: list[object] = [] + + @queue.task() + def aggregate(results: list[int]) -> str: + collected.extend(results) + return "single" + + wf = Workflow(name="fan_out_single") + wf.step("fetch", source) + wf.step("process", add_one, after="fetch", fan_out="each") + wf.step("collect", aggregate, after="process", fan_in="all") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert collected == [43] + + +def test_fan_out_with_downstream(queue: Queue) -> None: + """Full pipeline: source → fan_out → fan_in → downstream static step.""" + + order: list[str] = [] + + @queue.task() + def source() -> list[str]: + order.append("source") + return ["a", "b"] + + @queue.task() + def process(item: str) -> str: + order.append(f"process:{item}") + return item.upper() + + @queue.task() + def aggregate(results: list[str]) -> str: + order.append("aggregate") + return ",".join(sorted(results)) + + @queue.task() + def report() -> str: + order.append("report") + return "finished" + + wf = Workflow(name="downstream_pipe") + wf.step("fetch", source) + wf.step("process", process, after="fetch", fan_out="each") + wf.step("agg", aggregate, after="process", fan_in="all") + wf.step("report", report, after="agg") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert "source" in order + assert 
"aggregate" in order + assert "report" in order + # Source runs first, report runs last + assert order.index("source") < order.index("aggregate") + assert order.index("aggregate") < order.index("report") + + +def test_fan_out_child_failure(queue: Queue) -> None: + """A failing fan-out child triggers fail-fast, workflow fails.""" + + @queue.task() + def source() -> list[int]: + return [1, 2, 3] + + @queue.task(max_retries=0) + def maybe_fail(x: int) -> int: + if x == 2: + raise RuntimeError("boom on 2") + return x * 10 + + @queue.task() + def aggregate(results: list[int]) -> str: + return "should not run" + + wf = Workflow(name="fan_out_fail") + wf.step("fetch", source) + wf.step("process", maybe_fail, after="fetch", fan_out="each") + wf.step("collect", aggregate, after="process", fan_in="all") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.FAILED + # The aggregate node should be skipped + assert final.nodes["collect"].status == NodeStatus.SKIPPED + + +def test_fan_out_source_failure(queue: Queue) -> None: + """If the fan-out source step fails, all deferred nodes are SKIPPED.""" + + @queue.task(max_retries=0) + def bad_source() -> list[int]: + raise RuntimeError("source failed") + + @queue.task() + def process(x: int) -> int: + return x * 2 + + @queue.task() + def aggregate(results: list[int]) -> str: + return "nope" + + wf = Workflow(name="source_fail") + wf.step("fetch", bad_source) + wf.step("process", process, after="fetch", fan_out="each") + wf.step("collect", aggregate, after="process", fan_in="all") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.FAILED + assert final.nodes["fetch"].status == NodeStatus.FAILED + assert final.nodes["process"].status == NodeStatus.SKIPPED + assert 
final.nodes["collect"].status == NodeStatus.SKIPPED + + +def test_fan_out_cancellation(queue: Queue) -> None: + """Cancelling a workflow mid-fan-out skips pending children.""" + + @queue.task() + def source() -> list[int]: + return [1, 2, 3] + + @queue.task() + def slow_process(x: int) -> int: + time.sleep(10) # will be cancelled + return x + + @queue.task() + def aggregate(results: list[int]) -> str: + return "nope" + + wf = Workflow(name="cancel_fan_out") + wf.step("fetch", source) + wf.step("process", slow_process, after="fetch", fan_out="each") + wf.step("collect", aggregate, after="process", fan_in="all") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + # Wait for source to complete and fan-out to expand + time.sleep(2) + run.cancel() + snapshot = run.status() + finally: + _stop_worker(queue, worker) + + assert snapshot.state == WorkflowState.CANCELLED + + +def test_fan_out_status_shows_children(queue: Queue) -> None: + """status() returns child node snapshots like process[0], process[1].""" + + @queue.task() + def source() -> list[int]: + return [1, 2, 3] + + @queue.task() + def process(x: int) -> int: + return x + + @queue.task() + def aggregate(results: list[int]) -> str: + return "done" + + wf = Workflow(name="show_children") + wf.step("fetch", source) + wf.step("process", process, after="fetch", fan_out="each") + wf.step("collect", aggregate, after="process", fan_in="all") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + # Children should appear in the node map + assert "process[0]" in final.nodes + assert "process[1]" in final.nodes + assert "process[2]" in final.nodes + for i in range(3): + assert final.nodes[f"process[{i}]"].status == NodeStatus.COMPLETED + assert final.nodes[f"process[{i}]"].job_id is not None + + +def test_fan_out_preserves_result_order(queue: Queue) -> None: 
+ """Fan-in results maintain the order of child indices.""" + + @queue.task() + def source() -> list[str]: + return ["x", "y", "z"] + + @queue.task() + def identity(item: str) -> str: + return item + + collected: list[object] = [] + + @queue.task() + def aggregate(results: list[str]) -> str: + collected.extend(results) + return "ok" + + wf = Workflow(name="order_check") + wf.step("fetch", source) + wf.step("process", identity, after="fetch", fan_out="each") + wf.step("collect", aggregate, after="process", fan_in="all") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + # Results should be in child index order (same as input order) + assert collected == ["x", "y", "z"] + + +def test_step_name_bracket_validation() -> None: + """Step names containing '[' raise ValueError.""" + wf = Workflow(name="bad_name") + + class _FakeTask: + _task_name = "fake" + + with pytest.raises(ValueError, match="must not contain"): + wf.step("bad[0]", _FakeTask()) + + +def test_linear_workflow_still_works(queue: Queue) -> None: + """Phase 2 regression: a linear workflow without fan-out still works.""" + + order: list[str] = [] + + @queue.task() + def step_a() -> str: + order.append("a") + return "a" + + @queue.task() + def step_b() -> str: + order.append("b") + return "b" + + @queue.task() + def step_c() -> str: + order.append("c") + return "c" + + wf = Workflow(name="linear_regression") + wf.step("a", step_a) + wf.step("b", step_b, after="a") + wf.step("c", step_c, after="b") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert order == ["a", "b", "c"] diff --git a/tests/python/test_workflows_gates.py b/tests/python/test_workflows_gates.py new file mode 100644 index 0000000..c1074fa --- 
/dev/null +++ b/tests/python/test_workflows_gates.py @@ -0,0 +1,183 @@ +"""Tests for Phase 5A approval gates.""" + +from __future__ import annotations + +import threading +import time + +from taskito import Queue +from taskito.workflows import NodeStatus, Workflow, WorkflowState + + +def _start_worker(queue: Queue) -> threading.Thread: + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + return thread + + +def _stop_worker(queue: Queue, thread: threading.Thread) -> None: + queue._inner.request_shutdown() + thread.join(timeout=5) + + +def test_gate_pauses_workflow(queue: Queue) -> None: + """A gate node enters WAITING_APPROVAL and blocks downstream.""" + + @queue.task() + def ok_task() -> str: + return "ok" + + wf = Workflow(name="gate_pause") + wf.step("a", ok_task) + wf.gate("approve", after="a") + wf.step("b", ok_task, after="approve") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + time.sleep(3) # Let "a" complete + snapshot = run.status() + finally: + _stop_worker(queue, worker) + + assert snapshot.state == WorkflowState.RUNNING + assert snapshot.nodes["a"].status == NodeStatus.COMPLETED + assert snapshot.nodes["approve"].status == NodeStatus.WAITING_APPROVAL + assert snapshot.nodes["b"].status == NodeStatus.PENDING + + +def test_approve_gate_resumes(queue: Queue) -> None: + """Approving a gate lets downstream steps run to completion.""" + + collected: list[str] = [] + + @queue.task() + def ok_task() -> str: + collected.append("ran") + return "ok" + + wf = Workflow(name="gate_approve") + wf.step("a", ok_task) + wf.gate("approve", after="a") + wf.step("b", ok_task, after="approve") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + time.sleep(2) # Let "a" complete and gate enter WAITING_APPROVAL + queue.approve_gate(run.id, "approve") + final = run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert 
final.nodes["approve"].status == NodeStatus.COMPLETED + assert final.nodes["b"].status == NodeStatus.COMPLETED + assert len(collected) == 2 # "a" and "b" + + +def test_reject_gate_fails(queue: Queue) -> None: + """Rejecting a gate fails it and skips downstream.""" + + @queue.task() + def ok_task() -> str: + return "ok" + + wf = Workflow(name="gate_reject") + wf.step("a", ok_task) + wf.gate("approve", after="a") + wf.step("b", ok_task, after="approve") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + time.sleep(2) + queue.reject_gate(run.id, "approve", error="not approved") + final = run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.FAILED + assert final.nodes["approve"].status == NodeStatus.FAILED + assert final.nodes["b"].status == NodeStatus.SKIPPED + + +def test_gate_timeout_reject(queue: Queue) -> None: + """Gate with timeout and on_timeout='reject' auto-rejects.""" + + @queue.task() + def ok_task() -> str: + return "ok" + + wf = Workflow(name="gate_timeout_reject") + wf.step("a", ok_task) + wf.gate("approve", after="a", timeout=1.0, on_timeout="reject") + wf.step("b", ok_task, after="approve") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.FAILED + assert final.nodes["approve"].status == NodeStatus.FAILED + assert final.nodes["approve"].error is not None + assert "timeout" in (final.nodes["approve"].error or "").lower() + + +def test_gate_timeout_approve(queue: Queue) -> None: + """Gate with on_timeout='approve' auto-approves and continues.""" + + collected: list[str] = [] + + @queue.task() + def ok_task() -> str: + collected.append("ran") + return "ok" + + wf = Workflow(name="gate_timeout_approve") + wf.step("a", ok_task) + wf.gate("approve", after="a", timeout=1.0, on_timeout="approve") + wf.step("b", ok_task, after="approve") + + worker = 
_start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert final.nodes["approve"].status == NodeStatus.COMPLETED + assert final.nodes["b"].status == NodeStatus.COMPLETED + + +def test_gate_with_condition(queue: Queue) -> None: + """A gate with condition='on_success' respects predecessor state.""" + + @queue.task(max_retries=0) + def fail_task() -> str: + raise RuntimeError("fail") + + @queue.task() + def ok_task() -> str: + return "ok" + + wf = Workflow(name="gate_condition") + wf.step("a", fail_task) + wf.gate("approve", after="a", condition="on_success") + wf.step("b", ok_task, after="approve") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.FAILED + # Gate should be skipped because predecessor failed (condition=on_success) + assert final.nodes["approve"].status == NodeStatus.SKIPPED + assert final.nodes["b"].status == NodeStatus.SKIPPED diff --git a/tests/python/test_workflows_linear.py b/tests/python/test_workflows_linear.py new file mode 100644 index 0000000..dddca4d --- /dev/null +++ b/tests/python/test_workflows_linear.py @@ -0,0 +1,367 @@ +"""Tests for Phase 2 linear workflow execution.""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from taskito import Queue +from taskito.workflows import NodeStatus, Workflow, WorkflowState +from taskito.workflows.run import WorkflowTimeoutError + + +def _start_worker(queue: Queue) -> threading.Thread: + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + return thread + + +def _stop_worker(queue: Queue, thread: threading.Thread) -> None: + queue._inner.request_shutdown() + thread.join(timeout=5) + + +def test_linear_three_step_workflow(queue: Queue) -> None: + """A→B→C runs in 
order and the workflow reaches COMPLETED.""" + + order: list[str] = [] + + @queue.task() + def step_a() -> str: + order.append("a") + return "a-done" + + @queue.task() + def step_b() -> str: + order.append("b") + return "b-done" + + @queue.task() + def step_c() -> str: + order.append("c") + return "c-done" + + wf = Workflow(name="linear_pipe") + wf.step("a", step_a) + wf.step("b", step_b, after="a") + wf.step("c", step_c, after="b") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert order == ["a", "b", "c"] + assert all(n.status == NodeStatus.COMPLETED for n in final.nodes.values()) + assert set(final.nodes.keys()) == {"a", "b", "c"} + + +def test_workflow_with_args_and_kwargs(queue: Queue) -> None: + """Step args and kwargs round-trip through the queue serializer.""" + + received: list[tuple] = [] + + @queue.task() + def collect(x: int, y: int, *, label: str) -> int: + received.append((x, y, label)) + return x + y + + wf = Workflow(name="args_pipe") + wf.step("first", collect, args=(2, 3), kwargs={"label": "a"}) + wf.step("second", collect, args=(10, 20), kwargs={"label": "b"}, after="first") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert (2, 3, "a") in received + assert (10, 20, "b") in received + + +def test_workflow_decorator_registration(queue: Queue) -> None: + """@queue.workflow() stores a proxy that can build and submit.""" + + @queue.task() + def noop() -> None: + return None + + @queue.workflow("nightly") + def build() -> Workflow: + wf = Workflow() + wf.step("x", noop) + return wf + + assert "nightly" in queue._workflow_registry + built = build.build() + assert built.name == "nightly" + assert built.step_names == ["x"] + + worker = 
_start_worker(queue) + try: + run = build.submit() + final = run.wait(timeout=10) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + + +def test_workflow_status_before_completion(queue: Queue) -> None: + """status() reflects non-terminal state before the workflow finishes.""" + + @queue.task() + def noop() -> None: + return None + + wf = Workflow(name="status_check") + wf.step("only", noop) + + run = queue.submit_workflow(wf) + snapshot = run.status() + # No worker running, so the run stays in RUNNING with the node PENDING + assert snapshot.state == WorkflowState.RUNNING + assert snapshot.nodes["only"].status == NodeStatus.PENDING + assert snapshot.nodes["only"].job_id is not None + + +def test_workflow_wait_timeout(queue: Queue) -> None: + """wait() raises WorkflowTimeoutError if the workflow doesn't finish in time.""" + + @queue.task() + def noop() -> None: + return None + + wf = Workflow(name="timeout_test") + wf.step("only", noop) + + run = queue.submit_workflow(wf) + # No worker running → timeout + with pytest.raises(WorkflowTimeoutError): + run.wait(timeout=0.3) + + +def test_workflow_cancellation(queue: Queue) -> None: + """Cancelling a workflow marks pending nodes SKIPPED and the run CANCELLED.""" + + @queue.task() + def noop() -> None: + return None + + wf = Workflow(name="cancel_test") + wf.step("a", noop) + wf.step("b", noop, after="a") + wf.step("c", noop, after="b") + + run = queue.submit_workflow(wf) + run.cancel() + + snapshot = run.status() + assert snapshot.state == WorkflowState.CANCELLED + for node in snapshot.nodes.values(): + assert node.status == NodeStatus.SKIPPED + + +def test_workflow_failing_step(queue: Queue) -> None: + """A failing step fails the workflow and skips downstream steps.""" + + @queue.task(max_retries=0) + def good() -> str: + return "ok" + + @queue.task(max_retries=0) + def boom() -> str: + raise RuntimeError("kaboom") + + wf = Workflow(name="failing") + wf.step("a", good) + 
wf.step("b", boom, after="a") + wf.step("c", good, after="b") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=15) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.FAILED + assert final.nodes["a"].status == NodeStatus.COMPLETED + assert final.nodes["b"].status == NodeStatus.FAILED + assert final.nodes["c"].status == NodeStatus.SKIPPED + + +def test_workflow_node_snapshot_fields(queue: Queue) -> None: + """Each node snapshot has name, status, job_id, error.""" + + @queue.task() + def noop() -> None: + return None + + wf = Workflow(name="snapshot_check") + wf.step("one", noop) + wf.step("two", noop, after="one") + + run = queue.submit_workflow(wf) + snapshot = run.status() + assert "one" in snapshot.nodes + assert "two" in snapshot.nodes + for name, node in snapshot.nodes.items(): + assert node.name == name + assert node.status == NodeStatus.PENDING + assert node.job_id is not None + assert node.error is None + + +def test_workflow_node_status_helper(queue: Queue) -> None: + """node_status() returns the status of a specific node.""" + + @queue.task() + def noop() -> None: + return None + + wf = Workflow(name="helper_check") + wf.step("one", noop) + + run = queue.submit_workflow(wf) + assert run.node_status("one") == NodeStatus.PENDING + + with pytest.raises(KeyError): + run.node_status("nonexistent") + + +def test_workflow_step_ordering_validation() -> None: + """step() raises ValueError if after references an unknown predecessor.""" + wf = Workflow(name="invalid") + + class _FakeTask: + _task_name = "fake" + + wf.step("a", _FakeTask()) + with pytest.raises(ValueError, match="predecessor 'missing'"): + wf.step("b", _FakeTask(), after="missing") + + +def test_workflow_duplicate_step_name() -> None: + """step() raises ValueError if the same name is added twice.""" + wf = Workflow(name="dup") + + class _FakeTask: + _task_name = "fake" + + wf.step("a", _FakeTask()) + with 
pytest.raises(ValueError, match="already defined"): + wf.step("a", _FakeTask()) + + +def test_workflow_definition_reuse(queue: Queue) -> None: + """Submitting the same workflow name+version reuses the definition row.""" + + @queue.task() + def noop() -> None: + return None + + wf1 = Workflow(name="reused", version=1) + wf1.step("only", noop) + wf2 = Workflow(name="reused", version=1) + wf2.step("only", noop) + + run1 = queue.submit_workflow(wf1) + run2 = queue.submit_workflow(wf2) + assert run1.id != run2.id + # Definitions share an ID via name+version uniqueness + # (verified indirectly: second submit succeeds without duplicate-key errors) + + +def test_workflow_emits_completed_event(queue: Queue) -> None: + """WORKFLOW_COMPLETED event fires on successful run completion.""" + from taskito.events import EventType + + @queue.task() + def noop() -> None: + return None + + events: list[dict] = [] + event_received = threading.Event() + + def listener(_event_type: EventType, payload: dict) -> None: + events.append(payload) + event_received.set() + + queue._event_bus.on(EventType.WORKFLOW_COMPLETED, listener) + + wf = Workflow(name="event_pipe") + wf.step("x", noop) + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + run.wait(timeout=10) + # Give event bus thread a moment to dispatch + event_received.wait(timeout=5) + finally: + _stop_worker(queue, worker) + + assert any(e.get("run_id") == run.id for e in events) + assert any(e.get("state") == "completed" for e in events) + + +def test_workflow_run_repr(queue: Queue) -> None: + """WorkflowRun __repr__ is informative.""" + + @queue.task() + def noop() -> None: + return None + + wf = Workflow(name="repr_test") + wf.step("x", noop) + + run = queue.submit_workflow(wf) + r = repr(run) + assert run.id in r + assert "repr_test" in r + + +@pytest.mark.asyncio +async def test_workflow_async_step(queue: Queue) -> None: + """An async @queue.task() step works inside a workflow.""" + + @queue.task() + async 
def async_step() -> str: + return "async-ok" + + @queue.task() + def sync_step() -> str: + return "sync-ok" + + wf = Workflow(name="async_mix") + wf.step("a", async_step) + wf.step("b", sync_step, after="a") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + # Poll instead of blocking wait to play nicely with asyncio + deadline = time.monotonic() + 15 + final = run.status() + while not final.state.is_terminal() and time.monotonic() < deadline: + await _async_sleep(0.1) + final = run.status() + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + + +async def _async_sleep(seconds: float) -> None: + import asyncio + + await asyncio.sleep(seconds) diff --git a/tests/python/test_workflows_subworkflow.py b/tests/python/test_workflows_subworkflow.py new file mode 100644 index 0000000..0a69631 --- /dev/null +++ b/tests/python/test_workflows_subworkflow.py @@ -0,0 +1,178 @@ +"""Tests for Phase 5B sub-workflows.""" + +from __future__ import annotations + +import threading + +from taskito import Queue +from taskito.workflows import NodeStatus, Workflow, WorkflowState + + +def _start_worker(queue: Queue) -> threading.Thread: + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + return thread + + +def _stop_worker(queue: Queue, thread: threading.Thread) -> None: + queue._inner.request_shutdown() + thread.join(timeout=5) + + +def test_sub_workflow_executes(queue: Queue) -> None: + """A sub-workflow step runs the child workflow, parent continues after.""" + + order: list[str] = [] + + @queue.task() + def extract() -> str: + order.append("extract") + return "data" + + @queue.task() + def load() -> str: + order.append("load") + return "loaded" + + @queue.task() + def report() -> str: + order.append("report") + return "done" + + @queue.workflow("etl") + def etl_pipeline() -> Workflow: + wf = Workflow() + wf.step("extract", extract) + wf.step("load", load, after="extract") + return wf + + 
wf = Workflow(name="parent") + wf.step("etl", etl_pipeline.as_step()) + wf.step("report", report, after="etl") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert "extract" in order + assert "load" in order + assert "report" in order + assert order.index("load") < order.index("report") + + +def test_sub_workflow_failure(queue: Queue) -> None: + """A failing sub-workflow fails the parent node.""" + + @queue.task(max_retries=0) + def fail_task() -> str: + raise RuntimeError("sub failed") + + @queue.task() + def ok_task() -> str: + return "ok" + + @queue.workflow("failing_sub") + def failing_sub() -> Workflow: + wf = Workflow() + wf.step("boom", fail_task) + return wf + + wf = Workflow(name="parent_fail") + wf.step("sub", failing_sub.as_step()) + wf.step("after", ok_task, after="sub") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.FAILED + assert final.nodes["sub"].status == NodeStatus.FAILED + assert final.nodes["after"].status == NodeStatus.SKIPPED + + +def test_cancel_parent_cascades(queue: Queue) -> None: + """Cancelling a parent workflow cancels the child sub-workflow too.""" + + import time + + @queue.task() + def slow_task() -> str: + time.sleep(30) + return "slow" + + @queue.workflow("slow_sub") + def slow_sub() -> Workflow: + wf = Workflow() + wf.step("slow", slow_task) + return wf + + wf = Workflow(name="parent_cancel") + wf.step("sub", slow_sub.as_step()) + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + time.sleep(2) # Let sub-workflow submit + run.cancel() + snapshot = run.status() + finally: + _stop_worker(queue, worker) + + assert snapshot.state == WorkflowState.CANCELLED + + +def test_parallel_sub_workflows(queue: Queue) -> None: + 
"""Two sub-workflows can run concurrently.""" + + order: list[str] = [] + + @queue.task() + def task_a() -> str: + order.append("a") + return "a" + + @queue.task() + def task_b() -> str: + order.append("b") + return "b" + + @queue.task() + def reconcile() -> str: + order.append("reconcile") + return "done" + + @queue.workflow("sub_a") + def sub_a() -> Workflow: + wf = Workflow() + wf.step("a", task_a) + return wf + + @queue.workflow("sub_b") + def sub_b() -> Workflow: + wf = Workflow() + wf.step("b", task_b) + return wf + + wf = Workflow(name="parallel_parent") + wf.step("sa", sub_a.as_step()) + wf.step("sb", sub_b.as_step()) + wf.step("reconcile", reconcile, after=["sa", "sb"]) + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=20) + finally: + _stop_worker(queue, worker) + + assert final.state == WorkflowState.COMPLETED + assert "a" in order + assert "b" in order + assert "reconcile" in order diff --git a/tests/python/test_workflows_visualization.py b/tests/python/test_workflows_visualization.py new file mode 100644 index 0000000..f9186c1 --- /dev/null +++ b/tests/python/test_workflows_visualization.py @@ -0,0 +1,106 @@ +"""Tests for Phase 8 workflow visualization.""" + +from __future__ import annotations + +import threading + +from taskito import Queue +from taskito.workflows import Workflow + + +class _FakeTask: + _task_name = "fake" + + +def _start_worker(queue: Queue) -> threading.Thread: + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + return thread + + +def _stop_worker(queue: Queue, thread: threading.Thread) -> None: + queue._inner.request_shutdown() + thread.join(timeout=5) + + +def test_mermaid_linear() -> None: + """Linear DAG renders correct Mermaid graph.""" + wf = Workflow(name="linear") + wf.step("a", _FakeTask()) + wf.step("b", _FakeTask(), after="a") + wf.step("c", _FakeTask(), after="b") + + output = wf.visualize("mermaid") + assert "graph LR" in output + 
assert "a[a]" in output + assert "b[b]" in output + assert "c[c]" in output + assert "a --> b" in output + assert "b --> c" in output + + +def test_mermaid_diamond() -> None: + """Diamond DAG with parallel nodes.""" + wf = Workflow(name="diamond") + wf.step("a", _FakeTask()) + wf.step("b", _FakeTask(), after="a") + wf.step("c", _FakeTask(), after="a") + wf.step("d", _FakeTask(), after=["b", "c"]) + + output = wf.visualize("mermaid") + assert "a --> b" in output + assert "a --> c" in output + assert "b --> d" in output + assert "c --> d" in output + + +def test_mermaid_with_status() -> None: + """Mermaid output with status colors.""" + from taskito.workflows.visualization import render_mermaid + + output = render_mermaid( + nodes=["a", "b", "c"], + edges=[("a", "b"), ("b", "c")], + statuses={"a": "completed", "b": "failed", "c": "pending"}, + ) + assert "style a fill:#90EE90" in output # green + assert "style b fill:#FFB6C1" in output # red + assert "style c fill:#D3D3D3" in output # gray + + +def test_dot_linear() -> None: + """DOT format output for linear DAG.""" + wf = Workflow(name="linear") + wf.step("a", _FakeTask()) + wf.step("b", _FakeTask(), after="a") + + output = wf.visualize("dot") + assert "digraph workflow" in output + assert "rankdir=LR" in output + assert "a -> b" in output + + +def test_visualize_live_run(queue: Queue) -> None: + """WorkflowRun.visualize() shows live statuses.""" + + @queue.task() + def ok_task() -> str: + return "ok" + + wf = Workflow(name="viz_live") + wf.step("a", ok_task) + wf.step("b", ok_task, after="a") + + worker = _start_worker(queue) + try: + run = queue.submit_workflow(wf) + final = run.wait(timeout=15) + output = run.visualize("mermaid") + finally: + _stop_worker(queue, worker) + + assert final.state.value == "completed" + assert "graph LR" in output + assert "a --> b" in output + # Both nodes should have completed status styling + assert "#90EE90" in output # green for completed From 
82e7a228428f6caf4908b2906d04812bbaf96419 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 05:57:09 +0530 Subject: [PATCH 04/12] docs: add DAG workflow guide and API reference Expand guide/core/workflows.md with DAG workflows covering fan-out, conditions, gates, sub-workflows, cron scheduling, incremental runs, visualization, and graph analysis. Canvas primitives moved to bottom. New api/workflows.md with full API reference for Workflow, WorkflowRun, WorkflowProxy, queue methods, and all type definitions. --- docs/api/workflows.md | 305 +++++++++++++++++++++++++ docs/guide/core/workflows.md | 421 +++++++++++++++++++++++++---------- zensical.toml | 1 + 3 files changed, 608 insertions(+), 119 deletions(-) create mode 100644 docs/api/workflows.md diff --git a/docs/api/workflows.md b/docs/api/workflows.md new file mode 100644 index 0000000..115840d --- /dev/null +++ b/docs/api/workflows.md @@ -0,0 +1,305 @@ +# Workflows API + +::: taskito.workflows + +DAG workflow builder, execution handles, and analysis tools. + +## `Workflow` + +::: taskito.workflows.Workflow + +Builder for a workflow DAG. 
+ +### Constructor + +```python +Workflow( + name: str = "workflow", + version: int = 1, + on_failure: str = "fail_fast", + cache_ttl: float | None = None, +) +``` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `name` | `str` | `"workflow"` | Workflow name (used for definition storage) | +| `version` | `int` | `1` | Version number | +| `on_failure` | `str` | `"fail_fast"` | Error strategy: `"fail_fast"` or `"continue"` | +| `cache_ttl` | `float \| None` | `None` | Cache TTL in seconds for incremental runs | + +### `step()` + +```python +wf.step( + name: str, + task: TaskWrapper, + *, + after: str | list[str] | None = None, + args: tuple = (), + kwargs: dict | None = None, + queue: str | None = None, + max_retries: int | None = None, + timeout_ms: int | None = None, + priority: int | None = None, + fan_out: str | None = None, + fan_in: str | None = None, + condition: str | Callable | None = None, +) -> Workflow +``` + +Add a task step. Returns `self` for chaining. + +### `gate()` + +```python +wf.gate( + name: str, + *, + after: str | list[str] | None = None, + condition: str | Callable | None = None, + timeout: float | None = None, + on_timeout: str = "reject", + message: str | Callable | None = None, +) -> Workflow +``` + +Add an approval gate step. + +### `visualize()` + +```python +wf.visualize(fmt: str = "mermaid") -> str +``` + +Render the DAG as a Mermaid or DOT diagram string. + +### `ancestors()` / `descendants()` + +```python +wf.ancestors(node: str) -> list[str] +wf.descendants(node: str) -> list[str] +``` + +### `topological_levels()` + +```python +wf.topological_levels() -> list[list[str]] +``` + +### `stats()` + +```python +wf.stats() -> dict[str, int | float] +``` + +Returns `{nodes, edges, depth, width, density}`. + +### `critical_path()` + +```python +wf.critical_path(costs: dict[str, float]) -> tuple[list[str], float] +``` + +Returns `(path, total_cost)` — the longest-weighted path. 
+ +### `execution_plan()` + +```python +wf.execution_plan(max_workers: int = 1) -> list[list[str]] +``` + +### `bottleneck_analysis()` + +```python +wf.bottleneck_analysis(costs: dict[str, float]) -> dict[str, Any] +``` + +Returns `{node, cost, percentage, critical_path, total_cost, suggestion}`. + +--- + +## `WorkflowRun` + +::: taskito.workflows.WorkflowRun + +Handle for a submitted workflow run. + +### `status()` + +```python +run.status() -> WorkflowStatus +``` + +### `wait()` + +```python +run.wait(timeout: float | None = None, poll_interval: float = 0.1) -> WorkflowStatus +``` + +Block until the workflow reaches a terminal state. Raises `WorkflowTimeoutError` on timeout. + +### `cancel()` + +```python +run.cancel() -> None +``` + +### `node_status()` + +```python +run.node_status(node_name: str) -> NodeStatus +``` + +### `visualize()` + +```python +run.visualize(fmt: str = "mermaid") -> str +``` + +Render the DAG with live node status colors. + +--- + +## `WorkflowProxy` + +::: taskito.workflows.WorkflowProxy + +Returned by `@queue.workflow()`. Wraps a factory function. + +### `submit()` + +```python +proxy.submit(*args, **kwargs) -> WorkflowRun +``` + +Build and submit the workflow. + +### `build()` + +```python +proxy.build(*args, **kwargs) -> Workflow +``` + +Materialize without submitting. + +### `as_step()` + +```python +proxy.as_step(**params) -> SubWorkflowRef +``` + +Return a reference for use as a sub-workflow step. 
+ +--- + +## Queue Methods + +Added to `Queue` via `QueueWorkflowMixin`: + +### `submit_workflow()` + +```python +queue.submit_workflow( + workflow: Workflow, + *, + incremental: bool = False, + base_run: str | None = None, +) -> WorkflowRun +``` + +### `approve_gate()` + +```python +queue.approve_gate(run_id: str, node_name: str) -> None +``` + +### `reject_gate()` + +```python +queue.reject_gate(run_id: str, node_name: str, error: str = "rejected") -> None +``` + +### `@queue.workflow()` + +```python +@queue.workflow(name: str | None = None, *, version: int = 1) +def factory() -> Workflow: ... +``` + +--- + +## Types + +### `WorkflowState` + +```python +class WorkflowState(str, Enum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + PAUSED = "paused" +``` + +### `NodeStatus` + +```python +class NodeStatus(str, Enum): + PENDING = "pending" + READY = "ready" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + SKIPPED = "skipped" + WAITING_APPROVAL = "waiting_approval" + CACHE_HIT = "cache_hit" +``` + +### `WorkflowStatus` + +```python +@dataclass +class WorkflowStatus: + run_id: str + state: WorkflowState + started_at: int | None + completed_at: int | None + error: str | None + nodes: dict[str, NodeSnapshot] +``` + +### `NodeSnapshot` + +```python +@dataclass +class NodeSnapshot: + name: str + status: NodeStatus + job_id: str | None + error: str | None +``` + +### `WorkflowContext` + +```python +@dataclass(frozen=True) +class WorkflowContext: + run_id: str + results: dict[str, Any] + statuses: dict[str, str] + params: dict[str, Any] | None + failure_count: int + success_count: int +``` + +### `GateConfig` + +```python +@dataclass +class GateConfig: + timeout: float | None = None + on_timeout: str = "reject" + message: str | Callable | None = None +``` diff --git a/docs/guide/core/workflows.md b/docs/guide/core/workflows.md index 0de9e6c..2b35899 100644 --- 
a/docs/guide/core/workflows.md +++ b/docs/guide/core/workflows.md @@ -1,200 +1,383 @@ # Workflows -taskito provides three composition primitives for building complex task pipelines: **chain**, **group**, and **chord**. +taskito supports two workflow models: **canvas primitives** (chain, group, chord) for simple pipelines, and **DAG workflows** for complex multi-step pipelines with fan-out, conditions, approval gates, and more. -## Signatures +## DAG Workflows -A `Signature` wraps a task call for deferred execution. Create them with `.s()` or `.si()`: +Define multi-step pipelines as directed acyclic graphs. Each step is a registered task; the engine handles ordering, parallelism, failure propagation, and state tracking. ```python -from taskito import chain, group, chord +from taskito.workflows import Workflow + +@queue.task() +def extract(): return fetch_data() + +@queue.task() +def transform(data): return clean(data) + +@queue.task() +def load(data): db.insert(data) -# Mutable signature — receives previous result as first argument -sig = add.s(1, 2) +wf = Workflow(name="etl") +wf.step("extract", extract) +wf.step("transform", transform, after="extract") +wf.step("load", load, after="transform") -# Immutable signature — ignores previous result -sig = add.si(1, 2) +run = queue.submit_workflow(wf) +result = run.wait(timeout=60) +print(result.state) # WorkflowState.COMPLETED +``` + +### Step Configuration + +Each step accepts the same options as `queue.enqueue()`: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `name` | `str` | required | Unique step name | +| `task` | `TaskWrapper` | required | Registered task function | +| `after` | `str \| list[str]` | `None` | Predecessor step(s) | +| `args` | `tuple` | `()` | Positional arguments | +| `kwargs` | `dict` | `None` | Keyword arguments | +| `queue` | `str` | `None` | Override queue name | +| `max_retries` | `int` | `None` | Override retry count | +| `timeout_ms` | `int` 
| `None` | Override timeout | +| `priority` | `int` | `None` | Override priority | +| `fan_out` | `str` | `None` | Fan-out strategy (`"each"`) | +| `fan_in` | `str` | `None` | Fan-in strategy (`"all"`) | +| `condition` | `str \| callable` | `None` | Execution condition | + +### Workflow Decorator + +Register reusable workflow factories: + +```python +@queue.workflow("nightly_etl") +def etl_pipeline(): + wf = Workflow() + wf.step("extract", extract) + wf.step("load", load, after="extract") + return wf + +run = etl_pipeline.submit() +run.wait() ``` -## Chain +## Fan-Out / Fan-In -Execute tasks **sequentially**, piping each result as the first argument to the next task: +Split a step's result into parallel child jobs, then collect results: ```mermaid graph LR - S1["extract.s(url)"] -->|result| S2["transform.s()"] - S2 -->|result| S3["load.s()"] + fetch[fetch] --> process_0[process 0] + fetch --> process_1[process 1] + fetch --> process_2[process 2] + process_0 --> aggregate[aggregate] + process_1 --> aggregate + process_2 --> aggregate ``` ```python @queue.task() -def extract(url): - return requests.get(url).json() +def fetch() -> list[int]: + return [10, 20, 30] @queue.task() -def transform(data): - return [item["name"] for item in data] +def process(item: int) -> int: + return item * 2 @queue.task() -def load(names): - db.insert_many(names) - return len(names) +def aggregate(results: list[int]) -> int: + return sum(results) -# Build and execute the pipeline -result = chain( - extract.s("https://api.example.com/users"), - transform.s(), - load.s(), -).apply(queue) +wf = Workflow(name="map_reduce") +wf.step("fetch", fetch) +wf.step("process", process, after="fetch", fan_out="each") +wf.step("aggregate", aggregate, after="process", fan_in="all") -print(result.result(timeout=30)) # Number of records loaded +run = queue.submit_workflow(wf) +result = run.wait(timeout=30) +# aggregate receives [20, 40, 60] ``` -!!! 
tip - Use `.si()` (immutable signatures) when a step should **not** receive the previous result: +The `fan_out="each"` strategy calls the task once per element in the predecessor's return value. Results are collected in order by `fan_in="all"`. + +## Conditions + +Control which steps execute based on predecessor outcomes: + +```python +wf = Workflow(name="deploy_pipeline") +wf.step("test", run_tests) +wf.step("deploy", deploy, after="test") # default: on_success +wf.step("rollback", rollback, after="deploy", condition="on_failure") +wf.step("notify", send_slack, after="deploy", condition="always") +``` + +| Condition | Runs when | +|-----------|-----------| +| `None` / `"on_success"` | All predecessors completed successfully | +| `"on_failure"` | Any predecessor failed | +| `"always"` | Predecessors are terminal (regardless of outcome) | +| `callable` | `condition(ctx)` returns `True` | + +### Callable Conditions + +Pass a function that receives a `WorkflowContext`: + +```python +from taskito.workflows import WorkflowContext + +def high_score(ctx: WorkflowContext) -> bool: + return ctx.results["validate"]["score"] > 0.95 + +wf.step("deploy", deploy, after="validate", condition=high_score) +``` + +`WorkflowContext` provides: `results` (predecessor return values), `statuses`, `failure_count`, `success_count`, `run_id`. + +## Error Handling + +Control failure behavior at the workflow level: + +=== "Fail Fast (default)" ```python - chain( - step_a.s(input_data), - step_b.si(independent_data), # Ignores step_a's result - step_c.s(), - ).apply(queue) + wf = Workflow(name="strict", on_failure="fail_fast") ``` -## Group + One failure cancels all pending steps. The workflow transitions to `FAILED`. 
-Execute tasks **in parallel** (fan-out): +=== "Continue" -```mermaid -graph TD - G["group()"] --> S1["process.s(1)"] - G --> S2["process.s(2)"] - G --> S3["process.s(3)"] + ```python + wf = Workflow(name="resilient", on_failure="continue") + ``` + + Failed steps skip their `on_success` dependents, but independent branches keep running. Steps with `condition="on_failure"` or `"always"` still execute. + +## Approval Gates - S1 --> R1["Result 1"] - S2 --> R2["Result 2"] - S3 --> R3["Result 3"] +Pause a workflow for human review: + +```python +wf = Workflow(name="ml_deploy") +wf.step("train", train_model) +wf.step("evaluate", evaluate, after="train") +wf.gate("approve", after="evaluate", timeout=86400, on_timeout="reject") +wf.step("deploy", deploy, after="approve") ``` +The gate enters `WAITING_APPROVAL` status. Resolve it programmatically: + ```python -@queue.task() -def process(item_id): - return fetch_and_process(item_id) +run = queue.submit_workflow(wf) -# Enqueue all three in parallel -jobs = group( - process.s(1), - process.s(2), - process.s(3), -).apply(queue) +# Later, after review: +queue.approve_gate(run.id, "approve") # workflow continues +# or: +queue.reject_gate(run.id, "approve") # gate fails, downstream skipped +``` -# Collect results -results = [j.result(timeout=30) for j in jobs] +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `timeout` | `float` | `None` | Seconds until auto-resolve | +| `on_timeout` | `str` | `"reject"` | `"approve"` or `"reject"` | +| `message` | `str` | `None` | Human-readable approval message | + +## Sub-Workflows + +Nest workflows for composition: + +```python +@queue.workflow("etl") +def etl_pipeline(region): + wf = Workflow() + wf.step("extract", extract, args=[region]) + wf.step("load", load, after="extract") + return wf + +@queue.workflow("daily") +def daily_pipeline(): + wf = Workflow() + wf.step("eu_etl", etl_pipeline.as_step(region="eu")) + wf.step("us_etl", 
etl_pipeline.as_step(region="us")) + wf.step("reconcile", reconcile, after=["eu_etl", "us_etl"]) + return wf + +run = daily_pipeline.submit() ``` -## Chord +Child workflows run independently with their own nodes and state. Cancelling the parent cascades to children. -Fan-out with a **callback** — execute tasks in parallel, then pass all results to a final task: +## Cron-Scheduled Workflows -```mermaid -graph TD - F1["fetch.s(url1)"] --> C["Collect results"] - F2["fetch.s(url2)"] --> C - F3["fetch.s(url3)"] --> C - C -->|"[r1, r2, r3]"| CB["merge.s()"] +Combine `@queue.periodic()` with `@queue.workflow()`: + +```python +@queue.periodic(cron="0 0 2 * * *") +@queue.workflow("nightly_analytics") +def nightly(): + wf = Workflow() + wf.step("extract", extract_clickstream) + wf.step("aggregate", build_dashboards, after="extract") + return wf ``` +Each cron trigger submits a new workflow run. + +## Incremental Runs + +Skip unchanged steps by reusing results from a prior run: + ```python -@queue.task() -def fetch(url): - return requests.get(url).json() +run1 = queue.submit_workflow(wf) +run1.wait() -@queue.task() -def merge(results): - combined = {} - for r in results: - combined.update(r) - return combined +# Second run: only re-execute dirty nodes +run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) +run2.wait() +``` -# Fetch in parallel, then merge -result = chord( - group( - fetch.s("https://api1.example.com"), - fetch.s("https://api2.example.com"), - fetch.s("https://api3.example.com"), - ), - merge.s(), -).apply(queue) +Nodes that completed in the base run get `CACHE_HIT` status. If any predecessor is dirty (failed or missing in the base run), downstream nodes re-execute. 
+ +Set a TTL to expire cached results: -print(result.result(timeout=60)) +```python +wf = Workflow(name="pipeline", cache_ttl=3600) # 1 hour ``` -## chunks +## Monitoring -Split a list of items into batched groups, creating one task per chunk: +### Status ```python -from taskito import chunks +run = queue.submit_workflow(wf) +status = run.status() +print(status.state) # WorkflowState.RUNNING +print(status.nodes["step_a"]) # NodeSnapshot(status=COMPLETED, ...) +``` -@queue.task() -def process_batch(items): - return [transform(item) for item in items] +### Wait -# Split 1000 items into groups of 100 -results = chunks(process_batch, items, chunk_size=100).apply(queue) +```python +final = run.wait(timeout=60) +if final.state == WorkflowState.COMPLETED: + print("All steps succeeded") ``` -`chunks()` returns a `group`, so you can combine it with `chord` for a map-reduce pattern: +### Cancel ```python -result = chord( - chunks(process_batch, items, chunk_size=100), - merge_results.s(), -).apply(queue) +run.cancel() # Skips pending steps, cancels running jobs ``` -## starmap +## Visualization -Create one task per args tuple — similar to Python's `itertools.starmap`: +Render the workflow DAG as a diagram: ```python -from taskito import starmap +# Pre-execution (structure only) +print(wf.visualize("mermaid")) -@queue.task() -def add(a, b): - return a + b +# Live status (with node colors) +print(run.visualize("mermaid")) +print(run.visualize("dot")) # Graphviz DOT format +``` + +## Graph Analysis + +Analyze the workflow DAG before execution: + +```python +wf.ancestors("load") # ["extract", "transform"] +wf.descendants("extract") # ["transform", "load"] +wf.topological_levels() # [["extract"], ["transform"], ["load"]] +wf.stats() # {nodes: 3, edges: 2, depth: 3, ...} -results = starmap(add, [(1, 2), (3, 4), (5, 6)]).apply(queue) +# With cost estimates +path, cost = wf.critical_path({"extract": 2.0, "transform": 7.0, "load": 1.0}) +# path=["extract", "transform", "load"], 
cost=10.0 + +plan = wf.execution_plan(max_workers=4) +# [["extract"], ["transform"], ["load"]] + +analysis = wf.bottleneck_analysis({"extract": 2.0, "transform": 7.0, "load": 1.0}) +# {"node": "transform", "percentage": 70.0, ...} ``` -`starmap()` also returns a `group`, so all tasks execute in parallel. +## Node Statuses + +| Status | Meaning | +|--------|---------| +| `PENDING` | Waiting for predecessors or job creation | +| `RUNNING` | Job is executing | +| `COMPLETED` | Step succeeded | +| `FAILED` | Step failed (after retries exhausted) | +| `SKIPPED` | Skipped due to failure cascade or condition | +| `WAITING_APPROVAL` | Gate awaiting approve/reject | +| `CACHE_HIT` | Reused result from a prior run | + +--- -## Group Concurrency Limits +## Canvas Primitives -Limit how many group members run concurrently with `max_concurrency`: +For simpler pipelines without DAG features, taskito also provides **chain**, **group**, and **chord**. + +### Signatures + +A `Signature` wraps a task call for deferred execution: + +```python +from taskito import chain, group, chord + +sig = add.s(1, 2) # Mutable — receives previous result +sig = add.si(1, 2) # Immutable — ignores previous result +``` + +### Chain + +Execute tasks sequentially, piping results: + +```python +result = chain( + extract.s("https://api.example.com/users"), + transform.s(), + load.s(), +).apply(queue) +``` + +### Group + +Execute tasks in parallel: ```python -# Only 5 tasks run at a time; the rest wait in waves jobs = group( - *[fetch.s(url) for url in urls], - max_concurrency=5, + process.s(1), + process.s(2), + process.s(3), ).apply(queue) ``` -Without `max_concurrency`, all group members are enqueued immediately. With it, members are dispatched in waves — each wave waits for completion before the next starts. 
+### Chord -## Real-World Example: ETL Pipeline +Fan-out with a callback: ```python -# Extract from multiple sources in parallel, -# transform each, then load all results -pipeline = chord( - group( - chain(extract.s(source), transform.s()) - for source in data_sources - ), - load.s(), -) +result = chord( + group(fetch.s(url) for url in urls), + merge.s(), +).apply(queue) +``` + +### chunks / starmap + +```python +from taskito import chunks, starmap + +# Batch processing +results = chunks(process_batch, items, chunk_size=100).apply(queue) -result = pipeline.apply(queue) +# Tuple unpacking +results = starmap(add, [(1, 2), (3, 4)]).apply(queue) ``` diff --git a/zensical.toml b/zensical.toml index 48b0600..2a7544c 100644 --- a/zensical.toml +++ b/zensical.toml @@ -101,6 +101,7 @@ nav = [ { "JobResult" = "api/result.md" }, { "JobContext" = "api/context.md" }, { "Canvas" = "api/canvas.md" }, + { "Workflows" = "api/workflows.md" }, { "Testing" = "api/testing.md" }, { "CLI" = "api/cli.md" }, ] }, From 251292ebfc0b869ef6d2b6597641f3e1d07c1f42 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 06:35:22 +0530 Subject: [PATCH 05/12] fix: use git dependency for dagron-core instead of local path The local path ../dagron/crates/dagron-core doesn't exist in CI. Switch to a git dependency pointing at the ByteVeda/dagron repo. 
--- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 9fdc42a..6209854 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,4 +21,4 @@ redis = { version = "0.27", features = ["script"] } openssl-sys = { version = "0.9", features = ["vendored"] } pyo3 = { version = "0.22", features = ["multiple-pymethods"] } async-trait = "0.1" -dagron-core = { path = "../dagron/crates/dagron-core" } +dagron-core = { git = "https://github.com/ByteVeda/dagron.git" } From c1480973a935794247f31a030378d5483ee0cac8 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 06:53:02 +0530 Subject: [PATCH 06/12] docs: add dedicated Workflows section with architecture diagrams New docs/workflows/ section with 9 pages: - Overview with system architecture diagram - Building Workflows (step config, decorator, node state machine) - Fan-Out & Fan-In (sequence diagram, empty/failure handling) - Conditions & Error Handling (fail_fast vs continue diagrams) - Approval Gates (timeout, events) - Sub-Workflows & Scheduling (composition diagram, cron) - Incremental Runs (dirty-set propagation diagrams) - Analysis & Visualization (critical path, bottleneck, Mermaid/DOT) - Canvas Primitives (comparison table vs DAG workflows) Also fixes mypy CI failures: - Widen Workflow.step() task param to Any (accepts duck-typed objects) - Fix test type annotations for sorted(), tuple hints, bool returns --- docs/guide/core/workflows.md | 383 +--------------------- docs/workflows/analysis.md | 140 ++++++++ docs/workflows/building.md | 134 ++++++++ docs/workflows/caching.md | 69 ++++ docs/workflows/canvas.md | 118 +++++++ docs/workflows/composition.md | 83 +++++ docs/workflows/conditions.md | 119 +++++++ docs/workflows/fan-out.md | 108 ++++++ docs/workflows/gates.md | 73 +++++ docs/workflows/index.md | 73 +++++ py_src/taskito/workflows/builder.py | 3 +- tests/python/test_workflows_caching.py | 4 +- 
tests/python/test_workflows_conditions.py | 2 +- tests/python/test_workflows_fan_out.py | 9 +- zensical.toml | 11 + 15 files changed, 940 insertions(+), 389 deletions(-) create mode 100644 docs/workflows/analysis.md create mode 100644 docs/workflows/building.md create mode 100644 docs/workflows/caching.md create mode 100644 docs/workflows/canvas.md create mode 100644 docs/workflows/composition.md create mode 100644 docs/workflows/conditions.md create mode 100644 docs/workflows/fan-out.md create mode 100644 docs/workflows/gates.md create mode 100644 docs/workflows/index.md diff --git a/docs/guide/core/workflows.md b/docs/guide/core/workflows.md index 2b35899..5cf3033 100644 --- a/docs/guide/core/workflows.md +++ b/docs/guide/core/workflows.md @@ -1,383 +1,6 @@ # Workflows -taskito supports two workflow models: **canvas primitives** (chain, group, chord) for simple pipelines, and **DAG workflows** for complex multi-step pipelines with fan-out, conditions, approval gates, and more. +taskito provides two workflow models. See the dedicated **[Workflows section](../../workflows/index.md)** for full documentation. -## DAG Workflows - -Define multi-step pipelines as directed acyclic graphs. Each step is a registered task; the engine handles ordering, parallelism, failure propagation, and state tracking. 
- -```python -from taskito.workflows import Workflow - -@queue.task() -def extract(): return fetch_data() - -@queue.task() -def transform(data): return clean(data) - -@queue.task() -def load(data): db.insert(data) - -wf = Workflow(name="etl") -wf.step("extract", extract) -wf.step("transform", transform, after="extract") -wf.step("load", load, after="transform") - -run = queue.submit_workflow(wf) -result = run.wait(timeout=60) -print(result.state) # WorkflowState.COMPLETED -``` - -### Step Configuration - -Each step accepts the same options as `queue.enqueue()`: - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `name` | `str` | required | Unique step name | -| `task` | `TaskWrapper` | required | Registered task function | -| `after` | `str \| list[str]` | `None` | Predecessor step(s) | -| `args` | `tuple` | `()` | Positional arguments | -| `kwargs` | `dict` | `None` | Keyword arguments | -| `queue` | `str` | `None` | Override queue name | -| `max_retries` | `int` | `None` | Override retry count | -| `timeout_ms` | `int` | `None` | Override timeout | -| `priority` | `int` | `None` | Override priority | -| `fan_out` | `str` | `None` | Fan-out strategy (`"each"`) | -| `fan_in` | `str` | `None` | Fan-in strategy (`"all"`) | -| `condition` | `str \| callable` | `None` | Execution condition | - -### Workflow Decorator - -Register reusable workflow factories: - -```python -@queue.workflow("nightly_etl") -def etl_pipeline(): - wf = Workflow() - wf.step("extract", extract) - wf.step("load", load, after="extract") - return wf - -run = etl_pipeline.submit() -run.wait() -``` - -## Fan-Out / Fan-In - -Split a step's result into parallel child jobs, then collect results: - -```mermaid -graph LR - fetch[fetch] --> process_0[process 0] - fetch --> process_1[process 1] - fetch --> process_2[process 2] - process_0 --> aggregate[aggregate] - process_1 --> aggregate - process_2 --> aggregate -``` - -```python -@queue.task() -def fetch() -> 
list[int]: - return [10, 20, 30] - -@queue.task() -def process(item: int) -> int: - return item * 2 - -@queue.task() -def aggregate(results: list[int]) -> int: - return sum(results) - -wf = Workflow(name="map_reduce") -wf.step("fetch", fetch) -wf.step("process", process, after="fetch", fan_out="each") -wf.step("aggregate", aggregate, after="process", fan_in="all") - -run = queue.submit_workflow(wf) -result = run.wait(timeout=30) -# aggregate receives [20, 40, 60] -``` - -The `fan_out="each"` strategy calls the task once per element in the predecessor's return value. Results are collected in order by `fan_in="all"`. - -## Conditions - -Control which steps execute based on predecessor outcomes: - -```python -wf = Workflow(name="deploy_pipeline") -wf.step("test", run_tests) -wf.step("deploy", deploy, after="test") # default: on_success -wf.step("rollback", rollback, after="deploy", condition="on_failure") -wf.step("notify", send_slack, after="deploy", condition="always") -``` - -| Condition | Runs when | -|-----------|-----------| -| `None` / `"on_success"` | All predecessors completed successfully | -| `"on_failure"` | Any predecessor failed | -| `"always"` | Predecessors are terminal (regardless of outcome) | -| `callable` | `condition(ctx)` returns `True` | - -### Callable Conditions - -Pass a function that receives a `WorkflowContext`: - -```python -from taskito.workflows import WorkflowContext - -def high_score(ctx: WorkflowContext) -> bool: - return ctx.results["validate"]["score"] > 0.95 - -wf.step("deploy", deploy, after="validate", condition=high_score) -``` - -`WorkflowContext` provides: `results` (predecessor return values), `statuses`, `failure_count`, `success_count`, `run_id`. - -## Error Handling - -Control failure behavior at the workflow level: - -=== "Fail Fast (default)" - - ```python - wf = Workflow(name="strict", on_failure="fail_fast") - ``` - - One failure cancels all pending steps. The workflow transitions to `FAILED`. 
- -=== "Continue" - - ```python - wf = Workflow(name="resilient", on_failure="continue") - ``` - - Failed steps skip their `on_success` dependents, but independent branches keep running. Steps with `condition="on_failure"` or `"always"` still execute. - -## Approval Gates - -Pause a workflow for human review: - -```python -wf = Workflow(name="ml_deploy") -wf.step("train", train_model) -wf.step("evaluate", evaluate, after="train") -wf.gate("approve", after="evaluate", timeout=86400, on_timeout="reject") -wf.step("deploy", deploy, after="approve") -``` - -The gate enters `WAITING_APPROVAL` status. Resolve it programmatically: - -```python -run = queue.submit_workflow(wf) - -# Later, after review: -queue.approve_gate(run.id, "approve") # workflow continues -# or: -queue.reject_gate(run.id, "approve") # gate fails, downstream skipped -``` - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `timeout` | `float` | `None` | Seconds until auto-resolve | -| `on_timeout` | `str` | `"reject"` | `"approve"` or `"reject"` | -| `message` | `str` | `None` | Human-readable approval message | - -## Sub-Workflows - -Nest workflows for composition: - -```python -@queue.workflow("etl") -def etl_pipeline(region): - wf = Workflow() - wf.step("extract", extract, args=[region]) - wf.step("load", load, after="extract") - return wf - -@queue.workflow("daily") -def daily_pipeline(): - wf = Workflow() - wf.step("eu_etl", etl_pipeline.as_step(region="eu")) - wf.step("us_etl", etl_pipeline.as_step(region="us")) - wf.step("reconcile", reconcile, after=["eu_etl", "us_etl"]) - return wf - -run = daily_pipeline.submit() -``` - -Child workflows run independently with their own nodes and state. Cancelling the parent cascades to children. 
- -## Cron-Scheduled Workflows - -Combine `@queue.periodic()` with `@queue.workflow()`: - -```python -@queue.periodic(cron="0 0 2 * * *") -@queue.workflow("nightly_analytics") -def nightly(): - wf = Workflow() - wf.step("extract", extract_clickstream) - wf.step("aggregate", build_dashboards, after="extract") - return wf -``` - -Each cron trigger submits a new workflow run. - -## Incremental Runs - -Skip unchanged steps by reusing results from a prior run: - -```python -run1 = queue.submit_workflow(wf) -run1.wait() - -# Second run: only re-execute dirty nodes -run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) -run2.wait() -``` - -Nodes that completed in the base run get `CACHE_HIT` status. If any predecessor is dirty (failed or missing in the base run), downstream nodes re-execute. - -Set a TTL to expire cached results: - -```python -wf = Workflow(name="pipeline", cache_ttl=3600) # 1 hour -``` - -## Monitoring - -### Status - -```python -run = queue.submit_workflow(wf) -status = run.status() -print(status.state) # WorkflowState.RUNNING -print(status.nodes["step_a"]) # NodeSnapshot(status=COMPLETED, ...) 
-``` - -### Wait - -```python -final = run.wait(timeout=60) -if final.state == WorkflowState.COMPLETED: - print("All steps succeeded") -``` - -### Cancel - -```python -run.cancel() # Skips pending steps, cancels running jobs -``` - -## Visualization - -Render the workflow DAG as a diagram: - -```python -# Pre-execution (structure only) -print(wf.visualize("mermaid")) - -# Live status (with node colors) -print(run.visualize("mermaid")) -print(run.visualize("dot")) # Graphviz DOT format -``` - -## Graph Analysis - -Analyze the workflow DAG before execution: - -```python -wf.ancestors("load") # ["extract", "transform"] -wf.descendants("extract") # ["transform", "load"] -wf.topological_levels() # [["extract"], ["transform"], ["load"]] -wf.stats() # {nodes: 3, edges: 2, depth: 3, ...} - -# With cost estimates -path, cost = wf.critical_path({"extract": 2.0, "transform": 7.0, "load": 1.0}) -# path=["extract", "transform", "load"], cost=10.0 - -plan = wf.execution_plan(max_workers=4) -# [["extract"], ["transform"], ["load"]] - -analysis = wf.bottleneck_analysis({"extract": 2.0, "transform": 7.0, "load": 1.0}) -# {"node": "transform", "percentage": 70.0, ...} -``` - -## Node Statuses - -| Status | Meaning | -|--------|---------| -| `PENDING` | Waiting for predecessors or job creation | -| `RUNNING` | Job is executing | -| `COMPLETED` | Step succeeded | -| `FAILED` | Step failed (after retries exhausted) | -| `SKIPPED` | Skipped due to failure cascade or condition | -| `WAITING_APPROVAL` | Gate awaiting approve/reject | -| `CACHE_HIT` | Reused result from a prior run | - ---- - -## Canvas Primitives - -For simpler pipelines without DAG features, taskito also provides **chain**, **group**, and **chord**. 
- -### Signatures - -A `Signature` wraps a task call for deferred execution: - -```python -from taskito import chain, group, chord - -sig = add.s(1, 2) # Mutable — receives previous result -sig = add.si(1, 2) # Immutable — ignores previous result -``` - -### Chain - -Execute tasks sequentially, piping results: - -```python -result = chain( - extract.s("https://api.example.com/users"), - transform.s(), - load.s(), -).apply(queue) -``` - -### Group - -Execute tasks in parallel: - -```python -jobs = group( - process.s(1), - process.s(2), - process.s(3), -).apply(queue) -``` - -### Chord - -Fan-out with a callback: - -```python -result = chord( - group(fetch.s(url) for url in urls), - merge.s(), -).apply(queue) -``` - -### chunks / starmap - -```python -from taskito import chunks, starmap - -# Batch processing -results = chunks(process_batch, items, chunk_size=100).apply(queue) - -# Tuple unpacking -results = starmap(add, [(1, 2), (3, 4)]).apply(queue) -``` +- **[DAG Workflows](../../workflows/building.md)** — Multi-step pipelines as directed acyclic graphs with fan-out, conditions, gates, sub-workflows, incremental caching, and visualization. +- **[Canvas Primitives](../../workflows/canvas.md)** — Lightweight chain, group, and chord composition for simple pipelines. diff --git a/docs/workflows/analysis.md b/docs/workflows/analysis.md new file mode 100644 index 0000000..83a33bf --- /dev/null +++ b/docs/workflows/analysis.md @@ -0,0 +1,140 @@ +# Analysis & Visualization + +Analyze the workflow DAG before execution and render diagrams with live status. 
+ +## Graph inspection + +```python +wf = Workflow(name="pipeline") +wf.step("a", task_a) +wf.step("b", task_b, after="a") +wf.step("c", task_c, after="a") +wf.step("d", task_d, after=["b", "c"]) + +wf.ancestors("d") # ["a", "b", "c"] +wf.descendants("a") # ["b", "c", "d"] +wf.topological_levels() # [["a"], ["b", "c"], ["d"]] +wf.stats() +# {"nodes": 4, "edges": 4, "depth": 3, "width": 2, "density": 0.6667} +``` + +| Method | Returns | Description | +|--------|---------|-------------| +| `ancestors(node)` | `list[str]` | All transitive predecessors | +| `descendants(node)` | `list[str]` | All transitive successors | +| `topological_levels()` | `list[list[str]]` | Nodes grouped by depth | +| `stats()` | `dict` | Node count, edge count, depth, width, density | + +## Critical path + +Find the longest-weighted path through the DAG: + +```python +path, cost = wf.critical_path({ + "a": 2.0, + "b": 7.0, + "c": 1.0, + "d": 3.0, +}) +# path = ["a", "b", "d"], cost = 12.0 +``` + +Pass estimated durations per step. The critical path determines the minimum total execution time. + +## Execution plan + +Generate a step-by-step schedule respecting worker limits: + +```python +plan = wf.execution_plan(max_workers=2) +# [["a"], ["b", "c"], ["d"]] + +plan = wf.execution_plan(max_workers=1) +# [["a"], ["b"], ["c"], ["d"]] +``` + +Each stage contains up to `max_workers` nodes. Nodes in the same topological level are batched together. + +## Bottleneck analysis + +Identify the most expensive step on the critical path: + +```python +result = wf.bottleneck_analysis({ + "a": 2.0, "b": 7.0, "c": 1.0, "d": 3.0 +}) +# { +# "node": "b", +# "cost": 7.0, +# "percentage": 58.3, +# "critical_path": ["a", "b", "d"], +# "total_cost": 12.0, +# "suggestion": "b is the bottleneck (58.3% of total time). ..." 
+# } +``` + +## Visualization + +Render the DAG as a diagram string: + +=== "Mermaid" + + ```python + print(wf.visualize("mermaid")) + ``` + + ``` + graph LR + a[a] + b[b] + c[c] + d[d] + a --> b + a --> c + b --> d + c --> d + ``` + +=== "Graphviz DOT" + + ```python + print(wf.visualize("dot")) + ``` + + ``` + digraph workflow { + rankdir=LR; + a [label="a" style=filled fillcolor=white]; + b [label="b" style=filled fillcolor=white]; + ... + } + ``` + +### Live status visualization + +`WorkflowRun.visualize()` includes status colors: + +```python +run = queue.submit_workflow(wf) +run.wait() + +print(run.visualize("mermaid")) +``` + +``` +graph LR + a[a ✓] + b[b ✓] + a --> b + style a fill:#90EE90 + style b fill:#90EE90 +``` + +| Status | Color | +|--------|-------| +| Completed | Green `#90EE90` | +| Failed | Red `#FFB6C1` | +| Running | Blue `#87CEEB` | +| Pending | Gray `#D3D3D3` | +| Skipped | Light gray `#F5F5F5` | +| Waiting Approval | Yellow `#FFFACD` | diff --git a/docs/workflows/building.md b/docs/workflows/building.md new file mode 100644 index 0000000..6241bfc --- /dev/null +++ b/docs/workflows/building.md @@ -0,0 +1,134 @@ +# Building Workflows + +A workflow is a DAG of steps. Each step wraps a registered task. The engine creates jobs in topological order with `depends_on` chains so the existing scheduler handles execution. + +## Defining steps + +```python +from taskito.workflows import Workflow + +wf = Workflow(name="etl", version=1) +wf.step("extract", extract_task) +wf.step("transform", transform_task, after="extract") +wf.step("load", load_task, after="transform") +``` + +Steps are added in order. The `after` parameter declares predecessors — a step won't run until all its predecessors complete. 
+ +### Multiple predecessors + +```python +wf.step("merge", merge_task, after=["branch_a", "branch_b"]) +``` + +### Step arguments + +```python +wf.step("fetch", fetch_task, args=("https://api.example.com",)) +wf.step("process", process_task, after="fetch", kwargs={"mode": "strict"}) +``` + +Arguments are serialized at submission time using the queue's serializer. + +## Step configuration + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `name` | `str` | required | Unique step name within the workflow | +| `task` | `TaskWrapper` | required | Registered `@queue.task()` function | +| `after` | `str \| list[str]` | `None` | Predecessor step(s) | +| `args` | `tuple` | `()` | Positional arguments | +| `kwargs` | `dict` | `None` | Keyword arguments | +| `queue` | `str` | `None` | Override queue name | +| `max_retries` | `int` | `None` | Override retry count | +| `timeout_ms` | `int` | `None` | Override timeout (milliseconds) | +| `priority` | `int` | `None` | Override priority | + +## Workflow decorator + +Register reusable workflow factories with `@queue.workflow()`: + +```python +@queue.workflow("nightly_etl") +def etl_pipeline(): + wf = Workflow() + wf.step("extract", extract) + wf.step("load", load, after="extract") + return wf + +# Build and submit +run = etl_pipeline.submit() +run.wait() + +# Or build without submitting +wf = etl_pipeline.build() +print(wf.step_names) # ["extract", "load"] +``` + +## Submitting + +```python +run = queue.submit_workflow(wf) +``` + +This creates a `WorkflowRun` handle. Under the hood: + +1. A `WorkflowDefinition` is stored (or reused by name + version) +2. A `WorkflowRun` record is created +3. For each step in topological order, a job is enqueued with `depends_on` chains +4. 
The run transitions to `RUNNING` + +```mermaid +sequenceDiagram + participant P as Python + participant R as Rust Engine + participant S as Scheduler + + P->>R: submit_workflow(dag, payloads) + R->>R: Store definition + run + loop Each step in topo order + R->>S: enqueue(job, depends_on=[pred_ids]) + end + R-->>P: WorkflowRun handle + S->>S: Dequeue jobs as deps satisfied +``` + +## Workflow parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `name` | `str` | `"workflow"` | Workflow name | +| `version` | `int` | `1` | Version number | +| `on_failure` | `str` | `"fail_fast"` | Error strategy: `"fail_fast"` or `"continue"` | +| `cache_ttl` | `float` | `None` | Cache TTL in seconds for [incremental runs](caching.md) | + +## Node statuses + +Each step transitions through these states: + +```mermaid +stateDiagram-v2 + [*] --> Pending + Pending --> Running : job picked up + Running --> Completed : success + Running --> Failed : error (retries exhausted) + Pending --> Skipped : cascade / condition + Pending --> WaitingApproval : gate reached + WaitingApproval --> Completed : approved + WaitingApproval --> Failed : rejected + Pending --> CacheHit : incremental reuse + Completed --> [*] + Failed --> [*] + Skipped --> [*] + CacheHit --> [*] +``` + +| Status | Terminal | Meaning | +|--------|----------|---------| +| `PENDING` | No | Waiting for predecessors or job creation | +| `RUNNING` | No | Job is executing | +| `COMPLETED` | Yes | Step succeeded | +| `FAILED` | Yes | Step failed after retries exhausted | +| `SKIPPED` | Yes | Skipped due to failure cascade or unmet condition | +| `WAITING_APPROVAL` | No | Gate awaiting approve/reject | +| `CACHE_HIT` | Yes | Reused result from a prior run | diff --git a/docs/workflows/caching.md b/docs/workflows/caching.md new file mode 100644 index 0000000..dcb37c5 --- /dev/null +++ b/docs/workflows/caching.md @@ -0,0 +1,69 @@ +# Incremental Runs + +Skip unchanged steps by reusing 
results from a prior run. When a node completed successfully in the base run, it gets `CACHE_HIT` status instead of re-executing. + +## Basic usage + +```python +# First run: everything executes, results hashed (SHA-256) +run1 = queue.submit_workflow(wf) +run1.wait() + +# Second run: skip completed nodes from run1 +run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) +run2.wait() +``` + +Nodes that completed in `run1` with a stored result hash become `CACHE_HIT` in `run2`. Nodes that failed or are missing re-execute. + +## Dirty-set propagation + +If a node is dirty (failed or missing in the base run), all its downstream nodes are also dirty — even if they had cached results: + +```mermaid +graph LR + A["a — dirty (missing)"] --> B["b — dirty (propagated)"] + B --> C["c — dirty (propagated)"] + style A fill:#FFB6C1 + style B fill:#FFB6C1 + style C fill:#FFB6C1 +``` + +```mermaid +graph LR + A["a — CACHE_HIT ✓"] --> B["b — dirty (failed in base)"] + B --> C["c — dirty (propagated)"] + style A fill:#90EE90 + style B fill:#FFB6C1 + style C fill:#FFB6C1 +``` + +## Cache TTL + +Set a time-to-live on cached results: + +```python +wf = Workflow(name="pipeline", cache_ttl=3600) # 1 hour +``` + +If the base run completed more than `cache_ttl` seconds ago, all nodes are treated as dirty (full re-execution). + +## How it works + +At submit time with `incremental=True`: + +1. Fetch the base run's node data: `{name: (status, result_hash)}` +2. For each node in the new run: + - Base node completed + has result_hash → `CACHE_HIT` + - Base node failed / missing → dirty + - Any predecessor dirty → also dirty (propagated) +3. `CACHE_HIT` nodes are created with `status=cache_hit` and `completed_at` set — no job enqueued +4. 
Dirty nodes get normal jobs + +## Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `incremental` | `bool` | `False` | Enable cache comparison | +| `base_run` | `str` | `None` | Run ID to compare against | +| `cache_ttl` | `float` | `None` | TTL in seconds (on `Workflow`) | diff --git a/docs/workflows/canvas.md b/docs/workflows/canvas.md new file mode 100644 index 0000000..ff191cb --- /dev/null +++ b/docs/workflows/canvas.md @@ -0,0 +1,118 @@ +# Canvas Primitives + +For simpler pipelines without DAG features, taskito provides **chain**, **group**, and **chord** — lightweight composition that doesn't require the workflow engine. + +## Signatures + +A `Signature` wraps a task call for deferred execution: + +```python +from taskito import chain, group, chord + +sig = add.s(1, 2) # Mutable — receives previous result as first arg +sig = add.si(1, 2) # Immutable — ignores previous result +``` + +## Chain + +Execute tasks sequentially, piping each result to the next: + +```mermaid +graph LR + S1["extract.s(url)"] -->|result| S2["transform.s()"] + S2 -->|result| S3["load.s()"] +``` + +```python +result = chain( + extract.s("https://api.example.com/users"), + transform.s(), + load.s(), +).apply(queue) + +print(result.result(timeout=30)) +``` + +!!! 
tip + Use `.si()` when a step should **not** receive the previous result: + + ```python + chain( + step_a.s(input_data), + step_b.si(independent_data), + step_c.s(), + ).apply(queue) + ``` + +## Group + +Execute tasks in parallel (fan-out): + +```python +jobs = group( + process.s(1), + process.s(2), + process.s(3), +).apply(queue) + +results = [j.result(timeout=30) for j in jobs] +``` + +### Concurrency limits + +```python +jobs = group( + *[fetch.s(url) for url in urls], + max_concurrency=5, +).apply(queue) +``` + +## Chord + +Fan-out with a callback — run tasks in parallel, then pass all results to a final task: + +```python +result = chord( + group( + fetch.s("https://api1.example.com"), + fetch.s("https://api2.example.com"), + fetch.s("https://api3.example.com"), + ), + merge.s(), +).apply(queue) +``` + +## chunks / starmap + +```python +from taskito import chunks, starmap + +# Batch processing — split 1000 items into groups of 100 +results = chunks(process_batch, items, chunk_size=100).apply(queue) + +# Map-reduce pattern +result = chord( + chunks(process_batch, items, chunk_size=100), + merge_results.s(), +).apply(queue) + +# Tuple unpacking +results = starmap(add, [(1, 2), (3, 4), (5, 6)]).apply(queue) +``` + +## When to use canvas vs DAG workflows + +| Feature | Canvas | DAG Workflows | +|---------|--------|---------------| +| Setup | No imports needed | `from taskito.workflows import Workflow` | +| Topology | Linear chains, flat groups | Arbitrary DAGs | +| Fan-out | Static (known at build time) | Dynamic (from return values) | +| Conditions | None | `on_success`, `on_failure`, `always`, callables | +| Error handling | Per-task retries only | Workflow-level strategies | +| Approval gates | No | Yes | +| Sub-workflows | No | Yes | +| Incremental runs | No | Yes | +| Status tracking | Per-job only | Per-workflow + per-node | +| Visualization | No | Mermaid / DOT | + +Use canvas for quick one-off pipelines. 
Use DAG workflows for production pipelines that need monitoring, conditions, or complex topologies. diff --git a/docs/workflows/composition.md b/docs/workflows/composition.md new file mode 100644 index 0000000..d8b4d6d --- /dev/null +++ b/docs/workflows/composition.md @@ -0,0 +1,83 @@ +# Sub-Workflows & Scheduling + +Nest workflows for composition, and schedule workflows on a cron. + +## Sub-workflows + +Use `WorkflowProxy.as_step()` to embed one workflow inside another: + +```python +@queue.workflow("etl") +def etl_pipeline(region): + wf = Workflow() + wf.step("extract", extract, args=[region]) + wf.step("load", load, after="extract") + return wf + +@queue.workflow("daily") +def daily_pipeline(): + wf = Workflow() + wf.step("eu_etl", etl_pipeline.as_step(region="eu")) + wf.step("us_etl", etl_pipeline.as_step(region="us")) + wf.step("reconcile", reconcile, after=["eu_etl", "us_etl"]) + return wf + +run = daily_pipeline.submit() +``` + +```mermaid +graph TD + subgraph Parent["daily_pipeline"] + eu["eu_etl"] --> reconcile + us["us_etl"] --> reconcile + end + + subgraph Child1["etl (region=eu)"] + e1["extract"] --> l1["load"] + end + + subgraph Child2["etl (region=us)"] + e2["extract"] --> l2["load"] + end + + eu -.->|"submits"| Child1 + us -.->|"submits"| Child2 +``` + +### How it works + +1. Parent workflow submits the child workflow via `queue.submit_workflow()` with `parent_run_id` +2. The child runs independently with its own nodes and status +3. When the child completes/fails, the tracker updates the parent node +4. Downstream steps in the parent evaluate normally + +### Cancellation cascade + +Cancelling the parent cascades to all active child workflows: + +```python +run.cancel() # Cancels parent + all child sub-workflows +``` + +### Failure + +If a child workflow fails, the parent node is marked `FAILED`. Downstream steps follow the parent's `on_failure` strategy. 
+ +## Cron-scheduled workflows + +Stack `@queue.periodic()` on top of `@queue.workflow()`: + +```python +@queue.periodic(cron="0 0 2 * * *") # 2:00 AM daily +@queue.workflow("nightly_analytics") +def nightly(): + wf = Workflow() + wf.step("extract", extract_clickstream) + wf.step("aggregate", build_dashboards, after="extract") + return wf +``` + +Each cron trigger submits a new workflow run. Under the hood, a bridge task `_wf_launcher_nightly_analytics` is registered that calls `proxy.submit()`. + +!!! note + The `@queue.periodic()` decorator must be the **outer** decorator (applied second, listed first). diff --git a/docs/workflows/conditions.md b/docs/workflows/conditions.md new file mode 100644 index 0000000..1e6d1d4 --- /dev/null +++ b/docs/workflows/conditions.md @@ -0,0 +1,119 @@ +# Conditions & Error Handling + +Control which steps execute based on predecessor outcomes, and configure how the workflow responds to failures. + +## Step conditions + +```python +wf.step("deploy", deploy, after="test") # default: on_success +wf.step("rollback", rollback, after="deploy", condition="on_failure") +wf.step("notify", send_slack, after="deploy", condition="always") +``` + +| Condition | Runs when | +|-----------|-----------| +| `None` / `"on_success"` | All predecessors completed successfully | +| `"on_failure"` | Any predecessor failed | +| `"always"` | Predecessors are terminal (regardless of outcome) | +| `callable` | `condition(ctx)` returns `True` | + +## Callable conditions + +Pass a function that receives a `WorkflowContext`: + +```python +from taskito.workflows import WorkflowContext + +def high_score(ctx: WorkflowContext) -> bool: + return ctx.results["validate"]["score"] > 0.95 + +wf.step("deploy", deploy, after="validate", condition=high_score) +``` + +`WorkflowContext` fields: + +| Field | Type | Description | +|-------|------|-------------| +| `run_id` | `str` | Workflow run ID | +| `results` | `dict[str, Any]` | Deserialized return values of completed 
nodes | +| `statuses` | `dict[str, str]` | Status strings for all terminal nodes | +| `failure_count` | `int` | Number of failed nodes | +| `success_count` | `int` | Number of completed nodes | + +## Error strategies + +Set the workflow-level error strategy: + +=== "Fail Fast (default)" + + ```python + wf = Workflow(name="strict", on_failure="fail_fast") + ``` + + One failure skips **all** pending steps. The workflow transitions to `FAILED`. + + ```mermaid + graph LR + A["a ✓"] --> B["b ✗"] + B --> C["c ━ SKIPPED"] + B --> D["d ━ SKIPPED"] + style B fill:#FFB6C1 + style C fill:#F5F5F5 + style D fill:#F5F5F5 + ``` + +=== "Continue" + + ```python + wf = Workflow(name="resilient", on_failure="continue") + ``` + + Failed steps skip their `on_success` dependents, but **independent branches keep running**. + + ```mermaid + graph TD + root["root ✓"] --> fail_branch["fail ✗"] + root --> ok_branch["ok ✓"] + fail_branch --> after_fail["after_fail ━ SKIPPED"] + ok_branch --> after_ok["after_ok ✓"] + style fail_branch fill:#FFB6C1 + style after_fail fill:#F5F5F5 + style after_ok fill:#90EE90 + ``` + +## Skip propagation + +When a node is skipped, its successors are evaluated recursively: + +- `on_success` successors → **SKIPPED** (predecessor didn't succeed) +- `on_failure` successors → evaluated (predecessor is terminal) +- `always` successors → **run** regardless of how the predecessor ended + +```python +wf = Workflow(name="cleanup_pipeline") +wf.step("a", risky_task) +wf.step("b", next_step, after="a") # SKIPPED if a fails +wf.step("cleanup", cleanup, after="b", condition="always") # runs even if b is skipped +``` + +```mermaid +graph LR + A["a ✗ FAILED"] --> B["b ━ SKIPPED"] + B --> C["cleanup ✓ ALWAYS"] + style A fill:#FFB6C1 + style B fill:#F5F5F5 + style C fill:#90EE90 +``` + +## Combining conditions with fan-out + +Conditions work with fan-out nodes. 
If a fan-out child fails: + +```python +wf.step("fetch", fetch_data) +wf.step("process", process, after="fetch", fan_out="each") +wf.step("aggregate", aggregate, after="process", fan_in="all") +wf.step("on_error", alert, after="process", condition="on_failure") +``` + +If any `process[i]` child fails, the fan-out parent is marked `FAILED`, `aggregate` is skipped, and `on_error` runs. diff --git a/docs/workflows/fan-out.md b/docs/workflows/fan-out.md new file mode 100644 index 0000000..99645f0 --- /dev/null +++ b/docs/workflows/fan-out.md @@ -0,0 +1,108 @@ +# Fan-Out & Fan-In + +Split a step's result into parallel child jobs, then collect all results into a downstream step. + +```mermaid +graph LR + fetch --> process_0["process[0]"] + fetch --> process_1["process[1]"] + fetch --> process_2["process[2]"] + process_0 --> aggregate + process_1 --> aggregate + process_2 --> aggregate + + style fetch fill:#90EE90 + style aggregate fill:#87CEEB +``` + +## Fan-out with `"each"` + +The predecessor's return value must be iterable. Each element becomes a separate child job: + +```python +@queue.task() +def fetch() -> list[int]: + return [10, 20, 30] + +@queue.task() +def process(item: int) -> int: + return item * 2 + +@queue.task() +def aggregate(results: list[int]) -> int: + return sum(results) # receives [20, 40, 60] + +wf = Workflow(name="map_reduce") +wf.step("fetch", fetch) +wf.step("process", process, after="fetch", fan_out="each") +wf.step("aggregate", aggregate, after="process", fan_in="all") +``` + +Child nodes are named `process[0]`, `process[1]`, `process[2]` and appear in status queries. + +## How it works + +1. `fetch` completes — the tracker reads its return value +2. `apply_fan_out("each", result)` splits the list into individual items +3. `expand_fan_out()` creates N child nodes + N jobs (no `depends_on` — they're ready immediately) +4. Each child runs independently in parallel +5. When all children complete, `check_fan_out_completion()` marks the parent +6. 
The tracker collects all child results in index order +7. The fan-in job is created with `((results_list,), {})` as its payload + +```mermaid +sequenceDiagram + participant F as fetch job + participant T as Tracker + participant R as Rust Engine + + F->>T: JOB_COMPLETED(fetch) + T->>R: get_job(fetch_id).result_bytes + T->>R: expand_fan_out(3 children) + R-->>T: [child_job_ids] + Note over R: Children execute in parallel + R->>T: JOB_COMPLETED(process[0]) + R->>T: JOB_COMPLETED(process[1]) + R->>T: JOB_COMPLETED(process[2]) + T->>R: check_fan_out_completion → all done + T->>R: create_deferred_job(aggregate) +``` + +## Empty fan-out + +If the predecessor returns an empty list, the fan-out parent is marked `COMPLETED` immediately with zero children, and the fan-in receives an empty list: + +```python +@queue.task() +def fetch() -> list: + return [] # nothing to process + +# aggregate receives [] +``` + +## Fan-out with downstream steps + +Steps after the fan-in work normally: + +```python +wf = Workflow(name="full_pipeline") +wf.step("fetch", fetch) +wf.step("process", process, after="fetch", fan_out="each") +wf.step("aggregate", aggregate, after="process", fan_in="all") +wf.step("report", send_report, after="aggregate") # runs after aggregate +``` + +## Failure handling + +By default (`on_failure="fail_fast"`), if any fan-out child fails: + +- Remaining pending children are cancelled +- The fan-out parent is marked `FAILED` +- The fan-in and downstream steps are `SKIPPED` +- The workflow transitions to `FAILED` + +Combine with [conditions](conditions.md) for more control: + +```python +wf.step("handle_error", alert, after="process", condition="on_failure") +``` diff --git a/docs/workflows/gates.md b/docs/workflows/gates.md new file mode 100644 index 0000000..de2eda1 --- /dev/null +++ b/docs/workflows/gates.md @@ -0,0 +1,73 @@ +# Approval Gates + +Pause a workflow for human review. 
The gate enters `WAITING_APPROVAL` status until explicitly approved or rejected — or until a timeout fires. + +```mermaid +graph LR + train["train ✓"] --> eval["evaluate ✓"] + eval --> gate["approve ⏸"] + gate -->|approved| deploy["deploy"] + gate -->|rejected| skip["deploy ━ SKIPPED"] + style gate fill:#FFFACD +``` + +## Adding a gate + +```python +wf = Workflow(name="ml_deploy") +wf.step("train", train_model) +wf.step("evaluate", evaluate, after="train") +wf.gate("approve", after="evaluate") +wf.step("deploy", deploy, after="approve") +``` + +When the workflow reaches the gate, it pauses. Downstream steps won't execute until the gate is resolved. + +## Resolving a gate + +```python +run = queue.submit_workflow(wf) + +# Later, after review: +queue.approve_gate(run.id, "approve") # → gate COMPLETED, deploy runs +# or: +queue.reject_gate(run.id, "approve") # → gate FAILED, deploy SKIPPED +``` + +## Timeout + +Auto-resolve after a deadline: + +```python +wf.gate("approve", after="evaluate", + timeout=86400, # 24 hours + on_timeout="reject") # or "approve" +``` + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `timeout` | `float` | `None` | Seconds until auto-resolve. `None` waits forever | +| `on_timeout` | `str` | `"reject"` | Action on expiry: `"approve"` or `"reject"` | +| `message` | `str` | `None` | Human-readable message for approvers | + +## Gate with conditions + +Gates respect step conditions: + +```python +wf.step("test", run_tests) +wf.gate("approve", after="test", condition="on_success") +wf.step("deploy", deploy, after="approve") +``` + +If `test` fails, the gate is skipped (condition not met), and `deploy` is also skipped. 
+ +## Events + +When a gate enters `WAITING_APPROVAL`, a `WORKFLOW_GATE_REACHED` event fires: + +```python +@queue.on(EventType.WORKFLOW_GATE_REACHED) +def notify_team(event_type, payload): + send_slack(f"Workflow {payload['run_id']} needs approval at {payload['node_name']}") +``` diff --git a/docs/workflows/index.md b/docs/workflows/index.md new file mode 100644 index 0000000..c19c78d --- /dev/null +++ b/docs/workflows/index.md @@ -0,0 +1,73 @@ +# Workflows + +Build multi-step pipelines as directed acyclic graphs. Define steps, wire dependencies, and let taskito handle execution order, parallelism, failure propagation, and state tracking — all backed by a Rust engine with dagron-core for graph algorithms. + +```mermaid +graph TB + subgraph Python["Python Layer"] + WB["Workflow Builder
Workflow.step() / .gate()"] + WT["Workflow Tracker
Event-driven orchestration"] + WR["WorkflowRun
.status() / .wait() / .cancel()"] + end + + subgraph Rust["Rust Engine"] + WS["WorkflowStorage
SQLite (definitions, runs, nodes)"] + TE["Topology Engine
dagron-core DAG algorithms"] + SC["Scheduler
depends_on job chains"] + end + + subgraph Features["Capabilities"] + FO["Fan-Out / Fan-In"] + CD["Conditions & Gates"] + SW["Sub-Workflows"] + IC["Incremental Caching"] + VZ["Visualization"] + end + + WB -->|"_compile()"| TE + WB -->|"submit_workflow()"| WS + WS -->|"enqueue jobs"| SC + SC -->|"JOB_COMPLETED"| WT + WT -->|"evaluate successors"| WS + WT --> WR +``` + +## Quick start + +```python +from taskito import Queue +from taskito.workflows import Workflow + +queue = Queue(db_path="tasks.db") + +@queue.task() +def extract(): return fetch_data() + +@queue.task() +def transform(data): return clean(data) + +@queue.task() +def load(data): db.insert(data) + +wf = Workflow(name="etl_pipeline") +wf.step("extract", extract) +wf.step("transform", transform, after="extract") +wf.step("load", load, after="transform") + +run = queue.submit_workflow(wf) +result = run.wait(timeout=60) +print(result.state) # WorkflowState.COMPLETED +``` + +## Section overview + +| Page | What it covers | +|---|---| +| [Building Workflows](building.md) | `Workflow.step()`, decorator pattern, step configuration, DAG structure | +| [Fan-Out & Fan-In](fan-out.md) | Splitting results into parallel jobs, collecting with aggregation | +| [Conditions & Error Handling](conditions.md) | `on_success`, `on_failure`, `always`, callable conditions, `on_failure` modes | +| [Approval Gates](gates.md) | Human-in-the-loop pause/resume, timeout, approve/reject API | +| [Sub-Workflows & Scheduling](composition.md) | Nesting workflows, cron-scheduled runs | +| [Incremental Runs](caching.md) | Result hashing, `CACHE_HIT`, dirty-set propagation, TTL | +| [Analysis & Visualization](analysis.md) | Critical path, bottleneck analysis, Mermaid/DOT rendering | +| [Canvas Primitives](canvas.md) | Chain, group, chord — simple composition without DAGs | diff --git a/py_src/taskito/workflows/builder.py b/py_src/taskito/workflows/builder.py index 02e10b0..ed70106 100644 --- a/py_src/taskito/workflows/builder.py +++ 
b/py_src/taskito/workflows/builder.py @@ -14,7 +14,6 @@ from collections.abc import Callable from taskito.app import Queue - from taskito.task import TaskWrapper from .run import WorkflowRun @@ -95,7 +94,7 @@ def __init__( def step( self, name: str, - task: TaskWrapper, + task: Any, *, after: str | list[str] | None = None, args: tuple = (), diff --git a/tests/python/test_workflows_caching.py b/tests/python/test_workflows_caching.py index 3e1d6ee..e434112 100644 --- a/tests/python/test_workflows_caching.py +++ b/tests/python/test_workflows_caching.py @@ -139,7 +139,7 @@ def test_dirty_propagation(queue: Queue) -> None: predecessors = {"a": [], "b": ["a"], "c": ["b"]} # Simulate "a" being dirty (not in base). - base_nodes_missing_a = [ + base_nodes_missing_a: list[tuple[str, str, str | None]] = [ ("b", "completed", "hash_b"), ("c", "completed", "hash_c"), ] @@ -226,7 +226,7 @@ def test_cache_ttl_expires() -> None: """Expired base run results trigger re-execution.""" from taskito.workflows.incremental import compute_dirty_set - base_nodes = [ + base_nodes: list[tuple[str, str, str | None]] = [ ("a", "completed", "hash_a"), ] diff --git a/tests/python/test_workflows_conditions.py b/tests/python/test_workflows_conditions.py index 7dc810c..96f5b95 100644 --- a/tests/python/test_workflows_conditions.py +++ b/tests/python/test_workflows_conditions.py @@ -316,7 +316,7 @@ def skip_deploy() -> str: return "skip" def high_score(ctx: WorkflowContext) -> bool: - return ctx.results.get("validate", {}).get("score", 0) > 0.95 + return bool(ctx.results.get("validate", {}).get("score", 0) > 0.95) wf = Workflow(name="callable_results") wf.step("validate", score_task) diff --git a/tests/python/test_workflows_fan_out.py b/tests/python/test_workflows_fan_out.py index 2dc4921..8455af7 100644 --- a/tests/python/test_workflows_fan_out.py +++ b/tests/python/test_workflows_fan_out.py @@ -4,6 +4,7 @@ import threading import time +from typing import Any import pytest @@ -33,7 +34,7 @@ def 
source() -> list[int]: def double(x: int) -> int: return x * 2 - collected: list[object] = [] + collected: list[Any] = [] @queue.task() def aggregate(results: list[int]) -> str: @@ -67,7 +68,7 @@ def source() -> list: def process(x: int) -> int: return x * 2 - collected: list[object] = [] + collected: list[Any] = [] @queue.task() def aggregate(results: list) -> str: @@ -101,7 +102,7 @@ def source() -> list[int]: def add_one(x: int) -> int: return x + 1 - collected: list[object] = [] + collected: list[Any] = [] @queue.task() def aggregate(results: list[int]) -> str: @@ -320,7 +321,7 @@ def source() -> list[str]: def identity(item: str) -> str: return item - collected: list[object] = [] + collected: list[Any] = [] @queue.task() def aggregate(results: list[str]) -> str: diff --git a/zensical.toml b/zensical.toml index 2a7544c..181835c 100644 --- a/zensical.toml +++ b/zensical.toml @@ -86,6 +86,17 @@ nav = [ { "Testing" = "resources/testing.md" }, { "Observability" = "resources/observability.md" }, ] }, + { "Workflows" = [ + { "Overview" = "workflows/index.md" }, + { "Building Workflows" = "workflows/building.md" }, + { "Fan-Out & Fan-In" = "workflows/fan-out.md" }, + { "Conditions & Error Handling" = "workflows/conditions.md" }, + { "Approval Gates" = "workflows/gates.md" }, + { "Sub-Workflows & Scheduling" = "workflows/composition.md" }, + { "Incremental Runs" = "workflows/caching.md" }, + { "Analysis & Visualization" = "workflows/analysis.md" }, + { "Canvas Primitives" = "workflows/canvas.md" }, + ] }, { "Architecture" = "architecture.md" }, { "API Reference" = [ { "Overview" = "api/index.md" }, From a12deca39b95a516b3ff67fdda9b666b891b5e03 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 06:55:39 +0530 Subject: [PATCH 07/12] refactor(workflows): replace Any with HasTaskName protocol for step() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Define a 
runtime-checkable Protocol instead of using Any for the task parameter. Accepts TaskWrapper, SubWorkflowRef, or any object with a _task_name attribute — type-safe without being overly rigid. --- py_src/taskito/workflows/builder.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/py_src/taskito/workflows/builder.py b/py_src/taskito/workflows/builder.py index ed70106..e600b21 100644 --- a/py_src/taskito/workflows/builder.py +++ b/py_src/taskito/workflows/builder.py @@ -8,7 +8,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable if TYPE_CHECKING: from collections.abc import Callable @@ -18,6 +18,14 @@ from .run import WorkflowRun +@runtime_checkable +class HasTaskName(Protocol): + """Any object that exposes a ``_task_name`` attribute.""" + + @property + def _task_name(self) -> str: ... + + _VALID_FAN_OUT = frozenset({"each"}) _VALID_FAN_IN = frozenset({"all"}) _VALID_CONDITIONS = frozenset({"on_success", "on_failure", "always"}) @@ -94,7 +102,7 @@ def __init__( def step( self, name: str, - task: Any, + task: HasTaskName, *, after: str | list[str] | None = None, args: tuple = (), From e5b256597d0eaeb9a1f6406be45a3bb3bc3e3b34 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 06:58:57 +0530 Subject: [PATCH 08/12] fix(docs): restructure workflow architecture diagram as LR pipeline --- docs/workflows/index.md | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/docs/workflows/index.md b/docs/workflows/index.md index c19c78d..5d4411b 100644 --- a/docs/workflows/index.md +++ b/docs/workflows/index.md @@ -3,32 +3,27 @@ Build multi-step pipelines as directed acyclic graphs. 
Define steps, wire dependencies, and let taskito handle execution order, parallelism, failure propagation, and state tracking — all backed by a Rust engine with dagron-core for graph algorithms. ```mermaid -graph TB - subgraph Python["Python Layer"] - WB["Workflow Builder
Workflow.step() / .gate()"] - WT["Workflow Tracker
Event-driven orchestration"] - WR["WorkflowRun
.status() / .wait() / .cancel()"] +graph LR + subgraph Build["1 — Build"] + WB["Workflow Builder
step · gate · fan_out"] end - subgraph Rust["Rust Engine"] - WS["WorkflowStorage
SQLite (definitions, runs, nodes)"] - TE["Topology Engine
dagron-core DAG algorithms"] - SC["Scheduler
depends_on job chains"] + subgraph Submit["2 — Submit"] + TE["Topology Engine
dagron-core"] + WS["Storage
SQLite"] + SC["Scheduler
depends_on chains"] end - subgraph Features["Capabilities"] - FO["Fan-Out / Fan-In"] - CD["Conditions & Gates"] - SW["Sub-Workflows"] - IC["Incremental Caching"] - VZ["Visualization"] + subgraph Run["3 — Execute"] + WT["Tracker
event orchestration"] + WR["WorkflowRun
status · wait · cancel"] end - WB -->|"_compile()"| TE - WB -->|"submit_workflow()"| WS - WS -->|"enqueue jobs"| SC - SC -->|"JOB_COMPLETED"| WT - WT -->|"evaluate successors"| WS + WB -- "_compile()" --> TE + TE --> WS + WS -- "enqueue jobs" --> SC + SC -- "JOB_COMPLETED" --> WT + WT -- "evaluate
successors" --> WS WT --> WR ``` From 95af1a509e422a6f1875942b2fd8822dc98b86c8 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 07:11:18 +0530 Subject: [PATCH 09/12] docs: split architecture into focused pages, add workflow examples Break architecture.md into 8 pages: overview, job lifecycle, worker pool, storage (with ER diagram), scheduler, resource system, failure model, and serialization. Replace raw SQL schema with Mermaid erDiagram. Add docs/examples/workflows.md with 6 patterns: ML pipeline with approval gate, map-reduce fan-out, continue mode, multi-region sub-workflows, incremental runs, and pre-execution analysis. Fix workflow index diagram (flowchart TD without subgraph overlap). --- docs/architecture.md | 328 ----------------------------- docs/architecture/failure-model.md | 54 +++++ docs/architecture/index.md | 41 ++++ docs/architecture/job-lifecycle.md | 29 +++ docs/architecture/resources.md | 30 +++ docs/architecture/scheduler.md | 29 +++ docs/architecture/serialization.md | 27 +++ docs/architecture/storage.md | 113 ++++++++++ docs/architecture/worker-pool.md | 28 +++ docs/examples/index.md | 1 + docs/examples/workflows.md | 261 +++++++++++++++++++++++ docs/index.md | 2 +- docs/workflows/index.md | 30 +-- zensical.toml | 12 +- 14 files changed, 633 insertions(+), 352 deletions(-) delete mode 100644 docs/architecture.md create mode 100644 docs/architecture/failure-model.md create mode 100644 docs/architecture/index.md create mode 100644 docs/architecture/job-lifecycle.md create mode 100644 docs/architecture/resources.md create mode 100644 docs/architecture/scheduler.md create mode 100644 docs/architecture/serialization.md create mode 100644 docs/architecture/storage.md create mode 100644 docs/architecture/worker-pool.md create mode 100644 docs/examples/workflows.md diff --git a/docs/architecture.md b/docs/architecture.md deleted file mode 100644 index bd24919..0000000 --- 
a/docs/architecture.md +++ /dev/null @@ -1,328 +0,0 @@ -# Architecture - -taskito is a hybrid Python/Rust system. Python provides the user-facing API. Rust handles all the heavy lifting: storage, scheduling, dispatch, rate limiting, and worker management. - -## Overview - -```mermaid -graph TB - subgraph Python ["Python Layer"] - A["Queue"] - B["@queue.task()"] - C["TaskWrapper"] - D["JobResult"] - E["current_job"] - R["ResourceRuntime"] - IC["ArgumentInterceptor"] - PX["ProxyRegistry"] - end - - subgraph Rust ["Rust Core · PyO3"] - F["PyQueue"] - G["Scheduler
Tokio async runtime"] - H["Worker Pool
OS threads + crossbeam"] - I["Rate Limiter
Token bucket"] - end - - subgraph Storage ["Storage Backend"] - J[("SQLite · WAL mode
Diesel ORM · r2d2 pool")] - K[("PostgreSQL
Diesel ORM · r2d2 pool")] - end - - A -->|"enqueue() → intercept args"| IC - IC -->|"transformed args"| F - F -->|INSERT| J - F -->|INSERT| K - G -->|"dequeue (poll every 50ms)"| J - G -->|"dispatch via crossbeam"| H - H -->|"acquire GIL → run task"| B - B -->|"reconstruct proxies"| PX - B -->|"inject resources"| R - H -->|"JobResult"| G - G -->|"UPDATE status"| J - D -->|"poll status"| F - F -->|SELECT| J - G -.->|"check rate limit"| I - I -.->|"token state"| J -``` - -## Job Lifecycle - -```mermaid -stateDiagram-v2 - [*] --> Pending: enqueue() / delay() - Pending --> Running: dequeued by scheduler - Pending --> Cancelled: cancel_job() - Running --> Complete: task returns successfully - Running --> Failed: task raises exception - Failed --> Pending: retry (count < max_retries)\nwith exponential backoff - Failed --> Dead: retries exhausted\nmoved to DLQ - Dead --> Pending: retry_dead() - Complete --> [*] - Cancelled --> [*] - Dead --> [*]: purge_dead() -``` - -**Status codes:** - -| Status | Integer | Description | -|---|---|---| -| Pending | 0 | Waiting to be picked up | -| Running | 1 | Currently executing | -| Complete | 2 | Finished successfully | -| Failed | 3 | Last attempt failed (may retry) | -| Dead | 4 | All retries exhausted, in DLQ | -| Cancelled | 5 | Cancelled before execution | - -## Worker Pool - -```mermaid -graph LR - subgraph Scheduler ["Scheduler Thread"] - S["Tokio async runtime
50ms poll interval"] - end - - S -->|"sync job"| JCH["Job Channel
(bounded: workers×2)"] - S -->|"async job"| AP["Native Async Pool
(NativeAsyncPool)"] - - subgraph Pool ["Worker Threads"] - W1["Worker 1
GIL per task"] - W2["Worker 2
GIL per task"] - WN["Worker N
GIL per task"] - end - - subgraph AsyncPool ["Async Executor"] - EL["Dedicated Event Loop
(daemon thread)"] - SEM["Semaphore
(async_concurrency)"] - end - - JCH --> W1 - JCH --> W2 - JCH --> WN - - AP --> EL - EL --> SEM - - W1 -->|"Result"| RCH["Result Channel
(bounded: workers×2)"] - W2 -->|"Result"| RCH - WN -->|"Result"| RCH - EL -->|"PyResultSender"| RCH - - RCH --> ML["Main Loop
(py.allow_threads)"] - ML -->|"complete / retry / DLQ"| DB[("SQLite")] -``` - -**Key design decisions:** - -- **OS threads, not Python threads**: Sync workers are Rust `std::thread` threads. The GIL is only acquired when calling Python task code. -- **Bounded channels**: Both job and result channels are bounded to `workers × 2` to provide backpressure. -- **GIL isolation**: Each sync worker acquires the GIL independently using `Python::with_gil()`. The scheduler and result handler release the GIL via `py.allow_threads()`. -- **Native async dispatch**: `async def` tasks bypass the thread pool entirely. A `NativeAsyncPool` sends them to a dedicated `AsyncTaskExecutor` running on a Python daemon thread. `PyResultSender` (a `#[pyclass]`) bridges results back into the Rust scheduler. -- **Context isolation**: Sync tasks use `threading.local` for `current_job`; async tasks use `contextvars.ContextVar`, which is properly scoped across `await` boundaries and isolated between concurrent coroutines. 
- -## Storage Layer - -### SQLite Configuration - -| Pragma | Value | Why | -|---|---|---| -| `journal_mode` | WAL | Concurrent reads while writing | -| `busy_timeout` | 5000ms | Wait on lock contention instead of failing | -| `synchronous` | NORMAL | Fast writes, safe with WAL | -| `journal_size_limit` | 64MB | Prevent unbounded WAL file growth | - -### Database Schema - -**6 tables:** - -```sql --- Core job storage -jobs (id, queue, task_name, payload, status, priority, - created_at, scheduled_at, started_at, completed_at, - retry_count, max_retries, result, error, timeout_ms, - unique_key, progress, metadata, - cancel_requested, expires_at, result_ttl_ms) - --- Dead letter queue -dead_letter (id, original_job_id, queue, task_name, - payload, error, retry_count, failed_at, metadata, - priority, max_retries, timeout_ms, result_ttl_ms) - --- Token bucket rate limiting -rate_limits (key, tokens, max_tokens, refill_rate, last_refill) - --- Cron-scheduled tasks -periodic_tasks (name, task_name, cron_expr, args, kwargs, - queue, enabled, last_run, next_run) - --- Per-attempt error tracking -job_errors (id, job_id, attempt, error, failed_at) - --- Worker heartbeat tracking -workers (worker_id, last_heartbeat, queues, status) -``` - -**Key indexes:** - -- `idx_jobs_dequeue`: `(queue, status, priority DESC, scheduled_at)` — fast dequeue -- `idx_jobs_status`: `(status)` — fast stats queries -- `idx_jobs_unique_key`: partial unique index on `unique_key` where status is pending/running -- `idx_job_errors_job_id`: `(job_id)` — fast error history lookup - -### Connection Pooling - -Diesel's `r2d2` connection pool with up to 8 connections (SQLite) or 10 connections (Postgres). In-memory databases use a single connection (SQLite `:memory:` is per-connection). - -### Postgres Configuration - -taskito also supports PostgreSQL as an alternative storage backend. See the [Postgres Backend guide](guide/operations/postgres.md) for full details. 
- -Key differences from the SQLite storage layer: - -- **Connection pooling**: `r2d2` pool with a default of 10 connections (vs. 8 for SQLite) -- **Schema isolation**: All tables are created inside a configurable PostgreSQL schema (default: `taskito`), with `search_path` set on each connection -- **Additional tables**: The Postgres backend creates 11 tables (vs. 6 for SQLite), adding `job_dependencies`, `task_metrics`, `replay_history`, `task_logs`, and `circuit_breakers` -- **Concurrent writes**: No single-writer constraint — multiple workers can write simultaneously - -## Scheduler Loop - -The scheduler runs in a dedicated Tokio single-threaded async runtime: - -``` -loop { - sleep(50ms) or shutdown signal - - // Try to dequeue and dispatch a job - try_dispatch() - - // Every ~100 iterations (~5s): reap timed-out jobs - reap_stale() - - // Every ~60 iterations (~3s): check periodic tasks - check_periodic() - - // Every ~1200 iterations (~60s): auto-cleanup old jobs - auto_cleanup() -} -``` - -### Dispatch Flow - -1. `dequeue_from()` — atomically SELECT + UPDATE (pending → running) within a transaction -2. Check rate limit — if over limit, reschedule 1s in the future -3. Send job to worker pool via crossbeam channel -4. Worker executes task, sends result back -5. 
`handle_result()` — mark complete, schedule retry, or move to DLQ - -## Resource System - -The resource system is a three-layer Python pipeline that runs entirely outside Rust: - -```mermaid -graph LR - subgraph Enqueue ["On enqueue()"] - A["Task arguments"] -->|"classify"| IC["ArgumentInterceptor"] - IC -->|"PASS"| S["Serializer"] - IC -->|"CONVERT → marker"| S - IC -->|"REDIRECT → DI marker"| S - IC -->|"PROXY → recipe"| PX["ProxyHandler.deconstruct()"] - PX --> S - IC -->|"REJECT"| ERR["InterceptionError"] - end - - subgraph Worker ["On worker dispatch"] - S2["Deserialize payload"] --> RC["reconstruct_args()"] - RC -->|"CONVERT markers"| OBJ["Restored types"] - RC2["reconstruct_proxies()"] -->|"PROXY recipes"| PX2["ProxyHandler.reconstruct()"] - RT["ResourceRuntime.acquire_for_task()"] -->|"REDIRECT + inject="| INJ["Injected resources"] - OBJ --> FN["Task function"] - PX2 --> FN - INJ --> FN - end -``` - -**Layer 1 — Argument Interception**: The `ArgumentInterceptor` walks every argument before serialization, applying the strategy registered for its type. CONVERT types are transformed to JSON-safe markers. REDIRECT types are replaced with a DI placeholder. PROXY types are deconstructed by their handler. REJECT types raise an error in strict mode. - -**Layer 2 — Worker Resource Runtime**: `ResourceRuntime` initializes all registered resources at worker startup in topological dependency order. At task dispatch time it injects the requested resources (via `inject=` or `Inject["name"]` annotation) as keyword arguments. Task-scoped resources are acquired from a semaphore pool and returned after the task finishes. - -**Layer 3 — Resource Proxies**: `ProxyHandler` implementations know how to deconstruct live objects (file handles, HTTP sessions, cloud clients) into a JSON-serializable recipe, and how to reconstruct them on the worker before the task function is called. Recipes are optionally HMAC-signed for tamper detection. 
- -## Failure Model - -Taskito provides **at-least-once delivery**. Here's what happens when things go wrong: - -### Worker crash mid-task - -The job stays in `running` status. The scheduler's stale reaper detects it after `timeout_ms` elapses, marks it failed, and retries (if retries remain) or moves to the dead letter queue. No manual intervention needed. - -### Parent process crash - -All worker threads stop. Jobs in `running` stay in that state until the next worker starts, when the stale reaper picks them up. Jobs in `pending` are unaffected — they'll be dispatched normally on restart. - -### Database unavailable - -Scheduler polls fail silently (logged via `log::error!`). No new jobs are dispatched. In-flight jobs complete normally — results are cached in memory until the database becomes available. - -### Network partition (Postgres/Redis) - -Same behavior as database unavailable. The scheduler retries on the next poll cycle (default: every 50ms). Connection pools handle reconnection automatically. - -### Duplicate execution - -`claim_execution` prevents two workers from picking up the same job simultaneously. But if a worker crashes *after* starting execution, the job will be retried — potentially executing the same task twice. Design tasks to be [idempotent](guide/reliability/guarantees.md) to handle this safely. - -### Recovery timeline - -```mermaid -sequenceDiagram - participant C as Client - participant DB as Database - participant S as Scheduler - participant W as Worker - - C->>DB: enqueue(job) - S->>DB: dequeue + claim_execution - S->>W: dispatch job - W->>W: execute task... - Note over W: Worker crashes at T=5s - Note over S: Scheduler continues polling... - Note over S: T=300s: reap_stale_jobs() detects
job.started_at + timeout_ms < now - S->>DB: mark failed, schedule retry - S->>DB: dequeue (same job, retry_count=1) - S->>W: dispatch to different worker - W->>DB: complete + clear claim -``` - -### Partial writes - -If a task completes successfully but the result write to the database fails (e.g., database full, connection lost), the job stays in `running` status. The stale reaper eventually marks it failed and retries it. The task will execute again — make sure it's [idempotent](guide/reliability/guarantees.md). - -### Jobs without timeouts - -!!! warning - If a job has no `timeout_ms` set and the worker crashes, the job stays in `running` **forever**. The stale reaper only detects jobs that have exceeded their timeout. Always set a timeout on production tasks. - -## Serialization - -taskito uses a pluggable serializer for task arguments and results. The default is `CloudpickleSerializer`, which supports lambdas, closures, and complex Python objects. - -```python -from taskito import Queue, JsonSerializer - -# Use JSON for simpler, cross-language payloads -queue = Queue(serializer=JsonSerializer()) -``` - -**Built-in serializers:** - -| Serializer | Format | Best for | -|---|---|---| -| `CloudpickleSerializer` (default) | Binary (pickle) | Complex Python objects, lambdas, closures | -| `JsonSerializer` | JSON | Simple types, cross-language interop, debugging | - -**Custom serializers** implement the `Serializer` protocol (`dumps(obj) -> bytes`, `loads(data) -> Any`). 
- -- **Arguments**: `serializer.dumps((args, kwargs))` — stored as BLOB in `payload` -- **Results**: `serializer.dumps(return_value)` — stored as BLOB in `result` -- **Periodic task args**: Serialized at registration time, stored as BLOBs in `periodic_tasks.args` - diff --git a/docs/architecture/failure-model.md b/docs/architecture/failure-model.md new file mode 100644 index 0000000..58c3bf6 --- /dev/null +++ b/docs/architecture/failure-model.md @@ -0,0 +1,54 @@ +# Failure Model + +Taskito provides **at-least-once delivery**. Here's what happens when things go wrong. + +## Worker crash mid-task + +The job stays in `running` status. The scheduler's stale reaper detects it after `timeout_ms` elapses, marks it failed, and retries (if retries remain) or moves to the dead letter queue. No manual intervention needed. + +## Parent process crash + +All worker threads stop. Jobs in `running` stay in that state until the next worker starts, when the stale reaper picks them up. Jobs in `pending` are unaffected — they'll be dispatched normally on restart. + +## Database unavailable + +Scheduler polls fail silently (logged via `log::error!`). No new jobs are dispatched. In-flight jobs complete normally — results are cached in memory until the database becomes available. + +## Network partition (Postgres/Redis) + +Same behavior as database unavailable. The scheduler retries on the next poll cycle (default: every 50ms). Connection pools handle reconnection automatically. + +## Duplicate execution + +`claim_execution` prevents two workers from picking up the same job simultaneously. But if a worker crashes *after* starting execution, the job will be retried — potentially executing the same task twice. Design tasks to be [idempotent](../guide/reliability/guarantees.md) to handle this safely. 
+ +## Recovery timeline + +```mermaid +sequenceDiagram + participant C as Client + participant DB as Database + participant S as Scheduler + participant W as Worker + + C->>DB: enqueue(job) + S->>DB: dequeue + claim_execution + S->>W: dispatch job + W->>W: execute task... + Note over W: Worker crashes at T=5s + Note over S: Scheduler continues polling... + Note over S: T=300s: reap_stale_jobs() detects
job.started_at + timeout_ms < now + S->>DB: mark failed, schedule retry + S->>DB: dequeue (same job, retry_count=1) + S->>W: dispatch to different worker + W->>DB: complete + clear claim +``` + +## Partial writes + +If a task completes successfully but the result write to the database fails (e.g., database full, connection lost), the job stays in `running` status. The stale reaper eventually marks it failed and retries it. The task will execute again — make sure it's [idempotent](../guide/reliability/guarantees.md). + +## Jobs without timeouts + +!!! warning + If a job has no `timeout_ms` set and the worker crashes, the job stays in `running` **forever**. The stale reaper only detects jobs that have exceeded their timeout. Always set a timeout on production tasks. diff --git a/docs/architecture/index.md b/docs/architecture/index.md new file mode 100644 index 0000000..aa86f8d --- /dev/null +++ b/docs/architecture/index.md @@ -0,0 +1,41 @@ +# Architecture + +taskito is a hybrid Python/Rust system. Python provides the user-facing API. Rust handles all the heavy lifting: storage, scheduling, dispatch, rate limiting, and worker management. 
+ +```mermaid +flowchart TD + subgraph py ["Python Layer"] + direction LR + Q["Queue"] --> IC["ArgumentInterceptor"] + TW["@queue.task()"] ~~~ RR["ResourceRuntime"] + end + + subgraph rust ["Rust Core — PyO3"] + direction LR + PQ["PyQueue"] --> SCH["Scheduler"] + SCH --> WP["Worker Pool"] + SCH --> RL["Rate Limiter"] + end + + subgraph storage ["Storage"] + direction LR + SQ[("SQLite")] ~~~ PG[("PostgreSQL")] + end + + IC --> PQ + WP -->|"acquire GIL"| TW + SCH -->|"poll / update"| SQ + PQ -->|"INSERT"| SQ +``` + +## Section overview + +| Page | What it covers | +|---|---| +| [Job Lifecycle](job-lifecycle.md) | State machine, status codes, transitions | +| [Worker Pool](worker-pool.md) | Thread architecture, async dispatch, GIL management | +| [Storage Layer](storage.md) | SQLite pragmas, schema, indexes, Postgres differences | +| [Scheduler](scheduler.md) | Poll loop, dispatch flow, periodic tasks | +| [Resource System](resources.md) | Argument interception, DI, proxy reconstruction | +| [Failure Model](failure-model.md) | Crash recovery, duplicate execution, partial writes | +| [Serialization](serialization.md) | Pluggable serializers, format details | diff --git a/docs/architecture/job-lifecycle.md b/docs/architecture/job-lifecycle.md new file mode 100644 index 0000000..aa6c386 --- /dev/null +++ b/docs/architecture/job-lifecycle.md @@ -0,0 +1,29 @@ +# Job Lifecycle + +Every job moves through a state machine from creation to completion (or death). 
+ +```mermaid +stateDiagram-v2 + [*] --> Pending: enqueue() / delay() + Pending --> Running: dequeued by scheduler + Pending --> Cancelled: cancel_job() + Running --> Complete: task returns successfully + Running --> Failed: task raises exception + Failed --> Pending: retry (count < max_retries)\nwith exponential backoff + Failed --> Dead: retries exhausted\nmoved to DLQ + Dead --> Pending: retry_dead() + Complete --> [*] + Cancelled --> [*] + Dead --> [*]: purge_dead() +``` + +## Status codes + +| Status | Integer | Description | +|---|---|---| +| Pending | 0 | Waiting to be picked up | +| Running | 1 | Currently executing | +| Complete | 2 | Finished successfully | +| Failed | 3 | Last attempt failed (may retry) | +| Dead | 4 | All retries exhausted, in DLQ | +| Cancelled | 5 | Cancelled before execution | diff --git a/docs/architecture/resources.md b/docs/architecture/resources.md new file mode 100644 index 0000000..56ca699 --- /dev/null +++ b/docs/architecture/resources.md @@ -0,0 +1,30 @@ +# Resource System + +The resource system is a three-layer Python pipeline that runs entirely outside Rust: + +```mermaid +flowchart TD + subgraph enqueue ["enqueue()"] + ARGS["Task arguments"] --> IC["ArgumentInterceptor"] + IC -->|PASS / CONVERT / REDIRECT| SER["Serializer"] + IC -->|PROXY| PX["ProxyHandler.deconstruct()"] + PX --> SER + end + + SER -->|"serialized payload"| QUEUE[("Queue")] + + subgraph worker ["Worker dispatch"] + DE["Deserialize"] --> RC["reconstruct_args()"] + RC --> FN["Task function"] + RT["ResourceRuntime"] -->|"inject"| FN + PX2["ProxyHandler.reconstruct()"] --> FN + end + + QUEUE --> DE +``` + +**Layer 1 — Argument Interception**: The `ArgumentInterceptor` walks every argument before serialization, applying the strategy registered for its type. CONVERT types are transformed to JSON-safe markers. REDIRECT types are replaced with a DI placeholder. PROXY types are deconstructed by their handler. REJECT types raise an error in strict mode. 
+ +**Layer 2 — Worker Resource Runtime**: `ResourceRuntime` initializes all registered resources at worker startup in topological dependency order. At task dispatch time it injects the requested resources (via `inject=` or `Inject["name"]` annotation) as keyword arguments. Task-scoped resources are acquired from a semaphore pool and returned after the task finishes. + +**Layer 3 — Resource Proxies**: `ProxyHandler` implementations know how to deconstruct live objects (file handles, HTTP sessions, cloud clients) into a JSON-serializable recipe, and how to reconstruct them on the worker before the task function is called. Recipes are optionally HMAC-signed for tamper detection. diff --git a/docs/architecture/scheduler.md b/docs/architecture/scheduler.md new file mode 100644 index 0000000..60d05a5 --- /dev/null +++ b/docs/architecture/scheduler.md @@ -0,0 +1,29 @@ +# Scheduler + +The scheduler runs in a dedicated Tokio single-threaded async runtime: + +``` +loop { + sleep(50ms) or shutdown signal + + // Try to dequeue and dispatch a job + try_dispatch() + + // Every ~100 iterations (~5s): reap timed-out jobs + reap_stale() + + // Every ~60 iterations (~3s): check periodic tasks + check_periodic() + + // Every ~1200 iterations (~60s): auto-cleanup old jobs + auto_cleanup() +} +``` + +## Dispatch flow + +1. `dequeue_from()` — atomically SELECT + UPDATE (pending → running) within a transaction +2. Check rate limit — if over limit, reschedule 1s in the future +3. Send job to worker pool via crossbeam channel +4. Worker executes task, sends result back +5. `handle_result()` — mark complete, schedule retry, or move to DLQ diff --git a/docs/architecture/serialization.md b/docs/architecture/serialization.md new file mode 100644 index 0000000..1902707 --- /dev/null +++ b/docs/architecture/serialization.md @@ -0,0 +1,27 @@ +# Serialization + +taskito uses a pluggable serializer for task arguments and results. 
The default is `CloudpickleSerializer`, which supports lambdas, closures, and complex Python objects. + +```python +from taskito import Queue, JsonSerializer + +# Use JSON for simpler, cross-language payloads +queue = Queue(serializer=JsonSerializer()) +``` + +## Built-in serializers + +| Serializer | Format | Best for | +|---|---|---| +| `CloudpickleSerializer` (default) | Binary (pickle) | Complex Python objects, lambdas, closures | +| `JsonSerializer` | JSON | Simple types, cross-language interop, debugging | + +## Custom serializers + +Implement the `Serializer` protocol (`dumps(obj) -> bytes`, `loads(data) -> Any`). + +## What gets serialized + +- **Arguments**: `serializer.dumps((args, kwargs))` — stored as BLOB in `payload` +- **Results**: `serializer.dumps(return_value)` — stored as BLOB in `result` +- **Periodic task args**: Serialized at registration time, stored as BLOBs in `periodic_tasks.args` diff --git a/docs/architecture/storage.md b/docs/architecture/storage.md new file mode 100644 index 0000000..6ae0491 --- /dev/null +++ b/docs/architecture/storage.md @@ -0,0 +1,113 @@ +# Storage Layer + +## SQLite configuration + +| Pragma | Value | Why | +|---|---|---| +| `journal_mode` | WAL | Concurrent reads while writing | +| `busy_timeout` | 5000ms | Wait on lock contention instead of failing | +| `synchronous` | NORMAL | Fast writes, safe with WAL | +| `journal_size_limit` | 64MB | Prevent unbounded WAL file growth | + +## Database schema + +```mermaid +erDiagram + jobs { + TEXT id PK + TEXT queue + TEXT task_name + BLOB payload + INTEGER status + INTEGER priority + INTEGER created_at + INTEGER scheduled_at + INTEGER started_at + INTEGER completed_at + INTEGER retry_count + INTEGER max_retries + BLOB result + TEXT error + INTEGER timeout_ms + TEXT unique_key + INTEGER progress + TEXT metadata + BOOLEAN cancel_requested + INTEGER expires_at + INTEGER result_ttl_ms + } + + dead_letter { + TEXT id PK + TEXT original_job_id + TEXT queue + TEXT task_name + BLOB 
payload + TEXT error + INTEGER retry_count + INTEGER failed_at + TEXT metadata + INTEGER priority + INTEGER max_retries + INTEGER timeout_ms + INTEGER result_ttl_ms + } + + job_errors { + TEXT id PK + TEXT job_id FK + INTEGER attempt + TEXT error + INTEGER failed_at + } + + rate_limits { + TEXT key PK + REAL tokens + REAL max_tokens + REAL refill_rate + INTEGER last_refill + } + + periodic_tasks { + TEXT name PK + TEXT task_name + TEXT cron_expr + BLOB args + BLOB kwargs + TEXT queue + BOOLEAN enabled + INTEGER last_run + INTEGER next_run + } + + workers { + TEXT worker_id PK + INTEGER last_heartbeat + TEXT queues + TEXT status + } + + jobs ||--o{ job_errors : "error history" + jobs ||--o| dead_letter : "DLQ on exhaustion" +``` + +## Key indexes + +- `idx_jobs_dequeue`: `(queue, status, priority DESC, scheduled_at)` — fast dequeue +- `idx_jobs_status`: `(status)` — fast stats queries +- `idx_jobs_unique_key`: partial unique index on `unique_key` where status is pending/running +- `idx_job_errors_job_id`: `(job_id)` — fast error history lookup + +## Connection pooling + +Diesel's `r2d2` connection pool with up to 8 connections (SQLite) or 10 connections (Postgres). In-memory databases use a single connection (SQLite `:memory:` is per-connection). + +## Postgres differences + +taskito also supports PostgreSQL as an alternative storage backend. See the [Postgres Backend guide](../guide/operations/postgres.md) for full details. + +- **Connection pooling**: `r2d2` pool with a default of 10 connections (vs. 8 for SQLite) +- **Schema isolation**: All tables are created inside a configurable PostgreSQL schema (default: `taskito`), with `search_path` set on each connection +- **Additional tables**: The Postgres backend creates 11 tables (vs. 
6 for SQLite), adding `job_dependencies`, `task_metrics`, `replay_history`, `task_logs`, and `circuit_breakers` +- **Concurrent writes**: No single-writer constraint — multiple workers can write simultaneously diff --git a/docs/architecture/worker-pool.md b/docs/architecture/worker-pool.md new file mode 100644 index 0000000..4e62d9c --- /dev/null +++ b/docs/architecture/worker-pool.md @@ -0,0 +1,28 @@ +# Worker Pool + +The worker pool dispatches jobs from the scheduler to Python task functions. + +```mermaid +flowchart TD + SCH["Scheduler\nTokio async · 50ms poll"] + + SCH -->|"sync job"| JCH["Job Channel\nbounded: workers×2"] + SCH -->|"async job"| AP["Native Async Pool"] + + JCH --> WP["Worker Threads\nGIL per task · N threads"] + AP --> EL["Async Executor\ndedicated event loop"] + + WP -->|"Result"| RCH["Result Channel"] + EL -->|"PyResultSender"| RCH + + RCH --> ML["Main Loop\npy.allow_threads"] + ML -->|"complete / retry / DLQ"| DB[("SQLite")] +``` + +## Design decisions + +- **OS threads, not Python threads**: Sync workers are Rust `std::thread` threads. The GIL is only acquired when calling Python task code. +- **Bounded channels**: Both job and result channels are bounded to `workers × 2` to provide backpressure. +- **GIL isolation**: Each sync worker acquires the GIL independently using `Python::with_gil()`. The scheduler and result handler release the GIL via `py.allow_threads()`. +- **Native async dispatch**: `async def` tasks bypass the thread pool entirely. A `NativeAsyncPool` sends them to a dedicated `AsyncTaskExecutor` running on a Python daemon thread. `PyResultSender` (a `#[pyclass]`) bridges results back into the Rust scheduler. +- **Context isolation**: Sync tasks use `threading.local` for `current_job`; async tasks use `contextvars.ContextVar`, which is properly scoped across `await` boundaries and isolated between concurrent coroutines. 
diff --git a/docs/examples/index.md b/docs/examples/index.md index 5c964b0..1c863f6 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -8,4 +8,5 @@ End-to-end examples demonstrating common taskito patterns. | [Notification Service](notifications.md) | Multi-channel notifications with retries and rate limiting | | [Web Scraper Pipeline](web-scraper.md) | Distributed scraping with chains and error handling | | [Data Pipeline](data-pipeline.md) | ETL pipeline with dependencies, groups, and chords | +| [DAG Workflows](workflows.md) | Fan-out, conditions, gates, sub-workflows, incremental runs | | [Benchmark](benchmark.md) | Performance benchmarks comparing taskito to alternatives | diff --git a/docs/examples/workflows.md b/docs/examples/workflows.md new file mode 100644 index 0000000..b175a7d --- /dev/null +++ b/docs/examples/workflows.md @@ -0,0 +1,261 @@ +# Example: DAG Workflows + +Real-world workflow patterns demonstrating fan-out, conditions, approval gates, sub-workflows, and incremental runs. + +## ML Training Pipeline + +A training pipeline that evaluates a model, gates deployment on accuracy, and has a rollback path. + +```python +from taskito import Queue +from taskito.workflows import Workflow, WorkflowContext + +queue = Queue(db_path="ml.db", workers=4) + + +@queue.task() +def fetch_dataset() -> dict: + return {"rows": 50_000, "path": "/data/train.parquet"} + + +@queue.task() +def train_model(dataset: dict) -> dict: + # ... training logic ... 
+ return {"model_id": "v3.2", "accuracy": 0.97, "loss": 0.08} + + +@queue.task() +def evaluate(model: dict) -> dict: + return {"accuracy": model["accuracy"], "passed": model["accuracy"] > 0.90} + + +@queue.task() +def deploy(model_id: str) -> str: + return f"deployed {model_id}" + + +@queue.task() +def notify_failure() -> str: + return "sent alert: model below threshold" + + +def accuracy_gate(ctx: WorkflowContext) -> bool: + return ctx.results.get("evaluate", {}).get("passed", False) + + +wf = Workflow(name="ml_pipeline") +wf.step("fetch", fetch_dataset) +wf.step("train", train_model, after="fetch") +wf.step("evaluate", evaluate, after="train") +wf.gate("review", after="evaluate", timeout=3600, on_timeout="reject") +wf.step("deploy", deploy, after="review") +wf.step("alert", notify_failure, after="review", condition="on_failure") +``` + +```mermaid +flowchart TD + fetch --> train --> evaluate + evaluate --> review["review\n(approval gate)"] + review -->|approved| deploy + review -->|rejected| alert +``` + +Usage: + +```python +run = queue.submit_workflow(wf) + +# After human review: +queue.approve_gate(run.id, "review") + +result = run.wait(timeout=120) +print(run.visualize("mermaid")) +``` + +--- + +## Map-Reduce with Fan-Out + +Process a batch of items in parallel, then aggregate results. 
+ +```python +@queue.task() +def fetch_urls() -> list[str]: + return [ + "https://api.example.com/page/1", + "https://api.example.com/page/2", + "https://api.example.com/page/3", + ] + + +@queue.task() +def scrape(url: str) -> dict: + import httpx + resp = httpx.get(url) + return {"url": url, "status": resp.status_code, "size": len(resp.content)} + + +@queue.task() +def summarize(results: list[dict]) -> dict: + total = sum(r["size"] for r in results) + return {"pages": len(results), "total_bytes": total} + + +wf = Workflow(name="scrape_pipeline") +wf.step("fetch", fetch_urls) +wf.step("scrape", scrape, after="fetch", fan_out="each") +wf.step("summarize", summarize, after="scrape", fan_in="all") + +run = queue.submit_workflow(wf) +result = run.wait(timeout=60) +# summarize receives [{"url": ..., "size": ...}, ...] +``` + +--- + +## Resilient Pipeline with Continue Mode + +Independent branches keep running even when one fails. + +```python +@queue.task(max_retries=0) +def ingest_orders() -> str: + return "orders ingested" + + +@queue.task(max_retries=0) +def ingest_inventory() -> str: + raise RuntimeError("inventory source down") + + +@queue.task() +def build_report() -> str: + return "report built" + + +@queue.task() +def send_alert() -> str: + return "alert sent to #data-eng" + + +wf = Workflow(name="daily_ingest", on_failure="continue") +wf.step("orders", ingest_orders) +wf.step("inventory", ingest_inventory) +wf.step("report", build_report, after="orders") +wf.step("alert", send_alert, after="inventory", condition="on_failure") +``` + +```mermaid +flowchart TD + orders --> report["report ✓"] + inventory["inventory ✗"] --> alert["alert ✓\n(on_failure)"] +``` + +`inventory` fails, but `orders → report` runs to completion. `alert` fires because its predecessor failed. + +--- + +## Multi-Region ETL with Sub-Workflows + +Compose reusable pipelines as sub-workflow steps. 
+ +```python +@queue.task() +def extract(region: str) -> list: + return [{"region": region, "id": i} for i in range(100)] + + +@queue.task() +def load(data: list) -> int: + return len(data) + + +@queue.task() +def reconcile() -> str: + return "all regions reconciled" + + +@queue.workflow("region_etl") +def region_etl(region: str) -> Workflow: + wf = Workflow() + wf.step("extract", extract, args=(region,)) + wf.step("load", load, after="extract") + return wf + + +@queue.workflow("global_etl") +def global_etl() -> Workflow: + wf = Workflow() + wf.step("eu", region_etl.as_step(region="eu")) + wf.step("us", region_etl.as_step(region="us")) + wf.step("ap", region_etl.as_step(region="ap")) + wf.step("reconcile", reconcile, after=["eu", "us", "ap"]) + return wf + + +run = global_etl.submit() +run.wait(timeout=120) +``` + +EU, US, and AP ETL pipelines run concurrently as child workflows. `reconcile` runs after all three complete. + +--- + +## Incremental Re-Runs + +Skip unchanged steps on the second run. + +```python +wf = Workflow(name="nightly", cache_ttl=86400) # 24h TTL +wf.step("extract", extract) +wf.step("transform", transform, after="extract") +wf.step("load", load, after="transform") + +# First run: everything executes +run1 = queue.submit_workflow(wf) +run1.wait() + +# Next day: skip completed steps +run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) +run2.wait() + +for name, node in run2.status().nodes.items(): + print(f"{name}: {node.status}") +# extract: cache_hit +# transform: cache_hit +# load: cache_hit +``` + +--- + +## Pre-Execution Analysis + +Analyze a workflow before submitting it. 
+ +```python +wf = Workflow(name="complex") +wf.step("a", task_a) +wf.step("b", task_b, after="a") +wf.step("c", task_c, after="a") +wf.step("d", task_d, after=["b", "c"]) + +# Structure +print(wf.topological_levels()) +# [["a"], ["b", "c"], ["d"]] + +print(wf.stats()) +# {"nodes": 4, "edges": 4, "depth": 3, "width": 2, "density": 0.67} + +# Critical path with estimated durations +path, cost = wf.critical_path({"a": 2, "b": 10, "c": 3, "d": 1}) +print(f"Critical path: {path}, cost: {cost}s") +# Critical path: ["a", "b", "d"], cost: 13s + +# Bottleneck +analysis = wf.bottleneck_analysis({"a": 2, "b": 10, "c": 3, "d": 1}) +print(analysis["suggestion"]) +# "b is the bottleneck (76.9% of total time). Consider optimizing." + +# Visualization +print(wf.visualize("mermaid")) +``` diff --git a/docs/index.md b/docs/index.md index 0a316a5..507065e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -209,7 +209,7 @@ graph TB I -.->|token state| J ``` -[:octicons-arrow-right-24: Architecture deep dive](architecture.md) +[:octicons-arrow-right-24: Architecture deep dive](architecture/index.md) --- diff --git a/docs/workflows/index.md b/docs/workflows/index.md index 5d4411b..2b9ca89 100644 --- a/docs/workflows/index.md +++ b/docs/workflows/index.md @@ -3,28 +3,14 @@ Build multi-step pipelines as directed acyclic graphs. Define steps, wire dependencies, and let taskito handle execution order, parallelism, failure propagation, and state tracking — all backed by a Rust engine with dagron-core for graph algorithms. ```mermaid -graph LR - subgraph Build["1 — Build"] - WB["Workflow Builder
step · gate · fan_out"] - end - - subgraph Submit["2 — Submit"] - TE["Topology Engine
dagron-core"] - WS["Storage
SQLite"] - SC["Scheduler
depends_on chains"] - end - - subgraph Run["3 — Execute"] - WT["Tracker
event orchestration"] - WR["WorkflowRun
status · wait · cancel"] - end - - WB -- "_compile()" --> TE - TE --> WS - WS -- "enqueue jobs" --> SC - SC -- "JOB_COMPLETED" --> WT - WT -- "evaluate
successors" --> WS - WT --> WR +flowchart TD + WB["Workflow Builder\nstep() · gate() · fan_out"] + WB -->|"_compile()"| TE["Topology Engine\ndagron-core DAG"] + TE --> WS["Workflow Storage\nSQLite — definitions, runs, nodes"] + WS -->|"enqueue jobs"| SC["Scheduler\ndepends_on chains"] + SC -->|"JOB_COMPLETED event"| WT["Workflow Tracker\nevent-driven orchestration"] + WT -->|"evaluate successors"| WS + WT --> WR["WorkflowRun\nstatus() · wait() · cancel()"] ``` ## Quick start diff --git a/zensical.toml b/zensical.toml index 181835c..7aca009 100644 --- a/zensical.toml +++ b/zensical.toml @@ -97,7 +97,16 @@ nav = [ { "Analysis & Visualization" = "workflows/analysis.md" }, { "Canvas Primitives" = "workflows/canvas.md" }, ] }, - { "Architecture" = "architecture.md" }, + { "Architecture" = [ + { "Overview" = "architecture/index.md" }, + { "Job Lifecycle" = "architecture/job-lifecycle.md" }, + { "Worker Pool" = "architecture/worker-pool.md" }, + { "Storage Layer" = "architecture/storage.md" }, + { "Scheduler" = "architecture/scheduler.md" }, + { "Resource System" = "architecture/resources.md" }, + { "Failure Model" = "architecture/failure-model.md" }, + { "Serialization" = "architecture/serialization.md" }, + ] }, { "API Reference" = [ { "Overview" = "api/index.md" }, { "Queue" = [ @@ -122,6 +131,7 @@ nav = [ { "Notification Service" = "examples/notifications.md" }, { "Web Scraper Pipeline" = "examples/web-scraper.md" }, { "Data Pipeline" = "examples/data-pipeline.md" }, + { "DAG Workflows" = "examples/workflows.md" }, { "Benchmark" = "examples/benchmark.md" }, ] }, { "Comparison" = "comparison.md" }, From d6f550d4df7a5acdbdaeb733928b92e42ceb2233 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 13:03:56 +0530 Subject: [PATCH 10/12] ci: enable workflows feature in maturin builds CI and publish workflows passed --features explicitly, which replaces pyproject's default set. 
This dropped 'workflows' from the compiled extension, causing PyWorkflowBuilder ImportError in tests and breaking workflow support in released wheels. - ci.yml: add extension-module,postgres,redis,workflows - publish.yml: add native-async,workflows (4 platforms) --- .github/workflows/ci.yml | 2 +- .github/workflows/publish.yml | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1d29f3b..2310aff 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -154,7 +154,7 @@ jobs: uses: PyO3/maturin-action@v1 with: command: develop - args: --release --features native-async + args: --release --features extension-module,postgres,redis,native-async,workflows - name: Run Python tests run: | diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 80e7047..047b21c 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -24,7 +24,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist --features extension-module,postgres,redis --interpreter 3.10 3.11 3.12 3.13 + args: --release --out dist --features extension-module,postgres,redis,native-async,workflows --interpreter 3.10 3.11 3.12 3.13 manylinux: 2_28 before-script-linux: | dnf install -y openssl-devel perl-core perl-IPC-Cmd @@ -51,7 +51,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist --features extension-module,postgres,redis --interpreter 3.10 3.11 3.12 3.13 + args: --release --out dist --features extension-module,postgres,redis,native-async,workflows --interpreter 3.10 3.11 3.12 3.13 manylinux: musllinux_1_2 before-script-linux: | apt-get update && apt-get install -y libssl-dev perl pkg-config @@ -96,7 +96,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist --features extension-module,postgres,redis --interpreter 3.10 3.11 3.12 
3.13 + args: --release --out dist --features extension-module,postgres,redis,native-async,workflows --interpreter 3.10 3.11 3.12 3.13 - name: Upload wheels uses: actions/upload-artifact@v7 @@ -137,7 +137,7 @@ jobs: uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist --features extension-module,postgres,redis --interpreter 3.10 3.11 3.12 3.13 + args: --release --out dist --features extension-module,postgres,redis,native-async,workflows --interpreter 3.10 3.11 3.12 3.13 - name: Upload wheels uses: actions/upload-artifact@v7 From 2ff4e438baa813db170d52cc2e77887003d9e785 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 13:27:03 +0530 Subject: [PATCH 11/12] ci: bump actions, drop unused dashboard bindings - Swatinem/rust-cache v2.8.2 -> v2.9.1 (Node.js 20 deprecation) - actions/setup-node v4 -> v6 (Node.js 20 deprecation) - cache-dependency-glob uv.lock -> pyproject.toml (uv.lock is gitignored) - jobs.tsx: drop unused statsError, refetchStats destructure --- .github/workflows/ci.yml | 12 ++++++------ dashboard/src/pages/jobs.tsx | 6 +----- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2310aff..800d180 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,7 +27,7 @@ jobs: components: rustfmt, clippy - name: Rust cache - uses: Swatinem/rust-cache@v2.8.2 + uses: Swatinem/rust-cache@v2.9.1 - name: Cargo fmt run: cargo fmt --all --check @@ -39,7 +39,7 @@ jobs: uses: astral-sh/setup-uv@v7 with: version: "0.10.12" - cache-dependency-glob: uv.lock + cache-dependency-glob: pyproject.toml - name: Install Python dev deps run: uv sync --extra dev @@ -59,7 +59,7 @@ jobs: - uses: actions/checkout@v6 - name: Set up Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v6 with: node-version: "22" cache: "npm" @@ -91,7 +91,7 @@ jobs: uses: dtolnay/rust-toolchain@stable - 
name: Rust cache - uses: Swatinem/rust-cache@v2.8.2 + uses: Swatinem/rust-cache@v2.9.1 with: save-if: false @@ -137,7 +137,7 @@ jobs: uses: dtolnay/rust-toolchain@stable - name: Rust cache - uses: Swatinem/rust-cache@v2.8.2 + uses: Swatinem/rust-cache@v2.9.1 with: save-if: ${{ matrix.os != 'ubuntu-latest' }} @@ -145,7 +145,7 @@ jobs: uses: astral-sh/setup-uv@v7 with: version: "0.10.12" - cache-dependency-glob: uv.lock + cache-dependency-glob: pyproject.toml - name: Install dependencies run: uv sync --extra dev diff --git a/dashboard/src/pages/jobs.tsx b/dashboard/src/pages/jobs.tsx index 12ec6a2..63fd4f9 100644 --- a/dashboard/src/pages/jobs.tsx +++ b/dashboard/src/pages/jobs.tsx @@ -84,11 +84,7 @@ export function Jobs(_props: RoutableProps) { const [selected, setSelected] = useState>(new Set()); const [showBulkCancel, setShowBulkCancel] = useState(false); - const { - data: stats, - error: statsError, - refetch: refetchStats, - } = useApi("/api/stats"); + const { data: stats } = useApi("/api/stats"); const { data: jobs, loading, From 2349f51cc33454c0f0d82cb5916a9ef13b052968 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Thu, 16 Apr 2026 14:05:25 +0530 Subject: [PATCH 12/12] chore: bump version to 0.11.0 New feature release for the DAG workflow ecosystem. See changelog for full feature list. 
--- crates/taskito-async/Cargo.toml | 2 +- crates/taskito-core/Cargo.toml | 2 +- crates/taskito-python/Cargo.toml | 2 +- crates/taskito-workflows/Cargo.toml | 2 +- docs/changelog.md | 28 ++++++++++++++++++++++++++++ pyproject.toml | 2 +- 6 files changed, 33 insertions(+), 5 deletions(-) diff --git a/crates/taskito-async/Cargo.toml b/crates/taskito-async/Cargo.toml index 7f57b32..1c72438 100644 --- a/crates/taskito-async/Cargo.toml +++ b/crates/taskito-async/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "taskito-async" -version = "0.10.1" +version = "0.11.0" edition = "2021" [dependencies] diff --git a/crates/taskito-core/Cargo.toml b/crates/taskito-core/Cargo.toml index c736bbf..c2ce90c 100644 --- a/crates/taskito-core/Cargo.toml +++ b/crates/taskito-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "taskito-core" -version = "0.10.1" +version = "0.11.0" edition = "2021" [features] diff --git a/crates/taskito-python/Cargo.toml b/crates/taskito-python/Cargo.toml index 3fbf334..be92b56 100644 --- a/crates/taskito-python/Cargo.toml +++ b/crates/taskito-python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "taskito-python" -version = "0.10.1" +version = "0.11.0" edition = "2021" [features] diff --git a/crates/taskito-workflows/Cargo.toml b/crates/taskito-workflows/Cargo.toml index 0c943c6..af83b06 100644 --- a/crates/taskito-workflows/Cargo.toml +++ b/crates/taskito-workflows/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "taskito-workflows" -version = "0.10.1" +version = "0.11.0" edition = "2021" [dependencies] diff --git a/docs/changelog.md b/docs/changelog.md index d789ff7..374ae74 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,34 @@ All notable changes to taskito are documented here. 
+## 0.11.0 + +### Features + +- **DAG workflows** -- first-class support for directed acyclic graph workflows built on the new [dagron-core](https://github.com/ByteVeda/dagron) engine; `Workflow` builder with `step()`, `gate()`, and `after=` dependencies; `queue.submit_workflow(wf)` launches a run, `WorkflowRun.wait()` blocks until terminal, `run.status()` returns per-node snapshots, `run.cancel()` halts in-flight execution; workflows are persisted across restarts with full node history +- **Fan-out / fan-in** -- `step(fan_out="each")` expands a list result into N parallel child jobs; `step(fan_in="all")` aggregates all child results into a single downstream step; supports empty lists, single-item lists, and preserves result ordering +- **Conditional execution** -- per-step `condition="on_success" | "on_failure" | "always"` or a callable `(WorkflowContext) -> bool`; combine with `Workflow(on_failure="continue")` so independent branches keep running after a sibling fails; skip propagation respects `always` +- **Approval gates** -- `wf.gate("review", after="evaluate", timeout=3600, on_timeout="reject")` pauses the workflow until `queue.approve_gate(run_id, name)` or `queue.reject_gate(run_id, name)`; timeout enforced with a background timer; emits `WORKFLOW_GATE_REACHED` event +- **Sub-workflows** -- compose workflows by referencing another workflow as a step via `region_etl.as_step(region="eu")`; child workflows have a `parent_run_id` link and propagate cancellation and failure upward; child terminal status feeds into parent DAG evaluation +- **Cron-scheduled workflows** -- `@queue.periodic(cron=...)` now accepts a `WorkflowProxy`; launcher task is auto-registered and submits a fresh workflow run on every tick +- **Incremental re-runs** -- `Workflow(cache_ttl=86400)` hashes step results with SHA-256; `queue.submit_workflow(wf, incremental=True, base_run=prev_run.id)` skips completed steps whose inputs are unchanged; failed steps always re-run; dirty propagation 
cascades to downstream nodes; new `CACHE_HIT` terminal status distinguishes cached steps from freshly executed ones
+- **Graph algorithms** -- `wf.topological_levels()`, `wf.stats()`, `wf.critical_path(durations)`, `wf.bottleneck_analysis(durations)`, and `wf.execution_plan()` for pre-execution analysis; all algorithms operate on the compiled DAG without requiring a live run
+- **Visualization** -- `wf.visualize("mermaid")` and `wf.visualize("dot")` render the DAG; `run.visualize("mermaid")` color-codes live node status (running/completed/failed/cache-hit/waiting-approval)
+- **Workflow events** -- new event types `WORKFLOW_SUBMITTED`, `WORKFLOW_COMPLETED`, `WORKFLOW_FAILED`, `WORKFLOW_CANCELLED`, `WORKFLOW_GATE_REACHED` for observability hooks
+- **Type-safe builder** -- `step()` accepts any object satisfying the `HasTaskName` protocol (runtime-checkable), keeping the builder API strict without coupling to a concrete `TaskWrapper` class
+
+### Internal
+
+- New Rust crate `crates/taskito-workflows/` -- workflow engine with `WorkflowDefinition`, `WorkflowRun`, `WorkflowNode`, node status state machine (including `CacheHit` variant), and a storage trait with a SQLite backend; feature-gated behind `workflows` cargo feature
+- `dagron-core` added as git dependency (`https://github.com/ByteVeda/dagron.git`) for DAG construction and traversal
+- New PyO3 bindings in `crates/taskito-python/src/py_workflow/` -- `PyWorkflowBuilder`, `PyWorkflowHandle`, `PyWorkflowRunStatus`; `py_queue/workflow_ops.rs` exposes `submit_workflow`, `mark_workflow_node_result`, `expand_fan_out`, `check_fan_out_completion`, `skip_workflow_node`, `set_workflow_node_waiting_approval`, `resolve_workflow_gate`, `finalize_run_if_terminal`, and base-run lookup helpers
+- New Python package `py_src/taskito/workflows/` with 11 modules -- `builder.py` (Workflow, GateConfig, WorkflowProxy), `tracker.py` (cascade evaluator), `run.py` (WorkflowRun), `mixins.py` (QueueWorkflowMixin), 
`fan_out.py`, `context.py` (WorkflowContext), `incremental.py` (dirty-set computation), `analysis.py` (graph algorithms), `visualization.py`, `types.py`, `__init__.py` +- `maturin` CI feature list fixed -- `ci.yml` and `publish.yml` now include `workflows` alongside `extension-module,postgres,redis,native-async` (previously missing, which would have shipped broken wheels) +- CI action versions bumped -- `Swatinem/rust-cache@v2.9.1`, `actions/setup-node@v6` to silence Node.js 20 deprecation warnings +- 74 new Python tests across 10 files covering linear, fan-out, conditions, gates, sub-workflows, cron, analysis, caching, and visualization + +--- + ## 0.10.1 ### Changed diff --git a/pyproject.toml b/pyproject.toml index a2292a1..997c124 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "taskito" -version = "0.10.1" +version = "0.11.0" description = "Rust-powered task queue for Python. No broker required." requires-python = ">=3.10" license = { file = "LICENSE" }