Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2763,6 +2763,34 @@ def sync_node_execution(
logger.info("Skipping gate node execution for now - gate nodes don't have inputs and outputs filled in")
return execution

# Handle the case where it's a branch node
elif execution._node.branch_node is not None:
# We'll need to query child node executions regardless since this is a parent node
child_node_executions = iterate_node_executions(
self.client,
workflow_execution_identifier=execution.id.execution_id,
unique_parent_id=execution.id.node_id,
)
child_node_executions = [x for x in child_node_executions]

sub_flyte_workflow = typing.cast(FlyteBranchNode, execution._node.flyte_entity)
sub_node_mapping = {}
if sub_flyte_workflow.if_else.case.then_node:
then_node = sub_flyte_workflow.if_else.case.then_node
sub_node_mapping[then_node.id] = then_node
if sub_flyte_workflow.if_else.other:
for case in sub_flyte_workflow.if_else.other:
then_node = case.then_node
sub_node_mapping[then_node.id] = then_node
if sub_flyte_workflow.if_else.else_node:
else_node = sub_flyte_workflow.if_else.else_node
sub_node_mapping[else_node.id] = else_node

execution._underlying_node_executions = [
self.sync_node_execution(FlyteNodeExecution.promote_from_model(cne), sub_node_mapping)
for cne in child_node_executions
]

# This is the plain ol' task execution case
else:
execution._task_executions = [
Expand Down
7 changes: 7 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,3 +1358,10 @@ def test_run_wf_with_resource_requests_override(register):
],
limits=[],
)

def test_conditional_workflow():
execution_id = run("conditional_workflow.py", "wf")
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.fetch_execution(name=execution_id)
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}"
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import flytekit as fl
from flytekit import conditional
from flytekit.core.task import Echo

echo_radius = Echo(name="noop", inputs={"radius": float})


@fl.task
def calculate_circle_circumference(radius: float) -> float:
return 2 * 3.14 * radius # Task to calculate the circumference of a circle


@fl.task
def calculate_circle_area(radius: float) -> float:
return 3.14 * radius * radius # Task to calculate the area of a circle


@fl.task
def nop(radius: float) -> float:
return radius # Task that does nothing, effectively a no-op


@fl.workflow
def wf(radius: float = 0.5, get_area: bool = False, get_circumference: bool = True):
echoed_radius = nop(radius=radius)
(
conditional("if_area")
.if_(get_area.is_true())
.then(calculate_circle_area(radius=radius))
.else_()
.then(echo_radius(echoed_radius))
)
(
conditional("if_circumference")
.if_(get_circumference.is_true())
.then(calculate_circle_circumference(radius=echoed_radius))
.else_()
.then(echo_radius(echoed_radius))
)
Loading