Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
hooks:
- id: debug-statements
- id: name-tests-test
args: ["--pytest-test-first"]
# This hook cannot handle newer Python syntax like type aliases yet
# - repo: https://github.com/pre-commit/pre-commit-hooks
# rev: v6.0.0
# hooks:
# - id: debug-statements
# - id: name-tests-test
# args: ["--pytest-test-first"]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.13
Expand Down
14 changes: 7 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ dependencies = [

[dependency-groups]
dev = [
"hatch>=1.13.0,<2",
"prek>=0.3.6,<2",
"pytest>=9.0.2,<10",
"pytest-cov>=7.0.0,<8",
"mkdocstrings[python]>=1.0.3,<2",
"mkdocs-material>=9.7.5,<10",
"setuptools-scm>=9.2.2,<10",
"hatch>=1.13.0",
"prek>=0.3.6",
"pytest>=9.0.2",
"pytest-cov>=7.0.0",
"mkdocstrings[python]>=1.0.3",
"mkdocs-material>=9.7.5",
"setuptools-scm>=9.2.2",
]

[build-system]
Expand Down
243 changes: 34 additions & 209 deletions src/styx_compiler/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,30 +45,24 @@ class ComputeControlFlowGraph(cst.CSTVisitor):
Computes the control-flow graph of the code, expressed in indices from the IndexProvider
"""

METADATA_DEPENDENCIES = (IndexProvider,)

def __init__(self):
def __init__(self, provider: ControlFlowGraphProvider):
super().__init__()
self._provider = provider
self._cfg: dict[CfgNode, set[CfgNode]] = {}
self._start_end: list[tuple[CfgNode, CfgNode]] = []

def _edge(self, prev: list[CfgNode], cur: CfgNode) -> list[CfgNode]:
for p in prev:
if p not in self._cfg:
self._cfg[p] = set()
self._cfg[p].add(cur)
self._cfg.setdefault(p, set()).add(cur)
return [cur]

def _edges(self, prev: list[CfgNode], tos: list[CfgNode]) -> list[CfgNode]:
for p in prev:
if p not in self._cfg:
self._cfg[p] = set()
for to in tos:
self._cfg[p].add(to)
self._cfg.setdefault(p, set()).update(tos)
return tos

def _make_cfg_node(self, cst_node: cst.CSTNode, instance: int, prev: list[CfgNode]) -> list[CfgNode]:
cur = Node(self.get_metadata(IndexProvider, cst_node), instance)
cur = Node(self._provider.get_metadata(IndexProvider, cst_node), instance)
return self._edge(prev, cur)

def _clean_up_cfg_ghosts(self, start: CfgNode) -> None:
Expand Down Expand Up @@ -96,7 +90,7 @@ def _clean_up_cfg_ghosts(self, start: CfgNode) -> None:
seen.add(next_node)

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
index = self.get_metadata(IndexProvider, node)
index = self._provider.get_metadata(IndexProvider, node)
start = Node(index, 0)
end = Node(index, 1)
self._start_end.append((start, end))
Expand All @@ -114,7 +108,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:

self._clean_up_cfg_ghosts(start)

def leave_Module(self, node: cst.Module) -> None:
def leave_Module(self, module: cst.Module) -> None:
# Remove unreachable parts of the CFG (e.g. unused finally clause instantiations, dead code after a return)
reachable = set()
workstack = []
Expand All @@ -136,6 +130,8 @@ def leave_Module(self, node: cst.Module) -> None:
for k in to_remove:
del self._cfg[k]

self._provider.set_metadata(module, (self._cfg, self._start_end))

def _visit_BaseSuite(
self,
statements: cst.BaseSuite | cst.SimpleStatementLine,
Expand Down Expand Up @@ -188,15 +184,16 @@ def _visit_statement(
prev = self._visit_expression(statement.value, instance, prev)
# then the multiple LHS, from left to right
for target in statement.targets:
prev = self._visit_expression(target.target, instance, prev)
prev = self._make_cfg_node(target, instance, prev) # AssignTarget
elif m.matches(statement, m.AugAssign()):
statement: cst.AugAssign = cst.ensure_type(statement, cst.AugAssign)
# note we're making the AugAssign a node first to represent reading the value from the target
prev = self._make_cfg_node(statement, instance, prev) # AugAssign
# note we're visiting LHS first to represent reading the value from the target
prev = self._visit_expression(statement.target, instance, prev)
# then we visit the RHS expression to find more reads
prev = self._visit_expression(statement.value, instance, prev)
# finally we write to the LHS
prev = self._visit_expression(statement.target, instance, prev)
# finally we write to the LHS, represented by a node of the whole assignment
prev = self._make_cfg_node(statement, instance, prev) # AugAssign
elif m.matches(statement, m.Break()):
if loop_break_target is None:
msg = "Found break outside of loop"
Expand Down Expand Up @@ -255,7 +252,7 @@ def _visit_statement(
raise NotImplementedError(msg)
elif m.matches(statement, m.For()):
statement: cst.For = cst.ensure_type(statement, cst.For)
index = self.get_metadata(IndexProvider, statement)
index = self._provider.get_metadata(IndexProvider, statement)
for_loop_continue_target = Ghost(index, 0)
prev = self._edge(prev, for_loop_continue_target)
loop_expr_prev = self._visit_expression(statement.iter, instance, prev)
Expand Down Expand Up @@ -319,7 +316,7 @@ def _visit_statement(
def wrap_in_finally(exit: CfgNode) -> CfgNode:
nonlocal statement, finally_number, fn_end, exception_target, loop_continue_target, loop_break_target
if statement.finalbody is not None:
entry = Ghost(self.get_metadata(IndexProvider, statement.finalbody), finally_number)
entry = Ghost(self._provider.get_metadata(IndexProvider, statement.finalbody), finally_number)
finalbody: cst.Finally = cst.ensure_type(statement.finalbody, cst.Finally)
prev = self._visit_BaseSuite(
finalbody.body,
Expand Down Expand Up @@ -348,7 +345,7 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode:
# to the next conditional
for handler in statement.handlers:
handler: cst.ExceptHandler = cst.ensure_type(handler, cst.ExceptHandler) # noqa: PLW2901
handler_index = self.get_metadata(IndexProvider, handler)
handler_index = self._provider.get_metadata(IndexProvider, handler)
handler_entry = Ghost(handler_index, 0)
handler_exit = Ghost(handler_index, 1)
handler_entries.append(handler_entry)
Expand All @@ -360,6 +357,7 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode:
handler_exits.append(handler_exit)

if handler.name is not None:
handler_prev = self._visit_expression(handler.name.name, instance, handler_prev)
handler_prev = self._make_cfg_node(handler.name, instance, handler_prev) # AsName
handler_prev = self._visit_BaseSuite(
handler.body,
Expand Down Expand Up @@ -404,14 +402,14 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode:
loop_break_target=local_loop_break_target,
)
# Ghost node for exiting the finally clause normally
try_exit = Ghost(self.get_metadata(IndexProvider, statement), 0)
try_exit = Ghost(self._provider.get_metadata(IndexProvider, statement), 0)
finally_entry = wrap_in_finally(try_exit)
# The normal entry into a normal finally clause at the end of the body/else or handler
self._edge([*prev, *handler_exits], finally_entry)
prev = [try_exit]
elif m.matches(statement, m.While()):
statement: cst.While = cst.ensure_type(statement, cst.While)
index = self.get_metadata(IndexProvider, statement)
index = self._provider.get_metadata(IndexProvider, statement)
while_loop_continue_target = Ghost(index, 0)
prev = self._edge(prev, while_loop_continue_target)
prev = self._visit_expression(statement.test, instance, prev)
Expand Down Expand Up @@ -553,6 +551,7 @@ def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev:
prev = self._visit_expression(expression.func, instance, prev)
for arg in expression.args:
prev = self._visit_expression(arg.value, instance, prev)
prev = self._make_cfg_node(expression, instance, prev) # Call
## Literal Values
elif m.matches(expression, m.Ellipsis()):
pass
Expand Down Expand Up @@ -641,7 +640,7 @@ def _visit_CompFor(
elt: cst.BaseExpression | tuple[cst.BaseExpression, cst.BaseExpression],
prev: list[CfgNode],
) -> list[CfgNode]:
exit = Ghost(self.get_metadata(IndexProvider, for_in), 0)
exit = Ghost(self._provider.get_metadata(IndexProvider, for_in), 0)
prev = self._visit_expression(for_in.iter, instance, prev)
prev = self._visit_expression(for_in.target, instance, prev)
for compif in for_in.ifs:
Expand All @@ -658,193 +657,19 @@ def _visit_CompFor(
prev = self._visit_expression(elt, instance, prev)
return self._edge(prev, exit)

@property
def cfg(self):
return self._cfg

class CfgNodeTester(cst.CSTVisitor):
"""
Checks that each kind of CST Node that should have a corresponding CFG node has one
"""

METADATA_DEPENDENCIES = (IndexProvider,)

def __init__(self, cfg: dict[Node, set[Node]]):
super().__init__()
self.cfg = cfg
self.active = False

def _has_node(self, node: cst.CSTNode, instance: int = 0) -> bool:
"""
Tests if the CSTNode has a corresponding CFG node with outgoing edges
"""
n = Node(self.get_metadata(IndexProvider, node), instance)
return n in self.cfg

def visit_Param(self, node: cst.Param) -> bool | None:
if self.active:
assert self._has_node(node)
# TODO: should we visit deeper into Param too?
return False
return None

def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None:
if self.active:
assert self._has_node(node)
# TODO: should we visit deeper into AssignTarget too?
return False
return None

def visit_AugAssign(self, node: cst.AugAssign) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_NameItem(self, node: cst.NameItem) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Attribute(self, node: cst.Attribute) -> bool | None:
if self.active:
assert self._has_node(node)
# TODO: should we visit deeper into Attribute too?
return False
return None

def visit_Name(self, node: cst.Name) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_AsName(self, node: cst.AsName) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_UnaryOperation(self, node: cst.UnaryOperation) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_BinaryOperation(self, node: cst.BinaryOperation) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_BooleanOperation(self, node: cst.BooleanOperation) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_ComparisonTarget(self, node: cst.ComparisonTarget) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Await(self, node: cst.Await) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Yield(self, node: cst.Yield) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_From(self, node: cst.From) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Integer(self, node: cst.Integer) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Float(self, node: cst.Float) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Imaginary(self, node: cst.Imaginary) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_SimpleString(self, node: cst.SimpleString) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_FormattedStringExpression(self, node: cst.FormattedStringExpression) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_FormattedString(self, node: cst.FormattedString) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Tuple(self, node: cst.Tuple) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_List(self, node: cst.List) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Set(self, node: cst.Set) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Element(self, node: cst.Element) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_StarredElement(self, node: cst.StarredElement) -> bool | None:
if self.active:
assert self._has_node(node)
@property
def start_end(self):
return self._start_end

def visit_DictElement(self, node: cst.DictElement) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_StarredDictElement(self, node: cst.StarredDictElement) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_GeneratorExp(self, node: cst.GeneratorExp) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_ListComp(self, node: cst.ListComp) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_SetComp(self, node: cst.SetComp) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_DictComp(self, node: cst.DictComp) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Index(self, node: cst.Index) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Slice(self, node: cst.Slice) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_Subscript(self, node: cst.Subscript) -> bool | None:
if self.active:
assert self._has_node(node)

def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None:
assert self._has_node(node)
# We're not testing for instance 1, which is a final node and will not have outgoing edges

def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None:
self.active = True

def leave_FunctionDef_params(self, _node: cst.FunctionDef) -> None:
self.active = False

def visit_FunctionDef_body(self, _node: cst.FunctionDef) -> None:
self.active = True

def leave_FunctionDef_body(self, _node: cst.FunctionDef) -> None:
self.active = False

def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None:
self.active = False
class ControlFlowGraphProvider(
cst.BatchableMetadataProvider[tuple[dict[CfgNode, set[CfgNode]], list[tuple[CfgNode, CfgNode]]]]
):
METADATA_DEPENDENCIES = (IndexProvider,)

def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None:
self.active = True
def visit_Module(self, node: cst.Module) -> bool | None:
node.visit(ComputeControlFlowGraph(self))
Loading
Loading