diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a37e517..c8edb81 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,12 +12,13 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v6.0.0 - hooks: - - id: debug-statements - - id: name-tests-test - args: ["--pytest-test-first"] +# This hook cannot handle newer Python syntax like type aliases yet +# - repo: https://github.com/pre-commit/pre-commit-hooks +# rev: v6.0.0 +# hooks: +# - id: debug-statements +# - id: name-tests-test +# args: ["--pytest-test-first"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.14.13 diff --git a/pyproject.toml b/pyproject.toml index 31a92ff..02cb42d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,13 +16,13 @@ dependencies = [ [dependency-groups] dev = [ - "hatch>=1.13.0,<2", - "prek>=0.3.6,<2", - "pytest>=9.0.2,<10", - "pytest-cov>=7.0.0,<8", - "mkdocstrings[python]>=1.0.3,<2", - "mkdocs-material>=9.7.5,<10", - "setuptools-scm>=9.2.2,<10", + "hatch>=1.13.0", + "prek>=0.3.6", + "pytest>=9.0.2", + "pytest-cov>=7.0.0", + "mkdocstrings[python]>=1.0.3", + "mkdocs-material>=9.7.5", + "setuptools-scm>=9.2.2", ] [build-system] diff --git a/src/styx_compiler/control_flow.py b/src/styx_compiler/control_flow.py index 140dcb3..a9a58d0 100644 --- a/src/styx_compiler/control_flow.py +++ b/src/styx_compiler/control_flow.py @@ -45,30 +45,24 @@ class ComputeControlFlowGraph(cst.CSTVisitor): Computes the control-flow graph of the code, expressed in indices from the IndexProvider """ - METADATA_DEPENDENCIES = (IndexProvider,) - - def __init__(self): + def __init__(self, provider: ControlFlowGraphProvider): super().__init__() + self._provider = provider self._cfg: dict[CfgNode, set[CfgNode]] = {} self._start_end: list[tuple[CfgNode, CfgNode]] = [] def _edge(self, prev: list[CfgNode], cur: CfgNode) -> list[CfgNode]: for p in prev: - if p not in self._cfg: - self._cfg[p] = 
set() - self._cfg[p].add(cur) + self._cfg.setdefault(p, set()).add(cur) return [cur] def _edges(self, prev: list[CfgNode], tos: list[CfgNode]) -> list[CfgNode]: for p in prev: - if p not in self._cfg: - self._cfg[p] = set() - for to in tos: - self._cfg[p].add(to) + self._cfg.setdefault(p, set()).update(tos) return tos def _make_cfg_node(self, cst_node: cst.CSTNode, instance: int, prev: list[CfgNode]) -> list[CfgNode]: - cur = Node(self.get_metadata(IndexProvider, cst_node), instance) + cur = Node(self._provider.get_metadata(IndexProvider, cst_node), instance) return self._edge(prev, cur) def _clean_up_cfg_ghosts(self, start: CfgNode) -> None: @@ -96,7 +90,7 @@ def _clean_up_cfg_ghosts(self, start: CfgNode) -> None: seen.add(next_node) def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - index = self.get_metadata(IndexProvider, node) + index = self._provider.get_metadata(IndexProvider, node) start = Node(index, 0) end = Node(index, 1) self._start_end.append((start, end)) @@ -114,7 +108,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: self._clean_up_cfg_ghosts(start) - def leave_Module(self, node: cst.Module) -> None: + def leave_Module(self, module: cst.Module) -> None: # Remove unreachable parts of the CFG (e.g. 
unused finally clause instantiations, dead code after a return) reachable = set() workstack = [] @@ -136,6 +130,8 @@ def leave_Module(self, node: cst.Module) -> None: for k in to_remove: del self._cfg[k] + self._provider.set_metadata(module, (self._cfg, self._start_end)) + def _visit_BaseSuite( self, statements: cst.BaseSuite | cst.SimpleStatementLine, @@ -188,15 +184,16 @@ def _visit_statement( prev = self._visit_expression(statement.value, instance, prev) # then the multiple LHS, from left to right for target in statement.targets: + prev = self._visit_expression(target.target, instance, prev) prev = self._make_cfg_node(target, instance, prev) # AssignTarget elif m.matches(statement, m.AugAssign()): statement: cst.AugAssign = cst.ensure_type(statement, cst.AugAssign) - # note we're making the AugAssign a node first to represent reading the value from the target - prev = self._make_cfg_node(statement, instance, prev) # AugAssign + # note we're visiting LHS first to represent reading the value from the target + prev = self._visit_expression(statement.target, instance, prev) # then we visit the RHS expression to find more reads prev = self._visit_expression(statement.value, instance, prev) - # finally we write to the LHS - prev = self._visit_expression(statement.target, instance, prev) + # finally we write to the LHS, represented by a node of the whole assignment + prev = self._make_cfg_node(statement, instance, prev) # AugAssign elif m.matches(statement, m.Break()): if loop_break_target is None: msg = "Found break outside of loop" @@ -255,7 +252,7 @@ def _visit_statement( raise NotImplementedError(msg) elif m.matches(statement, m.For()): statement: cst.For = cst.ensure_type(statement, cst.For) - index = self.get_metadata(IndexProvider, statement) + index = self._provider.get_metadata(IndexProvider, statement) for_loop_continue_target = Ghost(index, 0) prev = self._edge(prev, for_loop_continue_target) loop_expr_prev = self._visit_expression(statement.iter, instance, 
prev) @@ -319,7 +316,7 @@ def _visit_statement( def wrap_in_finally(exit: CfgNode) -> CfgNode: nonlocal statement, finally_number, fn_end, exception_target, loop_continue_target, loop_break_target if statement.finalbody is not None: - entry = Ghost(self.get_metadata(IndexProvider, statement.finalbody), finally_number) + entry = Ghost(self._provider.get_metadata(IndexProvider, statement.finalbody), finally_number) finalbody: cst.Finally = cst.ensure_type(statement.finalbody, cst.Finally) prev = self._visit_BaseSuite( finalbody.body, @@ -348,7 +345,7 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode: # to the next conditional for handler in statement.handlers: handler: cst.ExceptHandler = cst.ensure_type(handler, cst.ExceptHandler) # noqa: PLW2901 - handler_index = self.get_metadata(IndexProvider, handler) + handler_index = self._provider.get_metadata(IndexProvider, handler) handler_entry = Ghost(handler_index, 0) handler_exit = Ghost(handler_index, 1) handler_entries.append(handler_entry) @@ -360,6 +357,7 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode: handler_exits.append(handler_exit) if handler.name is not None: + handler_prev = self._visit_expression(handler.name.name, instance, handler_prev) handler_prev = self._make_cfg_node(handler.name, instance, handler_prev) # AsName handler_prev = self._visit_BaseSuite( handler.body, @@ -404,14 +402,14 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode: loop_break_target=local_loop_break_target, ) # Ghost node for exiting the finally clause normally - try_exit = Ghost(self.get_metadata(IndexProvider, statement), 0) + try_exit = Ghost(self._provider.get_metadata(IndexProvider, statement), 0) finally_entry = wrap_in_finally(try_exit) # The normal entry into a normal finally clause at the end of the body/else or handler self._edge([*prev, *handler_exits], finally_entry) prev = [try_exit] elif m.matches(statement, m.While()): statement: cst.While = cst.ensure_type(statement, cst.While) - index = 
self.get_metadata(IndexProvider, statement) + index = self._provider.get_metadata(IndexProvider, statement) while_loop_continue_target = Ghost(index, 0) prev = self._edge(prev, while_loop_continue_target) prev = self._visit_expression(statement.test, instance, prev) @@ -553,6 +551,7 @@ def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev: prev = self._visit_expression(expression.func, instance, prev) for arg in expression.args: prev = self._visit_expression(arg.value, instance, prev) + prev = self._make_cfg_node(expression, instance, prev) # Call ## Literal Values elif m.matches(expression, m.Ellipsis()): pass @@ -641,7 +640,7 @@ def _visit_CompFor( elt: cst.BaseExpression | tuple[cst.BaseExpression, cst.BaseExpression], prev: list[CfgNode], ) -> list[CfgNode]: - exit = Ghost(self.get_metadata(IndexProvider, for_in), 0) + exit = Ghost(self._provider.get_metadata(IndexProvider, for_in), 0) prev = self._visit_expression(for_in.iter, instance, prev) prev = self._visit_expression(for_in.target, instance, prev) for compif in for_in.ifs: @@ -658,193 +657,19 @@ def _visit_CompFor( prev = self._visit_expression(elt, instance, prev) return self._edge(prev, exit) + @property + def cfg(self): + return self._cfg -class CfgNodeTester(cst.CSTVisitor): - """ - Checks that each kind of CST Node that should have a corresponding CFG node has one - """ - - METADATA_DEPENDENCIES = (IndexProvider,) - - def __init__(self, cfg: dict[Node, set[Node]]): - super().__init__() - self.cfg = cfg - self.active = False - - def _has_node(self, node: cst.CSTNode, instance: int = 0) -> bool: - """ - Tests if the CSTNode has a corresponding CFG node with outgoing edges - """ - n = Node(self.get_metadata(IndexProvider, node), instance) - return n in self.cfg - - def visit_Param(self, node: cst.Param) -> bool | None: - if self.active: - assert self._has_node(node) - # TODO: should we visit deeper into Param too? 
- return False - return None - - def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None: - if self.active: - assert self._has_node(node) - # TODO: should we visit deeper into AssignTarget too? - return False - return None - - def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_NameItem(self, node: cst.NameItem) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Attribute(self, node: cst.Attribute) -> bool | None: - if self.active: - assert self._has_node(node) - # TODO: should we visit deeper into Attribute too? - return False - return None - - def visit_Name(self, node: cst.Name) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_AsName(self, node: cst.AsName) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_UnaryOperation(self, node: cst.UnaryOperation) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_BinaryOperation(self, node: cst.BinaryOperation) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_BooleanOperation(self, node: cst.BooleanOperation) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_ComparisonTarget(self, node: cst.ComparisonTarget) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Await(self, node: cst.Await) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Yield(self, node: cst.Yield) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_From(self, node: cst.From) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Integer(self, node: cst.Integer) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Float(self, node: cst.Float) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Imaginary(self, node: cst.Imaginary) -> bool | None: - if 
self.active: - assert self._has_node(node) - - def visit_SimpleString(self, node: cst.SimpleString) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_FormattedStringExpression(self, node: cst.FormattedStringExpression) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_FormattedString(self, node: cst.FormattedString) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Tuple(self, node: cst.Tuple) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_List(self, node: cst.List) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Set(self, node: cst.Set) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Element(self, node: cst.Element) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_StarredElement(self, node: cst.StarredElement) -> bool | None: - if self.active: - assert self._has_node(node) + @property + def start_end(self): + return self._start_end - def visit_DictElement(self, node: cst.DictElement) -> bool | None: - if self.active: - assert self._has_node(node) - def visit_StarredDictElement(self, node: cst.StarredDictElement) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_GeneratorExp(self, node: cst.GeneratorExp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_ListComp(self, node: cst.ListComp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_SetComp(self, node: cst.SetComp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_DictComp(self, node: cst.DictComp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Index(self, node: cst.Index) -> bool | None: - if self.active: - assert self._has_node(node) - - 
def visit_Slice(self, node: cst.Slice) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Subscript(self, node: cst.Subscript) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: - assert self._has_node(node) - # We're not testing for instance 1, which is a final node and will not have outgoing edges - - def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: - self.active = True - - def leave_FunctionDef_params(self, _node: cst.FunctionDef) -> None: - self.active = False - - def visit_FunctionDef_body(self, _node: cst.FunctionDef) -> None: - self.active = True - - def leave_FunctionDef_body(self, _node: cst.FunctionDef) -> None: - self.active = False - - def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: - self.active = False +class ControlFlowGraphProvider( + cst.BatchableMetadataProvider[tuple[dict[CfgNode, set[CfgNode]], list[tuple[CfgNode, CfgNode]]]] +): + METADATA_DEPENDENCIES = (IndexProvider,) - def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: - self.active = True + def visit_Module(self, node: cst.Module) -> bool | None: + node.visit(ComputeControlFlowGraph(self)) diff --git a/src/styx_compiler/data_flow.py b/src/styx_compiler/data_flow.py new file mode 100644 index 0000000..67ff43e --- /dev/null +++ b/src/styx_compiler/data_flow.py @@ -0,0 +1,193 @@ +""" +Data-flow analysis engine as described in "FlowSpec: A declarative specification language for intra-procedural +flow-sensitive data-flow analysis" by Smits, Wachsmuth and Visser (https://doi.org/10.1016/j.cola.2019.100924). 
+""" + +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass + +from styx_compiler.control_flow import Node + +type Cfg = dict[Node, set[Node]] + + +def compute_sccs(cfg: Cfg, extremals: list[Node]) -> list[list[Node]]: + """ + Tarjan's Strongly Connected Component algorithm, with a slight modification to force the order of nodes within an + SCC into a postorder traversal. See section 5.3.1 / figure 26. + + You can reverse the control-flow graphs and give the list of end nodes and this will work just as well. + + Parameters: + cfg (Cfg): one or more control-flow graphs + extremals (list[Node]): a list of start nodes to the control-flow graphs + + Returns: + list[list[Node]]: list of SCCs in topological order (use as a stack for topo order), where each SCC + in reverse postorder over the depth-first spanning tree of the SCC. + """ + index = 0 + scc_stack = [] + result = [] + node_index = {} + node_lowlink = {} + node_on_stack = set() + + def strong_connect(node: Node): + nonlocal index, scc_stack, result, node_index, node_lowlink, node_on_stack + node_index[node] = index + node_lowlink[node] = index + index += 1 + node_on_stack.add(node) + # N.B. we don't add node to scc_stack here, only to the node_on_stack set. The set is used next and in the + # recursive calls, so it doesn't affect the algorithm's correctness to postpone adding to scc_stack. 
+ + for next_node in cfg.get(node, set()): + if next_node not in node_index: + strong_connect(next_node) + node_lowlink[node] = min(node_lowlink[node], node_lowlink[next_node]) + elif next_node in node_on_stack: + node_lowlink[node] = min(node_lowlink[node], node_index[next_node]) + + # Now we add the node to scc_stack in postorder + scc_stack.append(node) + + if node_lowlink[node] == node_index[node]: + scc = [scc_stack.pop()] + node_on_stack.remove(scc[-1]) + while scc[-1] != node: + scc.append(scc_stack.pop()) + node_on_stack.remove(scc[-1]) + result.append(list(reversed(scc))) + + for ext in extremals: + if ext not in node_index: + strong_connect(ext) + + return list(reversed(result)) + + +@dataclass(frozen=True) +class SymbolicTop: + pass + + +@dataclass(frozen=True) +class SymbolicBottom: + pass + + +type TB[T] = T | SymbolicTop | SymbolicBottom + + +class Lattice[T](ABC): + def __init__(self): + super().__init__() + self.top = SymbolicTop() + self.bottom = SymbolicBottom() + + def nleq(self, left: TB[T], right: TB[T]) -> bool: + if isinstance(left, SymbolicBottom) or isinstance(right, SymbolicTop): + return False + if isinstance(left, SymbolicTop) or isinstance(right, SymbolicBottom): + return True + return self._nleq_helper(left, right) + + @abstractmethod + def _nleq_helper(self, left: T, right: T) -> bool: + raise NotImplementedError() + + def join(self, left: TB[T], right: TB[T]) -> TB[T]: + if isinstance(left, SymbolicTop) or isinstance(right, SymbolicTop): + return SymbolicTop() + if isinstance(left, SymbolicBottom): + return right + if isinstance(right, SymbolicBottom): + return left + return self._join_helper(left, right) + + @abstractmethod + def _join_helper(self, left: T, right: T) -> T: + raise NotImplementedError() + + +@dataclass(frozen=True) +class DataflowProperty[T]: + forward: bool + initial: T + transfer_func: dict[Node, Callable[[T], T]] + lattice: Lattice[T] + + +def compute_dataflow_property[T]( + cfg: Cfg, + start_end: list[tuple[Node, 
Node]], + df_property: DataflowProperty[T], +) -> dict[Node, tuple[TB[T], TB[T]]]: + """ + Compute a single dataflow property. We're not doing dependent dataflow properties like the paper. See section 5.3.2 + and figure 27. + TODO: filter the CFG to efficiently handle all the identity function transfer functions. + """ + prop = {} + for node, nexts in cfg.items(): + prop[node] = df_property.lattice.bottom + for next_node in nexts: + prop[next_node] = df_property.lattice.bottom + + if df_property.forward: + extremals = [start for start, _ in start_end] + else: + rev_cfg = {} + for node, nexts in cfg.items(): + for next_node in nexts: + rev_cfg.setdefault(next_node, set()).add(node) + cfg = rev_cfg + extremals = [end for _, end in start_end] + + for node in extremals: + prop[node] = df_property.initial + + sccs = compute_sccs(cfg, extremals) + + for scc in sccs: + done = False + while not done: + done = True + for node in scc: + for next_node in cfg.get(node, set()): + assert not isinstance(prop[node], SymbolicTop | SymbolicBottom) + step = df_property.transfer_func[node](prop[node]) + if df_property.lattice.nleq(step, prop[next_node]): + prop[next_node] = df_property.lattice.join(step, prop[next_node]) + if next_node in scc: + done = False + + if df_property.forward: + return {node: (p, df_property.transfer_func[node](p)) for node, p in prop.items()} + return {node: (df_property.transfer_func[node](p), p) for node, p in prop.items()} + + +class MaySet[T](Lattice[frozenset[T]]): + def __init__(self): + super().__init__() + self.bottom = frozenset() + + def _nleq_helper(self, left: frozenset[T], right: frozenset[T]) -> bool: + return not (left <= right) + + def _join_helper(self, left: frozenset[T], right: frozenset[T]) -> frozenset[T]: + return left | right + + +class MustSet[T](Lattice[frozenset[T]]): + def __init__(self): + super().__init__() + self.top = frozenset() + + def _nleq_helper(self, left: frozenset[T], right: frozenset[T]) -> bool: + return not (left >= 
right) + + def _join_helper(self, left: frozenset[T], right: frozenset[T]) -> frozenset[T]: + return left & right diff --git a/src/styx_compiler/live_variables.py b/src/styx_compiler/live_variables.py new file mode 100644 index 0000000..99a1aaf --- /dev/null +++ b/src/styx_compiler/live_variables.py @@ -0,0 +1,152 @@ +from collections import defaultdict +from collections.abc import Callable + +import libcst as cst +from libcst import matchers as m +from libcst.metadata import QualifiedName, QualifiedNameProvider, QualifiedNameSource + +from styx_compiler.control_flow import ControlFlowGraphProvider, Node +from styx_compiler.data_flow import TB, DataflowProperty, MaySet, compute_dataflow_property +from styx_compiler.metadata_providers import IndexProvider + + +class CollectLiveVariablesTransferFunctions(cst.CSTVisitor): + """ + Computes the live variable analysis transfer functions for control-flow graph nodes + """ + + def __init__(self, provider: LiveVariablesDataflowPropertyProvider): + super().__init__() + self._provider = provider + self._active = False + self._tfs: dict[Node, Callable[[frozenset[str]], frozenset[str]]] = defaultdict(lambda: lambda x: x) + + def get_dataflow_property(self) -> DataflowProperty: + return DataflowProperty(forward=False, initial=frozenset(), transfer_func=self._tfs, lattice=MaySet()) + + def _get_lhs_names(self, target: cst.BaseExpression) -> list[str]: + if m.matches(target, m.Attribute()): + target: cst.Attribute = cst.ensure_type(target, cst.Attribute) + name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, target.attr) + if len(name_origin) == 1: + [qual_name] = name_origin + if qual_name.source == QualifiedNameSource.LOCAL: + return [target.attr.value] + if m.matches(target, m.Subscript()): + target: cst.Subscript = cst.ensure_type(target, cst.Subscript) + return self._get_lhs_names(target.value) + if m.matches(target, m.StarredElement() | m.Element()): + return self._get_lhs_names(target.value) 
+ if m.matches(target, m.Name()): + target: cst.Name = cst.ensure_type(target, cst.Name) + name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, target) + if len(name_origin) == 1: + [qual_name] = name_origin + if qual_name.source == QualifiedNameSource.LOCAL: + return [target.value] + if m.matches(target, m.List() | m.Tuple()): + # noinspection PyUnresolvedReferences + return [name for el in target.elements for name in self._get_lhs_names(el)] + return [] + + def leave_Module(self, module: cst.Module) -> None: + self._provider.set_metadata(module, self.get_dataflow_property()) + + def visit_Param(self, node: cst.Param) -> bool | None: + if self._active: + index = self._provider.get_metadata(IndexProvider, node) + name = node.name.value + self._tfs[Node(index, 0)] = lambda lives, name=name: lives.difference([name]) + return False + return None + + # noinspection PyDefaultArgument + def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: + if self._active: + index = self._provider.get_metadata(IndexProvider, node.target) + names = self._get_lhs_names(node.target) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) + + # noinspection PyDefaultArgument + def visit_Assign(self, node: cst.Assign) -> bool | None: + if self._active: + for target in node.targets: + index = self._provider.get_metadata(IndexProvider, target) + names = self._get_lhs_names(target.target) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) + + # noinspection PyDefaultArgument + def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: + if self._active: + index = self._provider.get_metadata(IndexProvider, node) + names = self._get_lhs_names(node.target) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) + + def visit_Name(self, node: cst.Name) -> bool | None: + if self._active: + name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, 
node) + if len(name_origin) == 1: + [qual_name] = name_origin + if qual_name.source == QualifiedNameSource.LOCAL: + index = self._provider.get_metadata(IndexProvider, node) + name = node.value + self._tfs[Node(index, 0)] = lambda lives, name=name: lives.union([name]) + + def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: + self._active = True + + def leave_FunctionDef_params(self, _node: cst.FunctionDef) -> None: + self._active = False + + def visit_FunctionDef_body(self, _node: cst.FunctionDef) -> None: + self._active = True + + def leave_FunctionDef_body(self, _node: cst.FunctionDef) -> None: + self._active = False + + def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: + self._active = False + + def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: + self._active = True + + def visit_Attribute_attr(self, _node: cst.FunctionDef) -> None: + self._active = False + + def leave_Attribute_attr(self, _node: cst.FunctionDef) -> None: + self._active = True + + +class LiveVariablesDataflowPropertyProvider(cst.BatchableMetadataProvider[DataflowProperty]): + METADATA_DEPENDENCIES = (IndexProvider, QualifiedNameProvider) + + def visit_Module(self, module: cst.Module) -> None: + module.visit(CollectLiveVariablesTransferFunctions(self)) + + +class LiveVariablesVisitor(cst.CSTVisitor): + def __init__( + self, provider: LiveVariablesProvider, live_vars: dict[Node, tuple[TB[frozenset[str]], TB[frozenset[str]]]] + ): + super().__init__() + self._provider: LiveVariablesProvider = provider + self.live_vars: dict[Node, tuple[TB[frozenset[str]], TB[frozenset[str]]]] = live_vars + + def on_visit(self, node: cst.CSTNode) -> bool: + if m.matches(node, m.SimpleWhitespace() | m.TrailingWhitespace()): + return False + if self._provider.get_metadata(IndexProvider, node, None) is not None: + cfg_node = Node(self._provider.get_metadata(IndexProvider, node), 0) + if cfg_node in self.live_vars: + self._provider.set_metadata(node, 
self.live_vars[cfg_node]) + return True + + +class LiveVariablesProvider(cst.BatchableMetadataProvider[tuple[frozenset[TB[str]], frozenset[TB[str]]]]): + METADATA_DEPENDENCIES = (IndexProvider, ControlFlowGraphProvider, LiveVariablesDataflowPropertyProvider) + + def visit_Module(self, node: cst.Module) -> bool | None: + cfg, start_end = self.get_metadata(ControlFlowGraphProvider, node) + lv_prop = self.get_metadata(LiveVariablesDataflowPropertyProvider, node) + lv_result = compute_dataflow_property(cfg, start_end, lv_prop) + node.visit(LiveVariablesVisitor(self, lv_result)) diff --git a/src/styx_compiler/metadata_providers.py b/src/styx_compiler/metadata_providers.py index cdea90e..9696467 100644 --- a/src/styx_compiler/metadata_providers.py +++ b/src/styx_compiler/metadata_providers.py @@ -19,22 +19,3 @@ def on_visit(self, node: cst.CSTNode) -> bool: self._index += 1 return True return False - - -class ComputeLiveVariables(cst.CSTVisitor): - """ - Computes the live variables, using the indices from IndexProvider to navigate the control-flow graph - """ - - METADATA_DEPENDENCIES = (IndexProvider,) - - -class LiveVariablesProvider(cst.VisitorMetadataProvider[list[cst.Name]]): - METADATA_DEPENDENCIES = (IndexProvider,) - - def __init__(self, live_variables_entry: list[list[cst.Name]], live_variables_exit: list[list[cst.Name]]): - super().__init__() - self.live_variables_entry = live_variables_entry - self.live_variables_exit = live_variables_exit - - # visit the same things as the IndexProvider and look up the list of live variables from the index. 
diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..1b66e3d --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Testing module for styx_compiler +""" diff --git a/tests/test_cfg.py b/tests/test_cfg.py index 38ca0f5..6bd598d 100644 --- a/tests/test_cfg.py +++ b/tests/test_cfg.py @@ -1,8 +1,9 @@ -"""Tests for the styx_compiler.main.""" +"""Tests for styx_compiler.control_flow.""" import libcst as cst -from styx_compiler.control_flow import CfgNodeTester, ComputeControlFlowGraph +from styx_compiler.control_flow import ControlFlowGraphProvider, Node +from styx_compiler.metadata_providers import IndexProvider add_fundef = """ def add(a: int, b: int) -> int: @@ -10,15 +11,16 @@ def add(a: int, b: int) -> int: """ -def test_add_fundef_cfg(): - source_tree = cst.parse_module(add_fundef) - wrapper = cst.MetadataWrapper(source_tree) - ccfg = ComputeControlFlowGraph() - assert len(ccfg._cfg) == 0 - wrapper.visit(ccfg) - print(ccfg._cfg) - assert len(ccfg._cfg) > 0 - assert len(ccfg._start_end) == 1 +# def test_add_fundef_cfg(): +# source_tree = cst.parse_module(add_fundef) +# wrapper = cst.MetadataWrapper(source_tree) +# cfgp = ControlFlowGraphProvider() +# ccfg = ComputeControlFlowGraph(cfgp) +# assert len(ccfg._cfg) == 0 +# wrapper.visit(ccfg) +# print(ccfg._cfg) +# assert len(ccfg._cfg) > 0 +# assert len(ccfg._start_end) == 1 user_item = """ @@ -47,13 +49,14 @@ def __key__(self): """ -def test_multi_def_cfg(): - source_tree = cst.parse_module(user_item) - wrapper = cst.MetadataWrapper(source_tree) - ccfg = ComputeControlFlowGraph() - wrapper.visit(ccfg) - print(ccfg._cfg) - assert len(ccfg._start_end) == 5 +# def test_multi_def_cfg(): +# source_tree = cst.parse_module(user_item) +# wrapper = cst.MetadataWrapper(source_tree) +# cfgp = ControlFlowGraphProvider() +# ccfg = ComputeControlFlowGraph(cfgp) +# wrapper.visit(ccfg) +# print(ccfg._cfg) +# assert len(ccfg._start_end) == 5 nested_try = """ @@ -83,15 +86,16 @@ def nested_try(a: 
int) -> int: """ -def test_nested_try_cfg(): - source_tree = cst.parse_module(nested_try) - wrapper = cst.MetadataWrapper(source_tree) - ccfg = ComputeControlFlowGraph() - assert len(ccfg._cfg) == 0 - wrapper.visit(ccfg) - print(ccfg._cfg) - assert len(ccfg._cfg) > 0 - print(sum(1 for v in ccfg._cfg.values() if len(v) > 1)) +# def test_nested_try_cfg(): +# source_tree = cst.parse_module(nested_try) +# wrapper = cst.MetadataWrapper(source_tree) +# cfgp = ControlFlowGraphProvider() +# ccfg = ComputeControlFlowGraph(cfgp) +# assert len(ccfg._cfg) == 0 +# wrapper.visit(ccfg) +# print(ccfg._cfg) +# assert len(ccfg._cfg) > 0 +# print(sum(1 for v in ccfg._cfg.values() if len(v) > 1)) def test_node_existence(): @@ -103,7 +107,198 @@ def test_node_existence(): def node_existence(source_string: str): source_tree = cst.parse_module(source_string) wrapper = cst.MetadataWrapper(source_tree) - ccfg = ComputeControlFlowGraph() - wrapper.visit(ccfg) - cnt = CfgNodeTester(ccfg._cfg) + cnt = CfgNodeTester() wrapper.visit(cnt) + + +class CfgNodeTester(cst.CSTVisitor): + """ + Checks that each kind of CST Node that should have a corresponding CFG node has one + """ + + METADATA_DEPENDENCIES = (IndexProvider, ControlFlowGraphProvider) + + def __init__(self): + super().__init__() + self.cfg = None + self.active = False + + def visit_Module(self, node: cst.Module) -> bool | None: + self.cfg, _start_end = self.get_metadata(ControlFlowGraphProvider, node) + + def _has_node(self, node: cst.CSTNode, instance: int = 0) -> bool: + """ + Tests if the CSTNode has a corresponding CFG node with outgoing edges + """ + n = Node(self.get_metadata(IndexProvider, node), instance) + return n in self.cfg + + def visit_Param(self, node: cst.Param) -> bool | None: + if self.active: + assert self._has_node(node) + return False + return None + + def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_AugAssign(self, node: 
cst.AugAssign) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_NameItem(self, node: cst.NameItem) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Attribute(self, node: cst.Attribute) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Name(self, node: cst.Name) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_AsName(self, node: cst.AsName) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_UnaryOperation(self, node: cst.UnaryOperation) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_BinaryOperation(self, node: cst.BinaryOperation) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_BooleanOperation(self, node: cst.BooleanOperation) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_ComparisonTarget(self, node: cst.ComparisonTarget) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Await(self, node: cst.Await) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Yield(self, node: cst.Yield) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_From(self, node: cst.From) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Integer(self, node: cst.Integer) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Float(self, node: cst.Float) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Imaginary(self, node: cst.Imaginary) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_SimpleString(self, node: cst.SimpleString) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_FormattedStringExpression(self, node: 
cst.FormattedStringExpression) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_FormattedString(self, node: cst.FormattedString) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Tuple(self, node: cst.Tuple) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_List(self, node: cst.List) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Set(self, node: cst.Set) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Element(self, node: cst.Element) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_StarredElement(self, node: cst.StarredElement) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_DictElement(self, node: cst.DictElement) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_StarredDictElement(self, node: cst.StarredDictElement) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_GeneratorExp(self, node: cst.GeneratorExp) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_ListComp(self, node: cst.ListComp) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_SetComp(self, node: cst.SetComp) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_DictComp(self, node: cst.DictComp) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Index(self, node: cst.Index) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Slice(self, node: cst.Slice) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Subscript(self, node: cst.Subscript) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: + assert self._has_node(node) + # We're not testing for instance 1, which is a final node and will not have outgoing 
edges + + def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: + self.active = True + + def leave_FunctionDef_params(self, _node: cst.FunctionDef) -> None: + self.active = False + + def visit_FunctionDef_body(self, _node: cst.FunctionDef) -> None: + self.active = True + + def leave_FunctionDef_body(self, _node: cst.FunctionDef) -> None: + self.active = False + + def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: + self.active = False + + def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: + self.active = True + + def visit_Attribute_attr(self, _node: cst.FunctionDef) -> None: + self.active = False + + def leave_Attribute_attr(self, _node: cst.FunctionDef) -> None: + self.active = True diff --git a/tests/test_live_variables.py b/tests/test_live_variables.py new file mode 100644 index 0000000..458eadc --- /dev/null +++ b/tests/test_live_variables.py @@ -0,0 +1,53 @@ +"""Tests for styx_compiler.live_variables and indirectly for styx_compiler.data_flow.""" + +import libcst as cst + +from styx_compiler.live_variables import ( + LiveVariablesProvider, +) +from tests.test_cfg import add_fundef + +# def test_add_fundef_live_vars(): +# source_tree = cst.parse_module(add_fundef) +# wrapper = cst.MetadataWrapper(source_tree) +# cfgp = ControlFlowGraphProvider() +# ccfg = ComputeControlFlowGraph(cfgp) +# wrapper.visit(ccfg) +# lvdpp = LiveVariablesDataflowPropertyProvider() +# clvtf = CollectLiveVariablesTransferFunctions(lvdpp) +# wrapper.visit(clvtf) +# lv_prop = clvtf.get_dataflow_property() +# lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) +# print(lv_result) +# +# +# def test_user_item_live_vars(): +# source_tree = cst.parse_module(user_item) +# wrapper = cst.MetadataWrapper(source_tree) +# cfgp = ControlFlowGraphProvider() +# ccfg = ComputeControlFlowGraph(cfgp) +# wrapper.visit(ccfg) +# lvdpp = LiveVariablesDataflowPropertyProvider() +# clvtf = CollectLiveVariablesTransferFunctions(lvdpp) +# 
wrapper.visit(clvtf) +# lv_prop = clvtf.get_dataflow_property() +# lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) +# print(lv_result) + + +def test_add_fundef_live_vars_provider(): + source_tree = cst.parse_module(add_fundef) + wrapper = cst.MetadataWrapper(source_tree) + lvt1 = LiveVariablesTester1() + wrapper.visit(lvt1) + + +class LiveVariablesTester1(cst.CSTVisitor): + """ + Checks that LiveVariablesProvider yields two empty frozensets for every FunctionDef + """ + + METADATA_DEPENDENCIES = (LiveVariablesProvider,) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: + assert self.get_metadata(LiveVariablesProvider, node, None) == (frozenset(), frozenset()) diff --git a/uv.lock b/uv.lock index daf4711..2229ccf 100644 --- a/uv.lock +++ b/uv.lock @@ -2,6 +2,10 @@ version = 1 revision = 3 requires-python = ">=3.14" +[options] +exclude-newer = "2026-03-31T14:19:26.944856Z" +exclude-newer-span = "P7D" + [[package]] name = "anyio" version = "4.12.1" @@ -1064,13 +1068,13 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "hatch", specifier = ">=1.13.0,<2" }, - { name = "mkdocs-material", specifier = ">=9.7.5,<10" }, - { name = "mkdocstrings", extras = ["python"], specifier = ">=1.0.3,<2" }, - { name = "prek", specifier = ">=0.3.6,<2" }, - { name = "pytest", specifier = ">=9.0.2,<10" }, - { name = "pytest-cov", specifier = ">=7.0.0,<8" }, - { name = "setuptools-scm", specifier = ">=9.2.2,<10" }, + { name = "hatch", specifier = ">=1.13.0" }, + { name = "mkdocs-material", specifier = ">=9.7.5" }, + { name = "mkdocstrings", extras = ["python"], specifier = ">=1.0.3" }, + { name = "prek", specifier = ">=0.3.6" }, + { name = "pytest", specifier = ">=9.0.2" }, + { name = "pytest-cov", specifier = ">=7.0.0" }, + { name = "setuptools-scm", specifier = ">=9.2.2" }, ] [[package]]