@@ -248,6 +248,246 @@ def pg_provider():
248248 return PostgresProvider .connect_with_schema (db_url , "duroxide_python_tags" )
249249
250250
251+ def test_heterogeneous_workers_gpu_cpu_untagged ():
252+ """Fan-out: GPU render → CPU encode → untagged upload via DefaultAnd filter."""
253+ provider = SqliteProvider .in_memory ()
254+ client = Client (provider )
255+ runtime = Runtime (
256+ provider ,
257+ PyRuntimeOptions (
258+ dispatcher_poll_interval_ms = 50 ,
259+ worker_tag_filter = TagFilter .default_and (["gpu" , "cpu" ]),
260+ ),
261+ )
262+
263+ @runtime .register_activity ("Render" )
264+ def render (ctx , input ):
265+ return f"rendered:{ input } "
266+
267+ @runtime .register_activity ("Encode" )
268+ def encode (ctx , input ):
269+ return f"encoded:{ input } "
270+
271+ @runtime .register_activity ("Upload" )
272+ def upload (ctx , input ):
273+ return f"uploaded:{ input } "
274+
275+ @runtime .register_orchestration ("VideoPipeline" )
276+ def video_pipeline (ctx , input ):
277+ rendered = yield ctx .schedule_activity ("Render" , "frame42" ).with_tag ("gpu" )
278+ encoded = yield ctx .schedule_activity ("Encode" , rendered ).with_tag ("cpu" )
279+ uploaded = yield ctx .schedule_activity ("Upload" , encoded )
280+ return uploaded
281+
282+ runtime .start ()
283+ try :
284+ client .start_orchestration ("video-1" , "VideoPipeline" , "" )
285+ result = client .wait_for_orchestration ("video-1" , 10_000 )
286+ assert result .status == "Completed"
287+ assert result .output == "uploaded:encoded:rendered:frame42"
288+ finally :
289+ runtime .shutdown (100 )
290+
291+
292+ def test_starvation_safe_tagged_activity_timeout_fallback ():
293+ """Race tagged activity vs timer; timer wins when no GPU worker exists."""
294+ provider = SqliteProvider .in_memory ()
295+ client = Client (provider )
296+ runtime = Runtime (
297+ provider ,
298+ PyRuntimeOptions (
299+ dispatcher_poll_interval_ms = 50 ,
300+ worker_tag_filter = TagFilter .DEFAULT_ONLY ,
301+ ),
302+ )
303+
304+ @runtime .register_activity ("GpuInference" )
305+ def gpu_inference (ctx , input ):
306+ return f"inference:{ input } "
307+
308+ @runtime .register_activity ("CpuFallback" )
309+ def cpu_fallback (ctx , input ):
310+ return f"cpu_fallback:{ input } "
311+
312+ @runtime .register_orchestration ("InferenceWithFallback" )
313+ def inference_with_fallback (ctx , input ):
314+ gpu_task = ctx .schedule_activity ("GpuInference" , input ).with_tag ("gpu" )
315+ timeout = ctx .schedule_timer (500 )
316+ winner = yield ctx .race (gpu_task , timeout )
317+ if winner ["index" ] == 0 :
318+ return winner ["value" ]
319+ else :
320+ result = yield ctx .schedule_activity ("CpuFallback" , input )
321+ return result
322+
323+ runtime .start ()
324+ try :
325+ client .start_orchestration ("infer-1" , "InferenceWithFallback" , "model-v3" )
326+ result = client .wait_for_orchestration ("infer-1" , 10_000 )
327+ assert result .status == "Completed"
328+ assert result .output == "cpu_fallback:model-v3"
329+ finally :
330+ runtime .shutdown (100 )
331+
332+
333+ def test_dual_runtime_orchestrator_plus_gpu_worker ():
334+ """Two runtimes on same store: RT-A dispatches + CPU, RT-B handles GPU tags."""
335+ provider = SqliteProvider .in_memory ()
336+ client = Client (provider )
337+
338+ # Runtime A: orchestrator + default (CPU) worker
339+ rt_a = Runtime (
340+ provider ,
341+ PyRuntimeOptions (
342+ dispatcher_poll_interval_ms = 50 ,
343+ worker_tag_filter = TagFilter .DEFAULT_ONLY ,
344+ ),
345+ )
346+
347+ @rt_a .register_activity ("PreProcess" )
348+ def preprocess_a (ctx , input ):
349+ return f"preprocessed:{ input } "
350+
351+ @rt_a .register_activity ("GpuTrain" )
352+ def gpu_train_a (ctx , input ):
353+ return f"trained:{ input } "
354+
355+ @rt_a .register_activity ("SaveModel" )
356+ def save_model_a (ctx , input ):
357+ return f"saved:{ input } "
358+
359+ @rt_a .register_orchestration ("MLPipeline" )
360+ def ml_pipeline_a (ctx , input ):
361+ preprocessed = yield ctx .schedule_activity ("PreProcess" , input )
362+ model = yield ctx .schedule_activity ("GpuTrain" , preprocessed ).with_tag ("gpu" )
363+ saved = yield ctx .schedule_activity ("SaveModel" , model )
364+ return saved
365+
366+ # Runtime B: GPU worker only (no orchestration dispatcher)
367+ rt_b = Runtime (
368+ provider ,
369+ PyRuntimeOptions (
370+ dispatcher_poll_interval_ms = 50 ,
371+ orchestration_concurrency = 0 ,
372+ worker_tag_filter = TagFilter .tags (["gpu" ]),
373+ ),
374+ )
375+
376+ @rt_b .register_activity ("PreProcess" )
377+ def preprocess_b (ctx , input ):
378+ return f"preprocessed:{ input } "
379+
380+ @rt_b .register_activity ("GpuTrain" )
381+ def gpu_train_b (ctx , input ):
382+ return f"trained:{ input } "
383+
384+ @rt_b .register_activity ("SaveModel" )
385+ def save_model_b (ctx , input ):
386+ return f"saved:{ input } "
387+
388+ @rt_b .register_orchestration ("MLPipeline" )
389+ def ml_pipeline_b (ctx , input ):
390+ preprocessed = yield ctx .schedule_activity ("PreProcess" , input )
391+ model = yield ctx .schedule_activity ("GpuTrain" , preprocessed ).with_tag ("gpu" )
392+ saved = yield ctx .schedule_activity ("SaveModel" , model )
393+ return saved
394+
395+ rt_a .start ()
396+ rt_b .start ()
397+ try :
398+ client .start_orchestration ("ml-1" , "MLPipeline" , "dataset-v5" )
399+ result = client .wait_for_orchestration ("ml-1" , 10_000 )
400+ assert result .status == "Completed"
401+ assert result .output == "saved:trained:preprocessed:dataset-v5"
402+ finally :
403+ rt_b .shutdown (100 )
404+ rt_a .shutdown (100 )
405+
406+
407+ def test_nested_error_handling_propagation ():
408+ """Activity error propagates through orchestration via yield."""
409+ provider = SqliteProvider .in_memory ()
410+ client = Client (provider )
411+ runtime = Runtime (provider , PyRuntimeOptions (dispatcher_poll_interval_ms = 50 ))
412+
413+ @runtime .register_activity ("ProcessData" )
414+ def process_data (ctx , input ):
415+ if "error" in input :
416+ raise Exception ("Processing failed" )
417+ return f"Processed: { input } "
418+
419+ @runtime .register_activity ("FormatOutput" )
420+ def format_output (ctx , input ):
421+ return f"Final: { input } "
422+
423+ @runtime .register_orchestration ("NestedErrorHandling" )
424+ def nested_error_handling (ctx , input ):
425+ processed = yield ctx .schedule_activity ("ProcessData" , input )
426+ formatted = yield ctx .schedule_activity ("FormatOutput" , processed )
427+ return formatted
428+
429+ runtime .start ()
430+ try :
431+ # Success case
432+ client .start_orchestration ("nested-ok" , "NestedErrorHandling" , "test" )
433+ ok = client .wait_for_orchestration ("nested-ok" , 5_000 )
434+ assert ok .status == "Completed"
435+ assert ok .output == "Final: Processed: test"
436+
437+ # Error case
438+ client .start_orchestration ("nested-err" , "NestedErrorHandling" , "error" )
439+ err = client .wait_for_orchestration ("nested-err" , 5_000 )
440+ assert err .status == "Failed"
441+ assert "Processing failed" in err .error
442+ finally :
443+ runtime .shutdown (100 )
444+
445+
446+ def test_error_recovery_with_logging ():
447+ """Activity error caught, logged via another activity, then re-raised."""
448+ provider = SqliteProvider .in_memory ()
449+ client = Client (provider )
450+ runtime = Runtime (provider , PyRuntimeOptions (dispatcher_poll_interval_ms = 50 ))
451+
452+ @runtime .register_activity ("ProcessData" )
453+ def process_data (ctx , input ):
454+ if "error" in input :
455+ raise Exception ("Processing failed" )
456+ return f"Processed: { input } "
457+
458+ @runtime .register_activity ("LogError" )
459+ def log_error (ctx , error ):
460+ return f"Logged: { error } "
461+
462+ @runtime .register_orchestration ("ErrorRecovery" )
463+ def error_recovery (ctx , input ):
464+ try :
465+ result = yield ctx .schedule_activity ("ProcessData" , input )
466+ return result
467+ except Exception as e :
468+ yield ctx .schedule_activity ("LogError" , str (e ))
469+ raise Exception (f"Failed to process '{ input } ': { e } " )
470+
471+ runtime .start ()
472+ try :
473+ # Success case
474+ client .start_orchestration ("recovery-ok" , "ErrorRecovery" , "test" )
475+ ok = client .wait_for_orchestration ("recovery-ok" , 5_000 )
476+ assert ok .status == "Completed"
477+ assert ok .output == "Processed: test"
478+
479+ # Error recovery case
480+ client .start_orchestration ("recovery-err" , "ErrorRecovery" , "error" )
481+ err = client .wait_for_orchestration ("recovery-err" , 5_000 )
482+ assert err .status == "Failed"
483+ assert "Failed to process 'error'" in err .error
484+ finally :
485+ runtime .shutdown (100 )
486+
487+
488+ # ─── PostgreSQL Tests ──────────────────────────────────────────────
489+
490+
251491def test_pg_tagged_activity (pg_provider ):
252492 """Full PostgreSQL test for tagged activity routing."""
253493 client = Client (pg_provider )
0 commit comments