Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion spatialstencil/lowering/spatial_ir_to_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@ def generate_rectangle(kernel: spir.Kernel,
if len(tasks) != len_for_reporting:
print(f'P{rect.x_range[0]},{rect.y_range[0]}: Reduced from {len_for_reporting} to {len(tasks)} tasks.')

task_bindings = task_recycling.plan_task_bindings(tasks, task_creation_behavior)
task_bindings = task_recycling.plan_task_bindings(tasks, task_creation_behavior, set(color_map.values()))

place_block_bytes = _place_block_storage_bytes(rect.metadata.place)

print(f'Stats P{rect.x_range[0]},{rect.y_range[0]}: {place_block_bytes} bytes/PE, '
Expand Down
28 changes: 16 additions & 12 deletions spatialstencil/syntax/csl/task_recycling.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def emit_local_transition_preamble(
def plan_task_bindings(
tasks: list[tdag.CSLTask],
task_creation_behavior: tdag.TaskCreationBehavior,
disallowed_task_ids: Optional[set[int]] = None,
) -> TaskBindingPlan:
"""Compute a local-task binding plan for the generated CSL.

Expand All @@ -285,34 +286,37 @@ def plan_task_bindings(
local_task_indices = [i for i, task in enumerate(tasks) if task.task_type == 'local']
if not local_task_indices:
return TaskBindingPlan((), {}, {})
disallowed_task_ids = disallowed_task_ids or set()

allowed_local_task_ids = [t for t in constants.LOCAL_TASK_IDS if t not in disallowed_task_ids]

if task_creation_behavior in (
tdag.TaskCreationBehavior.FAIL_ON_OVERRUN,
tdag.TaskCreationBehavior.SYNCHRONOUS_ON_OVERRUN,
):
if len(local_task_indices) > len(constants.LOCAL_TASK_IDS):
if len(local_task_indices) > len(allowed_local_task_ids):
raise ValueError('Too many local tasks')
return _plan_unique_slots(local_task_indices)
return _plan_unique_slots(local_task_indices, allowed_local_task_ids)

if task_creation_behavior == tdag.TaskCreationBehavior.NO_TASKS:
return _plan_unique_slots(local_task_indices)
return _plan_unique_slots(local_task_indices, allowed_local_task_ids)

if len(local_task_indices) <= len(constants.LOCAL_TASK_IDS):
return _plan_unique_slots(local_task_indices)
if len(local_task_indices) <= len(allowed_local_task_ids):
return _plan_unique_slots(local_task_indices, allowed_local_task_ids)

max_colors = len(constants.LOCAL_TASK_IDS)
max_colors = len(allowed_local_task_ids)
conflict_graph = _build_conflict_graph(tasks, local_task_indices)
coloring = greedy_coloring(conflict_graph, local_task_indices, max_colors=max_colors, load_balance=True)
if coloring is None:
raise ValueError('Too many concurrently-live local tasks for state-machine recycling')

return _build_plan_from_coloring(coloring)
return _build_plan_from_coloring(coloring, allowed_local_task_ids)


def _plan_unique_slots(local_task_indices: list[int]) -> TaskBindingPlan:
def _plan_unique_slots(local_task_indices: list[int], allowed_local_task_ids: list[int]) -> TaskBindingPlan:
"""Build the trivial plan where each local task receives its own slot."""
coloring = {task_index: color for color, task_index in enumerate(local_task_indices)}
return _build_plan_from_coloring(coloring)
return _build_plan_from_coloring(coloring, allowed_local_task_ids)


def _build_conflict_graph(tasks: list[tdag.CSLTask], local_task_indices: Iterable[int]) -> dict[int, set[int]]:
Expand Down Expand Up @@ -531,11 +535,11 @@ def greedy_coloring(



def _build_plan_from_coloring(coloring: dict[int, int]) -> TaskBindingPlan:
def _build_plan_from_coloring(coloring: dict[int, int], allowed_local_task_ids: list[int]) -> TaskBindingPlan:
"""Convert a graph coloring into the stable binding structures used by codegen.

Each color becomes one ``LocalTaskSlot`` backed by the corresponding
hardware ID in ``constants.LOCAL_TASK_IDS``. Within each slot, logical task
hardware ID in ``allowed_local_task_ids``. Within each slot, logical task
indices are sorted to make state assignment deterministic. The position of a
task inside that sorted tuple is its per-slot state number.
"""
Expand All @@ -548,7 +552,7 @@ def _build_plan_from_coloring(coloring: dict[int, int]) -> TaskBindingPlan:
task_to_local_state: dict[int, int] = {}
for slot_index, color in enumerate(sorted(color_to_tasks)):
task_indices = tuple(sorted(color_to_tasks[color]))
slot = LocalTaskSlot(slot_index, constants.LOCAL_TASK_IDS[slot_index], task_indices)
slot = LocalTaskSlot(slot_index, allowed_local_task_ids[slot_index], task_indices)
local_slots.append(slot)
for state_index, task_index in enumerate(task_indices):
task_to_local_slot[task_index] = slot_index
Expand Down
98 changes: 58 additions & 40 deletions spatialstencil/syntax/csl/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Contains a CSL task DAG representation and creation methods.
"""

from copy import deepcopy
from dataclasses import dataclass
from enum import Enum, auto
Expand All @@ -16,6 +17,7 @@ class TaskCreationBehavior(Enum):
"""
Enumeration prescribing how tasks should be created.
"""

NO_TASKS = auto() # All statements in one task
FAIL_ON_OVERRUN = auto() # Error if too many tasks
STATE_MACHINE_ON_OVERRUN = auto() # Recycle task IDs with a state machine
Expand All @@ -26,6 +28,7 @@ class InterTaskEdge(Enum):
"""
Enumeration representing a task dependency edge type.
"""

UNSET = auto()
SEQUENCE = auto()
ACTIVATE = auto()
Expand All @@ -37,8 +40,9 @@ class CSLTask:
"""
Object representing a task DAG node.
"""

task_id: int
task_type: Literal['local', 'data'] # We do not generate control tasks at the moment
task_type: Literal["local", "data"] # We do not generate control tasks at the moment
statements: list[int] # Index is the statement's index from the completion DAG
outgoing: list[tuple[int, InterTaskEdge]] # For each statement, the next task ID and the dependency type
blocked: bool # Whether there is an unblock edge leading to this task
Expand All @@ -61,10 +65,11 @@ def should_be_asynchronous(dtypes: dict[spir.Identifier, spir.IRType], stmt: spi


def create_csl_tasks(
completion_dag: nx.DiGraph,
block: spir.ComputeBlock,
dtypes: dict[spir.Identifier, spir.IRType],
task_creation_behavior: TaskCreationBehavior = TaskCreationBehavior.FAIL_ON_OVERRUN) -> list[CSLTask]:
completion_dag: nx.DiGraph,
block: spir.ComputeBlock,
dtypes: dict[spir.Identifier, spir.IRType],
task_creation_behavior: TaskCreationBehavior = TaskCreationBehavior.FAIL_ON_OVERRUN,
) -> list[CSLTask]:
"""
Creates a list of CSL tasks. The nodes are tasks that contain a unique ID
and the list of statements to include; and the edges are the type of dependency across tasks.
Expand All @@ -77,7 +82,7 @@ def create_csl_tasks(
* Send and receive statements that can be lowered to a ``FabricDSD`` operation, in turn can (and should)
be nonblocking, or ``async`` in CSL terms. In this lowering pipeline, these live in CSL local tasks.
* Other statements (e.g. free assignments) are blocking and also live in local tasks.

Given that tasks can ``@activate`` and ``@unblock`` other tasks, and that ``FabricDSD`` operations can also
do the same, both a task terminator and a nonblocking statement can trigger other tasks. Given that there are
no other options to trigger tasks, we run a preprocessing pass on the graph to convert nodes with in-degree over 2
Expand All @@ -87,7 +92,7 @@ def create_csl_tasks(
new tasks based on a set of necessary rules in which a new task must be formed:

1. A node with no predecessors creates a new activated and unblocked task
2. A node with more than one incoming edge must start a new task
2. A node with more than one incoming edge must start a new task
(the conditions below thus apply to the case where a node has one predecessor)
3. If a node's predecessor represents one kind of CSL task (e.g., data) and this node represents another
4. Node pairs with ``wait->wait`` edges create a new task (this also fulfills the condition for the above
Expand Down Expand Up @@ -118,12 +123,15 @@ def create_csl_tasks(
for cnode in nx.topological_sort(completion_dag):
node = block.statements[cnode.statement_id]
# Figure out whether this task type is a local task or a data task
if (isinstance(node, spir.ForeachStatement) and dsd_ops.get_dsd_op(dtypes, node) is None and
cnode.optype == 'post'):
if (
isinstance(node, spir.ForeachStatement)
and dsd_ops.get_dsd_op(dtypes, node) is None
and cnode.optype == "post"
):
# Only if it is a complex task (i.e., not a DSD operation)
this_task_type = 'data'
this_task_type = "data"
else:
this_task_type = 'local'
this_task_type = "local"

task_id = None
# Look at incoming edges:
Expand All @@ -132,12 +140,12 @@ def create_csl_tasks(
pred: analysis.CompletionDAGNode
pred, _ = next(iter(completion_dag.in_edges(cnode)))
# If {wait,post}->post and there is one edge, and the previous task is a local task, inherit task ID
if result[statement_id_to_task_id[pred.statement_id]].task_type == 'local' and this_task_type == 'local':
if pred.optype == 'post' and cnode.optype == 'post':
if result[statement_id_to_task_id[pred.statement_id]].task_type == "local" and this_task_type == "local":
if pred.optype == "post" and cnode.optype == "post":
task_id = statement_id_to_task_id[pred.statement_id]
elif pred.optype == 'wait' and cnode.optype == 'post':
elif pred.optype == "wait" and cnode.optype == "post":
task_id = statement_id_to_task_id[pred.statement_id]
elif pred.optype == 'post' and cnode.optype == 'wait':
elif pred.optype == "post" and cnode.optype == "wait":
# ``post->wait`` node pairs where the post is a nonblocking operation creates a new task for the
# ``wait`` node, depending on task creation behavior
should_create_task = True
Expand All @@ -151,7 +159,7 @@ def create_csl_tasks(
if not should_create_task:
task_id = statement_id_to_task_id[pred.statement_id]
# wait->wait will create a new task
elif result[statement_id_to_task_id[pred.statement_id]].task_type == 'local' and this_task_type == 'data':
elif result[statement_id_to_task_id[pred.statement_id]].task_type == "local" and this_task_type == "data":
# An empty wait task before a data task can be contracted
if not result[statement_id_to_task_id[pred.statement_id]].statements:
task_id = statement_id_to_task_id[pred.statement_id]
Expand Down Expand Up @@ -182,15 +190,16 @@ def create_csl_tasks(
task_id = len(result)
cnode_to_task_id[cnode] = task_id
current_task = CSLTask(
task_id, this_task_type, [], [], blocked=((indeg > 1) or (this_task_type == 'data' and indeg > 0)))
task_id, this_task_type, [], [], blocked=((indeg > 1) or (this_task_type == "data" and indeg > 0))
)
result.append(current_task)
if this_task_type == 'local':
if this_task_type == "local":
num_local_tasks += 1
else:
num_data_tasks += 1
statement_id_to_task_id[cnode.statement_id] = task_id

if cnode.optype == 'wait':
if cnode.optype == "wait":
# Nothing to do within the task
pass
else: # 'post'
Expand All @@ -203,11 +212,11 @@ def create_csl_tasks(
# Determine edge types between task statements
for cnode in nx.topological_sort(completion_dag):
stmt_task = cnode_to_task_id[cnode]
if cnode.optype == 'post':
if cnode.optype == "post":
# Find matching "wait" successor
succ_task = None
for succ in completion_dag.successors(cnode):
if succ.optype == 'wait':
if succ.optype == "wait":
succ_task = cnode_to_task_id[succ]
break

Expand All @@ -216,7 +225,7 @@ def create_csl_tasks(
# Find outgoing index within task
ind = next(i for i, s in enumerate(result[stmt_task].statements) if s == cnode.statement_id)

elif cnode.optype == 'wait': # Set the next task after the await to begin sequentially
elif cnode.optype == "wait": # Set the next task after the await to begin sequentially
ind = next((i for i, s in enumerate(result[stmt_task].statements) if s == cnode.statement_id), None)
if ind is None: # Wait already omitted from task
continue
Expand All @@ -227,8 +236,10 @@ def create_csl_tasks(
succ_task = cnode_to_task_id[succ_task]
elif num_successors > 1:
node = block.statements[cnode.statement_id]
raise ValueError('Multiple successors for a wait task should not appear after canonicalization.\n In '
f'line {node.lineinfo}')
raise ValueError(
"Multiple successors for a wait task should not appear after canonicalization.\n In "
f"line {node.lineinfo}"
)
else: # No successors
continue

Expand All @@ -238,7 +249,7 @@ def create_csl_tasks(
etype = InterTaskEdge.SEQUENCE
else:
# Check if task already has an activate edge
if succ_task in task_has_activate or result[succ_task].task_type == 'data':
if succ_task in task_has_activate or result[succ_task].task_type == "data":
etype = InterTaskEdge.UNBLOCK
else:
etype = InterTaskEdge.ACTIVATE
Expand Down Expand Up @@ -269,7 +280,7 @@ def create_csl_tasks(
if task.task_type == "local":
task.statements.append("TERMINATOR")
# After contracting an empty trailing task, stmt_task may equal len(result), meaning exit.
succ_is_data = stmt_task < len(result) and result[stmt_task].task_type == 'data'
succ_is_data = stmt_task < len(result) and result[stmt_task].task_type == "data"
if stmt_task in task_has_activate or succ_is_data:
task.outgoing.append((stmt_task, InterTaskEdge.UNBLOCK))
else:
Expand All @@ -283,7 +294,7 @@ def create_csl_tasks(
task_id_to_data_id: dict[int, int] = {}
for task_id, task in enumerate(result):
# Increment the current task ID and add a new task with the specified type
if task.task_type == 'local':
if task.task_type == "local":
current_local_task_id += 1
task_id_to_local_id[task_id] = current_local_task_id
else: # 'data'
Expand Down Expand Up @@ -325,11 +336,12 @@ def create_csl_tasks(
t for i, t in enumerate(result) if not any(n != i for n, _ in t.outgoing) or -1 in set(n for n, _ in t.outgoing)
]
if len(sink_tasks) > 2:
raise ValueError('Too many sink tasks')
raise ValueError("Too many sink tasks")
for i, task in enumerate(sink_tasks):
edge_type = InterTaskEdge.ACTIVATE if i == 0 else InterTaskEdge.UNBLOCK
if task_creation_behavior == TaskCreationBehavior.SYNCHRONOUS_ON_OVERRUN and num_local_tasks >= len(
constants.LOCAL_TASK_IDS):
constants.LOCAL_TASK_IDS
):
edge_type = InterTaskEdge.SEQUENCE
elif task_creation_behavior == TaskCreationBehavior.NO_TASKS:
edge_type = InterTaskEdge.SEQUENCE
Expand Down Expand Up @@ -395,19 +407,25 @@ def _limit_indegree(dag: nx.DiGraph):
current_node = node
# Create intermediate wait nodes (the counter changes the statement ID because it has to be unique)
for u, _ in edges[1:]:
new_node = analysis.CompletionDAGNode('wait', counter)
new_node = analysis.CompletionDAGNode("wait", counter)
counter -= 1
dag.remove_edge(u, node)
dag.add_edge(new_node, current_node)
dag.add_edge(u, new_node)
current_node = new_node


def fuse_tasks(tasks: list[CSLTask], dsds: UniqueDSDDict, dtypes: dict[spir.Identifier, spir.IRType], rect,
use_memcpy_mode: bool, compute: spir.ComputeBlock) -> list[CSLTask]:
def fuse_tasks(
tasks: list[CSLTask],
dsds: UniqueDSDDict,
dtypes: dict[spir.Identifier, spir.IRType],
rect,
use_memcpy_mode: bool,
compute: spir.ComputeBlock,
) -> list[CSLTask]:
"""
Fuses tasks where possible to reduce the number of tasks.

:param tasks: The list of CSL tasks.
:param dsds: The unique DSD dictionary.
:param dtypes: The dictionary of identifier types.
Expand All @@ -421,7 +439,7 @@ def fuse_tasks(tasks: list[CSLTask], dsds: UniqueDSDDict, dtypes: dict[spir.Iden
for i, task in enumerate(tasks):
if i in fused or i in removed:
continue
if task.task_type == 'data':
if task.task_type == "data":
continue
if not task.statements:
# Remove task
Expand All @@ -436,7 +454,7 @@ def fuse_tasks(tasks: list[CSLTask], dsds: UniqueDSDDict, dtypes: dict[spir.Iden
if any(et == InterTaskEdge.UNBLOCK for t in tasks for n, et in t.outgoing if n == outgoing_id):
# Cannot fuse if there are multiple predecessors to next task
continue
if tasks[outgoing_id].task_type != 'local':
if tasks[outgoing_id].task_type != "local":
# Cannot fuse with data tasks
continue

Expand Down Expand Up @@ -525,7 +543,7 @@ def renumber_tasks(tasks: list[CSLTask], task_creation_behavior: TaskCreationBeh
task_id_to_data_id: dict[int, int] = {}
for task_id, task in enumerate(tasks):
# Increment the current task ID and add a new task with the specified type
if task.task_type == 'local':
if task.task_type == "local":
current_local_task_id += 1
task_id_to_local_id[task_id] = current_local_task_id
else: # 'data'
Expand All @@ -534,13 +552,13 @@ def renumber_tasks(tasks: list[CSLTask], task_creation_behavior: TaskCreationBeh

# Re-number task IDs and outgoing connections based on CSL IDs
for task_id, task in enumerate(tasks):
if task.task_type == 'data':
if task.task_type == "data":
tid = task_id_to_data_id[task_id]
if tid >= len(constants.DATA_TASK_IDS) and task_creation_behavior == TaskCreationBehavior.FAIL_ON_OVERRUN:
raise ValueError('Too many data tasks')
raise ValueError("Too many data tasks")
task.task_id = constants.DATA_TASK_IDS[tid]
elif task.task_type == 'local':
elif task.task_type == "local":
tid = task_id_to_local_id[task_id]
if tid >= len(constants.LOCAL_TASK_IDS) and task_creation_behavior == TaskCreationBehavior.FAIL_ON_OVERRUN:
raise ValueError('Too many local tasks')
raise ValueError("Too many local tasks")
task.task_id = constants.LOCAL_TASK_IDS[tid]
Loading
Loading