@@ -152,9 +152,10 @@ def __init__(
152152 super ().__init__ ()
153153
154154 import pytato as pt
155- self ._freeze_prg_cache : dict [pt .DictOfNamedArrays , lp .TranslationUnit ] = {}
155+ self ._freeze_prg_cache : dict [
156+ pt .AbstractResultWithNamedArrays , lp .TranslationUnit ] = {}
156157 self ._dag_transform_cache : dict [
157- pt .DictOfNamedArrays ,
158+ pt .AbstractResultWithNamedArrays ,
158159 tuple [pt .AbstractResultWithNamedArrays , str ]] = {}
159160
160161 if compile_trace_callback is None :
@@ -600,12 +601,10 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
600601 dag = Deduplicator ()(dag )
601602
602603 # FIXME: Remove this if/when _normalize_pt_expr gets support for functions
603- dag = pt .tag_all_calls_to_be_inlined (
604- dag )
604+ dag = pt .tag_all_calls_to_be_inlined (dag )
605605 dag = pt .inline_calls (dag )
606606
607- normalized_expr , bound_arguments = _normalize_pt_expr (
608- dag )
607+ normalized_expr , bound_arguments = _normalize_pt_expr (dag )
609608
610609 try :
611610 pt_prg = self ._freeze_prg_cache [normalized_expr ]
@@ -760,10 +759,10 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
760759 def transform_dag (self , dag : pytato .AbstractResultWithNamedArrays
761760 ) -> pytato .AbstractResultWithNamedArrays :
762761 import pytato as pt
763- tdag = pt .tag_all_calls_to_be_inlined (dag )
764- tdag = pt .inline_calls (tdag )
765- tdag = pt .transform .materialize_with_mpms (tdag )
766- return tdag
762+ dag = pt .tag_all_calls_to_be_inlined (dag )
763+ dag = pt .inline_calls (dag )
764+ dag = pt .transform .materialize_with_mpms (dag )
765+ return dag
767766
768767 def einsum (self , spec , * args , arg_names = None , tagged = ()):
769768 import pytato as pt
0 commit comments