Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 73 additions & 34 deletions src/styx_compiler/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,26 @@ class Ghost:
instance: int


# A vertex of the control-flow graph: either a Node or a Ghost. Ghosts are
# created elsewhere in this file as for-loop continue targets and as
# comprehension entry/exit points.
CfgNode = Node | Ghost
type CfgNode = Node | Ghost


@dataclass(frozen=True)
class Read:
    """Marker context: the visited expression is evaluated for its value.

    This is the default context of ``_visit_expression``; in this context a
    bare ``Name`` gets its own CFG node.
    """


@dataclass(frozen=True)
class Write:
    """Marker context: the visited expression is a target being written.

    Passed to ``_visit_expression`` for assignment, ``del``, ``for`` and
    comprehension targets; in this context a bare ``Name`` does not get a CFG
    node of its own.
    """


# Tells _visit_expression whether an expression is visited as a value (Read)
# or as a target being written (Write).
type RWContext = Read | Write


def debug_print_cfg(cfg: dict[CfgNode, list[CfgNode]]) -> None:
    """Print every edge of *cfg* to stdout, one edge per line.

    Each edge is rendered as ``<index>_<instance> -> <index>_<instance>``,
    in the iteration order of the mapping and of each successor list.
    Nodes with no successors produce no output.
    """
    # Flatten the adjacency mapping lazily into (source, successor) pairs.
    edges = ((source, successor) for source, successors in cfg.items() for successor in successors)
    for source, successor in edges:
        print(f"{source.index}_{source.instance} -> {successor.index}_{successor.instance}")


class ComputeControlFlowGraph(cst.CSTVisitor):
Expand Down Expand Up @@ -170,8 +189,10 @@ def _visit_statement(
# RHS first if it exists
if statement.value is not None:
prev = self._visit_expression(statement.value, instance, prev)
# LHS
prev = self._visit_expression(statement.target, instance, prev)
# LHS reads
prev = self._visit_expression(statement.target, instance, prev, Write())
# LHS write
prev = self._make_cfg_node(statement, instance, prev) # AnnAssign
elif m.matches(statement, m.Assert()):
statement: cst.Assert = cst.ensure_type(statement, cst.Assert)
# test expression first?
Expand All @@ -184,7 +205,7 @@ def _visit_statement(
prev = self._visit_expression(statement.value, instance, prev)
# then the multiple LHS, from left to right
for target in statement.targets:
prev = self._visit_expression(target.target, instance, prev)
prev = self._visit_expression(target.target, instance, prev, Write())
prev = self._make_cfg_node(target, instance, prev) # AssignTarget
elif m.matches(statement, m.AugAssign()):
statement: cst.AugAssign = cst.ensure_type(statement, cst.AugAssign)
Expand All @@ -208,7 +229,9 @@ def _visit_statement(
prev = []
elif m.matches(statement, m.Del()):
statement: cst.Del = cst.ensure_type(statement, cst.Del)
prev = self._visit_expression(statement.target, instance, prev)
prev = self._visit_expression(statement.target, instance, prev, Write())
# write/del effect
prev = self._make_cfg_node(statement, instance, prev) # Del
elif m.matches(statement, m.Expr()):
statement: cst.Expr = cst.ensure_type(statement, cst.Expr)
prev = self._visit_expression(statement.value, instance, prev)
Expand Down Expand Up @@ -255,9 +278,11 @@ def _visit_statement(
index = self._provider.get_metadata(IndexProvider, statement)
for_loop_continue_target = Ghost(index, 0)
prev = self._edge(prev, for_loop_continue_target)
loop_expr_prev = self._visit_expression(statement.iter, instance, prev)
prev = loop_expr_prev
prev = self._visit_expression(statement.target, instance, prev)
prev = self._visit_expression(statement.iter, instance, prev)
prev = self._visit_expression(statement.target, instance, prev, Write())
# assignment effect, would be nice to have something other than the target to pin this to (as the target
# may be a Name, and we usually see a Name CfgNode as a Read effect)
prev = self._make_cfg_node(statement, instance, prev)
prev = self._visit_loop(
statement,
instance,
Expand Down Expand Up @@ -357,7 +382,6 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode:
handler_exits.append(handler_exit)

if handler.name is not None:
handler_prev = self._visit_expression(handler.name.name, instance, handler_prev)
handler_prev = self._make_cfg_node(handler.name, instance, handler_prev) # AsName
handler_prev = self._visit_BaseSuite(
handler.body,
Expand Down Expand Up @@ -493,64 +517,72 @@ def _visit_loop(
loop_continue_target=loop_continue_target,
loop_break_target=loop_break_target,
)
self._edge(prev, this_loop_continue_target)
return self._edge(prev, this_loop_break_target)

def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev: list[CfgNode]) -> list[CfgNode]:
def _visit_expression(
self,
expression: cst.BaseExpression,
instance: int,
prev: list[CfgNode],
context: RWContext = Read(), # noqa: B008
) -> list[CfgNode]:
## Names and Object Attributes
if m.matches(expression, m.Name()):
prev = self._make_cfg_node(expression, instance, prev) # Name
if context == Read():
prev = self._make_cfg_node(expression, instance, prev) # Name
elif m.matches(expression, m.Attribute()):
expression: cst.Attribute = cst.ensure_type(expression, cst.Attribute)
prev = self._visit_expression(expression.value, instance, prev)
prev = self._make_cfg_node(expression, instance, prev) # Attribute
## Operations and Comparisons
elif m.matches(expression, m.UnaryOperation()):
expression: cst.UnaryOperation = cst.ensure_type(expression, cst.UnaryOperation)
prev = self._visit_expression(expression.expression, instance, prev)
prev = self._visit_expression(expression.expression, instance, prev, context)
prev = self._make_cfg_node(expression.expression, instance, prev) # UnaryOperation
elif m.matches(expression, m.BinaryOperation() | m.BooleanOperation()):
expression: cst.BinaryOperation = cst.ensure_type(expression, cst.BinaryOperation)
prev = self._visit_expression(expression.left, instance, prev)
prev = self._visit_expression(expression.right, instance, prev)
prev = self._visit_expression(expression.left, instance, prev, context)
prev = self._visit_expression(expression.right, instance, prev, context)
prev = self._make_cfg_node(expression, instance, prev) # BinaryOperation, BooleanOperation
elif m.matches(expression, m.Comparison()):
# noinspection DuplicatedCode
expression: cst.Comparison = cst.ensure_type(expression, cst.Comparison)
prev = self._visit_expression(expression.left, instance, prev)
prev = self._visit_expression(expression.left, instance, prev, context)
for comparison in expression.comparisons:
prev = self._visit_expression(comparison.comparator, instance, prev)
prev = self._visit_expression(comparison.comparator, instance, prev, context)
prev = self._make_cfg_node(comparison, instance, prev) # ComparisonTarget
## Control Flow
elif m.matches(expression, m.Await()):
expression: cst.Await = cst.ensure_type(expression, cst.Await)
prev = self._visit_expression(expression.expression, instance, prev)
prev = self._visit_expression(expression.expression, instance, prev, context)
prev = self._make_cfg_node(expression, instance, prev) # Await
elif m.matches(expression, m.Yield()):
expression: cst.Yield = cst.ensure_type(expression, cst.Yield)
prev = self._visit_expression(expression.value, instance, prev)
prev = self._visit_expression(expression.value, instance, prev, context)
# yield is not like return. A later call to the generator will continue from the yield, so
# in the CFG we're just going to represent it as a normal node and pretend the control did
# not leave and re-enter because it probably doesn't matter for the analyses we want to do.
prev = self._make_cfg_node(expression, instance, prev) # Yield
elif m.matches(expression, m.From()):
expression: cst.From = cst.ensure_type(expression, cst.From)
prev = self._visit_expression(expression.item, instance, prev)
prev = self._visit_expression(expression.item, instance, prev, context)
prev = self._make_cfg_node(expression, instance, prev) # From
elif m.matches(expression, m.IfExp()):
expression: cst.IfExp = cst.ensure_type(expression, cst.IfExp)
prev = self._visit_expression(expression.test, instance, prev)
body = self._visit_expression(expression.body, instance, prev)
orelse = self._visit_expression(expression.orelse, instance, prev)
prev = self._visit_expression(expression.test, instance, prev, context)
body = self._visit_expression(expression.body, instance, prev, context)
orelse = self._visit_expression(expression.orelse, instance, prev, context)
prev = [*body, *orelse]
## Lambdas and Function Calls
elif m.matches(expression, m.Lambda()):
msg = "Lambdas are not yet supported"
raise NotImplementedError(msg)
elif m.matches(expression, m.Call()):
expression: cst.Call = cst.ensure_type(expression, cst.Call)
prev = self._visit_expression(expression.func, instance, prev)
prev = self._visit_expression(expression.func, instance, prev, context)
for arg in expression.args:
prev = self._visit_expression(arg.value, instance, prev)
prev = self._visit_expression(arg.value, instance, prev, context)
prev = self._make_cfg_node(expression, instance, prev) # Call
## Literal Values
elif m.matches(expression, m.Ellipsis()):
Expand All @@ -564,25 +596,25 @@ def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev:
for part in expression.parts:
if m.matches(part, m.FormattedStringExpression()):
part: cst.FormattedStringExpression = cst.ensure_type(part, cst.FormattedStringExpression) # noqa: PLW2901
prev = self._visit_expression(part.expression, instance, prev)
prev = self._visit_expression(part.expression, instance, prev, context)
prev = self._make_cfg_node(part, instance, prev) # FormattedStringExpression
prev = self._make_cfg_node(expression, instance, prev) # FormattedString
## Collections
elif m.matches(expression, m.Tuple() | m.List() | m.Set()):
# noinspection PyUnresolvedReferences
prev = self._visit_elements(expression.elements, instance, prev)
prev = self._visit_elements(expression.elements, instance, prev, context)
prev = self._make_cfg_node(expression, instance, prev) # Tuple, List, Set
elif m.matches(expression, m.Element() | m.StarredElement()):
# noinspection PyUnresolvedReferences
prev = self._visit_expression(expression.value, instance, prev)
prev = self._visit_expression(expression.value, instance, prev, context)
prev = self._make_cfg_node(expression, instance, prev) # Element, StarredElement
elif m.matches(expression, m.Dict()):
expression: cst.Dict = cst.ensure_type(expression, cst.Dict)
for element in expression.elements:
if m.matches(element, m.DictElement()):
element: cst.DictElement = cst.ensure_type(element, cst.DictElement) # noqa: PLW2901
prev = self._visit_expression(element.key, instance, prev)
prev = self._visit_expression(element.value, instance, prev)
prev = self._visit_expression(element.key, instance, prev, context)
prev = self._visit_expression(element.value, instance, prev, context)
prev = self._make_cfg_node(element, instance, prev) # DictElement, StarredDictElement
## Comprehensions
elif m.matches(expression, m.GeneratorExp() | m.ListComp() | m.SetComp()):
Expand All @@ -596,7 +628,7 @@ def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev:
## Subscripts and Slices
elif m.matches(expression, m.Subscript()):
expression: cst.Subscript = cst.ensure_type(expression, cst.Subscript)
prev = self._visit_expression(expression.value, instance, prev)
prev = self._visit_expression(expression.value, instance, prev, context)
for element in expression.slice:
if m.matches(element, m.Index()):
element: cst.Index = cst.ensure_type(element, cst.Index) # noqa: PLW2901
Expand All @@ -623,10 +655,12 @@ def _visit_ImportAlias(self, import_alias: cst.ImportAlias, instance: int, prev:
prev = self._make_cfg_node(import_alias.asname, instance, prev) # AsName
return prev

def _visit_elements(self, elements: Sequence[cst.BaseElement], instance: int, prev: list[CfgNode]) -> list[CfgNode]:
def _visit_elements(
self, elements: Sequence[cst.BaseElement], instance: int, prev: list[CfgNode], context: RWContext
) -> list[CfgNode]:
for element in elements:
if m.matches(element, m.Element() | m.StarredElement()):
prev = self._visit_expression(element.value, instance, prev)
prev = self._visit_expression(element.value, instance, prev, context)
prev = self._make_cfg_node(element, instance, prev) # Element, StarredElement
else:
msg = f"Unknown element type {element}"
Expand All @@ -640,9 +674,13 @@ def _visit_CompFor(
elt: cst.BaseExpression | tuple[cst.BaseExpression, cst.BaseExpression],
prev: list[CfgNode],
) -> list[CfgNode]:
exit = Ghost(self._provider.get_metadata(IndexProvider, for_in), 0)
entry = Ghost(self._provider.get_metadata(IndexProvider, for_in), 0)
exit = Ghost(self._provider.get_metadata(IndexProvider, for_in), 1)
prev = self._edge(prev, entry)
prev = self._visit_expression(for_in.iter, instance, prev)
prev = self._visit_expression(for_in.target, instance, prev)
prev = self._visit_expression(for_in.target, instance, prev, Write())
# write effect
prev = self._make_cfg_node(for_in, instance, prev) # CompFor
for compif in for_in.ifs:
prev = self._visit_expression(compif.test, instance, prev)
self._edge(prev, exit)
Expand All @@ -655,6 +693,7 @@ def _visit_CompFor(
prev = self._visit_expression(value, instance, prev)
else:
prev = self._visit_expression(elt, instance, prev)
self._edge(prev, entry)
return self._edge(prev, exit)

@property
Expand Down
7 changes: 4 additions & 3 deletions src/styx_compiler/data_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def strong_connect(node: Node):
node_index[node] = index
node_lowlink[node] = index
index += 1

stack_set_size = len(node_on_stack)
node_on_stack.add(node)
# N.B. we don't add node to scc_stack here, only to the node_on_stack set. The set is used next and in the
# recursive calls, so it doesn't affect the algorithm's correctness to postpone adding to scc_stack.
Expand All @@ -54,9 +56,8 @@ def strong_connect(node: Node):
scc_stack.append(node)

if node_lowlink[node] == node_index[node]:
scc = [scc_stack.pop()]
node_on_stack.remove(scc[-1])
while scc[-1] != node:
scc = []
for _ in range(len(node_on_stack), stack_set_size, -1):
scc.append(scc_stack.pop())
node_on_stack.remove(scc[-1])
result.append(list(reversed(scc)))
Expand Down
Loading
Loading