From 13b3450356e2a5d39cc51b8fdc6a90291dca3ddc Mon Sep 17 00:00:00 2001 From: Jeff Smits Date: Tue, 31 Mar 2026 13:59:45 +0200 Subject: [PATCH 1/6] Add FlowSpec's dataflow engine --- src/styx_compiler/data_flow.py | 154 +++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 src/styx_compiler/data_flow.py diff --git a/src/styx_compiler/data_flow.py b/src/styx_compiler/data_flow.py new file mode 100644 index 0000000..910091c --- /dev/null +++ b/src/styx_compiler/data_flow.py @@ -0,0 +1,154 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass + +from styx_compiler.control_flow import Node + +type Cfg = dict[Node, set[Node]] + + +def compute_sccs(cfg: Cfg, extremals: list[Node]) -> list[list[Node]]: + """ + Tarjan's Strongly Connected Component algorithm, with a slight modification to force the order of nodes within an + SCC into a postorder traversal. See https://doi.org/10.1016/j.cola.2019.100924, Section 5.3.1 / Figure 26. + + You can reverse the control-flow graphs and give the list of end nodes and this will work just as well. + + Parameters: + cfg (Cfg): one or more control-flow graphs + extremals (list[Node]): a list of start nodes to the control-flow graphs + + Returns: + list[list[Node]]: list of SCCs in topological order (use as a stack for topo order), where each SCC + in reverse postorder over the depth-first spanning tree of the SCC. + """ + index = 0 + scc_stack = [] + result = [] + node_index = {} + node_lowlink = {} + node_on_stack = set() + + def strong_connect(node: Node): + nonlocal index, scc_stack, result, node_index, node_lowlink, node_on_stack + node_index[node] = index + node_lowlink[node] = index + index += 1 + node_on_stack.add(node) + # N.B. we don't add node to scc_stack here, only to the node_on_stack set. The set is used next and in the recursive + # calls, so it doesn't affect the algorithm's correctness to postpone adding to scc_stack. + + for next_node in cfg[node]: + if next_node not in node_index: + strong_connect(next_node) + node_lowlink[node] = min(node_lowlink[node], node_lowlink[next_node]) + elif next_node in node_on_stack: + node_lowlink[node] = min(node_lowlink[node], node_index[next_node]) + + # Now we add the node to scc_stack in postorder + scc_stack.append(node) + + if node_lowlink[node] == node_index[node]: + scc = [scc_stack.pop()] + node_on_stack.remove(scc[-1]) + while scc[-1] != node: + scc.append(scc_stack.pop()) + node_on_stack.remove(scc[-1]) + result.append(list(reversed(scc))) + + for ext in extremals: + if ext not in node_index: + strong_connect(ext) + + return list(reversed(result)) + + +@dataclass(frozen=True) +class SymbolicTop: + pass + + +@dataclass(frozen=True) +class SymbolicBottom: + pass + + +type TB[T] = T | SymbolicTop | SymbolicBottom + + +class Lattice[T](ABC): + top = SymbolicTop() + bottom = SymbolicBottom() + + def nleq(self, left: TB[T], right: TB[T]) -> bool: + if isinstance(left, SymbolicBottom) or isinstance(right, SymbolicTop): + return False + if isinstance(left, SymbolicTop) or isinstance(right, SymbolicBottom): + return True + return self._nleq_helper(left, right) + + @abstractmethod + def _nleq_helper(self, left: T, right: T) -> bool: + raise NotImplementedError() + + def join(self, left: TB[T], right: TB[T]) -> TB[T]: + if isinstance(left, SymbolicTop) or isinstance(right, SymbolicTop): + return SymbolicTop() + if isinstance(left, SymbolicBottom): + return right + if isinstance(right, SymbolicBottom): + return left + return self._join_helper(left, right) + + @abstractmethod + def _join_helper(self, left: T, right: T) -> T: + raise NotImplementedError() + + +@dataclass(frozen=True) +class DataflowProperty[T]: + forward: bool + initial: T + transfer_func: Callable[[Node, TB[T]], TB[T]] + lattice: Lattice[T] + + +def compute_dataflow_property[T]( + cfg: Cfg, + start_end: list[tuple[Node, Node]], + df_property: DataflowProperty[T], +) -> dict[Node, tuple[TB[T], TB[T]]]: + prop = {} + for node, nexts in cfg: + prop[node] = df_property.lattice.bottom + for next_node in nexts: + prop[next_node] = df_property.lattice.bottom + + if df_property.forward: + extremals = [start for start, _ in start_end] + else: + rev_cfg = {} + for node, nexts in cfg: + for next_node in nexts: + rev_cfg[next_node] = cfg[node] + cfg = rev_cfg + extremals = [end for _, end in start_end] + + for node in extremals: + prop[node] = df_property.initial + + sccs = compute_sccs(cfg, extremals) + + for scc in sccs: + done = False + while not done: + done = True + for node in scc: + for next_node in cfg[node]: + step = df_property.transfer_func(node, prop[node]) + if df_property.lattice.nleq(step, prop[next_node]): + prop[next_node] = df_property.lattice.join(step, prop[next_node]) + if next_node in scc: + done = False + + return {node: (p, df_property.transfer_func(node, p)) for node, p in prop} From ddc9fa57639c5fcb44ecb6fd3d9644d180871cfc Mon Sep 17 00:00:00 2001 From: Jeff Smits Date: Wed, 1 Apr 2026 14:10:41 +0200 Subject: [PATCH 2/6] Add Live Variables transfer functions (updating some of the CFG shape_ --- src/styx_compiler/control_flow.py | 23 +++---- src/styx_compiler/data_flow.py | 49 +++++++++++++-- src/styx_compiler/live_variables.py | 96 +++++++++++++++++++++++++++++ uv.lock | 4 ++ 4 files changed, 155 insertions(+), 17 deletions(-) create mode 100644 src/styx_compiler/live_variables.py diff --git a/src/styx_compiler/control_flow.py b/src/styx_compiler/control_flow.py index 140dcb3..049d9b9 100644 --- a/src/styx_compiler/control_flow.py +++ b/src/styx_compiler/control_flow.py @@ -188,15 +188,16 @@ def _visit_statement( prev = self._visit_expression(statement.value, instance, prev) # then the multiple LHS, from left to right for target in statement.targets: + prev = self._visit_expression(target.target, instance, prev) prev = self._make_cfg_node(target, instance, prev) # AssignTarget elif m.matches(statement, m.AugAssign()): statement: cst.AugAssign = cst.ensure_type(statement, cst.AugAssign) - # note we're making the AugAssign a node first to represent reading the value from the target - prev = self._make_cfg_node(statement, instance, prev) # AugAssign + # note we're visiting LHS first to represent reading the value from the target + prev = self._visit_expression(statement.target, instance, prev) # then we visit the RHS expression to find more reads prev = self._visit_expression(statement.value, instance, prev) - # finally we write to the LHS - prev = self._visit_expression(statement.target, instance, prev) + # finally we write to the LHS, represented by a node of the whole assignment + prev = self._make_cfg_node(statement, instance, prev) # AugAssign elif m.matches(statement, m.Break()): if loop_break_target is None: msg = "Found break outside of loop" @@ -360,6 +361,7 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode: handler_exits.append(handler_exit) if handler.name is not None: + handler_prev = self._visit_expression(handler.name.name, instance, handler_prev) handler_prev = self._make_cfg_node(handler.name, instance, handler_prev) # AsName handler_prev = self._visit_BaseSuite( handler.body, @@ -681,16 +683,12 @@ def _has_node(self, node: cst.CSTNode, instance: int = 0) -> bool: def visit_Param(self, node: cst.Param) -> bool | None: if self.active: assert self._has_node(node) - # TODO: should we visit deeper into Param too? return False return None def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None: if self.active: assert self._has_node(node) - # TODO: should we visit deeper into AssignTarget too? - return False - return None def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: if self.active: @@ -703,9 +701,6 @@ def visit_NameItem(self, node: cst.NameItem) -> bool | None: def visit_Attribute(self, node: cst.Attribute) -> bool | None: if self.active: assert self._has_node(node) - # TODO: should we visit deeper into Attribute too? - return False - return None def visit_Name(self, node: cst.Name) -> bool | None: if self.active: @@ -848,3 +843,9 @@ def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: self.active = True + + def visit_Attribute_attr(self, _node: cst.FunctionDef) -> None: + self.active = False + + def leave_Attribute_attr(self, _node: cst.FunctionDef) -> None: + self.active = True diff --git a/src/styx_compiler/data_flow.py b/src/styx_compiler/data_flow.py index 910091c..6ac5456 100644 --- a/src/styx_compiler/data_flow.py +++ b/src/styx_compiler/data_flow.py @@ -1,3 +1,8 @@ +""" +Data-flow analysis engine as described in "FlowSpec: A declarative specification language for intra-procedural +flow-sensitive data-flow analysis" by Smits, Wachsmuth and Visser (https://doi.org/10.1016/j.cola.2019.100924). +""" + from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass @@ -10,7 +15,7 @@ def compute_sccs(cfg: Cfg, extremals: list[Node]) -> list[list[Node]]: """ Tarjan's Strongly Connected Component algorithm, with a slight modification to force the order of nodes within an - SCC into a postorder traversal. See https://doi.org/10.1016/j.cola.2019.100924, Section 5.3.1 / Figure 26. + SCC into a postorder traversal. See section 5.3.1 / figure 26. You can reverse the control-flow graphs and give the list of end nodes and this will work just as well. @@ -77,8 +82,10 @@ class SymbolicBottom: class Lattice[T](ABC): - top = SymbolicTop() - bottom = SymbolicBottom() + def __init__(self): + super().__init__() + self.top = SymbolicTop() + self.bottom = SymbolicBottom() def nleq(self, left: TB[T], right: TB[T]) -> bool: if isinstance(left, SymbolicBottom) or isinstance(right, SymbolicTop): @@ -109,7 +116,7 @@ def _join_helper(self, left: T, right: T) -> T: class DataflowProperty[T]: forward: bool initial: T - transfer_func: Callable[[Node, TB[T]], TB[T]] + transfer_func: dict[Node, Callable[[T], T]] lattice: Lattice[T] @@ -118,6 +125,11 @@ def compute_dataflow_property[T]( start_end: list[tuple[Node, Node]], df_property: DataflowProperty[T], ) -> dict[Node, tuple[TB[T], TB[T]]]: + """ + Compute a single dataflow property. We're not doing dependent dataflow properties like the paper. See section 5.3.2 + and figure 27. + TODO: filter the CFG to efficiently handle all the identity function transfer functions. + """ prop = {} for node, nexts in cfg: prop[node] = df_property.lattice.bottom @@ -145,10 +157,35 @@ def compute_dataflow_property[T]( done = True for node in scc: for next_node in cfg[node]: - step = df_property.transfer_func(node, prop[node]) + assert not isinstance(prop[node], SymbolicTop | SymbolicBottom) + step = df_property.transfer_func[node](prop[node]) if df_property.lattice.nleq(step, prop[next_node]): prop[next_node] = df_property.lattice.join(step, prop[next_node]) if next_node in scc: done = False - return {node: (p, df_property.transfer_func(node, p)) for node, p in prop} + return {node: (p, df_property.transfer_func[node](p)) for node, p in prop} + + +class MaySet[T](Lattice[frozenset[T]]): + def __init__(self): + super().__init__() + self.bottom = frozenset() + + def _nleq_helper(self, left: frozenset[T], right: frozenset[T]) -> bool: + return not (left <= right) + + def _join_helper(self, left: frozenset[T], right: frozenset[T]) -> frozenset[T]: + return left | right + + +class MustSet[T](Lattice[frozenset[T]]): + def __init__(self): + super().__init__() + self.top = frozenset() + + def _nleq_helper(self, left: frozenset[T], right: frozenset[T]) -> bool: + return not (left >= right) + + def _join_helper(self, left: frozenset[T], right: frozenset[T]) -> frozenset[T]: + return left & right diff --git a/src/styx_compiler/live_variables.py b/src/styx_compiler/live_variables.py new file mode 100644 index 0000000..1f94e96 --- /dev/null +++ b/src/styx_compiler/live_variables.py @@ -0,0 +1,96 @@ +from collections import defaultdict +from collections.abc import Callable + +import libcst as cst +from libcst import matchers as m + +from styx_compiler.control_flow import Node +from styx_compiler.data_flow import DataflowProperty, MaySet +from styx_compiler.metadata_providers import IndexProvider + + +class CollectLiveVariablesTransferFunctions(cst.CSTVisitor): + """ + Computes the live variable analysis transfer functions for control-flow graph nodes + """ + + METADATA_DEPENDENCIES = (IndexProvider,) + + def __init__(self): + super().__init__() + self.active = False + self._tfs: dict[Node, Callable[[frozenset[str]], frozenset[str]]] = defaultdict(lambda: lambda x: x) + + def get_dataflow_property(self) -> DataflowProperty: + return DataflowProperty(forward=False, initial=frozenset(), transfer_func=self._tfs, lattice=MaySet()) + + def _get_lhs_names(self, target: cst.BaseExpression) -> list[str]: + if m.matches(target, m.Attribute()): + target: cst.Attribute = cst.ensure_type(target, cst.Attribute) + return [target.attr.value] + if m.matches(target, m.Subscript()): + target: cst.Subscript = cst.ensure_type(target, cst.Subscript) + return self._get_lhs_names(target.value) + if m.matches(target, m.StarredElement() | m.Element()): + return self._get_lhs_names(target.value) + if m.matches(target, m.Name()): + target: cst.Name = cst.ensure_type(target, cst.Name) + return [target.value] + if m.matches(target, m.List() | m.Tuple()): + return [name for el in target.elements for name in self._get_lhs_names(el)] + return [] + + def visit_Param(self, node: cst.Param) -> bool | None: + if self.active: + index = self.get_metadata(IndexProvider, node) + self._tfs[Node(index, 0)] = lambda lives: lives.difference([node.name.value]) + return False + return None + + def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: + if self.active: + index = self.get_metadata(IndexProvider, node.target) + names = self._get_lhs_names(node.target) + self._tfs[Node(index, 0)] = lambda lives: lives.difference(names) + + def visit_Assign(self, node: cst.Assign) -> bool | None: + if self.active: + for target in node.targets: + index = self.get_metadata(IndexProvider, target) + names = self._get_lhs_names(target.target) + self._tfs[Node(index, 0)] = lambda lives: lives.difference(names) + + def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: + if self.active: + index = self.get_metadata(IndexProvider, node) + names = self._get_lhs_names(node.target) + self._tfs[Node(index, 0)] = lambda lives: lives.difference(names) + + def visit_Name(self, node: cst.Name) -> bool | None: + if self.active: + index = self.get_metadata(IndexProvider, node) + self._tfs[Node(index, 0)] = lambda lives: lives.union([node.value]) + + def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: + self.active = True + + def leave_FunctionDef_params(self, _node: cst.FunctionDef) -> None: + self.active = False + + def visit_FunctionDef_body(self, _node: cst.FunctionDef) -> None: + self.active = True + + def leave_FunctionDef_body(self, _node: cst.FunctionDef) -> None: + self.active = False + + def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: + self.active = False + + def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: + self.active = True + + def visit_Attribute_attr(self, _node: cst.FunctionDef) -> None: + self.active = False + + def leave_Attribute_attr(self, _node: cst.FunctionDef) -> None: + self.active = True diff --git a/uv.lock b/uv.lock index daf4711..20668dc 100644 --- a/uv.lock +++ b/uv.lock @@ -2,6 +2,10 @@ version = 1 revision = 3 requires-python = ">=3.14" +[options] +exclude-newer = "2026-03-25T09:17:12.685524Z" +exclude-newer-span = "P7D" + [[package]] name = "anyio" version = "4.12.1" From 6d595900f7b8d19f85f3591cf24612622d97f4f6 Mon Sep 17 00:00:00 2001 From: Jeff Smits Date: Wed, 1 Apr 2026 14:43:20 +0200 Subject: [PATCH 3/6] Add basic smoke tests for live variables, fix bugs in data_flow, live_variables --- src/styx_compiler/control_flow.py | 9 ++------ src/styx_compiler/data_flow.py | 14 +++++++------ src/styx_compiler/live_variables.py | 23 ++++++++++++++++----- tests/test_cfg.py | 2 +- tests/test_live_variables.py | 32 +++++++++++++++++++++++++++++ 5 files changed, 61 insertions(+), 19 deletions(-) create mode 100644 tests/test_live_variables.py diff --git a/src/styx_compiler/control_flow.py b/src/styx_compiler/control_flow.py index 049d9b9..1b869ed 100644 --- a/src/styx_compiler/control_flow.py +++ b/src/styx_compiler/control_flow.py @@ -54,17 +54,12 @@ def __init__(self): def _edge(self, prev: list[CfgNode], cur: CfgNode) -> list[CfgNode]: for p in prev: - if p not in self._cfg: - self._cfg[p] = set() - self._cfg[p].add(cur) + self._cfg.setdefault(p, set()).add(cur) return [cur] def _edges(self, prev: list[CfgNode], tos: list[CfgNode]) -> list[CfgNode]: for p in prev: - if p not in self._cfg: - self._cfg[p] = set() - for to in tos: - self._cfg[p].add(to) + self._cfg.setdefault(p, set()).update(tos) return tos def _make_cfg_node(self, cst_node: cst.CSTNode, instance: int, prev: list[CfgNode]) -> list[CfgNode]: diff --git a/src/styx_compiler/data_flow.py b/src/styx_compiler/data_flow.py index 6ac5456..a544367 100644 --- a/src/styx_compiler/data_flow.py +++ b/src/styx_compiler/data_flow.py @@ -43,7 +43,7 @@ def strong_connect(node: Node): # N.B. we don't add node to scc_stack here, only to the node_on_stack set. The set is used next and in the recursive # calls, so it doesn't affect the algorithm's correctness to postpone adding to scc_stack. - for next_node in cfg[node]: + for next_node in cfg.get(node, set()): if next_node not in node_index: strong_connect(next_node) node_lowlink[node] = min(node_lowlink[node], node_lowlink[next_node]) @@ -131,7 +131,7 @@ def compute_dataflow_property[T]( TODO: filter the CFG to efficiently handle all the identity function transfer functions. """ prop = {} - for node, nexts in cfg: + for node, nexts in cfg.items(): prop[node] = df_property.lattice.bottom for next_node in nexts: prop[next_node] = df_property.lattice.bottom @@ -140,9 +140,9 @@ def compute_dataflow_property[T]( extremals = [start for start, _ in start_end] else: rev_cfg = {} - for node, nexts in cfg: + for node, nexts in cfg.items(): for next_node in nexts: - rev_cfg[next_node] = cfg[node] + rev_cfg.setdefault(next_node, set()).add(node) cfg = rev_cfg extremals = [end for _, end in start_end] @@ -156,7 +156,7 @@ def compute_dataflow_property[T]( while not done: done = True for node in scc: - for next_node in cfg[node]: + for next_node in cfg.get(node, set()): assert not isinstance(prop[node], SymbolicTop | SymbolicBottom) step = df_property.transfer_func[node](prop[node]) if df_property.lattice.nleq(step, prop[next_node]): @@ -164,7 +164,9 @@ def compute_dataflow_property[T]( if next_node in scc: done = False - return {node: (p, df_property.transfer_func[node](p)) for node, p in prop} + if df_property.forward: + return {node: (p, df_property.transfer_func[node](p)) for node, p in prop.items()} + return {node: (df_property.transfer_func[node](p), p) for node, p in prop.items()} class MaySet[T](Lattice[frozenset[T]]): diff --git a/src/styx_compiler/live_variables.py b/src/styx_compiler/live_variables.py index 1f94e96..8721449 100644 --- a/src/styx_compiler/live_variables.py +++ b/src/styx_compiler/live_variables.py @@ -3,6 +3,7 @@ import libcst as cst from libcst import matchers as m +from libcst.metadata import QualifiedName, QualifiedNameProvider, QualifiedNameSource from styx_compiler.control_flow import Node from styx_compiler.data_flow import DataflowProperty, MaySet @@ -14,7 +15,7 @@ class CollectLiveVariablesTransferFunctions(cst.CSTVisitor): Computes the live variable analysis transfer functions for control-flow graph nodes """ - METADATA_DEPENDENCIES = (IndexProvider,) + METADATA_DEPENDENCIES = (IndexProvider, QualifiedNameProvider) def __init__(self): super().__init__() @@ -27,7 +28,11 @@ def get_dataflow_property(self) -> DataflowProperty: def _get_lhs_names(self, target: cst.BaseExpression) -> list[str]: if m.matches(target, m.Attribute()): target: cst.Attribute = cst.ensure_type(target, cst.Attribute) - return [target.attr.value] + name_origin: set[QualifiedName] = self.get_metadata(QualifiedNameProvider, target.attr) + if len(name_origin) == 1: + [qual_name] = name_origin + if qual_name.source == QualifiedNameSource.LOCAL: + return [target.attr.value] if m.matches(target, m.Subscript()): target: cst.Subscript = cst.ensure_type(target, cst.Subscript) return self._get_lhs_names(target.value) @@ -35,7 +40,11 @@ def _get_lhs_names(self, target: cst.BaseExpression) -> list[str]: return self._get_lhs_names(target.value) if m.matches(target, m.Name()): target: cst.Name = cst.ensure_type(target, cst.Name) - return [target.value] + name_origin: set[QualifiedName] = self.get_metadata(QualifiedNameProvider, target) + if len(name_origin) == 1: + [qual_name] = name_origin + if qual_name.source == QualifiedNameSource.LOCAL: + return [target.value] if m.matches(target, m.List() | m.Tuple()): return [name for el in target.elements for name in self._get_lhs_names(el)] return [] @@ -68,8 +77,12 @@ def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: def visit_Name(self, node: cst.Name) -> bool | None: if self.active: - index = self.get_metadata(IndexProvider, node) - self._tfs[Node(index, 0)] = lambda lives: lives.union([node.value]) + name_origin: set[QualifiedName] = self.get_metadata(QualifiedNameProvider, node) + if len(name_origin) == 1: + [qual_name] = name_origin + if qual_name.source == QualifiedNameSource.LOCAL: + index = self.get_metadata(IndexProvider, node) + self._tfs[Node(index, 0)] = lambda lives: lives.union([node.value]) def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: self.active = True diff --git a/tests/test_cfg.py b/tests/test_cfg.py index 38ca0f5..fe4deab 100644 --- a/tests/test_cfg.py +++ b/tests/test_cfg.py @@ -1,4 +1,4 @@ -"""Tests for the styx_compiler.main.""" +"""Tests for styx_compiler.control_flow.""" import libcst as cst diff --git a/tests/test_live_variables.py b/tests/test_live_variables.py new file mode 100644 index 0000000..48c7e62 --- /dev/null +++ b/tests/test_live_variables.py @@ -0,0 +1,32 @@ +"""Tests for styx_compiler.live_variables and indirectly for styx_compiler.data_flow.""" + +import libcst as cst +from tests.test_cfg import add_fundef, user_item + +from styx_compiler.control_flow import ComputeControlFlowGraph +from styx_compiler.data_flow import compute_dataflow_property +from styx_compiler.live_variables import CollectLiveVariablesTransferFunctions + + +def test_add_fundef_live_vars(): + source_tree = cst.parse_module(add_fundef) + wrapper = cst.MetadataWrapper(source_tree) + ccfg = ComputeControlFlowGraph() + wrapper.visit(ccfg) + clvtf = CollectLiveVariablesTransferFunctions() + wrapper.visit(clvtf) + lv_prop = clvtf.get_dataflow_property() + lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) + print(lv_result) + + +def test_user_item_live_vars(): + source_tree = cst.parse_module(user_item) + wrapper = cst.MetadataWrapper(source_tree) + ccfg = ComputeControlFlowGraph() + wrapper.visit(ccfg) + clvtf = CollectLiveVariablesTransferFunctions() + wrapper.visit(clvtf) + lv_prop = clvtf.get_dataflow_property() + lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) + print(lv_result) From fcd7f87624f8306e6ac22676c026c782bc9ba21f Mon Sep 17 00:00:00 2001 From: Jeff Smits Date: Thu, 2 Apr 2026 16:15:39 +0200 Subject: [PATCH 4/6] Change all visitors to take a provider to set the metadata on, so we can have each depend on another --- src/styx_compiler/control_flow.py | 223 +++--------------------- src/styx_compiler/live_variables.py | 91 +++++++--- src/styx_compiler/metadata_providers.py | 19 -- tests/test_cfg.py | 193 +++++++++++++++++++- tests/test_live_variables.py | 20 ++- 5 files changed, 301 insertions(+), 245 deletions(-) diff --git a/src/styx_compiler/control_flow.py b/src/styx_compiler/control_flow.py index 1b869ed..a9a58d0 100644 --- a/src/styx_compiler/control_flow.py +++ b/src/styx_compiler/control_flow.py @@ -45,10 +45,9 @@ class ComputeControlFlowGraph(cst.CSTVisitor): Computes the control-flow graph of the code, expressed in indices from the IndexProvider """ - METADATA_DEPENDENCIES = (IndexProvider,) - - def __init__(self): + def __init__(self, provider: ControlFlowGraphProvider): super().__init__() + self._provider = provider self._cfg: dict[CfgNode, set[CfgNode]] = {} self._start_end: list[tuple[CfgNode, CfgNode]] = [] @@ -63,7 +62,7 @@ def _edges(self, prev: list[CfgNode], tos: list[CfgNode]) -> list[CfgNode]: return tos def _make_cfg_node(self, cst_node: cst.CSTNode, instance: int, prev: list[CfgNode]) -> list[CfgNode]: - cur = Node(self.get_metadata(IndexProvider, cst_node), instance) + cur = Node(self._provider.get_metadata(IndexProvider, cst_node), instance) return self._edge(prev, cur) def _clean_up_cfg_ghosts(self, start: CfgNode) -> None: @@ -91,7 +90,7 @@ def _clean_up_cfg_ghosts(self, start: CfgNode) -> None: seen.add(next_node) def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - index = self.get_metadata(IndexProvider, node) + index = self._provider.get_metadata(IndexProvider, node) start = Node(index, 0) end = Node(index, 1) self._start_end.append((start, end)) @@ -109,7 +108,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: self._clean_up_cfg_ghosts(start) - def leave_Module(self, node: cst.Module) -> None: + def leave_Module(self, module: cst.Module) -> None: # Remove unreachable parts of the CFG (e.g. unused finally clause instantiations, dead code after a return) reachable = set() workstack = [] @@ -131,6 +130,8 @@ def leave_Module(self, node: cst.Module) -> None: for k in to_remove: del self._cfg[k] + self._provider.set_metadata(module, (self._cfg, self._start_end)) + def _visit_BaseSuite( self, statements: cst.BaseSuite | cst.SimpleStatementLine, @@ -251,7 +252,7 @@ def _visit_statement( raise NotImplementedError(msg) elif m.matches(statement, m.For()): statement: cst.For = cst.ensure_type(statement, cst.For) - index = self.get_metadata(IndexProvider, statement) + index = self._provider.get_metadata(IndexProvider, statement) for_loop_continue_target = Ghost(index, 0) prev = self._edge(prev, for_loop_continue_target) loop_expr_prev = self._visit_expression(statement.iter, instance, prev) @@ -315,7 +316,7 @@ def _visit_statement( def wrap_in_finally(exit: CfgNode) -> CfgNode: nonlocal statement, finally_number, fn_end, exception_target, loop_continue_target, loop_break_target if statement.finalbody is not None: - entry = Ghost(self.get_metadata(IndexProvider, statement.finalbody), finally_number) + entry = Ghost(self._provider.get_metadata(IndexProvider, statement.finalbody), finally_number) finalbody: cst.Finally = cst.ensure_type(statement.finalbody, cst.Finally) prev = self._visit_BaseSuite( finalbody.body, @@ -344,7 +345,7 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode: # to the next conditional for handler in statement.handlers: handler: cst.ExceptHandler = cst.ensure_type(handler, cst.ExceptHandler) # noqa: PLW2901 - handler_index = self.get_metadata(IndexProvider, handler) + handler_index = self._provider.get_metadata(IndexProvider, handler) handler_entry = Ghost(handler_index, 0) handler_exit = Ghost(handler_index, 1) handler_entries.append(handler_entry) @@ -401,14 +402,14 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode: loop_break_target=local_loop_break_target, ) # Ghost node for exiting the finally clause normally - try_exit = Ghost(self.get_metadata(IndexProvider, statement), 0) + try_exit = Ghost(self._provider.get_metadata(IndexProvider, statement), 0) finally_entry = wrap_in_finally(try_exit) # The normal entry into a normal finally clause at the end of the body/else or handler self._edge([*prev, *handler_exits], finally_entry) prev = [try_exit] elif m.matches(statement, m.While()): statement: cst.While = cst.ensure_type(statement, cst.While) - index = self.get_metadata(IndexProvider, statement) + index = self._provider.get_metadata(IndexProvider, statement) while_loop_continue_target = Ghost(index, 0) prev = self._edge(prev, while_loop_continue_target) prev = self._visit_expression(statement.test, instance, prev) @@ -550,6 +551,7 @@ def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev: prev = self._visit_expression(expression.func, instance, prev) for arg in expression.args: prev = self._visit_expression(arg.value, instance, prev) + prev = self._make_cfg_node(expression, instance, prev) # Call ## Literal Values elif m.matches(expression, m.Ellipsis()): pass @@ -638,7 +640,7 @@ def _visit_CompFor( elt: cst.BaseExpression | tuple[cst.BaseExpression, cst.BaseExpression], prev: list[CfgNode], ) -> list[CfgNode]: - exit = Ghost(self.get_metadata(IndexProvider, for_in), 0) + exit = Ghost(self._provider.get_metadata(IndexProvider, for_in), 0) prev = self._visit_expression(for_in.iter, instance, prev) prev = self._visit_expression(for_in.target, instance, prev) for compif in for_in.ifs: @@ -655,192 +657,19 @@ def _visit_CompFor( prev = self._visit_expression(elt, instance, prev) return self._edge(prev, exit) + @property + def cfg(self): + return self._cfg -class CfgNodeTester(cst.CSTVisitor): - """ - Checks that each kind of CST Node that should have a corresponding CFG node has one - """ - - METADATA_DEPENDENCIES = (IndexProvider,) - - def __init__(self, cfg: dict[Node, set[Node]]): - super().__init__() - self.cfg = cfg - self.active = False - - def _has_node(self, node: cst.CSTNode, instance: int = 0) -> bool: - """ - Tests if the CSTNode has a corresponding CFG node with outgoing edges - """ - n = Node(self.get_metadata(IndexProvider, node), instance) - return n in self.cfg - - def visit_Param(self, node: cst.Param) -> bool | None: - if self.active: - assert self._has_node(node) - return False - return None - - def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_NameItem(self, node: cst.NameItem) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Attribute(self, node: cst.Attribute) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Name(self, node: cst.Name) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_AsName(self, node: cst.AsName) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_UnaryOperation(self, node: cst.UnaryOperation) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_BinaryOperation(self, node: cst.BinaryOperation) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_BooleanOperation(self, node: cst.BooleanOperation) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_ComparisonTarget(self, node: cst.ComparisonTarget) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Await(self, node: cst.Await) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Yield(self, node: cst.Yield) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_From(self, node: cst.From) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Integer(self, node: cst.Integer) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Float(self, node: cst.Float) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Imaginary(self, node: cst.Imaginary) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_SimpleString(self, node: cst.SimpleString) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_FormattedStringExpression(self, node: cst.FormattedStringExpression) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_FormattedString(self, node: cst.FormattedString) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Tuple(self, node: cst.Tuple) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_List(self, node: cst.List) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Set(self, node: cst.Set) -> bool | None: - if self.active: - assert self._has_node(node) + @property + def start_end(self): + return self._start_end - def visit_Element(self, node: cst.Element) -> bool | None: - if self.active: - assert self._has_node(node) - def visit_StarredElement(self, node: cst.StarredElement) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_DictElement(self, node: cst.DictElement) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_StarredDictElement(self, node: cst.StarredDictElement) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_GeneratorExp(self, node: cst.GeneratorExp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_ListComp(self, node: cst.ListComp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_SetComp(self, node: cst.SetComp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_DictComp(self, node: cst.DictComp) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Index(self, node: cst.Index) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Slice(self, node: cst.Slice) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_Subscript(self, node: cst.Subscript) -> bool | None: - if self.active: - assert self._has_node(node) - - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: - assert self._has_node(node) - # We're not testing for instance 1, which is a final node and will not have outgoing edges - - def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: - self.active = True - - def leave_FunctionDef_params(self, _node: cst.FunctionDef) -> None: - self.active = False - - def visit_FunctionDef_body(self, _node: cst.FunctionDef) -> None: - self.active = True - - def leave_FunctionDef_body(self, _node: cst.FunctionDef) -> None: - self.active = False - - def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: - self.active = False - - def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: - self.active = True - - def visit_Attribute_attr(self, _node: cst.FunctionDef) -> None: - self.active = False +class ControlFlowGraphProvider( + cst.BatchableMetadataProvider[tuple[dict[CfgNode, set[CfgNode]], list[tuple[CfgNode, CfgNode]]]] +): + METADATA_DEPENDENCIES = (IndexProvider,) - def leave_Attribute_attr(self, _node: cst.FunctionDef) -> None: - self.active = True + def visit_Module(self, node: cst.Module) -> bool | None: + node.visit(ComputeControlFlowGraph(self)) diff --git a/src/styx_compiler/live_variables.py b/src/styx_compiler/live_variables.py index 8721449..b97f7d5 100644 --- a/src/styx_compiler/live_variables.py +++ b/src/styx_compiler/live_variables.py @@ -5,8 +5,8 @@ from libcst import matchers as m from libcst.metadata import QualifiedName, QualifiedNameProvider, QualifiedNameSource -from styx_compiler.control_flow import Node -from styx_compiler.data_flow import DataflowProperty, MaySet +from styx_compiler.control_flow import ControlFlowGraphProvider, Node +from styx_compiler.data_flow import TB, DataflowProperty, MaySet, compute_dataflow_property from styx_compiler.metadata_providers import IndexProvider @@ -15,11 +15,10 @@ class CollectLiveVariablesTransferFunctions(cst.CSTVisitor): Computes the live variable analysis transfer functions for control-flow graph nodes """ - METADATA_DEPENDENCIES = (IndexProvider, QualifiedNameProvider) - - def __init__(self): + def __init__(self, provider: LiveVariablesDataflowPropertyProvider): super().__init__() - self.active = False + self._provider = provider + self._active = False self._tfs: dict[Node, Callable[[frozenset[str]], frozenset[str]]] = defaultdict(lambda: lambda x: x) def get_dataflow_property(self) -> DataflowProperty: @@ -28,7 +27,7 @@ def get_dataflow_property(self) -> DataflowProperty: def _get_lhs_names(self, target: cst.BaseExpression) -> list[str]: if m.matches(target, m.Attribute()): target: cst.Attribute = cst.ensure_type(target, cst.Attribute) - name_origin: set[QualifiedName] = self.get_metadata(QualifiedNameProvider, target.attr) + name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, target.attr) if len(name_origin) == 1: [qual_name] = name_origin if qual_name.source == QualifiedNameSource.LOCAL: @@ -40,7 +39,7 @@ def _get_lhs_names(self, target: cst.BaseExpression) -> list[str]: return self._get_lhs_names(target.value) if m.matches(target, m.Name()): target: cst.Name = cst.ensure_type(target, cst.Name) - name_origin: set[QualifiedName] = self.get_metadata(QualifiedNameProvider, target) + name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, target) if len(name_origin) == 1: [qual_name] = name_origin if qual_name.source == QualifiedNameSource.LOCAL: @@ -49,61 +48,99 @@ def _get_lhs_names(self, target: cst.BaseExpression) -> list[str]: return [name for el in target.elements for name in self._get_lhs_names(el)] return [] + def leave_Module(self, module: cst.Module) -> None: + self._provider.set_metadata(module, self.get_dataflow_property()) + def visit_Param(self, node: cst.Param) -> bool | None: - if self.active: - index = self.get_metadata(IndexProvider, node) + if self._active: + index = self._provider.get_metadata(IndexProvider, node) self._tfs[Node(index, 0)] = lambda lives: lives.difference([node.name.value]) return False return None def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: - if self.active: - index = self.get_metadata(IndexProvider, node.target) + if self._active: + index = self._provider.get_metadata(IndexProvider, node.target) names = self._get_lhs_names(node.target) self._tfs[Node(index, 0)] = lambda lives: lives.difference(names) def visit_Assign(self, node: cst.Assign) -> bool | None: - if self.active: + if self._active: for target in node.targets: - index = self.get_metadata(IndexProvider, target) + index = self._provider.get_metadata(IndexProvider, target) names = self._get_lhs_names(target.target) self._tfs[Node(index, 0)] = lambda lives: lives.difference(names) def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: - if self.active: - index = self.get_metadata(IndexProvider, node) + if self._active: + index = self._provider.get_metadata(IndexProvider, node) names = self._get_lhs_names(node.target) self._tfs[Node(index, 0)] = lambda lives: lives.difference(names) def visit_Name(self, node: cst.Name) -> bool | None: - if self.active: - name_origin: set[QualifiedName] = self.get_metadata(QualifiedNameProvider, node) + if self._active: + name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, node) if len(name_origin) == 1: [qual_name] = name_origin if qual_name.source == QualifiedNameSource.LOCAL: - index = self.get_metadata(IndexProvider, node) + index = self._provider.get_metadata(IndexProvider, node) self._tfs[Node(index, 0)] = lambda lives: lives.union([node.value]) def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: - self.active = True + self._active = True def leave_FunctionDef_params(self, _node: cst.FunctionDef) -> None: - self.active = False + self._active = False def visit_FunctionDef_body(self, _node: cst.FunctionDef) -> None: - self.active = True + self._active = True def leave_FunctionDef_body(self, _node: cst.FunctionDef) -> None: - self.active = False + self._active = False def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: - self.active = False + self._active = False def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: - self.active = True + self._active = True def visit_Attribute_attr(self, _node: cst.FunctionDef) -> None: - self.active = False + self._active = False def leave_Attribute_attr(self, _node: cst.FunctionDef) -> None: - self.active = True + self._active = True + + +class LiveVariablesDataflowPropertyProvider(cst.BatchableMetadataProvider[DataflowProperty]): + METADATA_DEPENDENCIES = (IndexProvider, QualifiedNameProvider) + + def visit_Module(self, module: cst.Module) -> None: + module.visit(CollectLiveVariablesTransferFunctions(self)) + + +class LiveVariablesVisitor(cst.CSTVisitor): + def __init__( + self, provider: LiveVariablesProvider, live_vars: dict[Node, tuple[TB[frozenset[str]], TB[frozenset[str]]]] + ): + super().__init__() + self._provider: LiveVariablesProvider = provider + self.live_vars: dict[Node, tuple[TB[frozenset[str]], TB[frozenset[str]]]] = live_vars + + def on_visit(self, node: cst.CSTNode) -> bool: + if m.matches(node, m.SimpleWhitespace() | m.TrailingWhitespace()): + return False + if self._provider.get_metadata(IndexProvider, node, None) is not None: + cfg_node = Node(self._provider.get_metadata(IndexProvider, node), 0) + if cfg_node in self.live_vars: + self._provider.set_metadata(node, self.live_vars[cfg_node]) + return True + + +class LiveVariablesProvider(cst.BatchableMetadataProvider[tuple[frozenset[TB[str]], frozenset[TB[str]]]]): + METADATA_DEPENDENCIES = (IndexProvider, ControlFlowGraphProvider, LiveVariablesDataflowPropertyProvider) + + def visit_Module(self, node: cst.Module) -> bool | None: + cfg, start_end = self.get_metadata(ControlFlowGraphProvider, node) + lv_prop = self.get_metadata(LiveVariablesDataflowPropertyProvider, node) + lv_result = compute_dataflow_property(cfg, start_end, lv_prop) + node.visit(LiveVariablesVisitor(self, lv_result)) diff --git a/src/styx_compiler/metadata_providers.py b/src/styx_compiler/metadata_providers.py index cdea90e..9696467 100644 --- a/src/styx_compiler/metadata_providers.py +++ b/src/styx_compiler/metadata_providers.py @@ -19,22 +19,3 @@ def on_visit(self, node: cst.CSTNode) -> bool: self._index += 1 return True return False - - -class ComputeLiveVariables(cst.CSTVisitor): - """ - Computes the live variables, using the indices from IndexProvider to navigate the control-flow graph - """ - - METADATA_DEPENDENCIES = (IndexProvider,) - - -class LiveVariablesProvider(cst.VisitorMetadataProvider[list[cst.Name]]): - METADATA_DEPENDENCIES = (IndexProvider,) - - def __init__(self, live_variables_entry: list[list[cst.Name]], live_variables_exit: list[list[cst.Name]]): - super().__init__() - self.live_variables_entry = live_variables_entry - self.live_variables_exit = live_variables_exit - - # visit the same things as the IndexProvider and look up the list of live variables from the index. diff --git a/tests/test_cfg.py b/tests/test_cfg.py index fe4deab..82c8a41 100644 --- a/tests/test_cfg.py +++ b/tests/test_cfg.py @@ -2,7 +2,8 @@ import libcst as cst -from styx_compiler.control_flow import CfgNodeTester, ComputeControlFlowGraph +from styx_compiler.control_flow import ComputeControlFlowGraph, Node +from styx_compiler.metadata_providers import IndexProvider add_fundef = """ def add(a: int, b: int) -> int: @@ -107,3 +108,193 @@ def node_existence(source_string: str): wrapper.visit(ccfg) cnt = CfgNodeTester(ccfg._cfg) wrapper.visit(cnt) + + +class CfgNodeTester(cst.CSTVisitor): + """ + Checks that each kind of CST Node that should have a corresponding CFG node has one + """ + + METADATA_DEPENDENCIES = (IndexProvider,) + + def __init__(self, cfg: dict[Node, set[Node]]): + super().__init__() + self.cfg = cfg + self.active = False + + def _has_node(self, node: cst.CSTNode, instance: int = 0) -> bool: + """ + Tests if the CSTNode has a corresponding CFG node with outgoing edges + """ + n = Node(self.get_metadata(IndexProvider, node), instance) + return n in self.cfg + + def visit_Param(self, node: cst.Param) -> bool | None: + if self.active: + assert self._has_node(node) + return False + return None + + def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_NameItem(self, node: cst.NameItem) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Attribute(self, node: cst.Attribute) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Name(self, node: cst.Name) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_AsName(self, node: cst.AsName) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_UnaryOperation(self, node: cst.UnaryOperation) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_BinaryOperation(self, node: cst.BinaryOperation) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_BooleanOperation(self, node: cst.BooleanOperation) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_ComparisonTarget(self, node: cst.ComparisonTarget) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Await(self, node: cst.Await) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Yield(self, node: cst.Yield) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_From(self, node: cst.From) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Integer(self, node: cst.Integer) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Float(self, node: cst.Float) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Imaginary(self, node: cst.Imaginary) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_SimpleString(self, node: cst.SimpleString) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_FormattedStringExpression(self, node: cst.FormattedStringExpression) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_FormattedString(self, node: cst.FormattedString) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Tuple(self, node: cst.Tuple) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_List(self, node: cst.List) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Set(self, node: cst.Set) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Element(self, node: cst.Element) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_StarredElement(self, node: cst.StarredElement) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_DictElement(self, node: cst.DictElement) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_StarredDictElement(self, node: cst.StarredDictElement) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_GeneratorExp(self, node: cst.GeneratorExp) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_ListComp(self, node: cst.ListComp) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_SetComp(self, node: cst.SetComp) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_DictComp(self, node: cst.DictComp) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Index(self, node: cst.Index) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Slice(self, node: cst.Slice) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_Subscript(self, node: cst.Subscript) -> bool | None: + if self.active: + assert self._has_node(node) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: + assert self._has_node(node) + # We're not testing for instance 1, which is a final node and will not have outgoing edges + + def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: + self.active = True + + def leave_FunctionDef_params(self, _node: cst.FunctionDef) -> None: + self.active = False + + def visit_FunctionDef_body(self, _node: cst.FunctionDef) -> None: + self.active = True + + def leave_FunctionDef_body(self, _node: cst.FunctionDef) -> None: + self.active = False + + def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: + self.active = False + + def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: + self.active = True + + def visit_Attribute_attr(self, _node: cst.FunctionDef) -> None: + self.active = False + + def leave_Attribute_attr(self, _node: cst.FunctionDef) -> None: + self.active = True diff --git a/tests/test_live_variables.py b/tests/test_live_variables.py index 48c7e62..336385a 100644 --- a/tests/test_live_variables.py +++ b/tests/test_live_variables.py @@ -5,7 +5,7 @@ from styx_compiler.control_flow import ComputeControlFlowGraph from styx_compiler.data_flow import compute_dataflow_property -from styx_compiler.live_variables import CollectLiveVariablesTransferFunctions +from styx_compiler.live_variables import CollectLiveVariablesTransferFunctions, LiveVariablesProvider def test_add_fundef_live_vars(): @@ -30,3 +30,21 @@ def test_user_item_live_vars(): lv_prop = clvtf.get_dataflow_property() lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) print(lv_result) + + +def test_add_fundef_live_vars_provider(): + source_tree = cst.parse_module(add_fundef) + wrapper = cst.MetadataWrapper(source_tree) + lvt1 = LiveVariablesTester1() + wrapper.visit(lvt1) + + +class LiveVariablesTester1(cst.CSTVisitor): + """ + Checks that each kind of CST Node that should have a corresponding CFG node has one + """ + + METADATA_DEPENDENCIES = (LiveVariablesProvider,) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: + assert self.get_metadata(LiveVariablesProvider, node, None) == (frozenset(), frozenset()) From 747d31d2d0854e53f11ea95a75b1578e92f23b83 Mon Sep 17 00:00:00 2001 From: Jeff Smits Date: Tue, 7 Apr 2026 16:24:49 +0200 Subject: [PATCH 5/6] Fix bugs in lambdas, turn of pre-commit hook that chokes on type alias syntax, remove upper bounds on dependency versions as this no longer considered good practice by the Python community --- .pre-commit-config.yaml | 13 +++++++------ pyproject.toml | 14 +++++++------- src/styx_compiler/data_flow.py | 4 ++-- src/styx_compiler/live_variables.py | 16 +++++++++++----- uv.lock | 16 ++++++++-------- 5 files changed, 35 insertions(+), 28 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a37e517..c8edb81 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,12 +12,13 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v6.0.0 - hooks: - - id: debug-statements - - id: name-tests-test - args: ["--pytest-test-first"] +# This hook cannot handle newer Python syntax like type aliases yet +# - repo: https://github.com/pre-commit/pre-commit-hooks +# rev: v6.0.0 +# hooks: +# - id: debug-statements +# - id: name-tests-test +# args: ["--pytest-test-first"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.14.13 diff --git a/pyproject.toml b/pyproject.toml index 31a92ff..02cb42d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,13 +16,13 @@ dependencies = [ [dependency-groups] dev = [ - "hatch>=1.13.0,<2", - "prek>=0.3.6,<2", - "pytest>=9.0.2,<10", - "pytest-cov>=7.0.0,<8", - "mkdocstrings[python]>=1.0.3,<2", - "mkdocs-material>=9.7.5,<10", - "setuptools-scm>=9.2.2,<10", + "hatch>=1.13.0", + "prek>=0.3.6", + "pytest>=9.0.2", + "pytest-cov>=7.0.0", + "mkdocstrings[python]>=1.0.3", + "mkdocs-material>=9.7.5", + "setuptools-scm>=9.2.2", ] [build-system] diff --git a/src/styx_compiler/data_flow.py b/src/styx_compiler/data_flow.py index a544367..67ff43e 100644 --- a/src/styx_compiler/data_flow.py +++ b/src/styx_compiler/data_flow.py @@ -40,8 +40,8 @@ def strong_connect(node: Node): node_lowlink[node] = index index += 1 node_on_stack.add(node) - # N.B. we don't add node to scc_stack here, only to the node_on_stack set. The set is used next and in the recursive - # calls, so it doesn't affect the algorithm's correctness to postpone adding to scc_stack. + # N.B. we don't add node to scc_stack here, only to the node_on_stack set. The set is used next and in the + # recursive calls, so it doesn't affect the algorithm's correctness to postpone adding to scc_stack. for next_node in cfg.get(node, set()): if next_node not in node_index: diff --git a/src/styx_compiler/live_variables.py b/src/styx_compiler/live_variables.py index b97f7d5..99a1aaf 100644 --- a/src/styx_compiler/live_variables.py +++ b/src/styx_compiler/live_variables.py @@ -45,6 +45,7 @@ def _get_lhs_names(self, target: cst.BaseExpression) -> list[str]: if qual_name.source == QualifiedNameSource.LOCAL: return [target.value] if m.matches(target, m.List() | m.Tuple()): + # noinspection PyUnresolvedReferences return [name for el in target.elements for name in self._get_lhs_names(el)] return [] @@ -54,28 +55,32 @@ def leave_Module(self, module: cst.Module) -> None: def visit_Param(self, node: cst.Param) -> bool | None: if self._active: index = self._provider.get_metadata(IndexProvider, node) - self._tfs[Node(index, 0)] = lambda lives: lives.difference([node.name.value]) + name = node.name.value + self._tfs[Node(index, 0)] = lambda lives, name=name: lives.difference([name]) return False return None + # noinspection PyDefaultArgument def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: if self._active: index = self._provider.get_metadata(IndexProvider, node.target) names = self._get_lhs_names(node.target) - self._tfs[Node(index, 0)] = lambda lives: lives.difference(names) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) + # noinspection PyDefaultArgument def visit_Assign(self, node: cst.Assign) -> bool | None: if self._active: for target in node.targets: index = self._provider.get_metadata(IndexProvider, target) names = self._get_lhs_names(target.target) - self._tfs[Node(index, 0)] = lambda lives: lives.difference(names) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) + # noinspection PyDefaultArgument def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: if self._active: index = self._provider.get_metadata(IndexProvider, node) names = self._get_lhs_names(node.target) - self._tfs[Node(index, 0)] = lambda lives: lives.difference(names) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) def visit_Name(self, node: cst.Name) -> bool | None: if self._active: @@ -84,7 +89,8 @@ def visit_Name(self, node: cst.Name) -> bool | None: [qual_name] = name_origin if qual_name.source == QualifiedNameSource.LOCAL: index = self._provider.get_metadata(IndexProvider, node) - self._tfs[Node(index, 0)] = lambda lives: lives.union([node.value]) + name = node.value + self._tfs[Node(index, 0)] = lambda lives, name=name: lives.union([name]) def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: self._active = True diff --git a/uv.lock b/uv.lock index 20668dc..2229ccf 100644 --- a/uv.lock +++ b/uv.lock @@ -3,7 +3,7 @@ revision = 3 requires-python = ">=3.14" [options] -exclude-newer = "2026-03-25T09:17:12.685524Z" +exclude-newer = "2026-03-31T14:19:26.944856Z" exclude-newer-span = "P7D" [[package]] @@ -1068,13 +1068,13 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "hatch", specifier = ">=1.13.0,<2" }, - { name = "mkdocs-material", specifier = ">=9.7.5,<10" }, - { name = "mkdocstrings", extras = ["python"], specifier = ">=1.0.3,<2" }, - { name = "prek", specifier = ">=0.3.6,<2" }, - { name = "pytest", specifier = ">=9.0.2,<10" }, - { name = "pytest-cov", specifier = ">=7.0.0,<8" }, - { name = "setuptools-scm", specifier = ">=9.2.2,<10" }, + { name = "hatch", specifier = ">=1.13.0" }, + { name = "mkdocs-material", specifier = ">=9.7.5" }, + { name = "mkdocstrings", extras = ["python"], specifier = ">=1.0.3" }, + { name = "prek", specifier = ">=0.3.6" }, + { name = "pytest", specifier = ">=9.0.2" }, + { name = "pytest-cov", specifier = ">=7.0.0" }, + { name = "setuptools-scm", specifier = ">=9.2.2" }, ] [[package]] From a2d3d57a6ccec3c2497c7c60578ddd5da50bb903 Mon Sep 17 00:00:00 2001 From: Jeff Smits Date: Tue, 7 Apr 2026 16:39:32 +0200 Subject: [PATCH 6/6] Old tests don't work due to provider rewrite, need a little rewrite later --- tests/__init__.py | 3 ++ tests/test_cfg.py | 68 +++++++++++++++++++----------------- tests/test_live_variables.py | 59 ++++++++++++++++--------------- 3 files changed, 70 insertions(+), 60 deletions(-) create mode 100644 tests/__init__.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..1b66e3d --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Testing module for styx_compiler +""" diff --git a/tests/test_cfg.py b/tests/test_cfg.py index 82c8a41..6bd598d 100644 --- a/tests/test_cfg.py +++ b/tests/test_cfg.py @@ -2,7 +2,7 @@ import libcst as cst -from styx_compiler.control_flow import ComputeControlFlowGraph, Node +from styx_compiler.control_flow import ControlFlowGraphProvider, Node from styx_compiler.metadata_providers import IndexProvider add_fundef = """ @@ -11,15 +11,16 @@ def add(a: int, b: int) -> int: """ -def test_add_fundef_cfg(): - source_tree = cst.parse_module(add_fundef) - wrapper = cst.MetadataWrapper(source_tree) - ccfg = ComputeControlFlowGraph() - assert len(ccfg._cfg) == 0 - wrapper.visit(ccfg) - print(ccfg._cfg) - assert len(ccfg._cfg) > 0 - assert len(ccfg._start_end) == 1 +# def test_add_fundef_cfg(): +# source_tree = cst.parse_module(add_fundef) +# wrapper = cst.MetadataWrapper(source_tree) +# cfgp = ControlFlowGraphProvider() +# ccfg = ComputeControlFlowGraph(cfgp) +# assert len(ccfg._cfg) == 0 +# wrapper.visit(ccfg) +# print(ccfg._cfg) +# assert len(ccfg._cfg) > 0 +# assert len(ccfg._start_end) == 1 user_item = """ @@ -48,13 +49,14 @@ def __key__(self): """ -def test_multi_def_cfg(): - source_tree = cst.parse_module(user_item) - wrapper = cst.MetadataWrapper(source_tree) - ccfg = ComputeControlFlowGraph() - wrapper.visit(ccfg) - print(ccfg._cfg) - assert len(ccfg._start_end) == 5 +# def test_multi_def_cfg(): +# source_tree = cst.parse_module(user_item) +# wrapper = cst.MetadataWrapper(source_tree) +# cfgp = ControlFlowGraphProvider() +# ccfg = ComputeControlFlowGraph(cfgp) +# wrapper.visit(ccfg) +# print(ccfg._cfg) +# assert len(ccfg._start_end) == 5 nested_try = """ @@ -84,15 +86,16 @@ def nested_try(a: int) -> int: """ -def test_nested_try_cfg(): - source_tree = cst.parse_module(nested_try) - wrapper = cst.MetadataWrapper(source_tree) - ccfg = ComputeControlFlowGraph() - assert len(ccfg._cfg) == 0 - wrapper.visit(ccfg) - print(ccfg._cfg) - assert len(ccfg._cfg) > 0 - print(sum(1 for v in ccfg._cfg.values() if len(v) > 1)) +# def test_nested_try_cfg(): +# source_tree = cst.parse_module(nested_try) +# wrapper = cst.MetadataWrapper(source_tree) +# cfgp = ControlFlowGraphProvider() +# ccfg = ComputeControlFlowGraph(cfgp) +# assert len(ccfg._cfg) == 0 +# wrapper.visit(ccfg) +# print(ccfg._cfg) +# assert len(ccfg._cfg) > 0 +# print(sum(1 for v in ccfg._cfg.values() if len(v) > 1)) def test_node_existence(): @@ -104,9 +107,7 @@ def test_node_existence(): def node_existence(source_string: str): source_tree = cst.parse_module(source_string) wrapper = cst.MetadataWrapper(source_tree) - ccfg = ComputeControlFlowGraph() - wrapper.visit(ccfg) - cnt = CfgNodeTester(ccfg._cfg) + cnt = CfgNodeTester() wrapper.visit(cnt) @@ -115,13 +116,16 @@ class CfgNodeTester(cst.CSTVisitor): Checks that each kind of CST Node that should have a corresponding CFG node has one """ - METADATA_DEPENDENCIES = (IndexProvider,) + METADATA_DEPENDENCIES = (IndexProvider, ControlFlowGraphProvider) - def __init__(self, cfg: dict[Node, set[Node]]): + def __init__(self): super().__init__() - self.cfg = cfg + self.cfg = None self.active = False + def visit_Module(self, node: cst.Module) -> bool | None: + self.cfg, _start_end = self.get_metadata(ControlFlowGraphProvider, node) + def _has_node(self, node: cst.CSTNode, instance: int = 0) -> bool: """ Tests if the CSTNode has a corresponding CFG node with outgoing edges diff --git a/tests/test_live_variables.py b/tests/test_live_variables.py index 336385a..458eadc 100644 --- a/tests/test_live_variables.py +++ b/tests/test_live_variables.py @@ -1,35 +1,38 @@ """Tests for styx_compiler.live_variables and indirectly for styx_compiler.data_flow.""" import libcst as cst -from tests.test_cfg import add_fundef, user_item -from styx_compiler.control_flow import ComputeControlFlowGraph -from styx_compiler.data_flow import compute_dataflow_property -from styx_compiler.live_variables import CollectLiveVariablesTransferFunctions, LiveVariablesProvider - - -def test_add_fundef_live_vars(): - source_tree = cst.parse_module(add_fundef) - wrapper = cst.MetadataWrapper(source_tree) - ccfg = ComputeControlFlowGraph() - wrapper.visit(ccfg) - clvtf = CollectLiveVariablesTransferFunctions() - wrapper.visit(clvtf) - lv_prop = clvtf.get_dataflow_property() - lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) - print(lv_result) - - -def test_user_item_live_vars(): - source_tree = cst.parse_module(user_item) - wrapper = cst.MetadataWrapper(source_tree) - ccfg = ComputeControlFlowGraph() - wrapper.visit(ccfg) - clvtf = CollectLiveVariablesTransferFunctions() - wrapper.visit(clvtf) - lv_prop = clvtf.get_dataflow_property() - lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) - print(lv_result) +from styx_compiler.live_variables import ( + LiveVariablesProvider, +) +from tests.test_cfg import add_fundef + +# def test_add_fundef_live_vars(): +# source_tree = cst.parse_module(add_fundef) +# wrapper = cst.MetadataWrapper(source_tree) +# cfgp = ControlFlowGraphProvider() +# ccfg = ComputeControlFlowGraph(cfgp) +# wrapper.visit(ccfg) +# lvdpp = LiveVariablesDataflowPropertyProvider() +# clvtf = CollectLiveVariablesTransferFunctions(lvdpp) +# wrapper.visit(clvtf) +# lv_prop = clvtf.get_dataflow_property() +# lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) +# print(lv_result) +# +# +# def test_user_item_live_vars(): +# source_tree = cst.parse_module(user_item) +# wrapper = cst.MetadataWrapper(source_tree) +# cfgp = ControlFlowGraphProvider() +# ccfg = ComputeControlFlowGraph(cfgp) +# wrapper.visit(ccfg) +# lvdpp = LiveVariablesDataflowPropertyProvider() +# clvtf = CollectLiveVariablesTransferFunctions(lvdpp) +# wrapper.visit(clvtf) +# lv_prop = clvtf.get_dataflow_property() +# lv_result = compute_dataflow_property(ccfg._cfg, ccfg._start_end, lv_prop) +# print(lv_result) def test_add_fundef_live_vars_provider():