diff --git a/pyproject.toml b/pyproject.toml index 02cb42d..df2a39b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ readme = "README.md" requires-python = ">=3.14" dependencies = [ "libcst>=1.8.6", + "libcst-dfa>=0.0.1", "libcst-mypy>=0.1.0", "mypy>=1.19.1", ] diff --git a/src/styx_compiler/control_flow.py b/src/styx_compiler/control_flow.py deleted file mode 100644 index 8a42e9b..0000000 --- a/src/styx_compiler/control_flow.py +++ /dev/null @@ -1,730 +0,0 @@ -""" -A control-flow graph consists of a start CfgNode, an end CfgNode, and some Node CfgNodes in between. -It looks like ``dict[CfgNode, list[CfgNode]]``, along with a start and end. You can reuse the dict to -contain multiple graphs. Each CfgNode has an index, an integer that uniquely identifies a CST node. -""" - -from collections.abc import Sequence -from dataclasses import dataclass - -import libcst as cst -from libcst import matchers as m - -from styx_compiler.metadata_providers import IndexProvider - - -@dataclass(frozen=True) -class Node: - """ - A node in the control flow graph - - Uses an index from the IndexProvider to tie it to the CST, and an instance number to make multiple unique instances - """ - - index: int - instance: int - - -@dataclass(frozen=True) -class Ghost: - """ - Not a real node, just a construction device that gets removed later - - Uses an index from the IndexProvider to tie it to the CST, and an instance number to make multiple unique instances - """ - - index: int - instance: int - - -type CfgNode = Node | Ghost - - -@dataclass(frozen=True) -class Read: - pass - - -@dataclass(frozen=True) -class Write: - pass - - -type RWContext = Read | Write - - -def debug_print_cfg(cfg: dict[CfgNode, list[CfgNode]]) -> None: - for node_from, nodes_to in cfg.items(): - for node_to in nodes_to: - print(f"{node_from.index}_{node_from.instance} -> {node_to.index}_{node_to.instance}") - - -class ComputeControlFlowGraph(cst.CSTVisitor): - """ - Computes the control-flow graph of the code, expressed in indices from the IndexProvider - """ - - def __init__(self, provider: ControlFlowGraphProvider): - super().__init__() - self._provider = provider - self._cfg: dict[CfgNode, set[CfgNode]] = {} - self._start_end: list[tuple[CfgNode, CfgNode]] = [] - - def _edge(self, prev: list[CfgNode], cur: CfgNode) -> list[CfgNode]: - for p in prev: - self._cfg.setdefault(p, set()).add(cur) - return [cur] - - def _edges(self, prev: list[CfgNode], tos: list[CfgNode]) -> list[CfgNode]: - for p in prev: - self._cfg.setdefault(p, set()).update(tos) - return tos - - def _make_cfg_node(self, cst_node: cst.CSTNode, instance: int, prev: list[CfgNode]) -> list[CfgNode]: - cur = Node(self._provider.get_metadata(IndexProvider, cst_node), instance) - return self._edge(prev, cur) - - def _clean_up_cfg_ghosts(self, start: CfgNode) -> None: - seen: set[CfgNode] = set() - workstack: list[CfgNode] = [start] - seen.add(start) - while len(workstack) > 0: - node = workstack.pop() - ghost_workstack: list[Ghost] = [] - for next_node in self._cfg.get(node, set()): - if isinstance(next_node, Ghost): - ghost_workstack.append(next_node) - elif next_node not in seen: - workstack.append(next_node) - seen.add(next_node) - while len(ghost_workstack) > 0: - next_node = ghost_workstack.pop() - self._cfg[node].remove(next_node) - for next_next_node in self._cfg[next_node]: - self._cfg[node].add(next_next_node) - if isinstance(next_next_node, Ghost): - ghost_workstack.append(next_next_node) - elif next_node not in seen: - workstack.append(next_node) - seen.add(next_node) - - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - index = self._provider.get_metadata(IndexProvider, node) - start = Node(index, 0) - end = Node(index, 1) - self._start_end.append((start, end)) - - prev = [start] - - instance = 0 - - for param in node.params.params: - prev = self._make_cfg_node(param, instance, prev) # Param - - prev = self._visit_BaseSuite(node.body, instance, prev, fn_end=end, exception_target=end) - - self._edge(prev, end) - - self._clean_up_cfg_ghosts(start) - - def leave_Module(self, module: cst.Module) -> None: - # Remove unreachable parts of the CFG (e.g. unused finally clause instantiations, dead code after a return) - reachable = set() - workstack = [] - for start, _ in self._start_end: - reachable.add(start) - workstack.append(start) - - while len(workstack) > 0: - node = workstack.pop() - for to in self._cfg.get(node, []): - if to not in reachable: - reachable.add(to) - workstack.append(to) - - to_remove = [] - for k in self._cfg: - if k not in reachable: - to_remove.append(k) - for k in to_remove: - del self._cfg[k] - - self._provider.set_metadata(module, (self._cfg, self._start_end)) - - def _visit_BaseSuite( - self, - statements: cst.BaseSuite | cst.SimpleStatementLine, - instance: int, - prev: list[CfgNode], - fn_end: CfgNode, - exception_target: CfgNode, - loop_continue_target: CfgNode | None = None, - loop_break_target: CfgNode | None = None, - ) -> list[CfgNode]: - for statement in statements.body: - prev = self._visit_statement( - statement, - instance, - prev, - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=loop_continue_target, - loop_break_target=loop_break_target, - ) - return prev - - def _visit_statement( - self, - statement: cst.BaseStatement | cst.BaseSmallStatement, - instance: int, - prev: list[CfgNode], - fn_end: CfgNode, - exception_target: CfgNode, - loop_continue_target: CfgNode | None = None, - loop_break_target: CfgNode | None = None, - ) -> list[CfgNode]: - ## Simple Statements - if m.matches(statement, m.AnnAssign()): - statement: cst.AnnAssign = cst.ensure_type(statement, cst.AnnAssign) - # RHS first if it exists - if statement.value is not None: - prev = self._visit_expression(statement.value, instance, prev) - # LHS reads - prev = self._visit_expression(statement.target, instance, prev, Write()) - # LHS write - prev = self._make_cfg_node(statement, instance, prev) # AnnAssign - elif m.matches(statement, m.Assert()): - statement: cst.Assert = cst.ensure_type(statement, cst.Assert) - # test expression first? - prev = self._visit_expression(statement.test, instance, prev) - # then message - prev = self._visit_expression(statement.msg, instance, prev) - elif m.matches(statement, m.Assign()): - statement: cst.Assign = cst.ensure_type(statement, cst.Assign) - # RHS first - prev = self._visit_expression(statement.value, instance, prev) - # then the multiple LHS, from left to right - for target in statement.targets: - prev = self._visit_expression(target.target, instance, prev, Write()) - prev = self._make_cfg_node(target, instance, prev) # AssignTarget - elif m.matches(statement, m.AugAssign()): - statement: cst.AugAssign = cst.ensure_type(statement, cst.AugAssign) - # note we're visiting LHS first to represent reading the value from the target - prev = self._visit_expression(statement.target, instance, prev) - # then we visit the RHS expression to find more reads - prev = self._visit_expression(statement.value, instance, prev) - # finally we write to the LHS, represented by a node of the whole assignment - prev = self._make_cfg_node(statement, instance, prev) # AugAssign - elif m.matches(statement, m.Break()): - if loop_break_target is None: - msg = "Found break outside of loop" - raise RuntimeError(msg) - self._edge(prev, loop_break_target) - prev = [] - elif m.matches(statement, m.Continue()): - if loop_continue_target is None: - msg = "Found break outside of loop" - raise RuntimeError(msg) - self._edge(prev, loop_continue_target) - prev = [] - elif m.matches(statement, m.Del()): - statement: cst.Del = cst.ensure_type(statement, cst.Del) - prev = self._visit_expression(statement.target, instance, prev, Write()) - # write/del effect - prev = self._make_cfg_node(statement, instance, prev) # Del - elif m.matches(statement, m.Expr()): - statement: cst.Expr = cst.ensure_type(statement, cst.Expr) - prev = self._visit_expression(statement.value, instance, prev) - elif m.matches(statement, m.Global()): - statement: cst.Global = cst.ensure_type(statement, cst.Global) - for name in statement.names: - prev = self._make_cfg_node(name, instance, prev) # NameItem - elif m.matches(statement, m.Import()): - statement: cst.Import = cst.ensure_type(statement, cst.Import) - for name in statement.names: - prev = self._visit_ImportAlias(name, instance, prev) - elif m.matches(statement, m.ImportFrom()): - statement: cst.ImportFrom = cst.ensure_type(statement, cst.ImportFrom) - if statement.module is not None: - prev = self._make_cfg_node(statement.module, instance, prev) # Attribute | Name - if not m.matches(statement.names, m.ImportStar()): - for name in statement.names: - prev = self._visit_ImportAlias(name, instance, prev) - elif m.matches(statement, m.Nonlocal()): - statement: cst.Nonlocal = cst.ensure_type(statement, cst.Nonlocal) - for name in statement.names: - prev = self._make_cfg_node(name, instance, prev) # NameItem - elif m.matches(statement, m.Pass()): - pass - elif m.matches(statement, m.Raise()): - statement: cst.Raise = cst.ensure_type(statement, cst.Raise) - if statement.exc is not None: - prev = self._visit_expression(statement.exc, instance, prev) - if statement.cause is not None: - prev = self._visit_expression(statement.cause.item, instance, prev) - self._edge(prev, exception_target) - prev = [] - elif m.matches(statement, m.Return()): - statement: cst.Return = cst.ensure_type(statement, cst.Return) - prev = self._visit_expression(statement.value, instance, prev) - self._edge(prev, fn_end) - prev = [] - ## Compound Statements - elif m.matches(statement, m.ClassDef()): - msg = "Inline class definition is not yet supported" - raise NotImplementedError(msg) - elif m.matches(statement, m.For()): - statement: cst.For = cst.ensure_type(statement, cst.For) - index = self._provider.get_metadata(IndexProvider, statement) - for_loop_continue_target = Ghost(index, 0) - prev = self._edge(prev, for_loop_continue_target) - prev = self._visit_expression(statement.iter, instance, prev) - prev = self._visit_expression(statement.target, instance, prev, Write()) - # assignment effect, would be nice to have something other than the target to pin this to (as the target - # may be a Name, and we usually see a Name CfgNode as a Read effect) - prev = self._make_cfg_node(statement, instance, prev) - prev = self._visit_loop( - statement, - instance, - prev, - index, - for_loop_continue_target, - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=loop_continue_target, - loop_break_target=loop_break_target, - ) - elif m.matches(statement, m.FunctionDef()): - msg = "Inline function definition is not yet supported" - raise NotImplementedError(msg) - elif m.matches(statement, m.If()): - statement: cst.If = cst.ensure_type(statement, cst.If) - prev = self._visit_expression(statement.test, instance, prev) - body = self._visit_BaseSuite( - statement.body, - instance, - prev, - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=loop_continue_target, - loop_break_target=loop_break_target, - ) - if statement.orelse is None: - pass - elif m.matches(statement.orelse, m.Else()): - orelse: cst.Else = cst.ensure_type(statement.orelse, cst.Else) - prev = self._visit_BaseSuite( - orelse.body, - instance, - prev, - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=loop_continue_target, - loop_break_target=loop_break_target, - ) - else: - orelse: cst.If = cst.ensure_type(statement.orelse, cst.If) - prev = self._visit_statement( - orelse, - instance, - prev, - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=loop_continue_target, - loop_break_target=loop_break_target, - ) - prev = [*body, *prev] - elif m.matches(statement, m.Try()): - statement: cst.Try = cst.ensure_type(statement, cst.Try) - finally_number = instance - - def wrap_in_finally(exit: CfgNode) -> CfgNode: - nonlocal statement, finally_number, fn_end, exception_target, loop_continue_target, loop_break_target - if statement.finalbody is not None: - entry = Ghost(self._provider.get_metadata(IndexProvider, statement.finalbody), finally_number) - finalbody: cst.Finally = cst.ensure_type(statement.finalbody, cst.Finally) - prev = self._visit_BaseSuite( - finalbody.body, - finally_number, - [entry], - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=loop_continue_target, - loop_break_target=loop_break_target, - ) - self._edge(prev, exit) - finally_number += 1 - return entry - return exit - - # Install instantiation of finally clause before different ways you can exit a try body or handler body. - local_fn_end = wrap_in_finally(fn_end) - local_exception_target = local_fn_end if fn_end == exception_target else wrap_in_finally(exception_target) - local_loop_continue_target = wrap_in_finally(loop_continue_target) - local_loop_break_target = wrap_in_finally(loop_break_target) - - handler_entries = [] - handler_cond = [] - handler_exits = [] - # Build the chain of exception handlers, each is modeled with a conditional going into the handler body or - # to the next conditional - for handler in statement.handlers: - handler: cst.ExceptHandler = cst.ensure_type(handler, cst.ExceptHandler) # noqa: PLW2901 - handler_index = self._provider.get_metadata(IndexProvider, handler) - handler_entry = Ghost(handler_index, 0) - handler_exit = Ghost(handler_index, 1) - handler_entries.append(handler_entry) - if len(handler_cond) > 0: - self._edge(handler_cond[-1], handler_entry) - handler_prev = self._visit_expression(handler.type, instance, [handler_entry]) - handler_cond.append(handler_prev) - self._edge(handler_prev, handler_exit) - handler_exits.append(handler_exit) - - if handler.name is not None: - handler_prev = self._make_cfg_node(handler.name, instance, handler_prev) # AsName - handler_prev = self._visit_BaseSuite( - handler.body, - instance, - handler_prev, - fn_end=local_fn_end, - exception_target=local_exception_target, - loop_continue_target=local_loop_continue_target, - loop_break_target=local_loop_break_target, - ) - self._edge(handler_prev, handler_exit) - # Try body first, using the handler chain as exception target and the finally-wrapped other targets - prev = self._visit_BaseSuite( - statement.body, - instance, - prev, - fn_end=local_fn_end, - exception_target=handler_entries[0] if len(handler_entries) > 0 else local_exception_target, - loop_continue_target=local_loop_continue_target, - loop_break_target=local_loop_break_target, - ) - # If we have handlers, we go into them after the try body too in case of an exception that wasn't - # explicitly raised - if len(handler_entries) > 0: - self._edge(prev, handler_entries[0]) - # From the final handler cond we can go to the finally-wrapped outside exception target if none of our - # local handlers matched against the raised exception. - self._edge(handler_cond[-1], local_exception_target) - else: - # If there are no handlers, we might go to the finally-wrapped outside exception target. - self._edge(prev, local_exception_target) - # If no exception was raised, we go into the else clause if it exists - if statement.orelse is not None: - orelse: cst.Else = cst.ensure_type(statement.orelse, cst.Else) - prev = self._visit_BaseSuite( - orelse.body, - instance, - prev, - fn_end=local_fn_end, - exception_target=local_exception_target, - loop_continue_target=local_loop_continue_target, - loop_break_target=local_loop_break_target, - ) - # Ghost node for exiting the finally clause normally - try_exit = Ghost(self._provider.get_metadata(IndexProvider, statement), 0) - finally_entry = wrap_in_finally(try_exit) - # The normal entry into a normal finally clause at the end of the body/else or handler - self._edge([*prev, *handler_exits], finally_entry) - prev = [try_exit] - elif m.matches(statement, m.While()): - statement: cst.While = cst.ensure_type(statement, cst.While) - index = self._provider.get_metadata(IndexProvider, statement) - while_loop_continue_target = Ghost(index, 0) - prev = self._edge(prev, while_loop_continue_target) - prev = self._visit_expression(statement.test, instance, prev) - prev = self._visit_loop( - statement, - instance, - prev, - index, - while_loop_continue_target, - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=loop_continue_target, - loop_break_target=loop_break_target, - ) - elif m.matches(statement, m.With()): - statement: cst.With = cst.ensure_type(statement, cst.With) - for item in statement.items: - prev = self._visit_expression(item.item, instance, prev) - if item.asname is not None: - prev = self._make_cfg_node(item.asname, instance, prev) # AsName - prev = self._visit_BaseSuite( - statement.body, - instance, - prev, - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=loop_continue_target, - loop_break_target=loop_break_target, - ) - ## Statement Blocks - elif m.matches(statement, m.SimpleStatementLine() | m.SimpleStatementSuite() | m.IndentedBlock()): - if m.matches(statement, m.SimpleStatementLine()): - statement: cst.SimpleStatementLine = cst.ensure_type(statement, cst.SimpleStatementLine) - else: - statement: cst.BaseSuite = cst.ensure_type(statement, cst.BaseSuite) - prev = self._visit_BaseSuite( - statement, - instance, - prev, - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=loop_continue_target, - loop_break_target=loop_break_target, - ) - else: - msg = f"Unknown statement type {statement}" - raise RuntimeError(msg) - - return prev - - def _visit_loop( - self, - statement: cst.For | cst.While, - instance: int, - prev: list[CfgNode], - index: int, - this_loop_continue_target: Ghost, - fn_end: CfgNode, - exception_target: CfgNode, - loop_continue_target: CfgNode | None, - loop_break_target: CfgNode | None, - ) -> list[CfgNode]: - this_loop_break_target = Ghost(index, 1) - prev = self._visit_BaseSuite( - statement.body, - instance, - prev, - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=this_loop_continue_target, - loop_break_target=this_loop_break_target, - ) - if statement.orelse is not None: - orelse: cst.Else = cst.ensure_type(statement.orelse, cst.Else) - prev = self._visit_BaseSuite( - orelse.body, - instance, - prev, - fn_end=fn_end, - exception_target=exception_target, - loop_continue_target=loop_continue_target, - loop_break_target=loop_break_target, - ) - self._edge(prev, this_loop_continue_target) - return self._edge(prev, this_loop_break_target) - - def _visit_expression( - self, - expression: cst.BaseExpression, - instance: int, - prev: list[CfgNode], - context: RWContext = Read(), # noqa: B008 - ) -> list[CfgNode]: - ## Names and Object Attributes - if m.matches(expression, m.Name()): - if context == Read(): - prev = self._make_cfg_node(expression, instance, prev) # Name - elif m.matches(expression, m.Attribute()): - expression: cst.Attribute = cst.ensure_type(expression, cst.Attribute) - prev = self._visit_expression(expression.value, instance, prev) - prev = self._make_cfg_node(expression, instance, prev) # Attribute - ## Operations and Comparisons - elif m.matches(expression, m.UnaryOperation()): - expression: cst.UnaryOperation = cst.ensure_type(expression, cst.UnaryOperation) - prev = self._visit_expression(expression.expression, instance, prev, context) - prev = self._make_cfg_node(expression.expression, instance, prev) # UnaryOperation - elif m.matches(expression, m.BinaryOperation()): - expression: cst.BinaryOperation = cst.ensure_type(expression, cst.BinaryOperation) - prev = self._visit_expression(expression.left, instance, prev, context) - prev = self._visit_expression(expression.right, instance, prev, context) - prev = self._make_cfg_node(expression, instance, prev) # BinaryOperation - elif m.matches(expression, m.BooleanOperation()): - expression: cst.BooleanOperation = cst.ensure_type(expression, cst.BooleanOperation) - prev = self._visit_expression(expression.left, instance, prev, context) - prev = self._visit_expression(expression.right, instance, prev, context) - prev = self._make_cfg_node(expression, instance, prev) # BooleanOperation - elif m.matches(expression, m.Comparison()): - # noinspection DuplicatedCode - expression: cst.Comparison = cst.ensure_type(expression, cst.Comparison) - prev = self._visit_expression(expression.left, instance, prev, context) - for comparison in expression.comparisons: - prev = self._visit_expression(comparison.comparator, instance, prev, context) - prev = self._make_cfg_node(comparison, instance, prev) # ComparisonTarget - ## Control Flow - elif m.matches(expression, m.Await()): - expression: cst.Await = cst.ensure_type(expression, cst.Await) - prev = self._visit_expression(expression.expression, instance, prev, context) - prev = self._make_cfg_node(expression, instance, prev) # Await - elif m.matches(expression, m.Yield()): - expression: cst.Yield = cst.ensure_type(expression, cst.Yield) - prev = self._visit_expression(expression.value, instance, prev, context) - # yield is not like return. A later call to the generator will continue from the yield, so - # in the CFG we're just going to represent it as a normal node and pretend the control did - # not leave and re-enter because it probably doesn't matter for the analyses we want to do. - prev = self._make_cfg_node(expression, instance, prev) # Yield - elif m.matches(expression, m.From()): - expression: cst.From = cst.ensure_type(expression, cst.From) - prev = self._visit_expression(expression.item, instance, prev, context) - prev = self._make_cfg_node(expression, instance, prev) # From - elif m.matches(expression, m.IfExp()): - expression: cst.IfExp = cst.ensure_type(expression, cst.IfExp) - prev = self._visit_expression(expression.test, instance, prev, context) - body = self._visit_expression(expression.body, instance, prev, context) - orelse = self._visit_expression(expression.orelse, instance, prev, context) - prev = [*body, *orelse] - ## Lambdas and Function Calls - elif m.matches(expression, m.Lambda()): - msg = "Lambdas are not yet supported" - raise NotImplementedError(msg) - elif m.matches(expression, m.Call()): - expression: cst.Call = cst.ensure_type(expression, cst.Call) - prev = self._visit_expression(expression.func, instance, prev, context) - for arg in expression.args: - prev = self._visit_expression(arg.value, instance, prev, context) - prev = self._make_cfg_node(expression, instance, prev) # Call - ## Literal Values - elif m.matches(expression, m.Ellipsis()): - pass - elif m.matches(expression, m.Integer() | m.Float() | m.Imaginary() | m.SimpleString()): - prev = self._make_cfg_node(expression, instance, prev) # Integer, Float, Imaginary, SimpleString - elif m.matches(expression, m.ConcatenatedString()): - # Implicit concatenation like `"a" f"{x}"`. Visit left and right so - # any embedded expressions inside FormattedString parts are seen as reads. - expression: cst.ConcatenatedString = cst.ensure_type(expression, cst.ConcatenatedString) - prev = self._visit_expression(expression.left, instance, prev, context) - prev = self._visit_expression(expression.right, instance, prev, context) - prev = self._make_cfg_node(expression, instance, prev) # ConcatenatedString - elif m.matches(expression, m.FormattedString()): - expression: cst.FormattedString = cst.ensure_type(expression, cst.FormattedString) - for part in expression.parts: - if m.matches(part, m.FormattedStringExpression()): - part: cst.FormattedStringExpression = cst.ensure_type(part, cst.FormattedStringExpression) # noqa: PLW2901 - prev = self._visit_expression(part.expression, instance, prev, context) - prev = self._make_cfg_node(part, instance, prev) # FormattedStringExpression - prev = self._make_cfg_node(expression, instance, prev) # FormattedString - ## Collections - elif m.matches(expression, m.Tuple() | m.List() | m.Set()): - # noinspection PyUnresolvedReferences - prev = self._visit_elements(expression.elements, instance, prev, context) - prev = self._make_cfg_node(expression, instance, prev) # Tuple, List, Set - elif m.matches(expression, m.Element() | m.StarredElement()): - # noinspection PyUnresolvedReferences - prev = self._visit_expression(expression.value, instance, prev, context) - prev = self._make_cfg_node(expression, instance, prev) # Element, StarredElement - elif m.matches(expression, m.Dict()): - expression: cst.Dict = cst.ensure_type(expression, cst.Dict) - for element in expression.elements: - if m.matches(element, m.DictElement()): - element: cst.DictElement = cst.ensure_type(element, cst.DictElement) # noqa: PLW2901 - prev = self._visit_expression(element.key, instance, prev, context) - prev = self._visit_expression(element.value, instance, prev, context) - prev = self._make_cfg_node(element, instance, prev) # DictElement, StarredDictElement - ## Comprehensions - elif m.matches(expression, m.GeneratorExp() | m.ListComp() | m.SetComp()): - expression: cst.BaseSimpleComp = cst.ensure_type(expression, cst.BaseSimpleComp) - prev = self._visit_CompFor(expression.for_in, instance, expression.elt, prev) - prev = self._make_cfg_node(expression, instance, prev) # GeneratorExp, ListComp, SetComp - elif m.matches(expression, m.DictComp()): - expression: cst.DictComp = cst.ensure_type(expression, cst.DictComp) - prev = self._visit_CompFor(expression.for_in, instance, (expression.key, expression.value), prev) - prev = self._make_cfg_node(expression, instance, prev) # DictComp - ## Subscripts and Slices - elif m.matches(expression, m.Subscript()): - expression: cst.Subscript = cst.ensure_type(expression, cst.Subscript) - prev = self._visit_expression(expression.value, instance, prev, context) - for element in expression.slice: - baseslice: cst.BaseSlice = cst.ensure_type(element, cst.SubscriptElement).slice - if m.matches(baseslice, m.Index()): - element: cst.Index = cst.ensure_type(baseslice, cst.Index) # noqa: PLW2901 - prev = self._visit_expression(element.value, instance, prev) - prev = self._make_cfg_node(element, instance, prev) # Index - elif m.matches(baseslice, m.Slice()): - element: cst.Slice = cst.ensure_type(baseslice, cst.Slice) # noqa: PLW2901 - if element.lower is not None: - prev = self._visit_expression(element.lower, instance, prev) - if element.upper is not None: - prev = self._visit_expression(element.upper, instance, prev) - if element.step is not None: - prev = self._visit_expression(element.step, instance, prev) - prev = self._make_cfg_node(element, instance, prev) # Slice - else: - msg = f"Unknown BaseSlice type {baseslice}" - raise RuntimeError(msg) - prev = self._make_cfg_node(expression, instance, prev) # Subscript - elif expression is None: - return prev - else: - msg = f"Unknown expression type {expression}" - raise RuntimeError(msg) - return prev - - def _visit_ImportAlias(self, import_alias: cst.ImportAlias, instance: int, prev: list[CfgNode]) -> list[CfgNode]: - prev = self._make_cfg_node(import_alias.name, instance, prev) # Attribute, Name - if import_alias.asname is not None: - prev = self._make_cfg_node(import_alias.asname, instance, prev) # AsName - return prev - - def _visit_elements( - self, elements: Sequence[cst.BaseElement], instance: int, prev: list[CfgNode], context: RWContext - ) -> list[CfgNode]: - for element in elements: - if m.matches(element, m.Element() | m.StarredElement()): - prev = self._visit_expression(element.value, instance, prev, context) - prev = self._make_cfg_node(element, instance, prev) # Element, StarredElement - else: - msg = f"Unknown element type {element}" - raise RuntimeError(msg) - return prev - - def _visit_CompFor( - self, - for_in: cst.CompFor, - instance: int, - elt: cst.BaseExpression | tuple[cst.BaseExpression, cst.BaseExpression], - prev: list[CfgNode], - ) -> list[CfgNode]: - entry = Ghost(self._provider.get_metadata(IndexProvider, for_in), 0) - exit = Ghost(self._provider.get_metadata(IndexProvider, for_in), 1) - prev = self._edge(prev, entry) - prev = self._visit_expression(for_in.iter, instance, prev) - prev = self._visit_expression(for_in.target, instance, prev, Write()) - # write effect - prev = self._make_cfg_node(for_in, instance, prev) # CompFor - for compif in for_in.ifs: - prev = self._visit_expression(compif.test, instance, prev) - self._edge(prev, exit) - if for_in.inner_for_in is not None: - prev = self._visit_CompFor(for_in.inner_for_in, instance, elt, prev) - else: - if isinstance(elt, tuple): - key, value = elt - prev = self._visit_expression(key, instance, prev) - prev = self._visit_expression(value, instance, prev) - else: - prev = self._visit_expression(elt, instance, prev) - self._edge(prev, entry) - return self._edge(prev, exit) - - @property - def cfg(self): - return self._cfg - - @property - def start_end(self): - return self._start_end - - -class ControlFlowGraphProvider( - cst.BatchableMetadataProvider[tuple[dict[CfgNode, set[CfgNode]], list[tuple[CfgNode, CfgNode]]]] -): - METADATA_DEPENDENCIES = (IndexProvider,) - - def visit_Module(self, node: cst.Module) -> bool | None: - node.visit(ComputeControlFlowGraph(self)) diff --git a/src/styx_compiler/core.py b/src/styx_compiler/core.py index c32325c..4b80bef 100644 --- a/src/styx_compiler/core.py +++ b/src/styx_compiler/core.py @@ -12,12 +12,12 @@ import libcst.matchers as m import mypy.api from libcst import CSTNode, FlattenSentinel, FunctionDef, Module, RemovalSentinel +from libcst_dfa.live_variables import LiveVariablesProvider from libcst_mypy import MypyTypeInferenceProvider from libcst_mypy.utils import MypyType from styx_compiler.comprehension_expander import ComprehensionExpander from styx_compiler.config import N_PARTITIONS -from styx_compiler.live_variables import LiveVariablesProvider from styx_compiler.processor import FunctionProcessor from styx_compiler.transformers import ( EntityTypeReplacer, diff --git a/src/styx_compiler/data_flow.py b/src/styx_compiler/data_flow.py deleted file mode 100644 index e60e9df..0000000 --- a/src/styx_compiler/data_flow.py +++ /dev/null @@ -1,194 +0,0 @@ -""" -Data-flow analysis engine as described in "FlowSpec: A declarative specification language for intra-procedural -flow-sensitive data-flow analysis" by Smits, Wachsmuth and Visser (https://doi.org/10.1016/j.cola.2019.100924). -""" - -from abc import ABC, abstractmethod -from collections.abc import Callable -from dataclasses import dataclass - -from styx_compiler.control_flow import Node - -type Cfg = dict[Node, set[Node]] - - -def compute_sccs(cfg: Cfg, extremals: list[Node]) -> list[list[Node]]: - """ - Tarjan's Strongly Connected Component algorithm, with a slight modification to force the order of nodes within an - SCC into a postorder traversal. See section 5.3.1 / figure 26. - - You can reverse the control-flow graphs and give the list of end nodes and this will work just as well. - - Parameters: - cfg (Cfg): one or more control-flow graphs - extremals (list[Node]): a list of start nodes to the control-flow graphs - - Returns: - list[list[Node]]: list of SCCs in topological order (use as a stack for topo order), where each SCC - in reverse postorder over the depth-first spanning tree of the SCC. - """ - index = 0 - scc_stack = [] - result = [] - node_index = {} - node_lowlink = {} - node_on_stack = set() - - def strong_connect(node: Node): - nonlocal index, scc_stack, result, node_index, node_lowlink, node_on_stack - node_index[node] = index - node_lowlink[node] = index - index += 1 - - stack_set_size = len(node_on_stack) - node_on_stack.add(node) - # N.B. we don't add node to scc_stack here, only to the node_on_stack set. The set is used next and in the - # recursive calls, so it doesn't affect the algorithm's correctness to postpone adding to scc_stack. - - for next_node in cfg.get(node, set()): - if next_node not in node_index: - strong_connect(next_node) - node_lowlink[node] = min(node_lowlink[node], node_lowlink[next_node]) - elif next_node in node_on_stack: - node_lowlink[node] = min(node_lowlink[node], node_index[next_node]) - - # Now we add the node to scc_stack in postorder - scc_stack.append(node) - - if node_lowlink[node] == node_index[node]: - scc = [] - for _ in range(len(node_on_stack), stack_set_size, -1): - scc.append(scc_stack.pop()) - node_on_stack.remove(scc[-1]) - result.append(list(reversed(scc))) - - for ext in extremals: - if ext not in node_index: - strong_connect(ext) - - return list(reversed(result)) - - -@dataclass(frozen=True) -class SymbolicTop: - pass - - -@dataclass(frozen=True) -class SymbolicBottom: - pass - - -type TB[T] = T | SymbolicTop | SymbolicBottom - - -class Lattice[T](ABC): - def __init__(self): - super().__init__() - self.top = SymbolicTop() - self.bottom = SymbolicBottom() - - def nleq(self, left: TB[T], right: TB[T]) -> bool: - if isinstance(left, SymbolicBottom) or isinstance(right, SymbolicTop): - return False - if isinstance(left, SymbolicTop) or isinstance(right, SymbolicBottom): - return True - return self._nleq_helper(left, right) - - @abstractmethod - def _nleq_helper(self, left: T, right: T) -> bool: - raise NotImplementedError() - - def join(self, left: TB[T], right: TB[T]) -> TB[T]: - if isinstance(left, SymbolicTop) or isinstance(right, SymbolicTop): - return SymbolicTop() - if isinstance(left, SymbolicBottom): - return right - if isinstance(right, SymbolicBottom): - return left - return self._join_helper(left, right) - - @abstractmethod - def _join_helper(self, left: T, right: T) -> T: - raise NotImplementedError() - - -@dataclass(frozen=True) -class DataflowProperty[T]: - forward: bool - initial: T - transfer_func: dict[Node, Callable[[T], T]] - lattice: Lattice[T] - - -def compute_dataflow_property[T]( - cfg: Cfg, - start_end: list[tuple[Node, Node]], - df_property: DataflowProperty[T], -) -> dict[Node, tuple[TB[T], TB[T]]]: - """ - Compute a single dataflow property. We're not doing dependent dataflow properties like the paper. See section 5.3.2 - and figure 27. - TODO: filter the CFG to efficiently handle all the identity function transfer functions. - """ - prop = {} - for node, nexts in cfg.items(): - prop[node] = df_property.lattice.bottom - for next_node in nexts: - prop[next_node] = df_property.lattice.bottom - - if df_property.forward: - extremals = [start for start, _ in start_end] - else: - rev_cfg = {} - for node, nexts in cfg.items(): - for next_node in nexts: - rev_cfg.setdefault(next_node, set()).add(node) - cfg = rev_cfg - extremals = [end for _, end in start_end] - - for node in extremals: - prop[node] = df_property.initial - - sccs = compute_sccs(cfg, extremals) - - for scc in sccs: - done = False - while not done: - done = True - for node in scc: - for next_node in cfg.get(node, set()): - assert not isinstance(prop[node], SymbolicTop | SymbolicBottom) - step = df_property.transfer_func[node](prop[node]) - if df_property.lattice.nleq(step, prop[next_node]): - prop[next_node] = df_property.lattice.join(step, prop[next_node]) - if next_node in scc: - done = False - - if df_property.forward: - return {node: (p, df_property.transfer_func[node](p)) for node, p in prop.items()} - return {node: (df_property.transfer_func[node](p), p) for node, p in prop.items()} - - -class MaySet[T](Lattice[frozenset[T]]): - def __init__(self): - super().__init__() - self.bottom = frozenset() - - def _nleq_helper(self, left: frozenset[T], right: frozenset[T]) -> bool: - return not (left <= right) - - def _join_helper(self, left: frozenset[T], right: frozenset[T]) -> frozenset[T]: - return left | right - - -class MustSet[T](Lattice[frozenset[T]]): - def __init__(self): - super().__init__() - self.top = frozenset() - - def _nleq_helper(self, left: frozenset[T], right: frozenset[T]) -> bool: - return not (left >= right) - - def _join_helper(self, left: frozenset[T], right: frozenset[T]) -> frozenset[T]: - return left & right diff --git a/src/styx_compiler/live_variables.py b/src/styx_compiler/live_variables.py deleted file mode 100644 index e225fd1..0000000 --- a/src/styx_compiler/live_variables.py +++ /dev/null @@ -1,140 +0,0 @@ -from collections import defaultdict -from collections.abc import Callable, Sequence - -import libcst as cst -from libcst import matchers as m -from libcst.metadata import QualifiedName, QualifiedNameProvider, QualifiedNameSource - -from styx_compiler.control_flow import ControlFlowGraphProvider, Node -from styx_compiler.data_flow import TB, DataflowProperty, MaySet, compute_dataflow_property -from styx_compiler.metadata_providers import IndexProvider - - -class CollectLiveVariablesTransferFunctions(cst.CSTVisitor): - """ - Computes the live variable analysis transfer functions for control-flow graph nodes - """ - - def __init__(self, provider: LiveVariablesDataflowPropertyProvider): - super().__init__() - self._provider = provider - self._tfs: dict[Node, Callable[[frozenset[QualifiedName]], frozenset[QualifiedName]]] = defaultdict( - lambda: lambda x: x - ) - - def resolve_name(self, node: cst.CSTNode) -> Sequence[QualifiedName] | None: - name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, node) - if len(name_origin) == 1: - [qual_name] = name_origin - if qual_name.source == QualifiedNameSource.LOCAL: - return [qual_name] - return None - - def get_dataflow_property(self) -> DataflowProperty: - return DataflowProperty(forward=False, initial=frozenset(), transfer_func=self._tfs, lattice=MaySet()) - - def _get_lhs_names(self, target: cst.BaseExpression) -> Sequence[QualifiedName]: - if m.matches(target, m.Attribute()): - target: cst.Attribute = cst.ensure_type(target, cst.Attribute) - result = self.resolve_name(target) - if result is not None: - return result - if m.matches(target, m.Subscript()): - # target: cst.Subscript = cst.ensure_type(target, cst.Subscript) - # Don't return the name of the target.value, since only the subscripted part is written to - return [] - if m.matches(target, m.StarredElement() | m.Element()): - return self._get_lhs_names(target.value) - if m.matches(target, m.Name()): - target: cst.Name = cst.ensure_type(target, cst.Name) - result = self.resolve_name(target) - if result is not None: - return result - if m.matches(target, m.List() | m.Tuple()): - # noinspection PyUnresolvedReferences - return [name for el in target.elements for name in self._get_lhs_names(el)] - return [] - - def leave_Module(self, module: cst.Module) -> None: - self._provider.set_metadata(module, self.get_dataflow_property()) - - def visit_Param(self, node: cst.Param) -> bool | None: - index = self._provider.get_metadata(IndexProvider, node) - result = self.resolve_name(node.name) - if result is not None: - self._tfs[Node(index, 0)] = lambda lives, names=result: lives.difference(names) - - def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: - index = self._provider.get_metadata(IndexProvider, node) - names = self._get_lhs_names(node.target) - self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - - def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None: - index = self._provider.get_metadata(IndexProvider, node) - names = self._get_lhs_names(node.target) - self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - - def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: - index = self._provider.get_metadata(IndexProvider, node) - names = self._get_lhs_names(node.target) - self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - - def visit_Del(self, node: cst.Del) -> bool | None: - index = self._provider.get_metadata(IndexProvider, node) - names = self._get_lhs_names(node.target) - self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - - def visit_For(self, node: cst.For) -> bool | None: - index = self._provider.get_metadata(IndexProvider, node) - names = self._get_lhs_names(node.target) - self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - - def visit_AsName(self, node: cst.AsName) -> bool | None: - index = self._provider.get_metadata(IndexProvider, node) - names = self._get_lhs_names(node.name) - self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - - def visit_Name(self, node: cst.Name) -> bool | None: - index = self._provider.get_metadata(IndexProvider, node) - result = self.resolve_name(node) - if result is not None: - self._tfs[Node(index, 0)] = lambda lives, names=result: lives.union(names) - - -class LiveVariablesDataflowPropertyProvider(cst.BatchableMetadataProvider[DataflowProperty]): - METADATA_DEPENDENCIES = (IndexProvider, QualifiedNameProvider) - - def visit_Module(self, module: cst.Module) -> None: - module.visit(CollectLiveVariablesTransferFunctions(self)) - - -class LiveVariablesVisitor(cst.CSTVisitor): - def __init__( - self, - provider: LiveVariablesProvider, - live_vars: dict[Node, tuple[TB[frozenset[QualifiedName]], TB[frozenset[QualifiedName]]]], - ): - super().__init__() - self._provider: LiveVariablesProvider = provider - self.live_vars: dict[Node, tuple[TB[frozenset[QualifiedName]], TB[frozenset[QualifiedName]]]] = live_vars - - def on_visit(self, node: cst.CSTNode) -> bool: - if m.matches(node, m.SimpleWhitespace() | m.TrailingWhitespace()): - return False - if self._provider.get_metadata(IndexProvider, node, None) is not None: - cfg_node = Node(self._provider.get_metadata(IndexProvider, node), 0) - if cfg_node in self.live_vars: - self._provider.set_metadata(node, self.live_vars[cfg_node]) - return True - - -class LiveVariablesProvider( - cst.BatchableMetadataProvider[tuple[frozenset[TB[QualifiedName]], frozenset[TB[QualifiedName]]]] -): - METADATA_DEPENDENCIES = (IndexProvider, ControlFlowGraphProvider, LiveVariablesDataflowPropertyProvider) - - def visit_Module(self, node: cst.Module) -> bool | None: - cfg, start_end = self.get_metadata(ControlFlowGraphProvider, node) - lv_prop = self.get_metadata(LiveVariablesDataflowPropertyProvider, node) - lv_result = compute_dataflow_property(cfg, start_end, lv_prop) - node.visit(LiveVariablesVisitor(self, lv_result)) diff --git a/src/styx_compiler/liveness.py b/src/styx_compiler/liveness.py index d04f9c6..aa65d6a 100644 --- a/src/styx_compiler/liveness.py +++ b/src/styx_compiler/liveness.py @@ -8,6 +8,7 @@ import libcst as cst from libcst import matchers as m +from libcst_dfa.data_flow import ImmutableSet class LivenessHelper: @@ -22,16 +23,16 @@ def simple_name(v) -> str: return str(v.name).split(".")[-1] return str(v).split(".")[-1] if "." in str(v) else str(v) - def _live_set(self, node: cst.CSTNode | None, kind: str) -> frozenset | None: + def _live_set(self, node: cst.CSTNode | None, kind: str) -> ImmutableSet | None: if not self.live_vars or node is None: return None data = self.live_vars.get(node) if data is None: return None val = data[0] if kind == "in" else data[1] - return val if isinstance(val, frozenset) else None + return val if isinstance(val, ImmutableSet) else None - def live_out(self, stmt: cst.CSTNode) -> frozenset | None: + def live_out(self, stmt: cst.CSTNode) -> ImmutableSet | None: """Live-out set for a statement, or None if no liveness data.""" if isinstance(stmt, cst.SimpleStatementLine) and stmt.body: element = stmt.body[0] diff --git a/src/styx_compiler/metadata_providers.py b/src/styx_compiler/metadata_providers.py deleted file mode 100644 index 9696467..0000000 --- a/src/styx_compiler/metadata_providers.py +++ /dev/null @@ -1,21 +0,0 @@ -import libcst as cst -import libcst.matchers as m - - -class IndexProvider(cst.VisitorMetadataProvider[int]): - """ - Gives each CST node a number to refer to them uniquely - """ - - def __init__(self, index: int = 0): - super().__init__() - self._index = index - - def on_visit(self, node: cst.CSTNode) -> bool: - if m.matches(node, m.SimpleWhitespace() | m.TrailingWhitespace()): - return False - if not self.get_metadata(type(self), node, False): - self.set_metadata(node, self._index) - self._index += 1 - return True - return False diff --git a/tests/test_cfg.py b/tests/test_cfg.py deleted file mode 100644 index 10b769b..0000000 --- a/tests/test_cfg.py +++ /dev/null @@ -1,410 +0,0 @@ -"""Tests for styx_compiler.control_flow.""" - -import libcst as cst -import libcst.matchers as m - -from styx_compiler.control_flow import CfgNode, ControlFlowGraphProvider, Node -from styx_compiler.metadata_providers import IndexProvider - -add_fundef = """ -def add(a: int, b: int) -> int: - return a + b -""" - - -user_item = """ -@entity -class Item: - def __init__(self, item_name: str, price: int): - self.item_name: str = item_name - self.stock: int = 0 - self.price: int = price - - def get_price(self) -> int: - return self.price - - def get_stock(self) -> int: - return self.stock - - def update_stock(self, amount: int) -> bool: - if (self.stock + amount) < 0: - raise OutOfStock("Not enough stock to update.") - - self.stock += amount - return True - - def __key__(self): - return self.item_name -""" - - -nested_try = """ -def nested_try(a: int) -> int: - try: - if a < 0: - return 0 - elif a > 10: - raise RuntimeError - elif a == 10: - raise OutOfStock("Not enough stock to update.") - a += 1 - except OutOfStock: - if a < 0: - return 0 - elif a > 10: - raise RuntimeError - a += 1 - else: - if a > 10: - raise OutOfStock("Not enough stock to update.") - a += 1 - finally: - if a < 10: - return 9001 - return a + 42 -""" - - -def test_node_existence(): - node_existence(add_fundef) - node_existence(user_item) - node_existence(nested_try) - - -def node_existence(source_string: str): - source_tree = cst.parse_module(source_string) - wrapper = cst.MetadataWrapper(source_tree) - cnt = CfgNodeTester() - wrapper.visit(cnt) - - -class CfgNodeTester(cst.CSTVisitor): - """ - Checks that each kind of CST Node that should have a corresponding CFG node has one - """ - - METADATA_DEPENDENCIES = (IndexProvider, ControlFlowGraphProvider) - - def __init__(self): - super().__init__() - self.cfg = None - self.active = False - - def visit_Module(self, node: cst.Module) -> bool | None: - self.cfg, _start_end = self.get_metadata(ControlFlowGraphProvider, node) - - def _has_node(self, node: cst.CSTNode, instance: int = 0) -> bool: - """ - Tests if the CSTNode has a corresponding CFG node with outgoing edges - """ - n = Node(self.get_metadata(IndexProvider, node), instance) - return n in self.cfg - - def visit_Param(self, node: cst.Param) -> bool | None: - if self.active: - assert self._has_node(node) - return False - return None - - def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_NameItem(self, node: cst.NameItem) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Attribute(self, node: cst.Attribute) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Name(self, node: cst.Name) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_AsName(self, node: cst.AsName) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_UnaryOperation(self, node: cst.UnaryOperation) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_BinaryOperation(self, node: cst.BinaryOperation) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_BooleanOperation(self, node: cst.BooleanOperation) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_ComparisonTarget(self, node: cst.ComparisonTarget) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Await(self, node: cst.Await) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Yield(self, node: cst.Yield) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_From(self, node: cst.From) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Integer(self, node: cst.Integer) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Float(self, node: cst.Float) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Imaginary(self, node: cst.Imaginary) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_SimpleString(self, node: cst.SimpleString) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_FormattedStringExpression(self, node: cst.FormattedStringExpression) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_FormattedString(self, node: cst.FormattedString) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Tuple(self, node: cst.Tuple) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_List(self, node: cst.List) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Set(self, node: cst.Set) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Element(self, node: cst.Element) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_StarredElement(self, node: cst.StarredElement) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_DictElement(self, node: cst.DictElement) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_StarredDictElement(self, node: cst.StarredDictElement) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_GeneratorExp(self, node: cst.GeneratorExp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_ListComp(self, node: cst.ListComp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_SetComp(self, node: cst.SetComp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_DictComp(self, node: cst.DictComp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Index(self, node: cst.Index) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Slice(self, node: cst.Slice) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Subscript(self, node: cst.Subscript) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: - assert self._has_node(node) - # We're not testing for instance 1, which is a final node and will not have outgoing edges - - def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: - self.active = True - - def leave_FunctionDef_params(self, _node: cst.FunctionDef) -> None: - self.active = False - - def visit_FunctionDef_body(self, _node: cst.FunctionDef) -> None: - self.active = True - - def leave_FunctionDef_body(self, _node: cst.FunctionDef) -> None: - self.active = False - - def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: - self.active = False - - def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: - self.active = True - - def visit_Attribute_attr(self, _node: cst.FunctionDef) -> None: - self.active = False - - def leave_Attribute_attr(self, _node: cst.FunctionDef) -> None: - self.active = True - - -loop_test = """ -@entity -class Something: - def loop_test(self, cart: list[Item]) -> int: - val = 0 - - for item in cart: - attr_1 = item.get_price() - val += attr_1 - - temp = 3 - - val += temp - - return val -""" - - -def test_loop_test_cfg(): - module = cst.parse_module(loop_test) - wrapper = cst.MetadataWrapper(module) - ltct = LoopTestCfgTester() - wrapper.visit(ltct) - - -class LoopTestCfgTester(cst.CSTVisitor): - METADATA_DEPENDENCIES = (IndexProvider, ControlFlowGraphProvider) - - def __init__(self): - super().__init__() - self.cfg: dict[CfgNode, set[CfgNode]] | None = None - self.start_end: list[tuple[CfgNode, CfgNode]] | None = None - - def is_edge(self, from_node: CfgNode, to_node: CfgNode) -> bool: - return from_node in self.cfg and to_node in self.cfg[from_node] - - def assert_node(self, node: cst.CSTNode, prev: CfgNode) -> CfgNode: - cfg_node = Node(self.get_metadata(IndexProvider, node), 0) - assert cfg_node in self.cfg - assert self.is_edge(prev, cfg_node) - return cfg_node - - def visit_Module(self, node: cst.Module) -> bool | None: - self.cfg, self.start_end = self.get_metadata(ControlFlowGraphProvider, node) - - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: - assert node.name.value == "loop_test" - - # unpack - - init_val: cst.Assign - for_loop: cst.For - init_temp: cst.Assign - update_val: cst.AugAssign - return_val: cst.Return - - init_val, for_loop, init_temp, update_val, return_val = ( - cst.ensure_type(ssl, cst.SimpleStatementLine).body[0] if m.matches(ssl, m.SimpleStatementLine()) else ssl - for ssl in node.body.body - ) - - for_iter_cart: cst.Name = cst.ensure_type(for_loop.iter, cst.Name) - - for_init_attr_1: cst.Assign - for_update_val: cst.AugAssign - - for_init_attr_1, for_update_val = ( - cst.ensure_type(ssl, cst.SimpleStatementLine).body[0] for ssl in for_loop.body.body - ) - - for_init_attr_1_call: cst.Call = cst.ensure_type(for_init_attr_1.value, cst.Call) - for_init_attr_1_call_expr: cst.Attribute = cst.ensure_type(for_init_attr_1_call.func, cst.Attribute) - for_init_attr_1_call_item: cst.Name = cst.ensure_type(for_init_attr_1_call_expr.value, cst.Name) - - # FunctionDef - - index = self.get_metadata(IndexProvider, node) - start = Node(index, 0) - end = Node(index, 1) - assert start in self.cfg - - assert (start, end) in self.start_end - - param_names = ["self", "cart"] - prev = start - for param, param_name in zip(node.params.params, param_names, strict=True): - assert param.name.value == param_name - prev = self.assert_node(param, prev) - - # FunctionDef Assign - - prev = self.assert_node(init_val.value, prev) - prev = self.assert_node(init_val.targets[0], prev) - - # FunctionDef For - - prev = self.assert_node(for_iter_cart, prev) - prev = self.assert_node(for_loop, prev) - - # FunctionDef For Assign Call Attribute - - prev = self.assert_node(for_init_attr_1_call_item, prev) - prev = self.assert_node(for_init_attr_1_call_expr, prev) - - # FunctionDef For Assign Call - - prev = self.assert_node(for_init_attr_1_call, prev) - - # FunctionDef For Assign - - prev = self.assert_node(for_init_attr_1.targets[0], prev) - - # FunctionDef For AugAssign - - prev = self.assert_node(for_update_val.target, prev) - prev = self.assert_node(for_update_val.value, prev) - prev = self.assert_node(for_update_val, prev) - - # FunctionDef For (back-edge) - - assert self.is_edge(prev, Node(self.get_metadata(IndexProvider, for_iter_cart), 0)) - - # FunctionDef Assign - - prev = self.assert_node(init_temp.value, prev) - prev = self.assert_node(init_temp.targets[0], prev) - - # FunctionDef AugAssign - - prev = self.assert_node(update_val.target, prev) - prev = self.assert_node(update_val.value, prev) - prev = self.assert_node(update_val, prev) - - # FunctionDef Return - - prev = self.assert_node(return_val.value, prev) - assert self.is_edge(prev, end) diff --git a/tests/test_live_variables.py b/tests/test_live_variables.py deleted file mode 100644 index 22e612d..0000000 --- a/tests/test_live_variables.py +++ /dev/null @@ -1,272 +0,0 @@ -"""Tests for styx_compiler.live_variables and indirectly for styx_compiler.data_flow.""" - -import libcst as cst -import libcst.matchers as m -from libcst.metadata import QualifiedName, QualifiedNameSource - -from styx_compiler.control_flow import ControlFlowGraphProvider, Node -from styx_compiler.live_variables import ( - LiveVariablesProvider, -) -from styx_compiler.metadata_providers import IndexProvider -from tests.test_cfg import add_fundef, loop_test - -# def test_add_fundef_live_vars(): -# source_tree = cst.parse_module(add_fundef) -# wrapper = cst.MetadataWrapper(source_tree) -# cfgp = ControlFlowGraphProvider() -# ccfg = ComputeControlFlowGraph(cfgp) -# wrapper.visit(ccfg) -# lvdpp = LiveVariablesDataflowPropertyProvider() -# clvtf = CollectLiveVariablesTransferFunctions(lvdpp) -# wrapper.visit(clvtf) -# lv_prop = clvtf.get_dataflow_property() -# lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) -# print(lv_result) -# -# -# def test_user_item_live_vars(): -# source_tree = cst.parse_module(user_item) -# wrapper = cst.MetadataWrapper(source_tree) -# cfgp = ControlFlowGraphProvider() -# ccfg = ComputeControlFlowGraph(cfgp) -# wrapper.visit(ccfg) -# lvdpp = LiveVariablesDataflowPropertyProvider() -# clvtf = CollectLiveVariablesTransferFunctions(lvdpp) -# wrapper.visit(clvtf) -# lv_prop = clvtf.get_dataflow_property() -# lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) -# print(lv_result) - - -def test_add_fundef_live_vars_provider(): - source_tree = cst.parse_module(add_fundef) - wrapper = cst.MetadataWrapper(source_tree) - lvt1 = LiveVariablesTester1() - wrapper.visit(lvt1) - - -class LiveVariablesTester1(cst.CSTVisitor): - """ - Checks that each kind of CST Node that should have a corresponding CFG node has one - """ - - METADATA_DEPENDENCIES = (LiveVariablesProvider,) - - @staticmethod - def local(name: str) -> QualifiedName: - return QualifiedName(name=f"add..{name}", source=QualifiedNameSource.LOCAL) - - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: - empty = frozenset() - ab = frozenset([self.local("a"), self.local("b")]) - a = frozenset([self.local("a")]) - assert self.get_metadata(LiveVariablesProvider, node.params.params[1], None) == (a, ab) - assert self.get_metadata(LiveVariablesProvider, node.params.params[0], None) == (empty, a) - assert self.get_metadata(LiveVariablesProvider, node, None) == (empty, empty) - - -def test_loop_test_live_vars_provider(): - source_tree = cst.parse_module(loop_test) - wrapper = cst.MetadataWrapper(source_tree) - lvt1 = LiveVariablesTester2() - wrapper.visit(lvt1) - - -class LiveVariablesTester2(cst.CSTVisitor): - """ - Checks that each kind of CST Node that should have a corresponding CFG node has one - """ - - METADATA_DEPENDENCIES = (LiveVariablesProvider, ControlFlowGraphProvider, IndexProvider) - - def __init__(self): - super().__init__() - self._cfg = None - - def is_edge(self, from_node: cst.CSTNode, to_node: cst.CSTNode) -> bool: - from_idx: int | None = self.get_metadata(IndexProvider, from_node, None) - if from_idx is None: - return False - from_cfg = Node(from_idx, 0) - if from_cfg not in self._cfg: - return False - to_idx: int | None = self.get_metadata(IndexProvider, to_node, None) - if to_idx is None: - return False - to_cfg = Node(to_idx, 0) - return to_cfg in self._cfg[from_cfg] - - @staticmethod - def local(name: str) -> QualifiedName: - return QualifiedName(name=f"Something.loop_test..{name}", source=QualifiedNameSource.LOCAL) - - def visit_Module(self, node: cst.Module) -> bool | None: - self._cfg, _ = self.get_metadata(ControlFlowGraphProvider, node) - - def visit_Call(self, node: cst.Call) -> bool | None: - assert self.get_metadata(LiveVariablesProvider, node, None)[1] == frozenset( - [self.local("val"), self.local("cart")] - ) - - # def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: - # print(node) - # assert self.is_edge(node.target, node.value), ( - # "AugAssign target should have an edge to AugAssign value is both are cst.Name" - # ) - # print(self.get_metadata(LiveVariablesProvider, node.target, None)) - # print(self.get_metadata(LiveVariablesProvider, node.value, None)) - # print(self.get_metadata(LiveVariablesProvider, node, None)) - # - # def visit_Assign(self, node: cst.Assign) -> bool | None: - # print(node) - # print(self.get_metadata(LiveVariablesProvider, node.value, None)) - # for target in node.targets: - # print(self.get_metadata(LiveVariablesProvider, target, None)) - # - # def visit_Return(self, node: cst.Return) -> bool | None: - # print(node) - # print(self.get_metadata(LiveVariablesProvider, node.value, None)) - - -subscript_assign = """ -def subscript_assign(lst, i, val): - lst[i] = val - return lst -""" - - -def test_subscript_assign_keeps_list_live(): - """lst[i] = val must not kill lst — lst is mutated, not redefined.""" - source_tree = cst.parse_module(subscript_assign) - wrapper = cst.MetadataWrapper(source_tree) - wrapper.visit(SubscriptAssignTester()) - - -class SubscriptAssignTester(cst.CSTVisitor): - METADATA_DEPENDENCIES = (LiveVariablesProvider,) - - @staticmethod - def local(name: str) -> QualifiedName: - return QualifiedName(name=f"subscript_assign..{name}", source=QualifiedNameSource.LOCAL) - - def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None: - live_in, live_out = self.get_metadata(LiveVariablesProvider, node) - assert self.local("lst") in live_out - assert self.local("lst") in live_in - - -aug_subscript = """ -def aug_subscript(lst, i): - lst[i] += 1 - return lst -""" - - -def test_aug_subscript_keeps_list_live(): - source_tree = cst.parse_module(aug_subscript) - wrapper = cst.MetadataWrapper(source_tree) - wrapper.visit(AugSubscriptTester()) - - -class AugSubscriptTester(cst.CSTVisitor): - METADATA_DEPENDENCIES = (LiveVariablesProvider,) - - @staticmethod - def local(name: str) -> QualifiedName: - return QualifiedName(name=f"aug_subscript..{name}", source=QualifiedNameSource.LOCAL) - - def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: - live_in, live_out = self.get_metadata(LiveVariablesProvider, node) - assert self.local("lst") in live_out - assert self.local("lst") in live_in - - -method_call_source = """ -def method_call_func(items, item): - items.append(item) - return items -""" - - -def test_method_call_does_not_add_method_name(): - """items.append(item) — 'append' is a method, not a live variable.""" - source_tree = cst.parse_module(method_call_source) - wrapper = cst.MetadataWrapper(source_tree) - wrapper.visit(MethodCallLiveVarsTester()) - - -class MethodCallLiveVarsTester(cst.CSTVisitor): - METADATA_DEPENDENCIES = (LiveVariablesProvider,) - - def visit_Attribute(self, node: cst.Attribute) -> bool | None: - metadata = self.get_metadata(LiveVariablesProvider, node, None) - if metadata is None: - return None - live_in, _ = metadata - items_append = QualifiedName( - name="method_call_func..items.append", - source=QualifiedNameSource.LOCAL, - ) - assert items_append not in live_in - - -# Advanced test: nested loops with an if condition. -# Expected live sets computed by hand via backward fixed-point analysis: -# -# Lout (live at outer-loop header xs Name) = {xs, ys, limit, total} -# Lin (live at inner-loop header ys Name) = {xs, ys, x, limit, total} -# -# Outer For node kills x → live_in misses x, live_out has it. -# Inner For node kills y → live_in misses y, live_out has it. -# ComparisonTarget → all variables x, y, xs, ys, limit, total live. -# AugAssign kills total → live_in misses total, live_out has it (needed by -# the next iteration and the return). - -nested_loops_if = """ -def nested_loops_if(xs, ys, limit): - total = 0 - for x in xs: - for y in ys: - if x < limit: - total += y - return total -""" - - -def test_nested_loops_if_live_vars(): - source_tree = cst.parse_module(nested_loops_if) - wrapper = cst.MetadataWrapper(source_tree) - wrapper.visit(NestedLoopsIfTester()) - - -class NestedLoopsIfTester(cst.CSTVisitor): - METADATA_DEPENDENCIES = (LiveVariablesProvider,) - - @staticmethod - def local(name: str) -> QualifiedName: - return QualifiedName(name=f"nested_loops_if..{name}", source=QualifiedNameSource.LOCAL) - - def visit_For(self, node: cst.For) -> bool | None: - live_in, live_out = self.get_metadata(LiveVariablesProvider, node) - xs, ys, x, y, limit, total = (self.local(n) for n in ("xs", "ys", "x", "y", "limit", "total")) - - if m.matches(node.iter, m.Name(value="xs")): - assert live_in == frozenset([xs, ys, limit, total]) - assert live_out == frozenset([xs, ys, x, limit, total]) - - elif m.matches(node.iter, m.Name(value="ys")): - assert live_in == frozenset([xs, ys, x, limit, total]) - assert live_out == frozenset([xs, ys, x, y, limit, total]) - - def visit_ComparisonTarget(self, node: cst.ComparisonTarget) -> bool | None: - live_in, live_out = self.get_metadata(LiveVariablesProvider, node) - expected = frozenset(self.local(n) for n in ("xs", "ys", "x", "y", "limit", "total")) - assert live_in == expected - assert live_out == expected - - def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: - live_in, live_out = self.get_metadata(LiveVariablesProvider, node) - xs, ys, x, limit, total = (self.local(n) for n in ("xs", "ys", "x", "limit", "total")) - assert live_out == frozenset([xs, ys, x, limit, total]) - assert live_in == frozenset([xs, ys, x, limit]) diff --git a/uv.lock b/uv.lock index e179143..97c2a41 100644 --- a/uv.lock +++ b/uv.lock @@ -341,6 +341,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, ] +[[package]] +name = "immutables" +version = "0.21" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/69/41/0ccaa6ef9943c0609ec5aa663a3b3e681c1712c1007147b84590cec706a0/immutables-0.21.tar.gz", hash = "sha256:b55ffaf0449790242feb4c56ab799ea7af92801a0a43f9e2f4f8af2ab24dfc4a", size = 89008, upload-time = "2024-10-10T00:55:01.434Z" } + [[package]] name = "iniconfig" version = "2.3.0" @@ -448,6 +454,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/98/6d/5d6a790a02eb0d9d36c4aed4f41b277497e6178900b2fa29c35353aa45ed/libcst-1.8.6-cp314-cp314t-win_arm64.whl", hash = "sha256:819c8081e2948635cab60c603e1bbdceccdfe19104a242530ad38a36222cb88f", size = 2065000, upload-time = "2025-11-03T22:33:16.257Z" }, ] +[[package]] +name = "libcst-dfa" +version = "0.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "immutables" }, + { name = "libcst" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/11/624a418d6ba2e686a158459dc31416ff191d94308c254c5e7219a1555898/libcst_dfa-0.0.1.tar.gz", hash = "sha256:6054255c951cf7a3813e23ae179591767ea9ae3dcadf656dd26f4336acfdfe7d", size = 87271, upload-time = "2026-06-09T14:31:22.958Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/90/39/b62afe5d64c3ee346c19e1dba56ee3ddfd9b28a6bfb6fe24835b2f4069e8/libcst_dfa-0.0.1-py3-none-any.whl", hash = "sha256:056bcec83ef8022de554a9392b5b0553b0c99a8edc99cc8dfd28612b262603e6", size = 15813, upload-time = "2026-06-09T14:31:24.081Z" }, +] + [[package]] name = "libcst-mypy" version = "0.1.0" @@ -1040,6 +1059,7 @@ name = "styx-compiler" source = { editable = "." } dependencies = [ { name = "libcst" }, + { name = "libcst-dfa" }, { name = "libcst-mypy" }, { name = "mypy" }, ] @@ -1058,6 +1078,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "libcst", specifier = ">=1.8.6" }, + { name = "libcst-dfa", specifier = ">=0.0.1" }, { name = "libcst-mypy", specifier = ">=0.1.0" }, { name = "mypy", specifier = ">=1.19.1" }, ]