diff --git a/src/styx_compiler/control_flow.py b/src/styx_compiler/control_flow.py index a9a58d0..fd4bf21 100644 --- a/src/styx_compiler/control_flow.py +++ b/src/styx_compiler/control_flow.py @@ -37,7 +37,26 @@ class Ghost: instance: int -CfgNode = Node | Ghost +type CfgNode = Node | Ghost + + +@dataclass(frozen=True) +class Read: + pass + + +@dataclass(frozen=True) +class Write: + pass + + +type RWContext = Read | Write + + +def debug_print_cfg(cfg: dict[CfgNode, list[CfgNode]]) -> None: + for node_from, nodes_to in cfg.items(): + for node_to in nodes_to: + print(f"{node_from.index}_{node_from.instance} -> {node_to.index}_{node_to.instance}") class ComputeControlFlowGraph(cst.CSTVisitor): @@ -170,8 +189,10 @@ def _visit_statement( # RHS first if it exists if statement.value is not None: prev = self._visit_expression(statement.value, instance, prev) - # LHS - prev = self._visit_expression(statement.target, instance, prev) + # LHS reads + prev = self._visit_expression(statement.target, instance, prev, Write()) + # LHS write + prev = self._make_cfg_node(statement, instance, prev) # AnnAssign elif m.matches(statement, m.Assert()): statement: cst.Assert = cst.ensure_type(statement, cst.Assert) # test expression first? @@ -184,7 +205,7 @@ def _visit_statement( prev = self._visit_expression(statement.value, instance, prev) # then the multiple LHS, from left to right for target in statement.targets: - prev = self._visit_expression(target.target, instance, prev) + prev = self._visit_expression(target.target, instance, prev, Write()) prev = self._make_cfg_node(target, instance, prev) # AssignTarget elif m.matches(statement, m.AugAssign()): statement: cst.AugAssign = cst.ensure_type(statement, cst.AugAssign) @@ -208,7 +229,9 @@ def _visit_statement( prev = [] elif m.matches(statement, m.Del()): statement: cst.Del = cst.ensure_type(statement, cst.Del) - prev = self._visit_expression(statement.target, instance, prev) + prev = self._visit_expression(statement.target, instance, prev, Write()) + # write/del effect + prev = self._make_cfg_node(statement, instance, prev) # Del elif m.matches(statement, m.Expr()): statement: cst.Expr = cst.ensure_type(statement, cst.Expr) prev = self._visit_expression(statement.value, instance, prev) @@ -255,9 +278,11 @@ def _visit_statement( index = self._provider.get_metadata(IndexProvider, statement) for_loop_continue_target = Ghost(index, 0) prev = self._edge(prev, for_loop_continue_target) - loop_expr_prev = self._visit_expression(statement.iter, instance, prev) - prev = loop_expr_prev - prev = self._visit_expression(statement.target, instance, prev) + prev = self._visit_expression(statement.iter, instance, prev) + prev = self._visit_expression(statement.target, instance, prev, Write()) + # assignment effect, would be nice to have something other than the target to pin this to (as the target + # may be a Name, and we usually see a Name CfgNode as a Read effect) + prev = self._make_cfg_node(statement, instance, prev) prev = self._visit_loop( statement, instance, @@ -357,7 +382,6 @@ def wrap_in_finally(exit: CfgNode) -> CfgNode: handler_exits.append(handler_exit) if handler.name is not None: - handler_prev = self._visit_expression(handler.name.name, instance, handler_prev) handler_prev = self._make_cfg_node(handler.name, instance, handler_prev) # AsName handler_prev = self._visit_BaseSuite( handler.body, @@ -493,12 +517,20 @@ def _visit_loop( loop_continue_target=loop_continue_target, loop_break_target=loop_break_target, ) + self._edge(prev, this_loop_continue_target) return self._edge(prev, this_loop_break_target) - def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev: list[CfgNode]) -> list[CfgNode]: + def _visit_expression( + self, + expression: cst.BaseExpression, + instance: int, + prev: list[CfgNode], + context: RWContext = Read(), # noqa: B008 + ) -> list[CfgNode]: ## Names and Object Attributes if m.matches(expression, m.Name()): - prev = self._make_cfg_node(expression, instance, prev) # Name + if context == Read(): + prev = self._make_cfg_node(expression, instance, prev) # Name elif m.matches(expression, m.Attribute()): expression: cst.Attribute = cst.ensure_type(expression, cst.Attribute) prev = self._visit_expression(expression.value, instance, prev) @@ -506,41 +538,41 @@ def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev: ## Operations and Comparisons elif m.matches(expression, m.UnaryOperation()): expression: cst.UnaryOperation = cst.ensure_type(expression, cst.UnaryOperation) - prev = self._visit_expression(expression.expression, instance, prev) + prev = self._visit_expression(expression.expression, instance, prev, context) prev = self._make_cfg_node(expression.expression, instance, prev) # UnaryOperation elif m.matches(expression, m.BinaryOperation() | m.BooleanOperation()): expression: cst.BinaryOperation = cst.ensure_type(expression, cst.BinaryOperation) - prev = self._visit_expression(expression.left, instance, prev) - prev = self._visit_expression(expression.right, instance, prev) + prev = self._visit_expression(expression.left, instance, prev, context) + prev = self._visit_expression(expression.right, instance, prev, context) prev = self._make_cfg_node(expression, instance, prev) # BinaryOperation, BooleanOperation elif m.matches(expression, m.Comparison()): # noinspection DuplicatedCode expression: cst.Comparison = cst.ensure_type(expression, cst.Comparison) - prev = self._visit_expression(expression.left, instance, prev) + prev = self._visit_expression(expression.left, instance, prev, context) for comparison in expression.comparisons: - prev = self._visit_expression(comparison.comparator, instance, prev) + prev = self._visit_expression(comparison.comparator, instance, prev, context) prev = self._make_cfg_node(comparison, instance, prev) # ComparisonTarget ## Control Flow elif m.matches(expression, m.Await()): expression: cst.Await = cst.ensure_type(expression, cst.Await) - prev = self._visit_expression(expression.expression, instance, prev) + prev = self._visit_expression(expression.expression, instance, prev, context) prev = self._make_cfg_node(expression, instance, prev) # Await elif m.matches(expression, m.Yield()): expression: cst.Yield = cst.ensure_type(expression, cst.Yield) - prev = self._visit_expression(expression.value, instance, prev) + prev = self._visit_expression(expression.value, instance, prev, context) # yield is not like return. A later call to the generator will continue from the yield, so # in the CFG we're just going to represent it as a normal node and pretend the control did # not leave and re-enter because it probably doesn't matter for the analyses we want to do. prev = self._make_cfg_node(expression, instance, prev) # Yield elif m.matches(expression, m.From()): expression: cst.From = cst.ensure_type(expression, cst.From) - prev = self._visit_expression(expression.item, instance, prev) + prev = self._visit_expression(expression.item, instance, prev, context) prev = self._make_cfg_node(expression, instance, prev) # From elif m.matches(expression, m.IfExp()): expression: cst.IfExp = cst.ensure_type(expression, cst.IfExp) - prev = self._visit_expression(expression.test, instance, prev) - body = self._visit_expression(expression.body, instance, prev) - orelse = self._visit_expression(expression.orelse, instance, prev) + prev = self._visit_expression(expression.test, instance, prev, context) + body = self._visit_expression(expression.body, instance, prev, context) + orelse = self._visit_expression(expression.orelse, instance, prev, context) prev = [*body, *orelse] ## Lambdas and Function Calls elif m.matches(expression, m.Lambda()): @@ -548,9 +580,9 @@ def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev: raise NotImplementedError(msg) elif m.matches(expression, m.Call()): expression: cst.Call = cst.ensure_type(expression, cst.Call) - prev = self._visit_expression(expression.func, instance, prev) + prev = self._visit_expression(expression.func, instance, prev, context) for arg in expression.args: - prev = self._visit_expression(arg.value, instance, prev) + prev = self._visit_expression(arg.value, instance, prev, context) prev = self._make_cfg_node(expression, instance, prev) # Call ## Literal Values elif m.matches(expression, m.Ellipsis()): @@ -564,25 +596,25 @@ def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev: for part in expression.parts: if m.matches(part, m.FormattedStringExpression()): part: cst.FormattedStringExpression = cst.ensure_type(part, cst.FormattedStringExpression) # noqa: PLW2901 - prev = self._visit_expression(part.expression, instance, prev) + prev = self._visit_expression(part.expression, instance, prev, context) prev = self._make_cfg_node(part, instance, prev) # FormattedStringExpression prev = self._make_cfg_node(expression, instance, prev) # FormattedString ## Collections elif m.matches(expression, m.Tuple() | m.List() | m.Set()): # noinspection PyUnresolvedReferences - prev = self._visit_elements(expression.elements, instance, prev) + prev = self._visit_elements(expression.elements, instance, prev, context) prev = self._make_cfg_node(expression, instance, prev) # Tuple, List, Set elif m.matches(expression, m.Element() | m.StarredElement()): # noinspection PyUnresolvedReferences - prev = self._visit_expression(expression.value, instance, prev) + prev = self._visit_expression(expression.value, instance, prev, context) prev = self._make_cfg_node(expression, instance, prev) # Element, StarredElement elif m.matches(expression, m.Dict()): expression: cst.Dict = cst.ensure_type(expression, cst.Dict) for element in expression.elements: if m.matches(element, m.DictElement()): element: cst.DictElement = cst.ensure_type(element, cst.DictElement) # noqa: PLW2901 - prev = self._visit_expression(element.key, instance, prev) - prev = self._visit_expression(element.value, instance, prev) + prev = self._visit_expression(element.key, instance, prev, context) + prev = self._visit_expression(element.value, instance, prev, context) prev = self._make_cfg_node(element, instance, prev) # DictElement, StarredDictElement ## Comprehensions elif m.matches(expression, m.GeneratorExp() | m.ListComp() | m.SetComp()): @@ -596,7 +628,7 @@ def _visit_expression(self, expression: cst.BaseExpression, instance: int, prev: ## Subscripts and Slices elif m.matches(expression, m.Subscript()): expression: cst.Subscript = cst.ensure_type(expression, cst.Subscript) - prev = self._visit_expression(expression.value, instance, prev) + prev = self._visit_expression(expression.value, instance, prev, context) for element in expression.slice: if m.matches(element, m.Index()): element: cst.Index = cst.ensure_type(element, cst.Index) # noqa: PLW2901 @@ -623,10 +655,12 @@ def _visit_ImportAlias(self, import_alias: cst.ImportAlias, instance: int, prev: prev = self._make_cfg_node(import_alias.asname, instance, prev) # AsName return prev - def _visit_elements(self, elements: Sequence[cst.BaseElement], instance: int, prev: list[CfgNode]) -> list[CfgNode]: + def _visit_elements( + self, elements: Sequence[cst.BaseElement], instance: int, prev: list[CfgNode], context: RWContext + ) -> list[CfgNode]: for element in elements: if m.matches(element, m.Element() | m.StarredElement()): - prev = self._visit_expression(element.value, instance, prev) + prev = self._visit_expression(element.value, instance, prev, context) prev = self._make_cfg_node(element, instance, prev) # Element, StarredElement else: msg = f"Unknown element type {element}" @@ -640,9 +674,13 @@ def _visit_CompFor( elt: cst.BaseExpression | tuple[cst.BaseExpression, cst.BaseExpression], prev: list[CfgNode], ) -> list[CfgNode]: - exit = Ghost(self._provider.get_metadata(IndexProvider, for_in), 0) + entry = Ghost(self._provider.get_metadata(IndexProvider, for_in), 0) + exit = Ghost(self._provider.get_metadata(IndexProvider, for_in), 1) + prev = self._edge(prev, entry) prev = self._visit_expression(for_in.iter, instance, prev) - prev = self._visit_expression(for_in.target, instance, prev) + prev = self._visit_expression(for_in.target, instance, prev, Write()) + # write effect + prev = self._make_cfg_node(for_in, instance, prev) # CompFor for compif in for_in.ifs: prev = self._visit_expression(compif.test, instance, prev) self._edge(prev, exit) @@ -655,6 +693,7 @@ def _visit_CompFor( prev = self._visit_expression(value, instance, prev) else: prev = self._visit_expression(elt, instance, prev) + self._edge(prev, entry) return self._edge(prev, exit) @property diff --git a/src/styx_compiler/data_flow.py b/src/styx_compiler/data_flow.py index 67ff43e..e60e9df 100644 --- a/src/styx_compiler/data_flow.py +++ b/src/styx_compiler/data_flow.py @@ -39,6 +39,8 @@ def strong_connect(node: Node): node_index[node] = index node_lowlink[node] = index index += 1 + + stack_set_size = len(node_on_stack) node_on_stack.add(node) # N.B. we don't add node to scc_stack here, only to the node_on_stack set. The set is used next and in the # recursive calls, so it doesn't affect the algorithm's correctness to postpone adding to scc_stack. @@ -54,9 +56,8 @@ def strong_connect(node: Node): scc_stack.append(node) if node_lowlink[node] == node_index[node]: - scc = [scc_stack.pop()] - node_on_stack.remove(scc[-1]) - while scc[-1] != node: + scc = [] + for _ in range(len(node_on_stack), stack_set_size, -1): scc.append(scc_stack.pop()) node_on_stack.remove(scc[-1]) result.append(list(reversed(scc))) diff --git a/src/styx_compiler/live_variables.py b/src/styx_compiler/live_variables.py index 99a1aaf..a26717b 100644 --- a/src/styx_compiler/live_variables.py +++ b/src/styx_compiler/live_variables.py @@ -1,5 +1,5 @@ from collections import defaultdict -from collections.abc import Callable +from collections.abc import Callable, Sequence import libcst as cst from libcst import matchers as m @@ -18,20 +18,27 @@ class CollectLiveVariablesTransferFunctions(cst.CSTVisitor): def __init__(self, provider: LiveVariablesDataflowPropertyProvider): super().__init__() self._provider = provider - self._active = False - self._tfs: dict[Node, Callable[[frozenset[str]], frozenset[str]]] = defaultdict(lambda: lambda x: x) + self._tfs: dict[Node, Callable[[frozenset[QualifiedName]], frozenset[QualifiedName]]] = defaultdict( + lambda: lambda x: x + ) + + def resolve_name(self, node: cst.CSTNode) -> Sequence[QualifiedName] | None: + name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, node) + if len(name_origin) == 1: + [qual_name] = name_origin + if qual_name.source == QualifiedNameSource.LOCAL: + return [qual_name] + return None def get_dataflow_property(self) -> DataflowProperty: return DataflowProperty(forward=False, initial=frozenset(), transfer_func=self._tfs, lattice=MaySet()) - def _get_lhs_names(self, target: cst.BaseExpression) -> list[str]: + def _get_lhs_names(self, target: cst.BaseExpression) -> Sequence[QualifiedName]: if m.matches(target, m.Attribute()): target: cst.Attribute = cst.ensure_type(target, cst.Attribute) - name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, target.attr) - if len(name_origin) == 1: - [qual_name] = name_origin - if qual_name.source == QualifiedNameSource.LOCAL: - return [target.attr.value] + result = self.resolve_name(target) + if result is not None: + return result if m.matches(target, m.Subscript()): target: cst.Subscript = cst.ensure_type(target, cst.Subscript) return self._get_lhs_names(target.value) @@ -39,11 +46,9 @@ def _get_lhs_names(self, target: cst.BaseExpression) -> list[str]: return self._get_lhs_names(target.value) if m.matches(target, m.Name()): target: cst.Name = cst.ensure_type(target, cst.Name) - name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, target) - if len(name_origin) == 1: - [qual_name] = name_origin - if qual_name.source == QualifiedNameSource.LOCAL: - return [target.value] + result = self.resolve_name(target) + if result is not None: + return result if m.matches(target, m.List() | m.Tuple()): # noinspection PyUnresolvedReferences return [name for el in target.elements for name in self._get_lhs_names(el)] @@ -53,68 +58,52 @@ def leave_Module(self, module: cst.Module) -> None: self._provider.set_metadata(module, self.get_dataflow_property()) def visit_Param(self, node: cst.Param) -> bool | None: - if self._active: - index = self._provider.get_metadata(IndexProvider, node) - name = node.name.value - self._tfs[Node(index, 0)] = lambda lives, name=name: lives.difference([name]) - return False - return None + index = self._provider.get_metadata(IndexProvider, node) + result = self.resolve_name(node.name) + if result is not None: + self._tfs[Node(index, 0)] = lambda lives, names=result: lives.difference(names) - # noinspection PyDefaultArgument def visit_AnnAssign(self, node: cst.AnnAssign) -> bool | None: - if self._active: - index = self._provider.get_metadata(IndexProvider, node.target) - names = self._get_lhs_names(node.target) - self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - - # noinspection PyDefaultArgument - def visit_Assign(self, node: cst.Assign) -> bool | None: - if self._active: - for target in node.targets: - index = self._provider.get_metadata(IndexProvider, target) - names = self._get_lhs_names(target.target) - self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - - # noinspection PyDefaultArgument - def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: - if self._active: - index = self._provider.get_metadata(IndexProvider, node) - names = self._get_lhs_names(node.target) - self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) + index = self._provider.get_metadata(IndexProvider, node) + names = self._get_lhs_names(node.target) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - def visit_Name(self, node: cst.Name) -> bool | None: - if self._active: - name_origin: set[QualifiedName] = self._provider.get_metadata(QualifiedNameProvider, node) - if len(name_origin) == 1: - [qual_name] = name_origin - if qual_name.source == QualifiedNameSource.LOCAL: - index = self._provider.get_metadata(IndexProvider, node) - name = node.value - self._tfs[Node(index, 0)] = lambda lives, name=name: lives.union([name]) + def visit_AssignTarget(self, node: cst.AssignTarget) -> bool | None: + index = self._provider.get_metadata(IndexProvider, node) + names = self._get_lhs_names(node.target) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - def visit_FunctionDef_params(self, _node: cst.FunctionDef) -> None: - self._active = True - - def leave_FunctionDef_params(self, _node: cst.FunctionDef) -> None: - self._active = False - - def visit_FunctionDef_body(self, _node: cst.FunctionDef) -> None: - self._active = True + def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: + index = self._provider.get_metadata(IndexProvider, node) + names = self._get_lhs_names(node.target) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - def leave_FunctionDef_body(self, _node: cst.FunctionDef) -> None: - self._active = False + def visit_Del(self, node: cst.Del) -> bool | None: + index = self._provider.get_metadata(IndexProvider, node) + names = self._get_lhs_names(node.target) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - def visit_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: - self._active = False + def visit_For(self, node: cst.For) -> bool | None: + index = self._provider.get_metadata(IndexProvider, node) + names = self._get_lhs_names(node.target) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - def leave_AnnAssign_annotation(self, _node: cst.FunctionDef) -> None: - self._active = True + def visit_AsName(self, node: cst.AsName) -> bool | None: + index = self._provider.get_metadata(IndexProvider, node) + names = self._get_lhs_names(node.name) + self._tfs[Node(index, 0)] = lambda lives, names=names: lives.difference(names) - def visit_Attribute_attr(self, _node: cst.FunctionDef) -> None: - self._active = False + def visit_Name(self, node: cst.Name) -> bool | None: + index = self._provider.get_metadata(IndexProvider, node) + result = self.resolve_name(node) + if result is not None: + self._tfs[Node(index, 0)] = lambda lives, names=result: lives.union(names) - def leave_Attribute_attr(self, _node: cst.FunctionDef) -> None: - self._active = True + def visit_Attribute(self, node: cst.Attribute) -> bool | None: + index = self._provider.get_metadata(IndexProvider, node) + result = self.resolve_name(node) + if result is not None: + self._tfs[Node(index, 0)] = lambda lives, names=result: lives.union(names) class LiveVariablesDataflowPropertyProvider(cst.BatchableMetadataProvider[DataflowProperty]): @@ -126,11 +115,13 @@ def visit_Module(self, module: cst.Module) -> None: class LiveVariablesVisitor(cst.CSTVisitor): def __init__( - self, provider: LiveVariablesProvider, live_vars: dict[Node, tuple[TB[frozenset[str]], TB[frozenset[str]]]] + self, + provider: LiveVariablesProvider, + live_vars: dict[Node, tuple[TB[frozenset[QualifiedName]], TB[frozenset[QualifiedName]]]], ): super().__init__() self._provider: LiveVariablesProvider = provider - self.live_vars: dict[Node, tuple[TB[frozenset[str]], TB[frozenset[str]]]] = live_vars + self.live_vars: dict[Node, tuple[TB[frozenset[QualifiedName]], TB[frozenset[QualifiedName]]]] = live_vars def on_visit(self, node: cst.CSTNode) -> bool: if m.matches(node, m.SimpleWhitespace() | m.TrailingWhitespace()): @@ -142,7 +133,9 @@ def on_visit(self, node: cst.CSTNode) -> bool: return True -class LiveVariablesProvider(cst.BatchableMetadataProvider[tuple[frozenset[TB[str]], frozenset[TB[str]]]]): +class LiveVariablesProvider( + cst.BatchableMetadataProvider[tuple[frozenset[TB[QualifiedName]], frozenset[TB[QualifiedName]]]] +): METADATA_DEPENDENCIES = (IndexProvider, ControlFlowGraphProvider, LiveVariablesDataflowPropertyProvider) def visit_Module(self, node: cst.Module) -> bool | None: diff --git a/tests/test_cfg.py b/tests/test_cfg.py index 6bd598d..10b769b 100644 --- a/tests/test_cfg.py +++ b/tests/test_cfg.py @@ -1,8 +1,9 @@ """Tests for styx_compiler.control_flow.""" import libcst as cst +import libcst.matchers as m -from styx_compiler.control_flow import ControlFlowGraphProvider, Node +from styx_compiler.control_flow import CfgNode, ControlFlowGraphProvider, Node from styx_compiler.metadata_providers import IndexProvider add_fundef = """ @@ -11,18 +12,6 @@ def add(a: int, b: int) -> int: """ -# def test_add_fundef_cfg(): -# source_tree = cst.parse_module(add_fundef) -# wrapper = cst.MetadataWrapper(source_tree) -# cfgp = ControlFlowGraphProvider() -# ccfg = ComputeControlFlowGraph(cfgp) -# assert len(ccfg._cfg) == 0 -# wrapper.visit(ccfg) -# print(ccfg._cfg) -# assert len(ccfg._cfg) > 0 -# assert len(ccfg._start_end) == 1 - - user_item = """ @entity class Item: @@ -49,16 +38,6 @@ def __key__(self): """ -# def test_multi_def_cfg(): -# source_tree = cst.parse_module(user_item) -# wrapper = cst.MetadataWrapper(source_tree) -# cfgp = ControlFlowGraphProvider() -# ccfg = ComputeControlFlowGraph(cfgp) -# wrapper.visit(ccfg) -# print(ccfg._cfg) -# assert len(ccfg._start_end) == 5 - - nested_try = """ def nested_try(a: int) -> int: try: @@ -86,18 +65,6 @@ def nested_try(a: int) -> int: """ -# def test_nested_try_cfg(): -# source_tree = cst.parse_module(nested_try) -# wrapper = cst.MetadataWrapper(source_tree) -# cfgp = ControlFlowGraphProvider() -# ccfg = ComputeControlFlowGraph(cfgp) -# assert len(ccfg._cfg) == 0 -# wrapper.visit(ccfg) -# print(ccfg._cfg) -# assert len(ccfg._cfg) > 0 -# print(sum(1 for v in ccfg._cfg.values() if len(v) > 1)) - - def test_node_existence(): node_existence(add_fundef) node_existence(user_item) @@ -302,3 +269,142 @@ def visit_Attribute_attr(self, _node: cst.FunctionDef) -> None: def leave_Attribute_attr(self, _node: cst.FunctionDef) -> None: self.active = True + + +loop_test = """ +@entity +class Something: + def loop_test(self, cart: list[Item]) -> int: + val = 0 + + for item in cart: + attr_1 = item.get_price() + val += attr_1 + + temp = 3 + + val += temp + + return val +""" + + +def test_loop_test_cfg(): + module = cst.parse_module(loop_test) + wrapper = cst.MetadataWrapper(module) + ltct = LoopTestCfgTester() + wrapper.visit(ltct) + + +class LoopTestCfgTester(cst.CSTVisitor): + METADATA_DEPENDENCIES = (IndexProvider, ControlFlowGraphProvider) + + def __init__(self): + super().__init__() + self.cfg: dict[CfgNode, set[CfgNode]] | None = None + self.start_end: list[tuple[CfgNode, CfgNode]] | None = None + + def is_edge(self, from_node: CfgNode, to_node: CfgNode) -> bool: + return from_node in self.cfg and to_node in self.cfg[from_node] + + def assert_node(self, node: cst.CSTNode, prev: CfgNode) -> CfgNode: + cfg_node = Node(self.get_metadata(IndexProvider, node), 0) + assert cfg_node in self.cfg + assert self.is_edge(prev, cfg_node) + return cfg_node + + def visit_Module(self, node: cst.Module) -> bool | None: + self.cfg, self.start_end = self.get_metadata(ControlFlowGraphProvider, node) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: + assert node.name.value == "loop_test" + + # unpack + + init_val: cst.Assign + for_loop: cst.For + init_temp: cst.Assign + update_val: cst.AugAssign + return_val: cst.Return + + init_val, for_loop, init_temp, update_val, return_val = ( + cst.ensure_type(ssl, cst.SimpleStatementLine).body[0] if m.matches(ssl, m.SimpleStatementLine()) else ssl + for ssl in node.body.body + ) + + for_iter_cart: cst.Name = cst.ensure_type(for_loop.iter, cst.Name) + + for_init_attr_1: cst.Assign + for_update_val: cst.AugAssign + + for_init_attr_1, for_update_val = ( + cst.ensure_type(ssl, cst.SimpleStatementLine).body[0] for ssl in for_loop.body.body + ) + + for_init_attr_1_call: cst.Call = cst.ensure_type(for_init_attr_1.value, cst.Call) + for_init_attr_1_call_expr: cst.Attribute = cst.ensure_type(for_init_attr_1_call.func, cst.Attribute) + for_init_attr_1_call_item: cst.Name = cst.ensure_type(for_init_attr_1_call_expr.value, cst.Name) + + # FunctionDef + + index = self.get_metadata(IndexProvider, node) + start = Node(index, 0) + end = Node(index, 1) + assert start in self.cfg + + assert (start, end) in self.start_end + + param_names = ["self", "cart"] + prev = start + for param, param_name in zip(node.params.params, param_names, strict=True): + assert param.name.value == param_name + prev = self.assert_node(param, prev) + + # FunctionDef Assign + + prev = self.assert_node(init_val.value, prev) + prev = self.assert_node(init_val.targets[0], prev) + + # FunctionDef For + + prev = self.assert_node(for_iter_cart, prev) + prev = self.assert_node(for_loop, prev) + + # FunctionDef For Assign Call Attribute + + prev = self.assert_node(for_init_attr_1_call_item, prev) + prev = self.assert_node(for_init_attr_1_call_expr, prev) + + # FunctionDef For Assign Call + + prev = self.assert_node(for_init_attr_1_call, prev) + + # FunctionDef For Assign + + prev = self.assert_node(for_init_attr_1.targets[0], prev) + + # FunctionDef For AugAssign + + prev = self.assert_node(for_update_val.target, prev) + prev = self.assert_node(for_update_val.value, prev) + prev = self.assert_node(for_update_val, prev) + + # FunctionDef For (back-edge) + + assert self.is_edge(prev, Node(self.get_metadata(IndexProvider, for_iter_cart), 0)) + + # FunctionDef Assign + + prev = self.assert_node(init_temp.value, prev) + prev = self.assert_node(init_temp.targets[0], prev) + + # FunctionDef AugAssign + + prev = self.assert_node(update_val.target, prev) + prev = self.assert_node(update_val.value, prev) + prev = self.assert_node(update_val, prev) + + # FunctionDef Return + + prev = self.assert_node(return_val.value, prev) + assert self.is_edge(prev, end) diff --git a/tests/test_live_variables.py b/tests/test_live_variables.py index 458eadc..49006f5 100644 --- a/tests/test_live_variables.py +++ b/tests/test_live_variables.py @@ -1,11 +1,14 @@ """Tests for styx_compiler.live_variables and indirectly for styx_compiler.data_flow.""" import libcst as cst +from libcst.metadata import QualifiedName, QualifiedNameSource +from styx_compiler.control_flow import ControlFlowGraphProvider, Node from styx_compiler.live_variables import ( LiveVariablesProvider, ) -from tests.test_cfg import add_fundef +from styx_compiler.metadata_providers import IndexProvider +from tests.test_cfg import add_fundef, loop_test # def test_add_fundef_live_vars(): # source_tree = cst.parse_module(add_fundef) @@ -49,5 +52,77 @@ class LiveVariablesTester1(cst.CSTVisitor): METADATA_DEPENDENCIES = (LiveVariablesProvider,) + @staticmethod + def local(name: str) -> QualifiedName: + return QualifiedName(name=f"add..{name}", source=QualifiedNameSource.LOCAL) + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool | None: - assert self.get_metadata(LiveVariablesProvider, node, None) == (frozenset(), frozenset()) + empty = frozenset() + ab = frozenset([self.local("a"), self.local("b")]) + a = frozenset([self.local("a")]) + assert self.get_metadata(LiveVariablesProvider, node.params.params[1], None) == (a, ab) + assert self.get_metadata(LiveVariablesProvider, node.params.params[0], None) == (empty, a) + assert self.get_metadata(LiveVariablesProvider, node, None) == (empty, empty) + + +def test_loop_test_live_vars_provider(): + source_tree = cst.parse_module(loop_test) + wrapper = cst.MetadataWrapper(source_tree) + lvt1 = LiveVariablesTester2() + wrapper.visit(lvt1) + + +class LiveVariablesTester2(cst.CSTVisitor): + """ + Checks that each kind of CST Node that should have a corresponding CFG node has one + """ + + METADATA_DEPENDENCIES = (LiveVariablesProvider, ControlFlowGraphProvider, IndexProvider) + + def __init__(self): + super().__init__() + self._cfg = None + + def is_edge(self, from_node: cst.CSTNode, to_node: cst.CSTNode) -> bool: + from_idx: int | None = self.get_metadata(IndexProvider, from_node, None) + if from_idx is None: + return False + from_cfg = Node(from_idx, 0) + if from_cfg not in self._cfg: + return False + to_idx: int | None = self.get_metadata(IndexProvider, to_node, None) + if to_idx is None: + return False + to_cfg = Node(to_idx, 0) + return to_cfg in self._cfg[from_cfg] + + @staticmethod + def local(name: str) -> QualifiedName: + return QualifiedName(name=f"Something.loop_test..{name}", source=QualifiedNameSource.LOCAL) + + def visit_Module(self, node: cst.Module) -> bool | None: + self._cfg, _ = self.get_metadata(ControlFlowGraphProvider, node) + + def visit_Call(self, node: cst.Call) -> bool | None: + assert self.get_metadata(LiveVariablesProvider, node, None)[1] == frozenset( + [self.local("val"), self.local("item.get_price"), self.local("cart")] + ) + + # def visit_AugAssign(self, node: cst.AugAssign) -> bool | None: + # print(node) + # assert self.is_edge(node.target, node.value), ( + # "AugAssign target should have an edge to AugAssign value is both are cst.Name" + # ) + # print(self.get_metadata(LiveVariablesProvider, node.target, None)) + # print(self.get_metadata(LiveVariablesProvider, node.value, None)) + # print(self.get_metadata(LiveVariablesProvider, node, None)) + # + # def visit_Assign(self, node: cst.Assign) -> bool | None: + # print(node) + # print(self.get_metadata(LiveVariablesProvider, node.value, None)) + # for target in node.targets: + # print(self.get_metadata(LiveVariablesProvider, target, None)) + # + # def visit_Return(self, node: cst.Return) -> bool | None: + # print(node) + # print(self.get_metadata(LiveVariablesProvider, node.value, None))